Skip to content

Commit

Permalink
Parametrise type of static params to exclude from dynamic param creat…
Browse files Browse the repository at this point in the history
…ion in sources.

This allows us to extend the exclusion logic more simply for child sources.

PiperOrigin-RevId: 711463448
  • Loading branch information
Nush395 authored and Torax team committed Jan 2, 2025
1 parent fb1f53b commit 92a34e0
Showing 1 changed file with 23 additions and 20 deletions.
43 changes: 23 additions & 20 deletions torax/sources/runtime_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,25 @@ def build_static_params(self) -> StaticRuntimeParams:
)


@chex.dataclass(frozen=True)
class DynamicRuntimeParams:
"""Dynamic params for a single TORAX source.
These params can be changed without triggering a recompile. TORAX sources are
stateless, so these params are their inputs to determine their output
profiles.
"""
prescribed_values: array_typing.ArrayFloat


@chex.dataclass(frozen=True)
class StaticRuntimeParams:
"""Static params for the sources."""

mode: int
is_explicit: bool


@chex.dataclass
class RuntimeParamsProvider(
base.RuntimeParametersProvider['DynamicRuntimeParams']
Expand All @@ -111,12 +130,15 @@ class RuntimeParamsProvider(
def get_dynamic_params_kwargs(
self,
t: chex.Numeric,
static_runtime_params_type: type[
StaticRuntimeParams
] = StaticRuntimeParams,
) -> dict[str, Any]:
dynamic_params_kwargs = super(
RuntimeParamsProvider, self
).get_dynamic_params_kwargs(t)
# Remove any fields from runtime params that are included in static params.
for field in dataclasses.fields(StaticRuntimeParams):
for field in dataclasses.fields(static_runtime_params_type):
del dynamic_params_kwargs[field.name]
return dynamic_params_kwargs

Expand All @@ -125,22 +147,3 @@ def build_dynamic_params(
t: chex.Numeric,
) -> DynamicRuntimeParams:
return DynamicRuntimeParams(**self.get_dynamic_params_kwargs(t))


@chex.dataclass(frozen=True)
class DynamicRuntimeParams:
"""Dynamic params for a single TORAX source.
These params can be changed without triggering a recompile. TORAX sources are
stateless, so these params are their inputs to determine their output
profiles.
"""
prescribed_values: array_typing.ArrayFloat


@chex.dataclass(frozen=True)
class StaticRuntimeParams:
"""Static params for the sources."""

mode: int
is_explicit: bool

0 comments on commit 92a34e0

Please sign in to comment.