Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add Vloop BC #456

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
21 changes: 17 additions & 4 deletions torax/config/profile_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@
import logging

import chex
from typing_extensions import override

from torax import array_typing
from torax import geometry
from torax import interpolated_param
from torax.config import base
from torax.config import config_args
from typing_extensions import override


# pylint: disable=invalid-name
Expand All @@ -38,7 +39,11 @@ class ProfileConditions(
# total plasma current in MA
# Note that if Ip_from_parameters=False in geometry, then this Ip will be
# overwritten by values from the geometry data
Ip: interpolated_param.TimeInterpolatedInput = 15.0
Ip: interpolated_param.TimeInterpolatedInput | None = 15.0

# Boundary condition at LCFS for Vloop ( = dpsi/dt )
# Used if total plasma current is a state and edge flux is an input
Vloop_bound_right: interpolated_param.TimeInterpolatedInput | None = None

# Temperature boundary conditions at r=Rmin. If this is `None` the boundary
# condition will instead be taken from `Ti` and `Te` at rhon=1.
Expand Down Expand Up @@ -142,6 +147,12 @@ def __post_init__(self):
'ne',
)

# Check that only one of Vloop_bound_right and Ip are provided
# ^ is XOR
if not ((self.Ip is None) ^ (self.Vloop_bound_right is None)):
raise ValueError("Only one of Ip and Vloop_bound_right can be defined")


