Skip to content

Commit

Permalink
Hooking up new pedestal model API to experiments via build_sim and pl…
Browse files Browse the repository at this point in the history
…umbing through to where it is used in calc_coeffs and transport models.

The rest of this change is a whole lot of plumbing and updating interfaces.

TODO in follow up changes:
- rename pedestal_model/basic.py to something more meaningful
- only conditionally compute the pedestal model depending on `set_pedestal` as it
 could be wasteful to compute this in future if we have more expensive pedestal models.

PiperOrigin-RevId: 698369283
  • Loading branch information
Nush395 authored and Torax team committed Nov 21, 2024
1 parent e042da3 commit ae3b9f2
Show file tree
Hide file tree
Showing 88 changed files with 955 additions and 285 deletions.
7 changes: 7 additions & 0 deletions run_simulation_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,11 @@ def change_config(
new_stepper_builder = build_sim.build_stepper_builder_from_config(
sim_config['stepper']
)
new_pedestal_model_builder = (
build_sim.build_pedestal_model_builder_from_config(
sim_config['pedestal']
)
)
else:
# Assume the config module has several methods to define the individual Sim
# attributes (the "advanced", more Python-forward configuration method).
Expand All @@ -282,6 +287,7 @@ def change_config(
new_transport_model_builder = config_module.get_transport_model_builder()
source_models_builder = config_module.get_sources_builder()
new_stepper_builder = config_module.get_stepper_builder()
new_pedestal_model_builder = config_module.get_pedestal_model_builder()
new_source_params = {
name: runtime_params
for name, runtime_params in source_models_builder.runtime_params.items()
Expand All @@ -301,6 +307,7 @@ def change_config(
transport_runtime_params=new_transport_model_builder.runtime_params,
source_runtime_params=new_source_params,
stepper_runtime_params=new_stepper_builder.runtime_params,
pedestal_runtime_params=new_pedestal_model_builder.runtime_params,
)
return sim, new_runtime_params

Expand Down
1 change: 1 addition & 0 deletions torax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
'source_profile',
'explicit_source_profiles',
'source_models',
'pedestal_model',
'time_step_calculator',
'coeffs_callback',
'evolving_names',
Expand Down
52 changes: 26 additions & 26 deletions torax/calc_coeffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,17 @@
from torax import state
from torax.config import runtime_params_slice
from torax.fvm import block_1d_coeffs
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.sources import source_models as source_models_lib
from torax.sources import source_profiles as source_profiles_lib
from torax.transport_model import transport_model as transport_model_lib


def calculate_pereverzev_flux(
def _calculate_pereverzev_flux(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
pedestal_model_output: pedestal_model_lib.PedestalModelOutput,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]:
"""Adds Pereverzev-Corrigan flux to diffusion terms."""

Expand Down Expand Up @@ -80,7 +82,7 @@ def calculate_pereverzev_flux(
jnp.logical_and(
dynamic_runtime_params_slice.profile_conditions.set_pedestal,
geo.rho_face_norm
> dynamic_runtime_params_slice.profile_conditions.Ped_top,
> pedestal_model_output.rho_norm_ped,
),
0.0,
chi_face_per_ion,
Expand All @@ -89,7 +91,7 @@ def calculate_pereverzev_flux(
jnp.logical_and(
dynamic_runtime_params_slice.profile_conditions.set_pedestal,
geo.rho_face_norm
> dynamic_runtime_params_slice.profile_conditions.Ped_top,
> pedestal_model_output.rho_norm_ped,
),
0.0,
chi_face_per_el,
Expand All @@ -110,7 +112,7 @@ def calculate_pereverzev_flux(
jnp.logical_and(
dynamic_runtime_params_slice.profile_conditions.set_pedestal,
geo.rho_face_norm
> dynamic_runtime_params_slice.profile_conditions.Ped_top,
> pedestal_model_output.rho_norm_ped,
),
0.0,
d_face_per_el * geo.g1_over_vpr_face,
Expand All @@ -120,7 +122,7 @@ def calculate_pereverzev_flux(
jnp.logical_and(
dynamic_runtime_params_slice.profile_conditions.set_pedestal,
geo.rho_face_norm
> dynamic_runtime_params_slice.profile_conditions.Ped_top,
> pedestal_model_output.rho_norm_ped,
),
0.0,
v_face_per_el * geo.g0_face,
Expand All @@ -147,6 +149,7 @@ def calc_coeffs(
transport_model: transport_model_lib.TransportModel,
explicit_source_profiles: source_profiles_lib.SourceProfiles,
source_models: source_models_lib.SourceModels,
pedestal_model: pedestal_model_lib.PedestalModel,
evolving_names: tuple[str, ...],
use_pereverzev: bool = False,
explicit_call: bool = False,
Expand All @@ -173,6 +176,7 @@ def calc_coeffs(
source_models: All TORAX source/sink functions that generate the explicit
and implicit source profiles used as terms for the core profiles
equations.
pedestal_model: A PedestalModel subclass, calculates pedestal values.
evolving_names: The names of the evolving variables in the order that their
coefficients should be written to `coeffs`.
use_pereverzev: Toggle whether to calculate Pereverzev terms
Expand Down Expand Up @@ -202,6 +206,7 @@ def calc_coeffs(
transport_model,
explicit_source_profiles,
source_models,
pedestal_model,
evolving_names,
use_pereverzev,
)
Expand All @@ -212,6 +217,7 @@ def calc_coeffs(
static_argnames=[
'static_runtime_params_slice',
'transport_model',
'pedestal_model',
'source_models',
'evolving_names',
],
Expand All @@ -224,6 +230,7 @@ def _calc_coeffs_full(
transport_model: transport_model_lib.TransportModel,
explicit_source_profiles: source_profiles_lib.SourceProfiles,
source_models: source_models_lib.SourceModels,
pedestal_model: pedestal_model_lib.PedestalModel,
evolving_names: tuple[str, ...],
use_pereverzev: bool = False,
) -> block_1d_coeffs.Block1DCoeffs:
Expand All @@ -249,6 +256,7 @@ def _calc_coeffs_full(
source_models: All TORAX source/sink functions that generate the explicit
and implicit source profiles used as terms for the core profiles
equations.
pedestal_model: A PedestalModel subclass, calculates pedestal values.
evolving_names: The names of the evolving variables in the order that their
coefficients should be written to `coeffs`.
use_pereverzev: Toggle whether to calculate Pereverzev terms
Expand All @@ -259,11 +267,15 @@ def _calc_coeffs_full(

consts = constants.CONSTANTS

pedestal_model_output = pedestal_model(
dynamic_runtime_params_slice, geo, core_profiles
)

# Boolean mask for enforcing internal temperature boundary conditions to
# model the pedestal.
mask = physics.internal_boundary(
geo,
dynamic_runtime_params_slice.profile_conditions.Ped_top,
pedestal_model_output.rho_norm_ped,
dynamic_runtime_params_slice.profile_conditions.set_pedestal,
)

Expand Down Expand Up @@ -403,7 +415,7 @@ def _calc_coeffs_full(

# Diffusion term coefficients
transport_coeffs = transport_model(
dynamic_runtime_params_slice, geo, core_profiles
dynamic_runtime_params_slice, geo, core_profiles, pedestal_model_output
)
chi_face_ion = transport_coeffs.chi_face_ion
chi_face_el = transport_coeffs.chi_face_el
Expand Down Expand Up @@ -562,24 +574,11 @@ def _calc_coeffs_full(
source_models,
)

# calculate neped
# pylint: disable=invalid-name
nGW = (
dynamic_runtime_params_slice.profile_conditions.Ip_tot
/ (jnp.pi * geo.Rmin**2)
* 1e20
/ dynamic_runtime_params_slice.numerics.nref
)
# pylint: enable=invalid-name
neped_unnorm = jnp.where(
dynamic_runtime_params_slice.profile_conditions.neped_is_fGW,
dynamic_runtime_params_slice.profile_conditions.neped * nGW,
dynamic_runtime_params_slice.profile_conditions.neped,
)

source_ne += jnp.where(
dynamic_runtime_params_slice.profile_conditions.set_pedestal,
mask * dynamic_runtime_params_slice.numerics.largeValue_n * neped_unnorm,
mask
* dynamic_runtime_params_slice.numerics.largeValue_n
* pedestal_model_output.neped,
0.0,
)
source_mat_nn += jnp.where(
Expand All @@ -602,10 +601,11 @@ def _calc_coeffs_full(
v_face_per_el,
) = jax.lax.cond(
use_pereverzev,
lambda: calculate_pereverzev_flux(
lambda: _calculate_pereverzev_flux(
dynamic_runtime_params_slice,
geo,
core_profiles,
pedestal_model_output,
),
lambda: tuple([jnp.zeros_like(geo.rho_face)] * 6),
)
Expand Down Expand Up @@ -720,14 +720,14 @@ def _calc_coeffs_full(
dynamic_runtime_params_slice.profile_conditions.set_pedestal,
mask
* dynamic_runtime_params_slice.numerics.largeValue_T
* dynamic_runtime_params_slice.profile_conditions.Tiped,
* pedestal_model_output.Tiped,
0.0,
)
source_e += jnp.where(
dynamic_runtime_params_slice.profile_conditions.set_pedestal,
mask
* dynamic_runtime_params_slice.numerics.largeValue_T
* dynamic_runtime_params_slice.profile_conditions.Teped,
* pedestal_model_output.Teped,
0.0,
)

Expand Down
24 changes: 23 additions & 1 deletion torax/config/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Functions to build sim.Sim objects, which are used to run TORAX."""

from collections.abc import MutableMapping
import copy
from typing import Any
Expand All @@ -22,6 +23,8 @@
from torax import sim as sim_lib
from torax.config import config_args
from torax.config import runtime_params as runtime_params_lib
from torax.pedestal_model import basic as basic_pedestal_model
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.sources import formula_config
from torax.sources import formulas
from torax.sources import register_source
Expand All @@ -39,13 +42,15 @@
from torax.transport_model import constant as constant_transport
from torax.transport_model import critical_gradient as critical_gradient_transport
from torax.transport_model import qlknn_wrapper
from torax.transport_model import transport_model as transport_model_lib


# pylint: disable=g-import-not-at-top
try:
from torax.transport_model import qualikiz_wrapper
_QUALIKIZ_TRANSPORT_MODEL_AVAILABLE = True
except ImportError:
_QUALIKIZ_TRANSPORT_MODEL_AVAILABLE = False
from torax.transport_model import transport_model as transport_model_lib
# pylint: enable=g-import-not-at-top
# pylint: disable=invalid-name

Expand Down Expand Up @@ -239,6 +244,7 @@ def build_sim_from_config(
'runtime_params',
'geometry',
'sources',
'pedestal',
'transport',
'stepper',
'time_step_calculator',
Expand Down Expand Up @@ -268,6 +274,9 @@ def build_sim_from_config(
config['transport']
),
stepper_builder=build_stepper_builder_from_config(config['stepper']),
pedestal_model_builder=build_pedestal_model_builder_from_config(
config['pedestal']
),
time_step_calculator=build_time_step_calculator_from_config(
config['time_step_calculator']
),
Expand Down Expand Up @@ -601,6 +610,19 @@ def build_transport_model_builder_from_config(
raise ValueError(f'Unknown transport model: {transport_model}')


def build_pedestal_model_builder_from_config(
pedestal_config: dict[str, Any],
) -> pedestal_model_lib.PedestalModelBuilder:
"""Builds a `PedestalModelBuilder` from the input config."""
runtime_params = basic_pedestal_model.RuntimeParams()
runtime_params = config_args.recursive_replace(
runtime_params, **pedestal_config
)
return basic_pedestal_model.BasicPedestalModelBuilder(
runtime_params=runtime_params
)


def build_stepper_builder_from_config(
stepper_config: dict[str, Any],
) -> stepper_lib.StepperBuilder:
Expand Down
24 changes: 1 addition & 23 deletions torax/config/profile_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,22 +84,9 @@ class ProfileConditions(
ne_bound_right: interpolated_param.TimeInterpolatedInput | None = None
ne_bound_right_is_fGW: bool = False
ne_bound_right_is_absolute: bool = False

# Internal boundary condition (pedestal)
# Internal boundary condition (pedestal)
# Do not set internal boundary condition if this is False
set_pedestal: interpolated_param.TimeInterpolatedInput = True
# ion pedestal top temperature in keV
Tiped: interpolated_param.TimeInterpolatedInput = 5.0
# electron pedestal top temperature in keV
Teped: interpolated_param.TimeInterpolatedInput = 5.0
# pedestal top electron density
# In units of reference density if neped_is_fGW = False.
# In Greenwald fraction if neped_is_fGW = True.
neped: interpolated_param.TimeInterpolatedInput = 0.7
neped_is_fGW: bool = False
# Set ped top location.
Ped_top: interpolated_param.TimeInterpolatedInput = 0.91

# current profiles (broad "Ohmic" + localized "external" currents)
# peaking factor of "Ohmic" current: johm = j0*(1 - r^2/a^2)^nu
nu: float = 3.0
Expand Down Expand Up @@ -206,10 +193,6 @@ class ProfileConditionsProvider(
| interpolated_param.InterpolatedVarTimeRho
)
set_pedestal: interpolated_param.InterpolatedVarSingleAxis
Tiped: interpolated_param.InterpolatedVarSingleAxis
Teped: interpolated_param.InterpolatedVarSingleAxis
neped: interpolated_param.InterpolatedVarSingleAxis
Ped_top: interpolated_param.InterpolatedVarSingleAxis

@override
def build_dynamic_params(
Expand Down Expand Up @@ -241,11 +224,6 @@ class DynamicProfileConditions:
ne_bound_right_is_fGW: bool
ne_bound_right_is_absolute: bool
set_pedestal: array_typing.ScalarBool
Tiped: array_typing.ScalarFloat
Teped: array_typing.ScalarFloat
neped: array_typing.ScalarFloat
neped_is_fGW: bool
Ped_top: array_typing.ScalarFloat
nu: float
initial_j_is_total_current: bool
initial_psi_from_j: bool
Loading

0 comments on commit ae3b9f2

Please sign in to comment.