Skip to content

Commit

Permalink
Move circular geometry to its own file
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 717722211
  • Loading branch information
tamaranorman authored and Torax team committed Jan 21, 2025
1 parent e854805 commit 91c05b4
Show file tree
Hide file tree
Showing 49 changed files with 504 additions and 463 deletions.
9 changes: 5 additions & 4 deletions torax/config/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from torax import sim as sim_lib
from torax.config import config_args
from torax.config import runtime_params as runtime_params_lib
from torax.geometry import circular_geometry
from torax.geometry import geometry
from torax.geometry import geometry_provider
from torax.pedestal_model import pedestal_model as pedestal_model_lib
Expand Down Expand Up @@ -126,14 +127,14 @@ def _build_circular_geometry_provider(
raise ValueError('n_rho must be set in the input config.')
geometries = {}
for time, c in kwargs['geometry_configs'].items():
geometries[time] = geometry.build_circular_geometry(
geometries[time] = circular_geometry.build_circular_geometry(
n_rho=kwargs['n_rho'], **c
)
return geometry.CircularAnalyticalGeometryProvider.create_provider(
return circular_geometry.CircularAnalyticalGeometryProvider.create_provider(
geometries
)
return geometry_provider.ConstantGeometryProvider(
geometry.build_circular_geometry(**kwargs)
circular_geometry.build_circular_geometry(**kwargs)
)


Expand All @@ -153,7 +154,7 @@ def build_geometry_provider_from_config(
expected in the rest of the config. See the following functions to get a full
list of the arguments exposed:
- `geometry.build_circular_geometry()`
- `circular_geometry.build_circular_geometry()`
- `geometry.StandardGeometryIntermediates.from_chease()`
- `geometry.StandardGeometryIntermediates.from_fbt()`
Expand Down
7 changes: 4 additions & 3 deletions torax/config/tests/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torax.config import build_sim
from torax.config import runtime_params as runtime_params_lib
from torax.config import runtime_params_slice
from torax.geometry import circular_geometry
from torax.geometry import geometry
from torax.geometry import geometry_provider
from torax.pedestal_model import set_tped_nped
Expand Down Expand Up @@ -105,7 +106,7 @@ def test_build_sim_with_full_config(self):
)
with self.subTest('geometry'):
geo = sim.geometry_provider(sim.initial_state.t)
self.assertIsInstance(geo, geometry.CircularAnalyticalGeometry)
self.assertIsInstance(geo, circular_geometry.CircularAnalyticalGeometry)
self.assertEqual(geo.torax_mesh.nx, 5)
with self.subTest('sources'):
self.assertEqual(
Expand Down Expand Up @@ -185,7 +186,7 @@ def test_general_runtime_params_with_time_dependent_args(self):
self.assertEqual(runtime_params.profile_conditions.ne_is_fGW, False)
self.assertEqual(runtime_params.numerics.q_correction_factor, 0.2)
self.assertEqual(runtime_params.output_dir, '/tmp/this/is/a/test')
geo = geometry.build_circular_geometry()
geo = circular_geometry.build_circular_geometry()
dynamic_runtime_params_slice = (
runtime_params_slice.DynamicRuntimeParamsSliceProvider(
runtime_params,
Expand Down Expand Up @@ -218,7 +219,7 @@ def test_build_circular_geometry(self):
)
geo = geo_provider(t=0)
np.testing.assert_array_equal(geo_provider.torax_mesh.nx, 5)
self.assertIsInstance(geo, geometry.CircularAnalyticalGeometry)
self.assertIsInstance(geo, circular_geometry.CircularAnalyticalGeometry)
np.testing.assert_array_equal(geo.B0, 5.3) # test a default.

def test_build_geometry_from_chease(self):
Expand Down
6 changes: 3 additions & 3 deletions torax/config/tests/numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
from absl.testing import parameterized
from torax import interpolated_param
from torax.config import numerics
from torax.geometry import geometry
from torax.geometry import circular_geometry


class NumericsTest(parameterized.TestCase):
"""Unit tests for the `torax.config.numerics` module."""

def test_numerics_make_provider(self):
nums = numerics.Numerics()
geo = geometry.build_circular_geometry()
geo = circular_geometry.build_circular_geometry()
provider = nums.make_provider(geo.torax_mesh)
provider.build_dynamic_params(t=0.0)

Expand All @@ -35,7 +35,7 @@ def test_interpolated_vars_are_only_constructed_once(
):
"""Tests that interpolated vars are only constructed once."""
nums = numerics.Numerics()
geo = geometry.build_circular_geometry()
geo = circular_geometry.build_circular_geometry()
provider = nums.make_provider(geo.torax_mesh)
interpolated_params = {}
for field in provider:
Expand Down
10 changes: 5 additions & 5 deletions torax/config/tests/plasma_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torax import charge_states
from torax import interpolated_param
from torax.config import plasma_composition
from torax.geometry import geometry
from torax.geometry import circular_geometry


class PlasmaCompositionTest(parameterized.TestCase):
Expand All @@ -29,7 +29,7 @@ class PlasmaCompositionTest(parameterized.TestCase):
def test_plasma_composition_make_provider(self):
"""Checks provider construction with no issues."""
pc = plasma_composition.PlasmaComposition()
geo = geometry.build_circular_geometry()
geo = circular_geometry.build_circular_geometry()
provider = pc.make_provider(geo.torax_mesh)
provider.build_dynamic_params(t=0.0)

Expand All @@ -40,7 +40,7 @@ def test_plasma_composition_make_provider(self):
)
def test_zeff_accepts_float_inputs(self, zeff: float):
"""Tests that zeff accepts a single float input."""
geo = geometry.build_circular_geometry()
geo = circular_geometry.build_circular_geometry()
pc = plasma_composition.PlasmaComposition(Zeff=zeff)
provider = pc.make_provider(geo.torax_mesh)
dynamic_pc = provider.build_dynamic_params(t=0.0)
Expand All @@ -63,7 +63,7 @@ def test_zeff_and_zeff_face_match_expected(self):
1.0: {0.0: 1.8, 0.5: 2.1, 1.0: 2.4},
}

geo = geometry.build_circular_geometry()
geo = circular_geometry.build_circular_geometry()
pc = plasma_composition.PlasmaComposition(Zeff=zeff_profile)
provider = pc.make_provider(geo.torax_mesh)

Expand Down Expand Up @@ -102,7 +102,7 @@ def test_interpolated_vars_are_only_constructed_once(
):
"""Tests that interpolated vars are only constructed once."""
pc = plasma_composition.PlasmaComposition()
geo = geometry.build_circular_geometry()
geo = circular_geometry.build_circular_geometry()
provider = pc.make_provider(geo.torax_mesh)
interpolated_params = {}
for field in provider:
Expand Down
14 changes: 7 additions & 7 deletions torax/config/tests/profile_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torax import interpolated_param
from torax.config import config_args
from torax.config import profile_conditions
from torax.geometry import geometry
from torax.geometry import circular_geometry
import xarray as xr


Expand All @@ -30,7 +30,7 @@ class ProfileConditionsTest(parameterized.TestCase):

def test_profile_conditions_make_provider(self):
pc = profile_conditions.ProfileConditions()
geo = geometry.build_circular_geometry()
geo = circular_geometry.build_circular_geometry()
provider = pc.make_provider(geo.torax_mesh)
provider.build_dynamic_params(t=0.0)

Expand All @@ -46,7 +46,7 @@ def test_profile_conditions_sets_Te_bound_right_correctly(
Te={0: {0: 1.0, 1: 2.0}, 1.5: {0: 100.0, 1: 200.0}},
Te_bound_right=Te_bound_right,
)
geo = geometry.build_circular_geometry()
geo = circular_geometry.build_circular_geometry()
provider = pc.make_provider(geo.torax_mesh)
dcs = provider.build_dynamic_params(t=0.0)
self.assertEqual(dcs.Te_bound_right, expected_initial_value)
Expand All @@ -65,7 +65,7 @@ def test_profile_conditions_sets_Ti_bound_right_correctly(
Ti={0: {0: 1.0, 1: 2.0}, 1.5: {0: 100.0, 1: 200.0}},
Ti_bound_right=Ti_bound_right,
)
geo = geometry.build_circular_geometry()
geo = circular_geometry.build_circular_geometry()
provider = pc.make_provider(geo.torax_mesh)
dcs = provider.build_dynamic_params(t=0.0)
self.assertEqual(dcs.Ti_bound_right, expected_initial_value)
Expand All @@ -84,7 +84,7 @@ def test_profile_conditions_sets_ne_bound_right_correctly(
ne={0: {0: 1.0, 1: 2.0}, 1.5: {0: 100.0, 1: 200.0}},
ne_bound_right=ne_bound_right,
)
geo = geometry.build_circular_geometry()
geo = circular_geometry.build_circular_geometry()
provider = pc.make_provider(geo.torax_mesh)
dcs = provider.build_dynamic_params(t=0.0)
self.assertEqual(dcs.ne_bound_right, expected_initial_value)
Expand Down Expand Up @@ -126,7 +126,7 @@ def test_profile_conditions_sets_psi_correctly(
self, psi, expected_initial_value, expected_second_value
):
"""Tests that psi is set correctly."""
geo = geometry.build_circular_geometry(n_rho=4)
geo = circular_geometry.build_circular_geometry(n_rho=4)
pc = profile_conditions.ProfileConditions(
psi=psi,
)
Expand All @@ -147,7 +147,7 @@ def test_interpolated_vars_are_only_constructed_once(
):
"""Tests that interpolated vars are only constructed once."""
pc = profile_conditions.ProfileConditions()
geo = geometry.build_circular_geometry()
geo = circular_geometry.build_circular_geometry()
provider = pc.make_provider(geo.torax_mesh)
interpolated_params = {}
for field in provider:
Expand Down
4 changes: 2 additions & 2 deletions torax/config/tests/runtime_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torax.config import config_args
from torax.config import profile_conditions as profile_conditions_lib
from torax.config import runtime_params as general_runtime_params
from torax.geometry import geometry
from torax.geometry import circular_geometry


# pylint: disable=invalid-name
Expand Down Expand Up @@ -137,7 +137,7 @@ def test_runtime_params_make_provider(self):
runtime_params = general_runtime_params.GeneralRuntimeParams(
profile_conditions=profile_conditions_lib.ProfileConditions()
)
torax_mesh = geometry.build_circular_geometry().torax_mesh
torax_mesh = circular_geometry.build_circular_geometry().torax_mesh
runtime_params_provider = runtime_params.make_provider(torax_mesh)
runtime_params_provider.build_dynamic_params(0.0)

Expand Down
14 changes: 7 additions & 7 deletions torax/config/tests/runtime_params_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from torax.config import profile_conditions as profile_conditions_lib
from torax.config import runtime_params as general_runtime_params
from torax.config import runtime_params_slice as runtime_params_slice_lib
from torax.geometry import geometry
from torax.geometry import circular_geometry
from torax.pedestal_model import set_tped_nped
from torax.sources import electron_density_sources
from torax.sources import generic_current_source
Expand All @@ -37,7 +37,7 @@ class RuntimeParamsSliceTest(parameterized.TestCase):

def setUp(self):
super().setUp()
self._geo = geometry.build_circular_geometry()
self._geo = circular_geometry.build_circular_geometry()

def test_dynamic_slice_can_be_input_to_jitted_function(self):
"""Tests that the slice can be input to a jitted function."""
Expand Down Expand Up @@ -351,7 +351,7 @@ def test_profile_conditions_set_electron_temperature_and_boundary_condition(
runtime_params = general_runtime_params.GeneralRuntimeParams(
profile_conditions=profile_conditions,
)
geo = geometry.build_circular_geometry(n_rho=4)
geo = circular_geometry.build_circular_geometry(n_rho=4)
dcs = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
torax_mesh=geo.torax_mesh,
Expand Down Expand Up @@ -394,7 +394,7 @@ def test_profile_conditions_set_electron_density_and_boundary_condition(
ne_is_fGW=ne_is_fGW,
),
)
geo = geometry.build_circular_geometry(n_rho=4)
geo = circular_geometry.build_circular_geometry(n_rho=4)

dcs = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
Expand Down Expand Up @@ -427,7 +427,7 @@ def test_update_dynamic_slice_provider_updates_runtime_params(
Ti_bound_right={0.0: 1.0, 1.0: 2.0},
),
)
geo = geometry.build_circular_geometry(n_rho=4)
geo = circular_geometry.build_circular_geometry(n_rho=4)
provider = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
torax_mesh=geo.torax_mesh,
Expand Down Expand Up @@ -463,7 +463,7 @@ def test_update_dynamic_slice_provider_updates_sources(
source_models_builder.runtime_params[
generic_current_source.GenericCurrentSource.SOURCE_NAME
].Iext = 1.0
geo = geometry.build_circular_geometry(n_rho=4)
geo = circular_geometry.build_circular_geometry(n_rho=4)
provider = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
sources=source_models_builder.runtime_params,
Expand Down Expand Up @@ -519,7 +519,7 @@ def test_update_dynamic_slice_provider_updates_transport(
"""Tests that the dynamic slice provider can be updated."""
runtime_params = general_runtime_params.GeneralRuntimeParams()
transport = transport_params_lib.RuntimeParams(De_inner=1.0)
geo = geometry.build_circular_geometry(n_rho=4)
geo = circular_geometry.build_circular_geometry(n_rho=4)
provider = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
torax_mesh=geo.torax_mesh,
Expand Down
3 changes: 2 additions & 1 deletion torax/core_profile_setters.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from torax.config import profile_conditions
from torax.config import runtime_params_slice
from torax.fvm import cell_variable
from torax.geometry import circular_geometry
from torax.geometry import geometry
from torax.sources import ohmic_heat_source
from torax.sources import source_models as source_models_lib
Expand Down Expand Up @@ -648,7 +649,7 @@ def _init_psi_and_current(
)
# Calculating j according to nu formula and psi from j.
elif (
isinstance(geo, geometry.CircularAnalyticalGeometry)
isinstance(geo, circular_geometry.CircularAnalyticalGeometry)
or dynamic_runtime_params_slice.profile_conditions.initial_psi_from_j
):
currents = _prescribe_currents_no_bootstrap(
Expand Down
4 changes: 2 additions & 2 deletions torax/fvm/tests/calc_coeffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torax.config import runtime_params as general_runtime_params
from torax.config import runtime_params_slice as runtime_params_slice_lib
from torax.fvm import calc_coeffs
from torax.geometry import geometry
from torax.geometry import circular_geometry
from torax.pedestal_model import set_tped_nped
from torax.sources import runtime_params as source_runtime_params
from torax.sources import source_models as source_models_lib
Expand Down Expand Up @@ -55,7 +55,7 @@ def test_calc_coeffs_smoke_test(
predictor_corrector=False,
theta_imp=theta_imp,
)
geo = geometry.build_circular_geometry(n_rho=num_cells)
geo = circular_geometry.build_circular_geometry(n_rho=num_cells)
transport_model_builder = (
constant_transport_model.ConstantTransportModelBuilder(
runtime_params=constant_transport_model.RuntimeParams(
Expand Down
10 changes: 5 additions & 5 deletions torax/fvm/tests/fvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from torax.fvm import cell_variable
from torax.fvm import implicit_solve_block
from torax.fvm import residual_and_loss
from torax.geometry import geometry
from torax.geometry import circular_geometry
from torax.pedestal_model import set_tped_nped
from torax.sources import runtime_params as source_runtime_params
from torax.sources import source_models as source_models_lib
Expand Down Expand Up @@ -390,7 +390,7 @@ def test_nonlinear_solve_block_loss_minimum(
predictor_corrector=False,
theta_imp=theta_imp,
)
geo = geometry.build_circular_geometry(n_rho=num_cells)
geo = circular_geometry.build_circular_geometry(n_rho=num_cells)
transport_model_builder = (
constant_transport_model.ConstantTransportModelBuilder(
runtime_params=constant_transport_model.RuntimeParams(
Expand Down Expand Up @@ -558,7 +558,7 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self):
pedestal_model_builder = (
set_tped_nped.SetTemperatureDensityPedestalModelBuilder()
)
geo = geometry.build_circular_geometry(n_rho=num_cells)
geo = circular_geometry.build_circular_geometry(n_rho=num_cells)
dynamic_runtime_params_slice = (
runtime_params_slice.DynamicRuntimeParamsSliceProvider(
runtime_params,
Expand All @@ -579,7 +579,7 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self):
stepper=stepper_params,
)
)
geo = geometry.build_circular_geometry(n_rho=num_cells)
geo = circular_geometry.build_circular_geometry(n_rho=num_cells)
source_models = source_models_builder()
initial_core_profiles = core_profile_setters.initial_core_profiles(
static_runtime_params_slice,
Expand Down Expand Up @@ -681,7 +681,7 @@ def test_theta_residual_uses_updated_boundary_conditions(self):
predictor_corrector=False,
theta_imp=0.0,
)
geo = geometry.build_circular_geometry(n_rho=num_cells)
geo = circular_geometry.build_circular_geometry(n_rho=num_cells)
transport_model_builder = (
constant_transport_model.ConstantTransportModelBuilder(
runtime_params=constant_transport_model.RuntimeParams(
Expand Down
Loading

0 comments on commit 91c05b4

Please sign in to comment.