@@ -231,6 +231,7 @@ class Config:
231231 copy_on_model_validation = "none"
232232
233233 _cached_properties = pydantic .PrivateAttr ({})
234+ _has_tracers : Optional [bool ] = pydantic .PrivateAttr (default = None )
234235
235236 @pydantic .root_validator (skip_on_failure = True )
236237 def _special_characters_not_in_name (cls , values ):
@@ -283,6 +284,7 @@ def copy(self, deep: bool = True, validate: bool = True, **kwargs: Any) -> Self:
283284 # cached property is cleared automatically when validation is on, but it
284285 # needs to be manually cleared when validation is off
285286 new_copy ._cached_properties = {}
287+ new_copy ._has_tracers = None
286288 return new_copy
287289
288290 def updated_copy (
@@ -1054,7 +1056,7 @@ def _json(self, indent=INDENT, exclude_unset=False, **kwargs: Any) -> str:
10541056 return json_string
10551057
10561058 def _strip_traced_fields (
1057- self , starting_path : tuple [str ] = (), include_untraced_data_arrays : bool = False
1059+ self , starting_path : tuple [str , ... ] = (), include_untraced_data_arrays : bool = False
10581060 ) -> AutogradFieldMap :
10591061 """Extract a dictionary mapping paths in the model to the data traced by ``autograd``.
10601062
@@ -1073,6 +1075,10 @@ def _strip_traced_fields(
10731075
10741076 """
10751077
1078+ path = tuple (starting_path )
1079+ if self ._has_tracers is False and not include_untraced_data_arrays :
1080+ return dict_ag ()
1081+
10761082 field_mapping = {}
10771083
10781084 def handle_value (x : Any , path : tuple [str , ...]) -> None :
@@ -1100,14 +1106,20 @@ def handle_value(x: Any, path: tuple[str, ...]) -> None:
11001106 self_dict = self .dict ()
11011107
11021108 # if an include_only string was provided, only look at that subset of the dict
1103- if starting_path :
1104- for key in starting_path :
1109+ if path :
1110+ for key in path :
11051111 self_dict = self_dict [key ]
11061112
1107- handle_value (self_dict , path = starting_path )
1113+ handle_value (self_dict , path = path )
1114+
1115+ if field_mapping :
1116+ if not include_untraced_data_arrays :
1117+ self ._has_tracers = True
1118+ return dict_ag (field_mapping )
11081119
1109- # convert the resulting field_mapping to an autograd-traced dictionary
1110- return dict_ag (field_mapping )
1120+ if not include_untraced_data_arrays and not path :
1121+ self ._has_tracers = False
1122+ return dict_ag ()
11111123
11121124 def _insert_traced_fields (self , field_mapping : AutogradFieldMap ) -> Self :
11131125 """Recursively insert a map of paths to autograd-traced fields into a copy of this obj."""
@@ -1157,18 +1169,24 @@ def _serialized_traced_field_keys(
11571169 def to_static (self ) -> Self :
11581170 """Version of object with all autograd-traced fields removed."""
11591171
1172+ if self ._has_tracers is False :
1173+ return self
1174+
11601175 # get dictionary of all traced fields
11611176 field_mapping = self ._strip_traced_fields ()
11621177
11631178 # shortcut to just return self if no tracers found, for performance
11641179 if not field_mapping :
1180+ self ._has_tracers = False
11651181 return self
11661182
11671183 # convert all fields to static values
11681184 field_mapping_static = {key : get_static (val ) for key , val in field_mapping .items ()}
11691185
11701186 # insert the static values into a copy of self
1171- return self ._insert_traced_fields (field_mapping_static )
1187+ static_self = self ._insert_traced_fields (field_mapping_static )
1188+ static_self ._has_tracers = False
1189+ return static_self
11721190
11731191 @classmethod
11741192 def add_type_field (cls ) -> None :
0 commit comments