From c9134bdf35118de5556aa4be3b62bfea038d940f Mon Sep 17 00:00:00 2001 From: Theo Brown <7982453+theo-brown@users.noreply.github.com> Date: Mon, 21 Oct 2024 11:56:59 +0100 Subject: [PATCH 1/8] Add Vloop BC to config.profile_conditions --- torax/config/profile_conditions.py | 21 +++++++++++++++++---- torax/config/tests/profile_conditions.py | 23 +++++++++++++++++++++++ 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/torax/config/profile_conditions.py b/torax/config/profile_conditions.py index 47cb0f17..660896bb 100644 --- a/torax/config/profile_conditions.py +++ b/torax/config/profile_conditions.py @@ -20,12 +20,13 @@ import logging import chex +from typing_extensions import override + from torax import array_typing from torax import geometry from torax import interpolated_param from torax.config import base from torax.config import config_args -from typing_extensions import override # pylint: disable=invalid-name @@ -38,7 +39,11 @@ class ProfileConditions( # total plasma current in MA # Note that if Ip_from_parameters=False in geometry, then this Ip will be # overwritten by values from the geometry data - Ip: interpolated_param.TimeInterpolatedInput = 15.0 + Ip: interpolated_param.TimeInterpolatedInput | None = 15.0 + + # Boundary condition at LCFS for Vloop ( = dpsi/dt ) + # Used if total plasma current is a state and edge flux is an input + Vloop_bound_right: interpolated_param.TimeInterpolatedInput | None = None # Temperature boundary conditions at r=Rmin. If this is `None` the boundary # condition will instead be taken from `Ti` and `Te` at rhon=1. @@ -142,6 +147,12 @@ def __post_init__(self): 'ne', ) + # Check that only one of Vloop_bound_right and Ip are provided + # ^ is XOR + if not ((self.Ip is None) ^ (self.Vloop_bound_right is None)): + raise ValueError("Only one of Ip and Vloop_bound_right can be defined") + + @override def make_provider( self, @@ -187,7 +198,8 @@ class ProfileConditionsProvider( """Provider to retrieve initial and prescribed values and boundary conditions.""" runtime_params_config: ProfileConditions - Ip: interpolated_param.InterpolatedVarSingleAxis + Ip: interpolated_param.InterpolatedVarSingleAxis | None + Vloop_bound_right: interpolated_param.InterpolatedVarSingleAxis | None Ti_bound_right: ( interpolated_param.InterpolatedVarSingleAxis | interpolated_param.InterpolatedVarTimeRho @@ -224,7 +236,8 @@ def build_dynamic_params( class DynamicProfileConditions: """Prescribed values and boundary conditions for the core profiles.""" - Ip: array_typing.ScalarFloat + Ip: array_typing.ScalarFloat | None + Vloop_bound_right: array_typing.ScalarFloat Ti_bound_right: array_typing.ScalarFloat Te_bound_right: array_typing.ScalarFloat # Temperature profiles defined on the cell grid. diff --git a/torax/config/tests/profile_conditions.py b/torax/config/tests/profile_conditions.py index 1b3bfb84..43f2db79 100644 --- a/torax/config/tests/profile_conditions.py +++ b/torax/config/tests/profile_conditions.py @@ -250,6 +250,29 @@ def test_profile_conditions_raises_error_if_boundary_condition_not_defined( Ti_bound_right=None, ) + def test_raises_error_if_Ip_and_Vloop_clash(self, values, raises): + """Tests that an error is raised if a) Vloop and Ip are both None or + b) Vloop and Ip are both not None.""" + if raises: + with self.assertRaises(ValueError): + profile_conditions.ProfileConditions( + Ip=values, + Vloop_bound_right=values, + ) + with self.assertRaises(ValueError): + profile_conditions.ProfileConditions( + Ip=None, + Vloop_bound_right=None, + ) + else: + profile_conditions.ProfileConditions( + Ip=values, + Vloop_bound_right=values, + ) + profile_conditions.ProfileConditions( + Ip=None, + Vloop_bound_right=None, + ) if __name__ == '__main__': absltest.main() From 1932026c92e7ff3ae1f74eb29bfa97fac81b69f0 Mon Sep 17 00:00:00 2001 From: Theo Brown <7982453+theo-brown@users.noreply.github.com> Date: Mon, 21 Oct 2024 13:36:48 +0100 Subject: [PATCH 2/8] Add Vloop BC for Psi value --- torax/config/profile_conditions.py | 2 +- torax/core_profile_setters.py | 60 ++++++++++++++++++++++-------- torax/sim.py | 15 ++++++-- 3 files changed, 57 insertions(+), 20 deletions(-) diff --git a/torax/config/profile_conditions.py b/torax/config/profile_conditions.py index 660896bb..93c970cf 100644 --- a/torax/config/profile_conditions.py +++ b/torax/config/profile_conditions.py @@ -237,7 +237,7 @@ class DynamicProfileConditions: """Prescribed values and boundary conditions for the core profiles.""" Ip: array_typing.ScalarFloat | None - Vloop_bound_right: array_typing.ScalarFloat + Vloop_bound_right: array_typing.ScalarFloat | None Ti_bound_right: array_typing.ScalarFloat Te_bound_right: array_typing.ScalarFloat # Temperature profiles defined on the cell grid. diff --git a/torax/core_profile_setters.py b/torax/core_profile_setters.py index 9795b775..0c6f0b4a 100644 --- a/torax/core_profile_setters.py +++ b/torax/core_profile_setters.py @@ -19,8 +19,10 @@ evolved by the PDE system. """ import dataclasses + import jax from jax import numpy as jnp + from torax import constants from torax import geometry from torax import jax_utils @@ -541,7 +543,7 @@ def _update_psi_from_j( Returns: psi: Poloidal flux cell variable. """ - psi_grad_constraint = _calculate_psi_grad_constraint( + psi_grad_constraint = _psi_grad_constraint_from_Ip( dynamic_runtime_params_slice, geo, ) @@ -579,11 +581,11 @@ def _update_psi_from_j( # pylint: enable=invalid-name -def _calculate_psi_grad_constraint( +def _psi_grad_constraint_from_Ip( dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: Geometry, ) -> jax.Array: - """Calculates the constraint on the poloidal flux (psi).""" + """Calculates the gradient constraint on the poloidal flux (psi) from Ip.""" return ( dynamic_runtime_params_slice.profile_conditions.Ip * 1e6 @@ -591,6 +593,17 @@ def _calculate_psi_grad_constraint( / (geo.g2g3_over_rhon_face[-1] * geo.F_face[-1]) ) +def _psi_value_constraint_from_Vloop( + dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice, + core_profiles_t_minus_dt: state.CoreProfiles, + geo: Geometry, + dt: jax.Array, +) -> jax.Array: + """Calculates the value constraint on the poloidal flux (psi) from Vloop.""" + return ( + core_profiles_t_minus_dt.psi.face_value(geo.rho_b) + + dynamic_runtime_params_slice_t.profile_conditions.Vloop_bound_right * dt + ) def _initial_psi( dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, @@ -624,7 +637,7 @@ def _initial_psi( if dynamic_runtime_params_slice.profile_conditions.psi is not None: psi = cell_variable.CellVariable( value=dynamic_runtime_params_slice.profile_conditions.psi, - right_face_grad_constraint=_calculate_psi_grad_constraint( + right_face_grad_constraint=_psi_grad_constraint_from_Ip( dynamic_runtime_params_slice, geo, ), @@ -687,7 +700,7 @@ def _initial_psi( # psi is already provided from a numerical equilibrium, so no need to # first calculate currents. However, non-inductive currents are still # calculated and used in current diffusion equation. - psi_grad_constraint = _calculate_psi_grad_constraint( + psi_grad_constraint = _psi_grad_constraint_from_Ip( dynamic_runtime_params_slice, geo, ) @@ -907,8 +920,10 @@ def get_update(x_new, var): def compute_boundary_conditions( - dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, + dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice, + core_profiles_t_minus_dt: state.CoreProfiles, geo: geometry.Geometry, + dt: jax.Array, ) -> dict[str, dict[str, jax.Array | None]]: """Computes boundary conditions for time t and returns updates to State. @@ -922,17 +937,17 @@ def compute_boundary_conditions( values in a State object. """ Ti_bound_right = jax_utils.error_if_not_positive( # pylint: disable=invalid-name - dynamic_runtime_params_slice.profile_conditions.Ti_bound_right, + dynamic_runtime_params_slice_t.profile_conditions.Ti_bound_right, 'Ti_bound_right', ) Te_bound_right = jax_utils.error_if_not_positive( # pylint: disable=invalid-name - dynamic_runtime_params_slice.profile_conditions.Te_bound_right, + dynamic_runtime_params_slice_t.profile_conditions.Te_bound_right, 'Te_bound_right', ) ne = _get_ne( - dynamic_runtime_params_slice, + dynamic_runtime_params_slice_t, geo, ) ne_bound_right = ne.right_face_constraint @@ -943,14 +958,15 @@ def compute_boundary_conditions( # Zeff = (ni + Zimp**2 * nimp)/ne ; nimp*Zimp + ni = ne dilution_factor_edge = physics.get_main_ion_dilution_factor( - dynamic_runtime_params_slice.plasma_composition.Zimp, - dynamic_runtime_params_slice.plasma_composition.Zeff_face[-1], + dynamic_runtime_params_slice_t.plasma_composition.Zimp, + dynamic_runtime_params_slice_t.plasma_composition.Zeff_face[-1], ) ni_bound_right = ne_bound_right * dilution_factor_edge nimp_bound_right = ( ne_bound_right - ni_bound_right - ) / dynamic_runtime_params_slice.plasma_composition.Zimp + ) / dynamic_runtime_params_slice_t.plasma_composition.Zimp + return { 'temp_ion': dict( @@ -979,11 +995,25 @@ def compute_boundary_conditions( right_face_constraint=jnp.array(nimp_bound_right), ), 'psi': dict( - right_face_grad_constraint=_calculate_psi_grad_constraint( - dynamic_runtime_params_slice, + right_face_grad_constraint=jnp.where( + dynamic_runtime_params_slice_t.profile_conditions.Ip is not None, + _psi_grad_constraint_from_Ip( + dynamic_runtime_params_slice_t, geo, + ), + None ), - ), + right_face_constraint=jnp.where( + dynamic_runtime_params_slice_t.profile_conditions.Vloop_bound_right is not None, + _psi_value_constraint_from_Vloop( + dynamic_runtime_params_slice_t, + core_profiles_t_minus_dt, + geo, + dt, + ), + None, + ), + ), } diff --git a/torax/sim.py b/torax/sim.py index 278d3c34..be27c674 100644 --- a/torax/sim.py +++ b/torax/sim.py @@ -28,13 +28,15 @@ import dataclasses import time -from typing import Any, Optional +from typing import Any +from typing import Optional -from absl import logging import chex import jax import jax.numpy as jnp import numpy as np +from absl import logging + from torax import calc_coeffs from torax import core_profile_setters from torax import geometry @@ -485,6 +487,7 @@ def step( dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, geo_t_plus_dt=geo_t_plus_dt, core_profiles_t=core_profiles_t, + dt=dt, ) # Initial trial for stepper. If did not converge (can happen for nonlinear @@ -602,10 +605,11 @@ def body_fun( ) core_profiles_t_plus_dt = provide_core_profiles_t_plus_dt( - core_profiles_t=core_profiles_t, - dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, geo_t_plus_dt=geo_t_plus_dt, + core_profiles_t=core_profiles_t, + dt=dt, ) core_profiles, core_sources, core_transport, stepper_numeric_outputs = ( self._stepper_fn( @@ -1508,12 +1512,15 @@ def provide_core_profiles_t_plus_dt( dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice, geo_t_plus_dt: geometry.Geometry, core_profiles_t: state.CoreProfiles, + dt: jax.Array, ) -> state.CoreProfiles: """Provides state at t_plus_dt with new boundary conditions and prescribed profiles.""" updated_boundary_conditions = ( core_profile_setters.compute_boundary_conditions( dynamic_runtime_params_slice_t_plus_dt, + core_profiles_t, geo_t_plus_dt, + dt, ) ) updated_values = core_profile_setters.updated_prescribed_core_profiles( From 4a942a03e1aa770f734430750b29fbbc490cf19d Mon Sep 17 00:00:00 2001 From: Theo Brown <7982453+theo-brown@users.noreply.github.com> Date: Fri, 22 Nov 2024 18:25:43 +0000 Subject: [PATCH 3/8] Simplify Vloop config --- torax/config/profile_conditions.py | 29 ++++-------------------- torax/config/tests/profile_conditions.py | 29 +++--------------------- torax/core_profile_setters.py | 24 ++++++++------------ torax/tests/core_profile_setters.py | 26 ++++++++++++--------- 4 files changed, 33 insertions(+), 75 deletions(-) diff --git a/torax/config/profile_conditions.py b/torax/config/profile_conditions.py index 7b00aca5..c058853d 100644 --- a/torax/config/profile_conditions.py +++ b/torax/config/profile_conditions.py @@ -36,18 +36,13 @@ class ProfileConditions( ): """Prescribed values and boundary conditions for the core profiles.""" - # total plasma current in MA + # Total plasma current in MA # Note that if Ip_from_parameters=False in geometry, then this Ip will be # overwritten by values from the geometry data -<<<<<<< HEAD - Ip: interpolated_param.TimeInterpolatedInput | None = 15.0 + Ip_tot: interpolated_param.TimeInterpolatedInput = 15.0 - # Boundary condition at LCFS for Vloop ( = dpsi/dt ) - # Used if total plasma current is a state and edge flux is an input + # Boundary condition at LCFS for Vloop ( = dpsi_lcfs/dt ) Vloop_bound_right: interpolated_param.TimeInterpolatedInput | None = None -======= - Ip_tot: interpolated_param.TimeInterpolatedInput = 15.0 ->>>>>>> main # Temperature boundary conditions at r=Rmin. If this is `None` the boundary # condition will instead be taken from `Ti` and `Te` at rhon=1. @@ -151,12 +146,6 @@ def __post_init__(self): 'ne', ) - # Check that only one of Vloop_bound_right and Ip are provided - # ^ is XOR - if not ((self.Ip is None) ^ (self.Vloop_bound_right is None)): - raise ValueError("Only one of Ip and Vloop_bound_right can be defined") - - @override def make_provider( self, @@ -202,12 +191,8 @@ class ProfileConditionsProvider( """Provider to retrieve initial and prescribed values and boundary conditions.""" runtime_params_config: ProfileConditions -<<<<<<< HEAD - Ip: interpolated_param.InterpolatedVarSingleAxis | None - Vloop_bound_right: interpolated_param.InterpolatedVarSingleAxis | None -======= Ip_tot: interpolated_param.InterpolatedVarSingleAxis ->>>>>>> main + Vloop_bound_right: interpolated_param.InterpolatedVarSingleAxis | None Ti_bound_right: ( interpolated_param.InterpolatedVarSingleAxis | interpolated_param.InterpolatedVarTimeRho @@ -244,12 +229,8 @@ def build_dynamic_params( class DynamicProfileConditions: """Prescribed values and boundary conditions for the core profiles.""" -<<<<<<< HEAD - Ip: array_typing.ScalarFloat | None - Vloop_bound_right: array_typing.ScalarFloat | None -======= Ip_tot: array_typing.ScalarFloat ->>>>>>> main + Vloop_bound_right: array_typing.ScalarFloat | None Ti_bound_right: array_typing.ScalarFloat Te_bound_right: array_typing.ScalarFloat # Temperature profiles defined on the cell grid. diff --git a/torax/config/tests/profile_conditions.py b/torax/config/tests/profile_conditions.py index 43f2db79..c20a429f 100644 --- a/torax/config/tests/profile_conditions.py +++ b/torax/config/tests/profile_conditions.py @@ -14,14 +14,15 @@ """Unit tests for the `torax.config.profile_conditions` module.""" +import numpy as np +import xarray as xr from absl.testing import absltest from absl.testing import parameterized -import numpy as np + from torax import geometry from torax import interpolated_param from torax.config import config_args from torax.config import profile_conditions -import xarray as xr # pylint: disable=invalid-name @@ -250,29 +251,5 @@ def test_profile_conditions_raises_error_if_boundary_condition_not_defined( Ti_bound_right=None, ) - def test_raises_error_if_Ip_and_Vloop_clash(self, values, raises): - """Tests that an error is raised if a) Vloop and Ip are both None or - b) Vloop and Ip are both not None.""" - if raises: - with self.assertRaises(ValueError): - profile_conditions.ProfileConditions( - Ip=values, - Vloop_bound_right=values, - ) - with self.assertRaises(ValueError): - profile_conditions.ProfileConditions( - Ip=None, - Vloop_bound_right=None, - ) - else: - profile_conditions.ProfileConditions( - Ip=values, - Vloop_bound_right=values, - ) - profile_conditions.ProfileConditions( - Ip=None, - Vloop_bound_right=None, - ) - if __name__ == '__main__': absltest.main() diff --git a/torax/core_profile_setters.py b/torax/core_profile_setters.py index 1b61aed3..2fd6853e 100644 --- a/torax/core_profile_setters.py +++ b/torax/core_profile_setters.py @@ -516,7 +516,7 @@ def _update_psi_from_j( Returns: psi: Poloidal flux cell variable. """ - psi_grad_constraint = _calculate_psi_grad_constraint_from_Ip( + psi_grad_constraint = _calculate_psi_grad_constraint_from_Ip_tot( dynamic_runtime_params_slice, geo, ) @@ -552,7 +552,7 @@ def _update_psi_from_j( # pylint: enable=invalid-name -def _calculate_psi_grad_constraint_from_Ip( +def _calculate_psi_grad_constraint_from_Ip_tot( dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: Geometry, ) -> jax.Array: @@ -572,7 +572,7 @@ def _psi_value_constraint_from_Vloop( ) -> jax.Array: """Calculates the value constraint on the poloidal flux (psi) from Vloop.""" return ( - core_profiles_t_minus_dt.psi.face_value(geo.rho_b) + core_profiles_t_minus_dt.psi.face_value()[-1] + dynamic_runtime_params_slice_t.profile_conditions.Vloop_bound_right * dt ) @@ -605,7 +605,7 @@ def _init_psi_and_current( if dynamic_runtime_params_slice.profile_conditions.psi is not None: psi = cell_variable.CellVariable( value=dynamic_runtime_params_slice.profile_conditions.psi, - right_face_grad_constraint=_calculate_psi_grad_constraint_from_Ip( + right_face_grad_constraint=_calculate_psi_grad_constraint_from_Ip_tot( dynamic_runtime_params_slice, geo, ), @@ -628,7 +628,7 @@ def _init_psi_and_current( # calculated and used in current diffusion equation. psi = cell_variable.CellVariable( value=geo.psi_from_Ip, - right_face_grad_constraint=_calculate_psi_grad_constraint_from_Ip( + right_face_grad_constraint=_calculate_psi_grad_constraint_from_Ip_tot( dynamic_runtime_params_slice, geo, ), @@ -954,23 +954,19 @@ def compute_boundary_conditions( right_face_constraint=jnp.array(nimp_bound_right), ), 'psi': dict( - right_face_grad_constraint=jnp.where( - dynamic_runtime_params_slice_t.profile_conditions.Ip is not None, - _calculate_psi_grad_constraint_from_Ip( + right_face_grad_constraint=_calculate_psi_grad_constraint_from_Ip_tot( dynamic_runtime_params_slice_t, geo, ), - None - ), - right_face_constraint=jnp.where( - dynamic_runtime_params_slice_t.profile_conditions.Vloop_bound_right is not None, + right_face_constraint=( _psi_value_constraint_from_Vloop( dynamic_runtime_params_slice_t, core_profiles_t_minus_dt, geo, dt, - ), - None, + ) + if dynamic_runtime_params_slice_t.profile_conditions.Vloop_bound_right is not None + else None ), ), } diff --git a/torax/tests/core_profile_setters.py b/torax/tests/core_profile_setters.py index fb98a2fe..024dc0f9 100644 --- a/torax/tests/core_profile_setters.py +++ b/torax/tests/core_profile_setters.py @@ -14,9 +14,10 @@ """Tests for module torax.boundary_conditions.""" +import numpy as np from absl.testing import absltest from absl.testing import parameterized -import numpy as np + from torax import core_profile_setters from torax import geometry from torax import physics @@ -27,7 +28,6 @@ from torax.stepper import runtime_params as stepper_params_lib from torax.transport_model import runtime_params as transport_params_lib - SMALL_VALUE = 1e-6 @@ -494,8 +494,10 @@ def test_compute_boundary_conditions_ne( ) boundary_conditions = core_profile_setters.compute_boundary_conditions( - dynamic_runtime_params_slice, - self.geo, + dynamic_runtime_params_slice_t=dynamic_runtime_params_slice, + core_profiles_t_minus_dt=None, # This test does not hit the conditional that requires this + geo=self.geo, + dt=runtime_params.numerics.fixed_dt, ) if (ne_is_fGW and ne_bound_right is None) or ( @@ -585,8 +587,10 @@ def test_compute_boundary_conditions_Te( ) boundary_conditions = core_profile_setters.compute_boundary_conditions( - dynamic_runtime_params_slice, - self.geo, + dynamic_runtime_params_slice_t=dynamic_runtime_params_slice, + core_profiles_t_minus_dt=None, # This test does not hit the conditional that requires this + geo=self.geo, + dt=runtime_params.numerics.fixed_dt, ) self.assertEqual( @@ -616,13 +620,13 @@ def test_compute_boundary_conditions_Ti( stepper=stepper_params_lib.RuntimeParams(), torax_mesh=self.geo.torax_mesh, ) - dynamic_runtime_params_slice = provider( - t=1.0, - ) + dynamic_runtime_params_slice = provider(t=1.0) boundary_conditions = core_profile_setters.compute_boundary_conditions( - dynamic_runtime_params_slice, - self.geo, + dynamic_runtime_params_slice_t=dynamic_runtime_params_slice, + core_profiles_t_minus_dt=None, # This test does not hit the conditional that requires this + geo=self.geo, + dt=runtime_params.numerics.fixed_dt, ) self.assertEqual( From d3d66bef03b0906e47a80cec2901784f575f0f41 Mon Sep 17 00:00:00 2001 From: Theo Brown <7982453+theo-brown@users.noreply.github.com> Date: Fri, 22 Nov 2024 22:50:15 +0000 Subject: [PATCH 4/8] Fix arg order --- torax/core_profile_setters.py | 4 ++-- torax/sim.py | 10 +++++----- torax/tests/boundary_conditions.py | 9 ++++++--- torax/tests/test_lib/explicit_stepper.py | 7 +++++-- 4 files changed, 18 insertions(+), 12 deletions(-) diff --git a/torax/core_profile_setters.py b/torax/core_profile_setters.py index 2fd6853e..4203df63 100644 --- a/torax/core_profile_setters.py +++ b/torax/core_profile_setters.py @@ -565,10 +565,10 @@ def _calculate_psi_grad_constraint_from_Ip_tot( ) def _psi_value_constraint_from_Vloop( + dt: jax.Array, dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice, core_profiles_t_minus_dt: state.CoreProfiles, geo: Geometry, - dt: jax.Array, ) -> jax.Array: """Calculates the value constraint on the poloidal flux (psi) from Vloop.""" return ( @@ -879,10 +879,10 @@ def get_update(x_new, var): def compute_boundary_conditions( + dt: jax.Array, dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice, core_profiles_t_minus_dt: state.CoreProfiles, geo: geometry.Geometry, - dt: jax.Array, ) -> dict[str, dict[str, jax.Array | None]]: """Computes boundary conditions for time t and returns updates to State. diff --git a/torax/sim.py b/torax/sim.py index 0f61c627..046389a5 100644 --- a/torax/sim.py +++ b/torax/sim.py @@ -35,6 +35,7 @@ import jax import jax.numpy as jnp import numpy as np +import xarray as xr from absl import logging from torax import calc_coeffs @@ -58,7 +59,6 @@ from torax.time_step_calculator import chi_time_step_calculator from torax.time_step_calculator import time_step_calculator as ts from torax.transport_model import transport_model as transport_model_lib -import xarray as xr def _log_timestep( @@ -485,11 +485,11 @@ def step( # conditions and time-dependent prescribed profiles not directly solved by # PDE system. core_profiles_t_plus_dt = provide_core_profiles_t_plus_dt( + dt=dt, static_runtime_params_slice=static_runtime_params_slice, dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, geo_t_plus_dt=geo_t_plus_dt, core_profiles_t=core_profiles_t, - dt=dt, ) # Initial trial for stepper. If did not converge (can happen for nonlinear @@ -607,11 +607,11 @@ def body_fun( ) core_profiles_t_plus_dt = provide_core_profiles_t_plus_dt( + dt=dt, static_runtime_params_slice=static_runtime_params_slice, dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, geo_t_plus_dt=geo_t_plus_dt, core_profiles_t=core_profiles_t, - dt=dt, ) core_profiles, core_sources, core_transport, stepper_numeric_outputs = ( self._stepper_fn( @@ -1566,19 +1566,19 @@ def update_psidot( def provide_core_profiles_t_plus_dt( + dt: jax.Array, static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice, geo_t_plus_dt: geometry.Geometry, core_profiles_t: state.CoreProfiles, - dt: jax.Array, ) -> state.CoreProfiles: """Provides state at t_plus_dt with new boundary conditions and prescribed profiles.""" updated_boundary_conditions = ( core_profile_setters.compute_boundary_conditions( + dt, dynamic_runtime_params_slice_t_plus_dt, core_profiles_t, geo_t_plus_dt, - dt, ) ) updated_values = core_profile_setters.updated_prescribed_core_profiles( diff --git a/torax/tests/boundary_conditions.py b/torax/tests/boundary_conditions.py index 24959a4c..9fb4bb8f 100644 --- a/torax/tests/boundary_conditions.py +++ b/torax/tests/boundary_conditions.py @@ -15,9 +15,10 @@ """Tests for module torax.boundary_conditions.""" +import numpy as np from absl.testing import absltest from absl.testing import parameterized -import numpy as np + from torax import constants from torax import core_profile_setters from torax import geometry @@ -97,8 +98,10 @@ def test_setting_boundary_conditions( ) bc = core_profile_setters.compute_boundary_conditions( - dynamic_runtime_params_slice, - geo, + dt=runtime_params.numerics.fixed_dt, + dynamic_runtime_params_slice_t=dynamic_runtime_params_slice, + core_profiles_t_minus_dt=core_profiles, + geo=geo, ) updated = config_args.recursive_replace(core_profiles, **bc) diff --git a/torax/tests/test_lib/explicit_stepper.py b/torax/tests/test_lib/explicit_stepper.py index cc2f9434..a468f73e 100644 --- a/torax/tests/test_lib/explicit_stepper.py +++ b/torax/tests/test_lib/explicit_stepper.py @@ -23,6 +23,7 @@ import jax from jax import numpy as jnp + from torax import constants from torax import core_profile_setters from torax import geometry @@ -119,8 +120,10 @@ def __call__( # Update the potentially time-dependent boundary conditions as well. updated_boundary_conditions = ( core_profile_setters.compute_boundary_conditions( - dynamic_runtime_params_slice_t_plus_dt, - geo_t, + dt=dynamic_runtime_params_slice_t_plus_dt.numerics.fixed_dt, + dynamic_runtime_params_slice_t=dynamic_runtime_params_slice_t_plus_dt, + core_profiles_t_minus_dt=core_profiles_t, + geo=geo_t_plus_dt, ) ) temp_ion_new = dataclasses.replace( From c67a98d95f6e7026d394614724538afb0e1f875e Mon Sep 17 00:00:00 2001 From: Theo Brown <7982453+theo-brown@users.noreply.github.com> Date: Tue, 26 Nov 2024 10:20:29 +0000 Subject: [PATCH 5/8] Add comment on Vloop vs Ip_tot --- torax/config/profile_conditions.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torax/config/profile_conditions.py b/torax/config/profile_conditions.py index c058853d..794b005b 100644 --- a/torax/config/profile_conditions.py +++ b/torax/config/profile_conditions.py @@ -38,10 +38,14 @@ class ProfileConditions( # Total plasma current in MA # Note that if Ip_from_parameters=False in geometry, then this Ip will be - # overwritten by values from the geometry data + # overwritten by values from the geometry data. + # If Vloop_bound_right is not None, then this Ip is used as an + # initial condition ONLY. Ip_tot: interpolated_param.TimeInterpolatedInput = 15.0 # Boundary condition at LCFS for Vloop ( = dpsi_lcfs/dt ) + # If this is `None` the boundary condition for the psi equation at each timestep + # will instead be taken from `Ip_tot`. Vloop_bound_right: interpolated_param.TimeInterpolatedInput | None = None # Temperature boundary conditions at r=Rmin. If this is `None` the boundary From 3b7f33c963cae820121485b2b52b26a856ef728c Mon Sep 17 00:00:00 2001 From: Theo Brown <7982453+theo-brown@users.noreply.github.com> Date: Tue, 26 Nov 2024 10:34:16 +0000 Subject: [PATCH 6/8] TEMP: Initial attempts at getting Vloop BC working This includes a hack to support outputting different BCs in the output file. Previously, the simulation could only use one type of BC for the whole run (ie either grad or value constraint). By wrapping the output in a jnp.array(), BCs that are None get turned into NaN, which is compatible with tree_map. Hence, this change allows you to have `grad_constraint = [XXX, None, None, ...]` and `value_constraint = [None, XXX, YYY, ...]` which is useful for testing the Vloop BC. --- torax/core_profile_setters.py | 26 ++++++++-- torax/examples/vloop.py | 90 +++++++++++++++++++++++++++++++++++ torax/output.py | 8 +++- 3 files changed, 118 insertions(+), 6 deletions(-) create mode 100644 torax/examples/vloop.py diff --git a/torax/core_profile_setters.py b/torax/core_profile_setters.py index 4203df63..11b33839 100644 --- a/torax/core_profile_setters.py +++ b/torax/core_profile_setters.py @@ -564,6 +564,7 @@ def _calculate_psi_grad_constraint_from_Ip_tot( / (geo.g2g3_over_rhon_face[-1] * geo.F_face[-1]) ) + def _psi_value_constraint_from_Vloop( dt: jax.Array, dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice, @@ -576,6 +577,7 @@ def _psi_value_constraint_from_Vloop( + dynamic_runtime_params_slice_t.profile_conditions.Vloop_bound_right * dt ) + def _init_psi_and_current( dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: Geometry, @@ -601,8 +603,14 @@ def _init_psi_and_current( Returns: Refined core profiles. """ + use_Vloop_bound_right = ( + dynamic_runtime_params_slice.profile_conditions.Vloop_bound_right is not None + ) + # Retrieving psi from the profile conditions. if dynamic_runtime_params_slice.profile_conditions.psi is not None: + # TODO: do we need to support the case where psi is given, but Vloop_bound_right + # is used to set the BC rather than Ip_tot? psi = cell_variable.CellVariable( value=dynamic_runtime_params_slice.profile_conditions.psi, right_face_grad_constraint=_calculate_psi_grad_constraint_from_Ip_tot( @@ -631,7 +639,12 @@ def _init_psi_and_current( right_face_grad_constraint=_calculate_psi_grad_constraint_from_Ip_tot( dynamic_runtime_params_slice, geo, - ), + ) + if not use_Vloop_bound_right + else None, + right_face_constraint=geo.psi_from_Ip[-1] + if use_Vloop_bound_right + else None, dr=geo.drho_norm, ) core_profiles = dataclasses.replace(core_profiles, psi=psi) @@ -954,10 +967,15 @@ def compute_boundary_conditions( right_face_constraint=jnp.array(nimp_bound_right), ), 'psi': dict( - right_face_grad_constraint=_calculate_psi_grad_constraint_from_Ip_tot( + right_face_grad_constraint=( + _calculate_psi_grad_constraint_from_Ip_tot( dynamic_runtime_params_slice_t, geo, - ), + ) + if dynamic_runtime_params_slice_t.profile_conditions.Vloop_bound_right + is None + else None + ), right_face_constraint=( _psi_value_constraint_from_Vloop( dynamic_runtime_params_slice_t, @@ -969,7 +987,7 @@ def compute_boundary_conditions( else None ), ), - } + } # pylint: disable=invalid-name diff --git a/torax/examples/vloop.py b/torax/examples/vloop.py new file mode 100644 index 00000000..a58ebe6c --- /dev/null +++ b/torax/examples/vloop.py @@ -0,0 +1,90 @@ +CONFIG = { + "runtime_params": { + "profile_conditions": { + 'Ip_tot': 15.0, + "Vloop_bound_right": 0.0, + "psi": dict( + zip( + [ + 0.02, + 0.06, + 0.1, + 0.14, + 0.18, + 0.22, + 0.26, + 0.3, + 0.34, + 0.38, + 0.42, + 0.46, + 0.5, + 0.54, + 0.58, + 0.62, + 0.66, + 0.7, + 0.74, + 0.78, + 0.82, + 0.86, + 0.9, + 0.94, + 0.98, + ], + [ + 4.27722812e-02, + 3.94379591e-01, + 1.09263271e00, + 2.10681692e00, + 3.41030703e00, + 4.97594319e00, + 6.77607994e00, + 8.79231633e00, + 1.10530604e01, + 1.36745883e01, + 1.67709665e01, + 2.02615760e01, + 2.39018070e01, + 2.74979635e01, + 3.09738348e01, + 3.43088066e01, + 3.74947223e01, + 4.05261312e01, + 4.33998473e01, + 4.61154118e01, + 4.86753469e01, + 5.10851399e01, + 5.33529513e01, + 5.54890272e01, + 5.75047356e01, + ], + ) + ), + "set_pedestal": False, + } + }, + "geometry": { + "geometry_type": "circular", + }, + "sources": { + "j_bootstrap": {}, + "generic_current_source": {}, + "generic_particle_source": {}, + "gas_puff_source": {}, + "pellet_source": {}, + "generic_ion_el_heat_source": {}, + "fusion_heat_source": {}, + "qei_source": {}, + "ohmic_heat_source": {}, + }, + "transport": { + "transport_model": "constant", + }, + "stepper": { + "stepper_type": "linear", + }, + "time_step_calculator": { + "calculator_type": "chi", + }, +} diff --git a/torax/output.py b/torax/output.py index f5d01c25..5c37176b 100644 --- a/torax/output.py +++ b/torax/output.py @@ -60,6 +60,7 @@ class ToraxSimOutputs: PSI = "psi" PSIDOT = "psidot" PSI_RIGHT_GRAD_BC = "psi_right_grad_bc" +PSI_RIGHT_BC = "psi_right_bc" NE = "ne" NE_RIGHT_BC = "ne_right_bc" NI = "ni" @@ -195,9 +196,9 @@ def __init__( post_processed_output = [ state.post_processed_outputs for state in sim_outputs.sim_history ] - stack = lambda *ys: jnp.stack(ys) + stack = lambda *ys: jnp.stack(jnp.array(ys)) self.core_profiles: state.CoreProfiles = jax.tree_util.tree_map( - stack, *core_profiles + stack, *core_profiles, is_leaf=lambda x: x is None, ) self.core_sources: source_profiles.SourceProfiles = jax.tree_util.tree_map( stack, *core_sources @@ -263,6 +264,9 @@ def _get_core_profiles( xr_dict[PSI_RIGHT_GRAD_BC] = ( self.core_profiles.psi.right_face_grad_constraint ) + xr_dict[PSI_RIGHT_BC] = ( + self.core_profiles.psi.right_face_constraint + ) xr_dict[PSIDOT] = self.core_profiles.psidot.value xr_dict[NE] = self.core_profiles.ne.value xr_dict[NE_RIGHT_BC] = self.core_profiles.ne.right_face_constraint From bcdbb044340d03fa52180c0b2ee9e71ac6a029f9 Mon Sep 17 00:00:00 2001 From: Theo Brown <7982453+theo-brown@users.noreply.github.com> Date: Fri, 6 Dec 2024 20:04:14 +0000 Subject: [PATCH 7/8] Add initialisation logic --- torax/config/profile_conditions.py | 3 +- torax/config/tests/profile_conditions.py | 6 +- torax/core_profile_setters.py | 77 ++++++++++++++------ torax/examples/vloop.py | 90 ------------------------ torax/sim.py | 8 +-- torax/tests/boundary_conditions.py | 3 +- torax/tests/core_profile_setters.py | 6 +- torax/tests/test_lib/explicit_stepper.py | 1 - 8 files changed, 67 insertions(+), 127 deletions(-) delete mode 100644 torax/examples/vloop.py diff --git a/torax/config/profile_conditions.py b/torax/config/profile_conditions.py index 9aebb0bd..1ba8be3e 100644 --- a/torax/config/profile_conditions.py +++ b/torax/config/profile_conditions.py @@ -20,13 +20,12 @@ import logging import chex -from typing_extensions import override - from torax import array_typing from torax import geometry from torax import interpolated_param from torax.config import base from torax.config import config_args +from typing_extensions import override # pylint: disable=invalid-name diff --git a/torax/config/tests/profile_conditions.py b/torax/config/tests/profile_conditions.py index c20a429f..1b3bfb84 100644 --- a/torax/config/tests/profile_conditions.py +++ b/torax/config/tests/profile_conditions.py @@ -14,15 +14,14 @@ """Unit tests for the `torax.config.profile_conditions` module.""" -import numpy as np -import xarray as xr from absl.testing import absltest from absl.testing import parameterized - +import numpy as np from torax import geometry from torax import interpolated_param from torax.config import config_args from torax.config import profile_conditions +import xarray as xr # pylint: disable=invalid-name @@ -251,5 +250,6 @@ def test_profile_conditions_raises_error_if_boundary_condition_not_defined( Ti_bound_right=None, ) + if __name__ == '__main__': absltest.main() diff --git a/torax/core_profile_setters.py b/torax/core_profile_setters.py index 11b33839..eee5c620 100644 --- a/torax/core_profile_setters.py +++ b/torax/core_profile_setters.py @@ -607,18 +607,35 @@ def _init_psi_and_current( dynamic_runtime_params_slice.profile_conditions.Vloop_bound_right is not None ) - # Retrieving psi from the profile conditions. + # Case 1: retrieving psi from the profile conditions. if dynamic_runtime_params_slice.profile_conditions.psi is not None: - # TODO: do we need to support the case where psi is given, but Vloop_bound_right - # is used to set the BC rather than Ip_tot? - psi = cell_variable.CellVariable( + # Calculate the dpsi/drho necessary to achieve the given Ip_tot + dpsi_drho_edge = _calculate_psi_grad_constraint_from_Ip_tot( + dynamic_runtime_params_slice, + geo, + ) + + # Set the psi BCs to ensure the correct Ip_tot + if use_Vloop_bound_right: + # Extrapolate using the dpsi/drho calculated above to set the psi value at the right face + psi = cell_variable.CellVariable( value=dynamic_runtime_params_slice.profile_conditions.psi, - right_face_grad_constraint=_calculate_psi_grad_constraint_from_Ip_tot( - dynamic_runtime_params_slice, - geo, + right_face_grad_constraint=None, + right_face_constraint=( + dynamic_runtime_params_slice.profile_conditions.psi[-1] + + dpsi_drho_edge * geo.drho / 2 ), dr=geo.drho_norm, - ) + ) + else: + # Use the dpsi/drho calculated above as the right face gradient constraint + psi = cell_variable.CellVariable( + value=dynamic_runtime_params_slice.profile_conditions.psi, + right_face_grad_constraint=dpsi_drho_edge, + right_face_constraint=None, + dr=geo.drho_norm, + ) + core_profiles = dataclasses.replace(core_profiles, psi=psi) currents = _calculate_currents_from_psi( dynamic_runtime_params_slice=dynamic_runtime_params_slice, @@ -626,7 +643,8 @@ def _init_psi_and_current( core_profiles=core_profiles, source_models=source_models, ) - # Retrieving psi from the standard geometry input. + + # Case 2: retrieving psi from the standard geometry input. elif ( isinstance(geo, geometry.StandardGeometry) and not dynamic_runtime_params_slice.profile_conditions.initial_psi_from_j @@ -634,31 +652,48 @@ def _init_psi_and_current( # psi is already provided from a numerical equilibrium, so no need to # first calculate currents. However, non-inductive currents are still # calculated and used in current diffusion equation. + + # Calculate the dpsi/drho necessary to achieve the given Ip_tot + dpsi_drho_edge = _calculate_psi_grad_constraint_from_Ip_tot( + dynamic_runtime_params_slice, + geo, + ) + + # Set the psi BCs based on whether Vloop is provided and the source of Ip + if use_Vloop_bound_right and geo.Ip_from_parameters: + right_face_grad_constraint = None + right_face_constraint = geo.psi_from_Ip[-1] + dpsi_drho_edge * geo.drho / 2 + elif use_Vloop_bound_right: + right_face_grad_constraint = None + right_face_constraint = geo.psi_from_Ip[-1] + else: + right_face_grad_constraint = dpsi_drho_edge + right_face_constraint = None + psi = cell_variable.CellVariable( - value=geo.psi_from_Ip, - right_face_grad_constraint=_calculate_psi_grad_constraint_from_Ip_tot( - dynamic_runtime_params_slice, - geo, - ) - if not use_Vloop_bound_right - else None, - right_face_constraint=geo.psi_from_Ip[-1] - if use_Vloop_bound_right - else None, - dr=geo.drho_norm, + value=geo.psi_from_Ip, # Use psi from equilibrium + right_face_grad_constraint=right_face_grad_constraint, + right_face_constraint=right_face_constraint, + dr=geo.drho_norm, ) core_profiles = dataclasses.replace(core_profiles, psi=psi) + # Calculate non-inductive currents currents = _calculate_currents_from_psi( dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, core_profiles=core_profiles, source_models=source_models, ) - # Calculating j according to nu formula and psi from j. + + # Case 3: calculating j according to nu formula and psi from j. elif ( isinstance(geo, geometry.CircularAnalyticalGeometry) or dynamic_runtime_params_slice.profile_conditions.initial_psi_from_j ): + # TODO: Vloop_bound_right is not yet supported for this case. + if use_Vloop_bound_right: + raise NotImplementedError('Vloop_bound_right not yet supported for this case.') + currents = _prescribe_currents_no_bootstrap( dynamic_runtime_params_slice=dynamic_runtime_params_slice, geo=geo, diff --git a/torax/examples/vloop.py b/torax/examples/vloop.py deleted file mode 100644 index a58ebe6c..00000000 --- a/torax/examples/vloop.py +++ /dev/null @@ -1,90 +0,0 @@ -CONFIG = { - "runtime_params": { - "profile_conditions": { - 'Ip_tot': 15.0, - "Vloop_bound_right": 0.0, - "psi": dict( - zip( - [ - 0.02, - 0.06, - 0.1, - 0.14, - 0.18, - 0.22, - 0.26, - 0.3, - 0.34, - 0.38, - 0.42, - 0.46, - 0.5, - 0.54, - 0.58, - 0.62, - 0.66, - 0.7, - 0.74, - 0.78, - 0.82, - 0.86, - 0.9, - 0.94, - 0.98, - ], - [ - 4.27722812e-02, - 3.94379591e-01, - 1.09263271e00, - 2.10681692e00, - 3.41030703e00, - 4.97594319e00, - 6.77607994e00, - 8.79231633e00, - 1.10530604e01, - 1.36745883e01, - 1.67709665e01, - 2.02615760e01, - 2.39018070e01, - 2.74979635e01, - 3.09738348e01, - 3.43088066e01, - 3.74947223e01, - 4.05261312e01, - 4.33998473e01, - 4.61154118e01, - 4.86753469e01, - 5.10851399e01, - 5.33529513e01, - 5.54890272e01, - 5.75047356e01, - ], - ) - ), - "set_pedestal": False, - } - }, - "geometry": { - "geometry_type": "circular", - }, - "sources": { - "j_bootstrap": {}, - "generic_current_source": {}, - "generic_particle_source": {}, - "gas_puff_source": {}, - "pellet_source": {}, - "generic_ion_el_heat_source": {}, - "fusion_heat_source": {}, - "qei_source": {}, - "ohmic_heat_source": {}, - }, - "transport": { - "transport_model": "constant", - }, - "stepper": { - "stepper_type": "linear", - }, - "time_step_calculator": { - "calculator_type": "chi", - }, -} diff --git a/torax/sim.py b/torax/sim.py index fb0293fd..39ca5fc0 100644 --- a/torax/sim.py +++ b/torax/sim.py @@ -28,16 +28,13 @@ import dataclasses import time -from typing import Any -from typing import Optional +from typing import Any, Optional +from absl import logging import chex import jax import jax.numpy as jnp import numpy as np -import xarray as xr -from absl import logging - from torax import calc_coeffs from torax import core_profile_setters from torax import geometry @@ -60,6 +57,7 @@ from torax.time_step_calculator import chi_time_step_calculator from torax.time_step_calculator import time_step_calculator as ts from torax.transport_model import transport_model as transport_model_lib +import xarray as xr def _log_timestep( diff --git a/torax/tests/boundary_conditions.py b/torax/tests/boundary_conditions.py index 9fb4bb8f..608007f5 100644 --- a/torax/tests/boundary_conditions.py +++ b/torax/tests/boundary_conditions.py @@ -15,10 +15,9 @@ """Tests for module torax.boundary_conditions.""" -import numpy as np from absl.testing import absltest from absl.testing import parameterized - +import numpy as np from torax import constants from torax import core_profile_setters from torax import geometry diff --git a/torax/tests/core_profile_setters.py b/torax/tests/core_profile_setters.py index 024dc0f9..283bfb61 100644 --- a/torax/tests/core_profile_setters.py +++ b/torax/tests/core_profile_setters.py @@ -12,12 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for module torax.boundary_conditions.""" +"""Tests for module torax.core_profile_setters.""" -import numpy as np from absl.testing import absltest from absl.testing import parameterized - +import numpy as np from torax import core_profile_setters from torax import geometry from torax import physics @@ -28,6 +27,7 @@ from torax.stepper import runtime_params as stepper_params_lib from torax.transport_model import runtime_params as transport_params_lib + SMALL_VALUE = 1e-6 diff --git a/torax/tests/test_lib/explicit_stepper.py b/torax/tests/test_lib/explicit_stepper.py index e301ee94..7fa12343 100644 --- a/torax/tests/test_lib/explicit_stepper.py +++ b/torax/tests/test_lib/explicit_stepper.py @@ -23,7 +23,6 @@ import jax from jax import numpy as jnp - from torax import constants from torax import core_profile_setters from torax import geometry From 0ce5eba968f3bd6e1deaab1cee793b7a046388ad Mon Sep 17 00:00:00 2001 From: Theo Brown <7982453+theo-brown@users.noreply.github.com> Date: Fri, 6 Dec 2024 20:46:04 +0000 Subject: [PATCH 8/8] Fix arg order --- torax/core_profile_setters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torax/core_profile_setters.py b/torax/core_profile_setters.py index a2653838..a1dd73b9 100644 --- a/torax/core_profile_setters.py +++ b/torax/core_profile_setters.py @@ -1057,10 +1057,10 @@ def compute_boundary_conditions( ), right_face_constraint=( _psi_value_constraint_from_Vloop( + dt, dynamic_runtime_params_slice_t, core_profiles_t_minus_dt, geo, - dt, ) if dynamic_runtime_params_slice_t.profile_conditions.Vloop_bound_right is not None else None