Skip to content

Commit

Permalink
Change core_profile_setters_test to have test prefix and simplify tests
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 716503857
  • Loading branch information
tamaranorman authored and Torax team committed Jan 20, 2025
1 parent ce6da2d commit aee685c
Show file tree
Hide file tree
Showing 15 changed files with 1,021 additions and 1,194 deletions.
2 changes: 1 addition & 1 deletion torax/config/profile_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def build_dynamic_params(


@chex.dataclass
class DynamicProfileConditions:
class DynamicProfileConditions:
"""Prescribed values and boundary conditions for the core profiles."""

Ip_tot: array_typing.ScalarFloat
Expand Down
18 changes: 18 additions & 0 deletions torax/config/runtime_params_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from torax.config import profile_conditions
from torax.config import runtime_params as general_runtime_params_lib
from torax.geometry import geometry
from torax.geometry import geometry_provider as geometry_provider_lib
from torax.pedestal_model import runtime_params as pedestal_model_params
from torax.sources import runtime_params as sources_params
from torax.stepper import runtime_params as stepper_params
Expand Down Expand Up @@ -354,6 +355,23 @@ def __call__(
)


def get_consistent_dynamic_runtime_params_slice_and_geometry(
*,
t: chex.Numeric,
dynamic_runtime_params_slice_provider: DynamicRuntimeParamsSliceProvider,
geometry_provider: geometry_provider_lib.GeometryProvider,
) -> tuple[DynamicRuntimeParamsSlice, geometry.Geometry]:
"""Returns the dynamic runtime params and geometry for a given time."""
geo = geometry_provider(t)
dynamic_runtime_params_slice = dynamic_runtime_params_slice_provider(
t=t,
)
dynamic_runtime_params_slice, geo = make_ip_consistent(
dynamic_runtime_params_slice, geo
)
return dynamic_runtime_params_slice, geo


def make_ip_consistent(
dynamic_runtime_params_slice: DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
Expand Down
12 changes: 6 additions & 6 deletions torax/config/tests/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,12 +293,12 @@ def test_chease_geometry_updates_Ip(self):
torax_mesh=geo_provider.torax_mesh,
)
)
geo = geo_provider(t=0)
dynamic_runtime_params_slice = runtime_params_provider(
t=0,
)
dynamic_slice, geo = runtime_params_slice.make_ip_consistent(
dynamic_runtime_params_slice, geo
dynamic_slice, geo = (
runtime_params_slice.get_consistent_dynamic_runtime_params_slice_and_geometry(
t=0,
dynamic_runtime_params_slice_provider=runtime_params_provider,
geometry_provider=geo_provider,
)
)
self.assertIsInstance(geo, geometry.StandardGeometry)
self.assertIsNotNone(dynamic_slice)
Expand Down
122 changes: 54 additions & 68 deletions torax/core_profile_setters.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from torax import math_utils
from torax import physics
from torax import state
from torax.config import numerics
from torax.config import profile_conditions
from torax.config import runtime_params_slice
from torax.fvm import cell_variable
from torax.geometry import geometry
Expand All @@ -35,83 +37,75 @@

_trapz = jax.scipy.integrate.trapezoid

# Using capitalized variables for physics notational conventions rather than
# Python style.
# pylint: disable=invalid-name

def updated_ion_temperature(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,

def _updated_ion_temperature(
dynamic_profile_conditions: profile_conditions.DynamicProfileConditions,
geo: geometry.Geometry,
) -> cell_variable.CellVariable:
"""Updated ion temp. Used upon initialization and if temp_ion=False."""
# pylint: disable=invalid-name
Ti_bound_right = (
dynamic_runtime_params_slice.profile_conditions.Ti_bound_right
)

Ti_bound_right = jax_utils.error_if_not_positive(
Ti_bound_right,
dynamic_profile_conditions.Ti_bound_right,
'Ti_bound_right',
)
temp_ion = cell_variable.CellVariable(
value=dynamic_runtime_params_slice.profile_conditions.Ti,
value=dynamic_profile_conditions.Ti,
left_face_grad_constraint=jnp.zeros(()),
right_face_grad_constraint=None,
right_face_constraint=Ti_bound_right,
dr=geo.drho_norm,
)
# pylint: enable=invalid-name

return temp_ion


def updated_electron_temperature(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
def _updated_electron_temperature(
dynamic_profile_conditions: profile_conditions.DynamicProfileConditions,
geo: geometry.Geometry,
) -> cell_variable.CellVariable:
"""Updated electron temp. Used upon initialization and if temp_el=False."""
# pylint: disable=invalid-name
Te_bound_right = (
dynamic_runtime_params_slice.profile_conditions.Te_bound_right
)

Te_bound_right = jax_utils.error_if_not_positive(
Te_bound_right,
dynamic_profile_conditions.Te_bound_right,
'Te_bound_right',
)
temp_el = cell_variable.CellVariable(
value=dynamic_runtime_params_slice.profile_conditions.Te,
value=dynamic_profile_conditions.Te,
left_face_grad_constraint=jnp.zeros(()),
right_face_grad_constraint=None,
right_face_constraint=Te_bound_right,
dr=geo.drho_norm,
)
# pylint: enable=invalid-name
return temp_el


# pylint: disable=invalid-name
def _get_ne(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
dynamic_numerics: numerics.DynamicNumerics,
dynamic_profile_conditions: profile_conditions.DynamicProfileConditions,
geo: geometry.Geometry,
) -> cell_variable.CellVariable:
"""Helper to get the electron density profile at the current timestep."""
# pylint: disable=invalid-name
nGW = (
dynamic_runtime_params_slice.profile_conditions.Ip_tot
dynamic_profile_conditions.Ip_tot
/ (jnp.pi * geo.Rmin**2)
* 1e20
/ dynamic_runtime_params_slice.numerics.nref
/ dynamic_numerics.nref
)
ne_value = jnp.where(
dynamic_runtime_params_slice.profile_conditions.ne_is_fGW,
dynamic_runtime_params_slice.profile_conditions.ne * nGW,
dynamic_runtime_params_slice.profile_conditions.ne,
dynamic_profile_conditions.ne_is_fGW,
dynamic_profile_conditions.ne * nGW,
dynamic_profile_conditions.ne,
)
# Calculate ne_bound_right.
ne_bound_right = jnp.where(
dynamic_runtime_params_slice.profile_conditions.ne_bound_right_is_fGW,
dynamic_runtime_params_slice.profile_conditions.ne_bound_right * nGW,
dynamic_runtime_params_slice.profile_conditions.ne_bound_right,
dynamic_profile_conditions.ne_bound_right_is_fGW,
dynamic_profile_conditions.ne_bound_right * nGW,
dynamic_profile_conditions.ne_bound_right,
)

if dynamic_runtime_params_slice.profile_conditions.normalize_to_nbar:
if dynamic_profile_conditions.normalize_to_nbar:
face_left = ne_value[0] # Zero gradient boundary condition at left face.
face_right = ne_bound_right
face_inner = (ne_value[..., :-1] + ne_value[..., 1:]) / 2.0
Expand All @@ -129,16 +123,15 @@ def _get_ne(
Rmin_out = geo.Rout_face[-1] - geo.Rout_face[0]
# find target nbar in absolute units
target_nbar = jnp.where(
dynamic_runtime_params_slice.profile_conditions.ne_is_fGW,
dynamic_runtime_params_slice.profile_conditions.nbar * nGW,
dynamic_runtime_params_slice.profile_conditions.nbar,
dynamic_profile_conditions.ne_is_fGW,
dynamic_profile_conditions.nbar * nGW,
dynamic_profile_conditions.nbar,
)
if (
not dynamic_runtime_params_slice.profile_conditions.ne_bound_right_is_absolute
not dynamic_profile_conditions.ne_bound_right_is_absolute
):
# In this case, ne_bound_right is taken from ne and we also normalize it.
C = target_nbar / (_trapz(ne_face, geo.Rout_face) / Rmin_out)
# pylint: enable=invalid-name
ne_bound_right = C * ne_bound_right
else:
# If ne_bound_right is absolute, subtract off contribution from outer
Expand Down Expand Up @@ -166,7 +159,7 @@ def _get_ne(
return ne


def _updated_ion_density(
def updated_ion_density(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
ne: cell_variable.CellVariable,
Expand Down Expand Up @@ -212,7 +205,7 @@ def _updated_ion_density(
return ni, nimp


def updated_density(
def _updated_density(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
) -> tuple[
Expand All @@ -222,10 +215,11 @@ def updated_density(
]:
"""Updated particle density. Used upon initialization and if dens_eq=False."""
ne = _get_ne(
dynamic_runtime_params_slice,
dynamic_runtime_params_slice.numerics,
dynamic_runtime_params_slice.profile_conditions,
geo,
)
ni, nimp = _updated_ion_density(
ni, nimp = updated_ion_density(
dynamic_runtime_params_slice,
geo,
ne,
Expand Down Expand Up @@ -255,7 +249,6 @@ def _prescribe_currents_no_bootstrap(
"""
# Many variables throughout this function are capitalized based on physics
# notational conventions rather than on Google Python style
# pylint: disable=invalid-name

# Calculate splitting of currents depending on input runtime params.
Ip = dynamic_runtime_params_slice.profile_conditions.Ip_tot
Expand Down Expand Up @@ -344,7 +337,6 @@ def _prescribe_currents_with_bootstrap(

# Many variables throughout this function are capitalized based on physics
# notational conventions rather than on Google Python style
# pylint: disable=invalid-name
Ip = dynamic_runtime_params_slice.profile_conditions.Ip_tot

bootstrap_profile = source_models.j_bootstrap.get_value(
Expand Down Expand Up @@ -430,7 +422,6 @@ def _calculate_currents_from_psi(

# Many variables throughout this function are capitalized based on physics
# notational conventions rather than on Google Python style
# pylint: disable=invalid-name
jtot, jtot_face, Ip_profile_face = physics.calc_jtot_from_psi(
geo,
core_profiles.psi,
Expand Down Expand Up @@ -521,7 +512,6 @@ def _update_psi_from_j(
return psi


# pylint: enable=invalid-name
def _calculate_psi_grad_constraint(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
Expand Down Expand Up @@ -636,12 +626,10 @@ def _init_psi_and_current(
geo,
currents.jtot_hires,
)
# pylint: disable=invalid-name
_, _, Ip_profile_face = physics.calc_jtot_from_psi(
geo,
psi,
)
# pylint: enable=invalid-name
currents = dataclasses.replace(currents, Ip_profile_face=Ip_profile_face)
else:
raise ValueError('Cannot compute psi for given config.')
Expand All @@ -668,14 +656,17 @@ def initial_core_profiles(
Returns:
Initial core profiles.
"""
# pylint: disable=invalid-name

# To set initial values and compute the boundary conditions, we need to handle
# potentially time-varying inputs from the users.
# The default time in build_dynamic_runtime_params_slice is t_initial
temp_ion = updated_ion_temperature(dynamic_runtime_params_slice, geo)
temp_el = updated_electron_temperature(dynamic_runtime_params_slice, geo)
ne, ni, nimp = updated_density(dynamic_runtime_params_slice, geo)
temp_ion = _updated_ion_temperature(
dynamic_runtime_params_slice.profile_conditions, geo
)
temp_el = _updated_electron_temperature(
dynamic_runtime_params_slice.profile_conditions, geo
)
ne, ni, nimp = _updated_density(dynamic_runtime_params_slice, geo)

# The later calculation needs core profiles.
# So initialize these quantities with zeros.
Expand Down Expand Up @@ -732,15 +723,12 @@ def initial_core_profiles(
core_profiles = dataclasses.replace(core_profiles, psidot=psidot)

# Set psi as source of truth and recalculate jtot, q, s
core_profiles = physics.update_jtot_q_face_s_face(
return physics.update_jtot_q_face_s_face(
geo=geo,
core_profiles=core_profiles,
q_correction_factor=dynamic_runtime_params_slice.numerics.q_correction_factor,
)

# pylint: enable=invalid-name
return core_profiles


def updated_prescribed_core_profiles(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
Expand All @@ -761,31 +749,30 @@ def updated_prescribed_core_profiles(
Returns:
Updated core profiles.
"""
# pylint: disable=invalid-name

# If profiles are not evolved, they can still potential be time-evolving,
# depending on the runtime params. If so, they are updated below.
if (
not static_runtime_params_slice.ion_heat_eq
and dynamic_runtime_params_slice.numerics.enable_prescribed_profile_evolution
):
temp_ion = updated_ion_temperature(dynamic_runtime_params_slice, geo).value
temp_ion = _updated_ion_temperature(
dynamic_runtime_params_slice.profile_conditions, geo).value
else:
temp_ion = core_profiles.temp_ion.value
if (
not static_runtime_params_slice.el_heat_eq
and dynamic_runtime_params_slice.numerics.enable_prescribed_profile_evolution
):
temp_el = updated_electron_temperature(
dynamic_runtime_params_slice, geo
temp_el = _updated_electron_temperature(
dynamic_runtime_params_slice.profile_conditions, geo
).value
else:
temp_el = core_profiles.temp_el.value
if (
not static_runtime_params_slice.dens_eq
and dynamic_runtime_params_slice.numerics.enable_prescribed_profile_evolution
):
ne, ni, nimp = updated_density(dynamic_runtime_params_slice, geo)
ne, ni, nimp = _updated_density(dynamic_runtime_params_slice, geo)
ne = ne.value
ni = ni.value
nimp = nimp.value
Expand Down Expand Up @@ -832,7 +819,8 @@ def get_update(x_new, var):
psi = get_update(x_new, 'psi')
ne = get_update(x_new, 'ne')

ni, nimp = _updated_ion_density(dynamic_runtime_params_slice, geo, ne)
ni, nimp = updated_ion_density(
dynamic_runtime_params_slice, geo, ne)

return dataclasses.replace(
core_profiles,
Expand Down Expand Up @@ -860,18 +848,19 @@ def compute_boundary_conditions(
each CellVariable in the state. This dict can in theory recursively replace
values in a State object.
"""
Ti_bound_right = jax_utils.error_if_not_positive( # pylint: disable=invalid-name
Ti_bound_right = jax_utils.error_if_not_positive(
dynamic_runtime_params_slice.profile_conditions.Ti_bound_right,
'Ti_bound_right',
)

Te_bound_right = jax_utils.error_if_not_positive( # pylint: disable=invalid-name
Te_bound_right = jax_utils.error_if_not_positive(
dynamic_runtime_params_slice.profile_conditions.Te_bound_right,
'Te_bound_right',
)

ne = _get_ne(
dynamic_runtime_params_slice,
dynamic_runtime_params_slice.numerics,
dynamic_runtime_params_slice.profile_conditions,
geo,
)
ne_bound_right = ne.right_face_constraint
Expand Down Expand Up @@ -929,7 +918,6 @@ def compute_boundary_conditions(
}


# pylint: disable=invalid-name
def _get_jtot_hires(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
Expand Down Expand Up @@ -967,5 +955,3 @@ def _get_jtot_hires(
johm_hires = jformula_hires * Cohm_hires
jtot_hires = johm_hires + external_current_hires + j_bootstrap_hires
return jtot_hires

# pylint: enable=invalid-name
Loading

0 comments on commit aee685c

Please sign in to comment.