Skip to content

Commit

Permalink
Fix arg order
Browse files Browse the repository at this point in the history
  • Loading branch information
theo-brown committed Nov 22, 2024
1 parent 4a942a0 commit d3d66be
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 12 deletions.
4 changes: 2 additions & 2 deletions torax/core_profile_setters.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,10 +565,10 @@ def _calculate_psi_grad_constraint_from_Ip_tot(
)

def _psi_value_constraint_from_Vloop(
dt: jax.Array,
dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice,
core_profiles_t_minus_dt: state.CoreProfiles,
geo: Geometry,
dt: jax.Array,
) -> jax.Array:
"""Calculates the value constraint on the poloidal flux (psi) from Vloop."""
return (
Expand Down Expand Up @@ -879,10 +879,10 @@ def get_update(x_new, var):


def compute_boundary_conditions(
dt: jax.Array,
dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice,
core_profiles_t_minus_dt: state.CoreProfiles,
geo: geometry.Geometry,
dt: jax.Array,
) -> dict[str, dict[str, jax.Array | None]]:
"""Computes boundary conditions for time t and returns updates to State.
Expand Down
10 changes: 5 additions & 5 deletions torax/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import jax
import jax.numpy as jnp
import numpy as np
import xarray as xr
from absl import logging

from torax import calc_coeffs
Expand All @@ -58,7 +59,6 @@
from torax.time_step_calculator import chi_time_step_calculator
from torax.time_step_calculator import time_step_calculator as ts
from torax.transport_model import transport_model as transport_model_lib
import xarray as xr


def _log_timestep(
Expand Down Expand Up @@ -485,11 +485,11 @@ def step(
# conditions and time-dependent prescribed profiles not directly solved by
# PDE system.
core_profiles_t_plus_dt = provide_core_profiles_t_plus_dt(
dt=dt,
static_runtime_params_slice=static_runtime_params_slice,
dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt,
geo_t_plus_dt=geo_t_plus_dt,
core_profiles_t=core_profiles_t,
dt=dt,
)

# Initial trial for stepper. If did not converge (can happen for nonlinear
Expand Down Expand Up @@ -607,11 +607,11 @@ def body_fun(
)

core_profiles_t_plus_dt = provide_core_profiles_t_plus_dt(
dt=dt,
static_runtime_params_slice=static_runtime_params_slice,
dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt,
geo_t_plus_dt=geo_t_plus_dt,
core_profiles_t=core_profiles_t,
dt=dt,
)
core_profiles, core_sources, core_transport, stepper_numeric_outputs = (
self._stepper_fn(
Expand Down Expand Up @@ -1566,19 +1566,19 @@ def update_psidot(


def provide_core_profiles_t_plus_dt(
dt: jax.Array,
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice,
geo_t_plus_dt: geometry.Geometry,
core_profiles_t: state.CoreProfiles,
dt: jax.Array,
) -> state.CoreProfiles:
"""Provides state at t_plus_dt with new boundary conditions and prescribed profiles."""
updated_boundary_conditions = (
core_profile_setters.compute_boundary_conditions(
dt,
dynamic_runtime_params_slice_t_plus_dt,
core_profiles_t,
geo_t_plus_dt,
dt,
)
)
updated_values = core_profile_setters.updated_prescribed_core_profiles(
Expand Down
9 changes: 6 additions & 3 deletions torax/tests/boundary_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
"""Tests for module torax.boundary_conditions."""


import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

from torax import constants
from torax import core_profile_setters
from torax import geometry
Expand Down Expand Up @@ -97,8 +98,10 @@ def test_setting_boundary_conditions(
)

bc = core_profile_setters.compute_boundary_conditions(
dynamic_runtime_params_slice,
geo,
dt=runtime_params.numerics.fixed_dt,
dynamic_runtime_params_slice_t=dynamic_runtime_params_slice,
core_profiles_t_minus_dt=core_profiles,
geo=geo,
)

updated = config_args.recursive_replace(core_profiles, **bc)
Expand Down
7 changes: 5 additions & 2 deletions torax/tests/test_lib/explicit_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import jax
from jax import numpy as jnp

from torax import constants
from torax import core_profile_setters
from torax import geometry
Expand Down Expand Up @@ -119,8 +120,10 @@ def __call__(
# Update the potentially time-dependent boundary conditions as well.
updated_boundary_conditions = (
core_profile_setters.compute_boundary_conditions(
dynamic_runtime_params_slice_t_plus_dt,
geo_t,
dt=dynamic_runtime_params_slice_t_plus_dt.numerics.fixed_dt,
dynamic_runtime_params_slice_t=dynamic_runtime_params_slice_t_plus_dt,
core_profiles_t_minus_dt=core_profiles_t,
geo=geo_t_plus_dt,
)
)
temp_ion_new = dataclasses.replace(
Expand Down

0 comments on commit d3d66be

Please sign in to comment.