From 41b9d2da512fee04f9d3c82ff39317fa2057058c Mon Sep 17 00:00:00 2001 From: Anushan Fernando Date: Thu, 2 Jan 2025 09:01:46 -0800 Subject: [PATCH] Parametrise type of static params to exclude from dynamic param creation in sources. This allows us to extend the exclusion logic more simply for child sources. PiperOrigin-RevId: 711439698 --- torax/sources/runtime_params.py | 43 ++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/torax/sources/runtime_params.py b/torax/sources/runtime_params.py index b42bac7a..1cd244ea 100644 --- a/torax/sources/runtime_params.py +++ b/torax/sources/runtime_params.py @@ -99,6 +99,25 @@ def build_static_params(self) -> StaticRuntimeParams: ) +@chex.dataclass(frozen=True) +class DynamicRuntimeParams: + """Dynamic params for a single TORAX source. + + These params can be changed without triggering a recompile. TORAX sources are + stateless, so these params are their inputs to determine their output + profiles. + """ + prescribed_values: array_typing.ArrayFloat + + +@chex.dataclass(frozen=True) +class StaticRuntimeParams: + """Static params for the sources.""" + + mode: int + is_explicit: bool + + @chex.dataclass class RuntimeParamsProvider( base.RuntimeParametersProvider['DynamicRuntimeParams'] @@ -111,12 +130,15 @@ class RuntimeParamsProvider( def get_dynamic_params_kwargs( self, t: chex.Numeric, + static_runtime_params_type: type[ + StaticRuntimeParams + ] = StaticRuntimeParams, ) -> dict[str, Any]: dynamic_params_kwargs = super( RuntimeParamsProvider, self ).get_dynamic_params_kwargs(t) # Remove any fields from runtime params that are included in static params. - for field in dataclasses.fields(StaticRuntimeParams): + for field in dataclasses.fields(static_runtime_params_type): del dynamic_params_kwargs[field.name] return dynamic_params_kwargs @@ -125,22 +147,3 @@ def build_dynamic_params( t: chex.Numeric, ) -> DynamicRuntimeParams: return DynamicRuntimeParams(**self.get_dynamic_params_kwargs(t)) - - -@chex.dataclass(frozen=True) -class DynamicRuntimeParams: - """Dynamic params for a single TORAX source. - - These params can be changed without triggering a recompile. TORAX sources are - stateless, so these params are their inputs to determine their output - profiles. - """ - prescribed_values: array_typing.ArrayFloat - - -@chex.dataclass(frozen=True) -class StaticRuntimeParams: - """Static params for the sources.""" - - mode: int - is_explicit: bool