Skip to content

Commit

Permalink
Move standard_geometry logic to its own file
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 718283158
  • Loading branch information
tamaranorman authored and Torax team committed Jan 27, 2025
1 parent 2f7514b commit a04a7fb
Show file tree
Hide file tree
Showing 11 changed files with 1,426 additions and 1,351 deletions.
32 changes: 21 additions & 11 deletions torax/config/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from torax.config import config_args
from torax.config import runtime_params as runtime_params_lib
from torax.geometry import circular_geometry
from torax.geometry import geometry
from torax.geometry import geometry_provider
from torax.geometry import standard_geometry
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.pedestal_model import set_tped_nped
from torax.sources import register_source
Expand Down Expand Up @@ -59,7 +59,9 @@ def _build_standard_geometry_provider(
"""Constructs a geometry provider for a standard geometry."""
global_params = {'Ip_from_parameters', 'n_rho', 'geometry_dir'}
if geometry_type == 'chease':
intermediate_builder = geometry.StandardGeometryIntermediates.from_chease
intermediate_builder = (
standard_geometry.StandardGeometryIntermediates.from_chease
)
elif geometry_type == 'fbt':
# Check if parameters indicate a bundled FBT file and input validity.
if 'LY_bundle_object' in kwargs:
Expand All @@ -72,20 +74,26 @@ def _build_standard_geometry_provider(
"Cannot use 'LY_object' together with a bundled FBT file"
)
# Build and return the GeometryProvider for the bundled case.
intermediates = geometry.StandardGeometryIntermediates.from_fbt_bundle(
**kwargs,
intermediates = (
standard_geometry.StandardGeometryIntermediates.from_fbt_bundle(
**kwargs,
)
)
geometries = {
t: geometry.build_standard_geometry(intermediates[t])
t: standard_geometry.build_standard_geometry(intermediates[t])
for t in intermediates
}
return geometry.StandardGeometryProvider.create_provider(geometries)
return standard_geometry.StandardGeometryProvider.create_provider(
geometries
)
else:
intermediate_builder = (
geometry.StandardGeometryIntermediates.from_fbt_single_slice
standard_geometry.StandardGeometryIntermediates.from_fbt_single_slice
)
elif geometry_type == 'eqdsk':
intermediate_builder = geometry.StandardGeometryIntermediates.from_eqdsk
intermediate_builder = (
standard_geometry.StandardGeometryIntermediates.from_eqdsk
)
else:
raise ValueError(f'Unknown geometry type: {geometry_type}')
if 'geometry_configs' in kwargs:
Expand All @@ -101,14 +109,16 @@ def _build_standard_geometry_provider(
f' {", ".join(x)}'
)
config.update(global_kwargs)
geometries[time] = geometry.build_standard_geometry(
geometries[time] = standard_geometry.build_standard_geometry(
intermediate_builder(
**config,
)
)
return geometry.StandardGeometryProvider.create_provider(geometries)
return standard_geometry.StandardGeometryProvider.create_provider(
geometries
)
return geometry_provider.ConstantGeometryProvider(
geometry.build_standard_geometry(
standard_geometry.build_standard_geometry(
intermediate_builder(
**kwargs,
)
Expand Down
3 changes: 2 additions & 1 deletion torax/config/runtime_params_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from torax.config import runtime_params as general_runtime_params_lib
from torax.geometry import geometry
from torax.geometry import geometry_provider as geometry_provider_lib
from torax.geometry import standard_geometry
from torax.pedestal_model import runtime_params as pedestal_model_params
from torax.sources import runtime_params as sources_params
from torax.stepper import runtime_params as stepper_params
Expand Down Expand Up @@ -377,7 +378,7 @@ def make_ip_consistent(
geo: geometry.Geometry,
) -> tuple[DynamicRuntimeParamsSlice, geometry.Geometry]:
"""Fixes Ip to be the same across dynamic_runtime_params_slice and geo."""
if isinstance(geo, geometry.StandardGeometry):
if isinstance(geo, standard_geometry.StandardGeometry):
if geo.Ip_from_parameters:
# If Ip is from parameters, renormalise psi etc to match the Ip in the
# parameters.
Expand Down
12 changes: 7 additions & 5 deletions torax/config/tests/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from torax.config import runtime_params as runtime_params_lib
from torax.config import runtime_params_slice
from torax.geometry import circular_geometry
from torax.geometry import geometry
from torax.geometry import geometry_provider
from torax.geometry import standard_geometry
from torax.pedestal_model import set_tped_nped
from torax.sources import runtime_params as source_runtime_params_lib
from torax.stepper import linear_theta_method
Expand Down Expand Up @@ -232,7 +232,7 @@ def test_build_geometry_from_chease(self):
self.assertIsInstance(
geo_provider, geometry_provider.ConstantGeometryProvider
)
self.assertIsInstance(geo_provider(t=0), geometry.StandardGeometry)
self.assertIsInstance(geo_provider(t=0), standard_geometry.StandardGeometry)
np.testing.assert_array_equal(geo_provider.torax_mesh.nx, 5)

def test_build_time_dependent_geometry_from_chease(self):
Expand Down Expand Up @@ -260,8 +260,10 @@ def test_build_time_dependent_geometry_from_chease(self):

# Test valid config
geo_provider = build_sim.build_geometry_provider_from_config(base_config)
self.assertIsInstance(geo_provider, geometry.StandardGeometryProvider)
self.assertIsInstance(geo_provider(t=0), geometry.StandardGeometry)
self.assertIsInstance(
geo_provider, standard_geometry.StandardGeometryProvider
)
self.assertIsInstance(geo_provider(t=0), standard_geometry.StandardGeometry)
np.testing.assert_array_equal(geo_provider.torax_mesh.nx, 10)

# Test invalid configs:
Expand Down Expand Up @@ -301,7 +303,7 @@ def test_chease_geometry_updates_Ip(self):
geometry_provider=geo_provider,
)
)
self.assertIsInstance(geo, geometry.StandardGeometry)
self.assertIsInstance(geo, standard_geometry.StandardGeometry)
self.assertIsNotNone(dynamic_slice)
self.assertNotEqual(
dynamic_slice.profile_conditions.Ip_tot, original_Ip_tot
Expand Down
3 changes: 2 additions & 1 deletion torax/core_profile_setters.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from torax.fvm import cell_variable
from torax.geometry import circular_geometry
from torax.geometry import geometry
from torax.geometry import standard_geometry
from torax.sources import ohmic_heat_source
from torax.sources import source_models as source_models_lib
from torax.sources import source_profiles as source_profiles_lib
Expand Down Expand Up @@ -625,7 +626,7 @@ def _init_psi_and_current(
)
# Retrieving psi from the standard geometry input.
elif (
isinstance(geo, geometry.StandardGeometry)
isinstance(geo, standard_geometry.StandardGeometry)
and not dynamic_runtime_params_slice.profile_conditions.initial_psi_from_j
):
# psi is already provided from a numerical equilibrium, so no need to
Expand Down
Loading

0 comments on commit a04a7fb

Please sign in to comment.