@override
def make_provider(
self,
Expand Down Expand Up @@ -187,7 +198,8 @@ class ProfileConditionsProvider(
"""Provider to retrieve initial and prescribed values and boundary conditions."""

runtime_params_config: ProfileConditions
Ip: interpolated_param.InterpolatedVarSingleAxis
Ip: interpolated_param.InterpolatedVarSingleAxis | None
Vloop_bound_right: interpolated_param.InterpolatedVarSingleAxis | None
Ti_bound_right: (
interpolated_param.InterpolatedVarSingleAxis
| interpolated_param.InterpolatedVarTimeRho
Expand Down Expand Up @@ -224,7 +236,8 @@ def build_dynamic_params(
class DynamicProfileConditions:
"""Prescribed values and boundary conditions for the core profiles."""

Ip: array_typing.ScalarFloat
Ip: array_typing.ScalarFloat | None
Vloop_bound_right: array_typing.ScalarFloat | None
Ti_bound_right: array_typing.ScalarFloat
Te_bound_right: array_typing.ScalarFloat
# Temperature profiles defined on the cell grid.
Expand Down
23 changes: 23 additions & 0 deletions torax/config/tests/profile_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,29 @@ def test_profile_conditions_raises_error_if_boundary_condition_not_defined(
Ti_bound_right=None,
)

def test_raises_error_if_Ip_and_Vloop_clash(self, values, raises):
"""Tests that an error is raised if a) Vloop and Ip are both None or
b) Vloop and Ip are both not None."""
if raises:
with self.assertRaises(ValueError):
profile_conditions.ProfileConditions(
Ip=values,
Vloop_bound_right=values,
)
with self.assertRaises(ValueError):
profile_conditions.ProfileConditions(
Ip=None,
Vloop_bound_right=None,
)
else:
profile_conditions.ProfileConditions(
Ip=values,
Vloop_bound_right=values,
)
profile_conditions.ProfileConditions(
Ip=None,
Vloop_bound_right=None,
)

if __name__ == '__main__':
absltest.main()
60 changes: 45 additions & 15 deletions torax/core_profile_setters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
evolved by the PDE system.
"""
import dataclasses

import jax
from jax import numpy as jnp

from torax import constants
from torax import geometry
from torax import jax_utils
Expand Down Expand Up @@ -541,7 +543,7 @@ def _update_psi_from_j(
Returns:
psi: Poloidal flux cell variable.
"""
psi_grad_constraint = _calculate_psi_grad_constraint(
psi_grad_constraint = _psi_grad_constraint_from_Ip(
dynamic_runtime_params_slice,
geo,
)
Expand Down Expand Up @@ -579,18 +581,29 @@ def _update_psi_from_j(
# pylint: enable=invalid-name


def _calculate_psi_grad_constraint(
def _psi_grad_constraint_from_Ip(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: Geometry,
) -> jax.Array:
"""Calculates the constraint on the poloidal flux (psi)."""
"""Calculates the gradient constraint on the poloidal flux (psi) from Ip."""
return (
dynamic_runtime_params_slice.profile_conditions.Ip
* 1e6
* (16 * jnp.pi**3 * constants.CONSTANTS.mu0 * geo.Phib)
/ (geo.g2g3_over_rhon_face[-1] * geo.F_face[-1])
)

def _psi_value_constraint_from_Vloop(
dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice,
core_profiles_t_minus_dt: state.CoreProfiles,
geo: Geometry,
dt: jax.Array,
) -> jax.Array:
"""Calculates the value constraint on the poloidal flux (psi) from Vloop."""
return (
core_profiles_t_minus_dt.psi.face_value(geo.rho_b)
theo-brown marked this conversation as resolved.
Show resolved Hide resolved
+ dynamic_runtime_params_slice_t.profile_conditions.Vloop_bound_right * dt
)

def _initial_psi(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
Expand Down Expand Up @@ -624,7 +637,7 @@ def _initial_psi(
if dynamic_runtime_params_slice.profile_conditions.psi is not None:
psi = cell_variable.CellVariable(
value=dynamic_runtime_params_slice.profile_conditions.psi,
right_face_grad_constraint=_calculate_psi_grad_constraint(
right_face_grad_constraint=_psi_grad_constraint_from_Ip(
dynamic_runtime_params_slice,
geo,
),
Expand Down Expand Up @@ -687,7 +700,7 @@ def _initial_psi(
# psi is already provided from a numerical equilibrium, so no need to
# first calculate currents. However, non-inductive currents are still
# calculated and used in current diffusion equation.
psi_grad_constraint = _calculate_psi_grad_constraint(
psi_grad_constraint = _psi_grad_constraint_from_Ip(
dynamic_runtime_params_slice,
geo,
)
Expand Down Expand Up @@ -907,8 +920,10 @@ def get_update(x_new, var):


def compute_boundary_conditions(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice,
core_profiles_t_minus_dt: state.CoreProfiles,
geo: geometry.Geometry,
dt: jax.Array,
) -> dict[str, dict[str, jax.Array | None]]:
"""Computes boundary conditions for time t and returns updates to State.

Expand All @@ -922,17 +937,17 @@ def compute_boundary_conditions(
values in a State object.
"""
Ti_bound_right = jax_utils.error_if_not_positive( # pylint: disable=invalid-name
dynamic_runtime_params_slice.profile_conditions.Ti_bound_right,
dynamic_runtime_params_slice_t.profile_conditions.Ti_bound_right,
'Ti_bound_right',
)

Te_bound_right = jax_utils.error_if_not_positive( # pylint: disable=invalid-name
dynamic_runtime_params_slice.profile_conditions.Te_bound_right,
dynamic_runtime_params_slice_t.profile_conditions.Te_bound_right,
'Te_bound_right',
)

ne = _get_ne(
dynamic_runtime_params_slice,
dynamic_runtime_params_slice_t,
geo,
)
ne_bound_right = ne.right_face_constraint
Expand All @@ -943,14 +958,15 @@ def compute_boundary_conditions(
# Zeff = (ni + Zimp**2 * nimp)/ne ; nimp*Zimp + ni = ne

dilution_factor_edge = physics.get_main_ion_dilution_factor(
dynamic_runtime_params_slice.plasma_composition.Zimp,
dynamic_runtime_params_slice.plasma_composition.Zeff_face[-1],
dynamic_runtime_params_slice_t.plasma_composition.Zimp,
dynamic_runtime_params_slice_t.plasma_composition.Zeff_face[-1],
)

ni_bound_right = ne_bound_right * dilution_factor_edge
nimp_bound_right = (
ne_bound_right - ni_bound_right
) / dynamic_runtime_params_slice.plasma_composition.Zimp
) / dynamic_runtime_params_slice_t.plasma_composition.Zimp


return {
'temp_ion': dict(
Expand Down Expand Up @@ -979,11 +995,25 @@ def compute_boundary_conditions(
right_face_constraint=jnp.array(nimp_bound_right),
),
'psi': dict(
right_face_grad_constraint=_calculate_psi_grad_constraint(
dynamic_runtime_params_slice,
right_face_grad_constraint=jnp.where(
dynamic_runtime_params_slice_t.profile_conditions.Ip is not None,
_psi_grad_constraint_from_Ip(
dynamic_runtime_params_slice_t,
geo,
),
None
),
),
right_face_constraint=jnp.where(
dynamic_runtime_params_slice_t.profile_conditions.Vloop_bound_right is not None,
_psi_value_constraint_from_Vloop(
dynamic_runtime_params_slice_t,
core_profiles_t_minus_dt,
geo,
dt,
),
None,
),
),
}


Expand Down
15 changes: 11 additions & 4 deletions torax/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@

import dataclasses
import time
from typing import Any, Optional
from typing import Any
from typing import Optional

from absl import logging
import chex
import jax
import jax.numpy as jnp
import numpy as np
from absl import logging

from torax import calc_coeffs
from torax import core_profile_setters
from torax import geometry
Expand Down Expand Up @@ -485,6 +487,7 @@ def step(
dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt,
geo_t_plus_dt=geo_t_plus_dt,
core_profiles_t=core_profiles_t,
dt=dt,
)

# Initial trial for stepper. If did not converge (can happen for nonlinear
Expand Down Expand Up @@ -602,10 +605,11 @@ def body_fun(
)

core_profiles_t_plus_dt = provide_core_profiles_t_plus_dt(
core_profiles_t=core_profiles_t,
dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt,
static_runtime_params_slice=static_runtime_params_slice,
dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt,
geo_t_plus_dt=geo_t_plus_dt,
core_profiles_t=core_profiles_t,
dt=dt,
)
core_profiles, core_sources, core_transport, stepper_numeric_outputs = (
self._stepper_fn(
Expand Down Expand Up @@ -1508,12 +1512,15 @@ def provide_core_profiles_t_plus_dt(
dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice,
geo_t_plus_dt: geometry.Geometry,
core_profiles_t: state.CoreProfiles,
dt: jax.Array,
) -> state.CoreProfiles:
"""Provides state at t_plus_dt with new boundary conditions and prescribed profiles."""
updated_boundary_conditions = (
core_profile_setters.compute_boundary_conditions(
dynamic_runtime_params_slice_t_plus_dt,
core_profiles_t,
geo_t_plus_dt,
dt,
)
)
updated_values = core_profile_setters.updated_prescribed_core_profiles(
Expand Down