diff --git a/torax/config/profile_conditions.py b/torax/config/profile_conditions.py index 96e574c7..1ba8be3e 100644 --- a/torax/config/profile_conditions.py +++ b/torax/config/profile_conditions.py @@ -35,11 +35,18 @@ class ProfileConditions( ): """Prescribed values and boundary conditions for the core profiles.""" - # total plasma current in MA + # Total plasma current in MA # Note that if Ip_from_parameters=False in geometry, then this Ip will be - # overwritten by values from the geometry data + # overwritten by values from the geometry data. + # If Vloop_bound_right is not None, then this Ip is used as an + # initial condition ONLY. Ip_tot: interpolated_param.TimeInterpolatedInput = 15.0 + # Boundary condition at LCFS for Vloop ( = dpsi_lcfs/dt ) + # If this is `None` the boundary condition for the psi equation at each timestep + # will instead be taken from `Ip_tot`. + Vloop_bound_right: interpolated_param.TimeInterpolatedInput | None = None + # Temperature boundary conditions at r=Rmin. If this is `None` the boundary # condition will instead be taken from `Ti` and `Te` at rhon=1. Ti_bound_right: interpolated_param.TimeInterpolatedInput | None = None @@ -175,6 +182,7 @@ class ProfileConditionsProvider( runtime_params_config: ProfileConditions Ip_tot: interpolated_param.InterpolatedVarSingleAxis + Vloop_bound_right: interpolated_param.InterpolatedVarSingleAxis | None Ti_bound_right: ( interpolated_param.InterpolatedVarSingleAxis | interpolated_param.InterpolatedVarTimeRho @@ -208,6 +216,7 @@ class DynamicProfileConditions: """Prescribed values and boundary conditions for the core profiles.""" Ip_tot: array_typing.ScalarFloat + Vloop_bound_right: array_typing.ScalarFloat | None Ti_bound_right: array_typing.ScalarFloat Te_bound_right: array_typing.ScalarFloat # Temperature profiles defined on the cell grid. diff --git a/torax/core_profile_setters.py b/torax/core_profile_setters.py index 114646be..a1dd73b9 100644 --- a/torax/core_profile_setters.py +++ b/torax/core_profile_setters.py @@ -19,8 +19,10 @@ evolved by the PDE system. """ import dataclasses + import jax from jax import numpy as jnp + from torax import constants from torax import geometry from torax import jax_utils @@ -548,7 +550,7 @@ def _update_psi_from_j( Returns: psi: Poloidal flux cell variable. """ - psi_grad_constraint = _calculate_psi_grad_constraint( + psi_grad_constraint = _calculate_psi_grad_constraint_from_Ip_tot( dynamic_runtime_params_slice, geo, ) @@ -584,11 +586,11 @@ def _update_psi_from_j( # pylint: enable=invalid-name -def _calculate_psi_grad_constraint( +def _calculate_psi_grad_constraint_from_Ip_tot( dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: Geometry, ) -> jax.Array: - """Calculates the constraint on the poloidal flux (psi).""" + """Calculates the gradient constraint on the poloidal flux (psi) from Ip.""" return ( dynamic_runtime_params_slice.profile_conditions.Ip_tot * 1e6 @@ -597,6 +599,19 @@ def _calculate_psi_grad_constraint( ) +def _psi_value_constraint_from_Vloop( + dt: jax.Array, + dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice, + core_profiles_t_minus_dt: state.CoreProfiles, + geo: Geometry, +) -> jax.Array: + """Calculates the value constraint on the poloidal flux (psi) from Vloop.""" + return ( + core_profiles_t_minus_dt.psi.face_value()[-1] + + dynamic_runtime_params_slice_t.profile_conditions.Vloop_bound_right * dt + ) + + def _init_psi_and_current( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, @@ -624,16 +639,39 @@ def _init_psi_and_current( Returns: Refined core profiles. """ - # Retrieving psi from the profile conditions. + use_Vloop_bound_right = ( + dynamic_runtime_params_slice.profile_conditions.Vloop_bound_right is not None + ) + + # Case 1: retrieving psi from the profile conditions. if dynamic_runtime_params_slice.profile_conditions.psi is not None: - psi = cell_variable.CellVariable( + # Calculate the dpsi/drho necessary to achieve the given Ip_tot + dpsi_drho_edge = _calculate_psi_grad_constraint_from_Ip_tot( + dynamic_runtime_params_slice, + geo, + ) + + # Set the psi BCs to ensure the correct Ip_tot + if use_Vloop_bound_right: + # Extrapolate using the dpsi/drho calculated above to set the psi value at the right face + psi = cell_variable.CellVariable( value=dynamic_runtime_params_slice.profile_conditions.psi, - right_face_grad_constraint=_calculate_psi_grad_constraint( - dynamic_runtime_params_slice, - geo, + right_face_grad_constraint=None, + right_face_constraint=( + dynamic_runtime_params_slice.profile_conditions.psi[-1] + + dpsi_drho_edge * geo.drho / 2 ), dr=geo.drho_norm, - ) + ) + else: + # Use the dpsi/drho calculated above as the right face gradient constraint + psi = cell_variable.CellVariable( + value=dynamic_runtime_params_slice.profile_conditions.psi, + right_face_grad_constraint=dpsi_drho_edge, + right_face_constraint=None, + dr=geo.drho_norm, + ) + core_profiles = dataclasses.replace(core_profiles, psi=psi) currents = _calculate_currents_from_psi( dynamic_runtime_params_slice=dynamic_runtime_params_slice, @@ -642,7 +680,8 @@ def _init_psi_and_current( core_profiles=core_profiles, source_models=source_models, ) - # Retrieving psi from the standard geometry input. + + # Case 2: retrieving psi from the standard geometry input. elif ( isinstance(geo, geometry.StandardGeometry) and not dynamic_runtime_params_slice.profile_conditions.initial_psi_from_j @@ -650,15 +689,32 @@ def _init_psi_and_current( # psi is already provided from a numerical equilibrium, so no need to # first calculate currents. However, non-inductive currents are still # calculated and used in current diffusion equation. + + # Calculate the dpsi/drho necessary to achieve the given Ip_tot + dpsi_drho_edge = _calculate_psi_grad_constraint_from_Ip_tot( + dynamic_runtime_params_slice, + geo, + ) + + # Set the psi BCs based on whether Vloop is provided and the source of Ip + if use_Vloop_bound_right and geo.Ip_from_parameters: + right_face_grad_constraint = None + right_face_constraint = geo.psi_from_Ip[-1] + dpsi_drho_edge * geo.drho / 2 + elif use_Vloop_bound_right: + right_face_grad_constraint = None + right_face_constraint = geo.psi_from_Ip[-1] + else: + right_face_grad_constraint = dpsi_drho_edge + right_face_constraint = None + psi = cell_variable.CellVariable( - value=geo.psi_from_Ip, - right_face_grad_constraint=_calculate_psi_grad_constraint( - dynamic_runtime_params_slice, - geo, - ), - dr=geo.drho_norm, + value=geo.psi_from_Ip, # Use psi from equilibrium + right_face_grad_constraint=right_face_grad_constraint, + right_face_constraint=right_face_constraint, + dr=geo.drho_norm, ) core_profiles = dataclasses.replace(core_profiles, psi=psi) + # Calculate non-inductive currents currents = _calculate_currents_from_psi( dynamic_runtime_params_slice=dynamic_runtime_params_slice, static_runtime_params_slice=static_runtime_params_slice, @@ -666,11 +722,16 @@ def _init_psi_and_current( core_profiles=core_profiles, source_models=source_models, ) - # Calculating j according to nu formula and psi from j. + + # Case 3: calculating j according to nu formula and psi from j. elif ( isinstance(geo, geometry.CircularAnalyticalGeometry) or dynamic_runtime_params_slice.profile_conditions.initial_psi_from_j ): + # TODO: Vloop_bound_right is not yet supported for this case. + if use_Vloop_bound_right: + raise NotImplementedError('Vloop_bound_right not yet supported for this case.') + currents = _prescribe_currents_no_bootstrap( dynamic_runtime_params_slice=dynamic_runtime_params_slice, static_runtime_params_slice=static_runtime_params_slice, @@ -910,7 +971,9 @@ def get_update(x_new, var): def compute_boundary_conditions( - dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, + dt: jax.Array, + dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice, + core_profiles_t_minus_dt: state.CoreProfiles, geo: geometry.Geometry, ) -> dict[str, dict[str, jax.Array | None]]: """Computes boundary conditions for time t and returns updates to State. @@ -925,17 +988,17 @@ def compute_boundary_conditions( values in a State object. """ Ti_bound_right = jax_utils.error_if_not_positive( # pylint: disable=invalid-name - dynamic_runtime_params_slice.profile_conditions.Ti_bound_right, + dynamic_runtime_params_slice_t.profile_conditions.Ti_bound_right, 'Ti_bound_right', ) Te_bound_right = jax_utils.error_if_not_positive( # pylint: disable=invalid-name - dynamic_runtime_params_slice.profile_conditions.Te_bound_right, + dynamic_runtime_params_slice_t.profile_conditions.Te_bound_right, 'Te_bound_right', ) ne = _get_ne( - dynamic_runtime_params_slice, + dynamic_runtime_params_slice_t, geo, ) ne_bound_right = ne.right_face_constraint @@ -946,14 +1009,15 @@ def compute_boundary_conditions( # Zeff = (ni + Zimp**2 * nimp)/ne ; nimp*Zimp + ni = ne dilution_factor_edge = physics.get_main_ion_dilution_factor( - dynamic_runtime_params_slice.plasma_composition.Zimp, - dynamic_runtime_params_slice.plasma_composition.Zeff_face[-1], + dynamic_runtime_params_slice_t.plasma_composition.Zimp, + dynamic_runtime_params_slice_t.plasma_composition.Zeff_face[-1], ) ni_bound_right = ne_bound_right * dilution_factor_edge nimp_bound_right = ( ne_bound_right - ni_bound_right - ) / dynamic_runtime_params_slice.plasma_composition.Zimp + ) / dynamic_runtime_params_slice_t.plasma_composition.Zimp + return { 'temp_ion': dict( @@ -982,12 +1046,27 @@ def compute_boundary_conditions( right_face_constraint=jnp.array(nimp_bound_right), ), 'psi': dict( - right_face_grad_constraint=_calculate_psi_grad_constraint( - dynamic_runtime_params_slice, + right_face_grad_constraint=( + _calculate_psi_grad_constraint_from_Ip_tot( + dynamic_runtime_params_slice_t, geo, + ) + if dynamic_runtime_params_slice_t.profile_conditions.Vloop_bound_right + is None + else None ), - ), - } + right_face_constraint=( + _psi_value_constraint_from_Vloop( + dt, + dynamic_runtime_params_slice_t, + core_profiles_t_minus_dt, + geo, + ) + if dynamic_runtime_params_slice_t.profile_conditions.Vloop_bound_right is not None + else None + ), + ), + } # pylint: disable=invalid-name diff --git a/torax/output.py b/torax/output.py index f5d01c25..5c37176b 100644 --- a/torax/output.py +++ b/torax/output.py @@ -60,6 +60,7 @@ class ToraxSimOutputs: PSI = "psi" PSIDOT = "psidot" PSI_RIGHT_GRAD_BC = "psi_right_grad_bc" +PSI_RIGHT_BC = "psi_right_bc" NE = "ne" NE_RIGHT_BC = "ne_right_bc" NI = "ni" @@ -195,9 +196,9 @@ def __init__( post_processed_output = [ state.post_processed_outputs for state in sim_outputs.sim_history ] - stack = lambda *ys: jnp.stack(ys) + stack = lambda *ys: jnp.stack(jnp.array(ys)) self.core_profiles: state.CoreProfiles = jax.tree_util.tree_map( - stack, *core_profiles + stack, *core_profiles, is_leaf=lambda x: x is None, ) self.core_sources: source_profiles.SourceProfiles = jax.tree_util.tree_map( stack, *core_sources @@ -263,6 +264,9 @@ def _get_core_profiles( xr_dict[PSI_RIGHT_GRAD_BC] = ( self.core_profiles.psi.right_face_grad_constraint ) + xr_dict[PSI_RIGHT_BC] = ( + self.core_profiles.psi.right_face_constraint + ) xr_dict[PSIDOT] = self.core_profiles.psidot.value xr_dict[NE] = self.core_profiles.ne.value xr_dict[NE_RIGHT_BC] = self.core_profiles.ne.right_face_constraint diff --git a/torax/sim.py b/torax/sim.py index d78c29ac..e91ec57a 100644 --- a/torax/sim.py +++ b/torax/sim.py @@ -503,6 +503,7 @@ def step( # conditions and time-dependent prescribed profiles not directly solved by # PDE system. core_profiles_t_plus_dt = provide_core_profiles_t_plus_dt( + dt=dt, static_runtime_params_slice=static_runtime_params_slice, dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, geo_t_plus_dt=geo_t_plus_dt, @@ -624,10 +625,11 @@ def body_fun( ) core_profiles_t_plus_dt = provide_core_profiles_t_plus_dt( - core_profiles_t=core_profiles_t, - dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, + dt=dt, static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, geo_t_plus_dt=geo_t_plus_dt, + core_profiles_t=core_profiles_t, ) core_profiles, core_sources, core_transport, stepper_numeric_outputs = ( self._stepper_fn( @@ -1611,6 +1613,7 @@ def _update_psidot( def provide_core_profiles_t_plus_dt( + dt: jax.Array, static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice, geo_t_plus_dt: geometry.Geometry, @@ -1619,7 +1622,9 @@ def provide_core_profiles_t_plus_dt( """Provides state at t_plus_dt with new boundary conditions and prescribed profiles.""" updated_boundary_conditions = ( core_profile_setters.compute_boundary_conditions( + dt, dynamic_runtime_params_slice_t_plus_dt, + core_profiles_t, geo_t_plus_dt, ) ) diff --git a/torax/tests/boundary_conditions.py b/torax/tests/boundary_conditions.py index 9c451151..e3066492 100644 --- a/torax/tests/boundary_conditions.py +++ b/torax/tests/boundary_conditions.py @@ -102,8 +102,10 @@ def test_setting_boundary_conditions( ) bc = core_profile_setters.compute_boundary_conditions( - dynamic_runtime_params_slice, - geo, + dt=runtime_params.numerics.fixed_dt, + dynamic_runtime_params_slice_t=dynamic_runtime_params_slice, + core_profiles_t_minus_dt=core_profiles, + geo=geo, ) updated = config_args.recursive_replace(core_profiles, **bc) diff --git a/torax/tests/core_profile_setters.py b/torax/tests/core_profile_setters.py index 51194ad1..09089155 100644 --- a/torax/tests/core_profile_setters.py +++ b/torax/tests/core_profile_setters.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for module torax.boundary_conditions.""" +"""Tests for module torax.core_profile_setters.""" from absl.testing import absltest from absl.testing import parameterized @@ -494,8 +494,10 @@ def test_compute_boundary_conditions_ne( ) boundary_conditions = core_profile_setters.compute_boundary_conditions( - dynamic_runtime_params_slice, - self.geo, + dynamic_runtime_params_slice_t=dynamic_runtime_params_slice, + core_profiles_t_minus_dt=None, # This test does not hit the conditional that requires this + geo=self.geo, + dt=runtime_params.numerics.fixed_dt, ) if (ne_is_fGW and ne_bound_right is None) or ( @@ -590,8 +592,10 @@ def test_compute_boundary_conditions_Te( ) boundary_conditions = core_profile_setters.compute_boundary_conditions( - dynamic_runtime_params_slice, - self.geo, + dynamic_runtime_params_slice_t=dynamic_runtime_params_slice, + core_profiles_t_minus_dt=None, # This test does not hit the conditional that requires this + geo=self.geo, + dt=runtime_params.numerics.fixed_dt, ) self.assertEqual( @@ -621,13 +625,13 @@ def test_compute_boundary_conditions_Ti( stepper=stepper_params_lib.RuntimeParams(), torax_mesh=self.geo.torax_mesh, ) - dynamic_runtime_params_slice = provider( - t=1.0, - ) + dynamic_runtime_params_slice = provider(t=1.0) boundary_conditions = core_profile_setters.compute_boundary_conditions( - dynamic_runtime_params_slice, - self.geo, + dynamic_runtime_params_slice_t=dynamic_runtime_params_slice, + core_profiles_t_minus_dt=None, # This test does not hit the conditional that requires this + geo=self.geo, + dt=runtime_params.numerics.fixed_dt, ) self.assertEqual( diff --git a/torax/tests/test_lib/explicit_stepper.py b/torax/tests/test_lib/explicit_stepper.py index 6c45bedc..7fa12343 100644 --- a/torax/tests/test_lib/explicit_stepper.py +++ b/torax/tests/test_lib/explicit_stepper.py @@ -119,8 +119,10 @@ def __call__( # Update the potentially time-dependent boundary conditions as well. updated_boundary_conditions = ( core_profile_setters.compute_boundary_conditions( - dynamic_runtime_params_slice_t_plus_dt, - geo_t, + dt=dynamic_runtime_params_slice_t_plus_dt.numerics.fixed_dt, + dynamic_runtime_params_slice_t=dynamic_runtime_params_slice_t_plus_dt, + core_profiles_t_minus_dt=core_profiles_t, + geo=geo_t_plus_dt, ) ) temp_ion_new = dataclasses.replace(