Skip to content

Commit 3d93236

Browse files
committed
perf: short-circuit tracer-free conversions
1 parent 9095eb2 commit 3d93236

File tree

1 file changed

+25
-7
lines changed

1 file changed

+25
-7
lines changed

tidy3d/components/base.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ class Config:
231231
copy_on_model_validation = "none"
232232

233233
_cached_properties = pydantic.PrivateAttr({})
234+
_has_tracers: Optional[bool] = pydantic.PrivateAttr(default=None)
234235

235236
@pydantic.root_validator(skip_on_failure=True)
236237
def _special_characters_not_in_name(cls, values):
@@ -283,6 +284,7 @@ def copy(self, deep: bool = True, validate: bool = True, **kwargs: Any) -> Self:
283284
# cached property is cleared automatically when validation is on, but it
284285
# needs to be manually cleared when validation is off
285286
new_copy._cached_properties = {}
287+
new_copy._has_tracers = None
286288
return new_copy
287289

288290
def updated_copy(
@@ -1054,7 +1056,7 @@ def _json(self, indent=INDENT, exclude_unset=False, **kwargs: Any) -> str:
10541056
return json_string
10551057

10561058
def _strip_traced_fields(
1057-
self, starting_path: tuple[str] = (), include_untraced_data_arrays: bool = False
1059+
self, starting_path: tuple[str, ...] = (), include_untraced_data_arrays: bool = False
10581060
) -> AutogradFieldMap:
10591061
"""Extract a dictionary mapping paths in the model to the data traced by ``autograd``.
10601062
@@ -1073,6 +1075,10 @@ def _strip_traced_fields(
10731075
10741076
"""
10751077

1078+
path = tuple(starting_path)
1079+
if self._has_tracers is False and not include_untraced_data_arrays:
1080+
return dict_ag()
1081+
10761082
field_mapping = {}
10771083

10781084
def handle_value(x: Any, path: tuple[str, ...]) -> None:
@@ -1100,14 +1106,20 @@ def handle_value(x: Any, path: tuple[str, ...]) -> None:
11001106
self_dict = self.dict()
11011107

11021108
# if an include_only string was provided, only look at that subset of the dict
1103-
if starting_path:
1104-
for key in starting_path:
1109+
if path:
1110+
for key in path:
11051111
self_dict = self_dict[key]
11061112

1107-
handle_value(self_dict, path=starting_path)
1113+
handle_value(self_dict, path=path)
1114+
1115+
if field_mapping:
1116+
if not include_untraced_data_arrays:
1117+
self._has_tracers = True
1118+
return dict_ag(field_mapping)
11081119

1109-
# convert the resulting field_mapping to an autograd-traced dictionary
1110-
return dict_ag(field_mapping)
1120+
if not include_untraced_data_arrays and not path:
1121+
self._has_tracers = False
1122+
return dict_ag()
11111123

11121124
def _insert_traced_fields(self, field_mapping: AutogradFieldMap) -> Self:
11131125
"""Recursively insert a map of paths to autograd-traced fields into a copy of this obj."""
@@ -1157,18 +1169,24 @@ def _serialized_traced_field_keys(
11571169
def to_static(self) -> Self:
11581170
"""Version of object with all autograd-traced fields removed."""
11591171

1172+
if self._has_tracers is False:
1173+
return self
1174+
11601175
# get dictionary of all traced fields
11611176
field_mapping = self._strip_traced_fields()
11621177

11631178
# shortcut to just return self if no tracers found, for performance
11641179
if not field_mapping:
1180+
self._has_tracers = False
11651181
return self
11661182

11671183
# convert all fields to static values
11681184
field_mapping_static = {key: get_static(val) for key, val in field_mapping.items()}
11691185

11701186
# insert the static values into a copy of self
1171-
return self._insert_traced_fields(field_mapping_static)
1187+
static_self = self._insert_traced_fields(field_mapping_static)
1188+
static_self._has_tracers = False
1189+
return static_self
11721190

11731191
@classmethod
11741192
def add_type_field(cls) -> None:

0 commit comments

Comments
 (0)