diff --git a/torax/config/runtime_params_slice.py b/torax/config/runtime_params_slice.py index a421d406..db2eb639 100644 --- a/torax/config/runtime_params_slice.py +++ b/torax/config/runtime_params_slice.py @@ -46,6 +46,7 @@ from torax.config import profile_conditions from torax.config import runtime_params as general_runtime_params_lib from torax.geometry import geometry +from torax.geometry import geometry_provider as geometry_provider_lib from torax.pedestal_model import runtime_params as pedestal_model_params from torax.sources import runtime_params as sources_params from torax.stepper import runtime_params as stepper_params @@ -354,6 +355,23 @@ def __call__( ) +def get_consistent_dynamic_runtime_params_slice_and_geometry( + *, + t: chex.Numeric, + dynamic_runtime_params_slice_provider: DynamicRuntimeParamsSliceProvider, + geometry_provider: geometry_provider_lib.GeometryProvider, +) -> tuple[DynamicRuntimeParamsSlice, geometry.Geometry]: + """Returns the dynamic runtime params and geometry for a given time.""" + geo = geometry_provider(t) + dynamic_runtime_params_slice = dynamic_runtime_params_slice_provider( + t=t, + ) + dynamic_runtime_params_slice, geo = make_ip_consistent( + dynamic_runtime_params_slice, geo + ) + return dynamic_runtime_params_slice, geo + + def make_ip_consistent( dynamic_runtime_params_slice: DynamicRuntimeParamsSlice, geo: geometry.Geometry, diff --git a/torax/config/tests/build_sim.py b/torax/config/tests/build_sim.py index 3c7c03b6..9f5669cb 100644 --- a/torax/config/tests/build_sim.py +++ b/torax/config/tests/build_sim.py @@ -293,12 +293,12 @@ def test_chease_geometry_updates_Ip(self): torax_mesh=geo_provider.torax_mesh, ) ) - geo = geo_provider(t=0) - dynamic_runtime_params_slice = runtime_params_provider( - t=0, - ) - dynamic_slice, geo = runtime_params_slice.make_ip_consistent( - dynamic_runtime_params_slice, geo + dynamic_slice, geo = ( + runtime_params_slice.get_consistent_dynamic_runtime_params_slice_and_geometry( + t=0, + dynamic_runtime_params_slice_provider=runtime_params_provider, + geometry_provider=geo_provider, + ) ) self.assertIsInstance(geo, geometry.StandardGeometry) self.assertIsNotNone(dynamic_slice) diff --git a/torax/orchestration/step_function.py b/torax/orchestration/step_function.py new file mode 100644 index 00000000..9eec5dc4 --- /dev/null +++ b/torax/orchestration/step_function.py @@ -0,0 +1,819 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Logic which controls the stepping over time of the simulation.""" + +from __future__ import annotations + +import dataclasses +from typing import Any + +import jax.numpy as jnp +from torax import core_profile_setters +from torax import jax_utils +from torax import physics +from torax import post_processing +from torax import state +from torax.config import runtime_params_slice +from torax.geometry import geometry +from torax.geometry import geometry_provider as geometry_provider_lib +from torax.pedestal_model import pedestal_model as pedestal_model_lib +from torax.sources import ohmic_heat_source +from torax.sources import source_models as source_models_lib +from torax.sources import source_profiles as source_profiles_lib +from torax.stepper import stepper as stepper_lib +from torax.time_step_calculator import time_step_calculator as ts +from torax.transport_model import transport_model as transport_model_lib + + +class SimulationStepFn: + """Advances the TORAX simulation one time step. + + Unlike the Stepper class, which updates certain parts of the state, a + SimulationStepFn takes in the ToraxSimState and outputs the updated + ToraxSimState, which contains not only the CoreProfiles but also extra + simulation state useful for stepping as well as extra outputs useful for + inspection inside the main run loop in `run_simulation()`. It wraps calls to + Stepper with useful features to increase robustness for convergence, like + dt-backtracking. + """ + + def __init__( + self, + stepper: stepper_lib.Stepper, + time_step_calculator: ts.TimeStepCalculator, + transport_model: transport_model_lib.TransportModel, + pedestal_model: pedestal_model_lib.PedestalModel, + ): + """Initializes the SimulationStepFn. + + If you wish to run a simulation with new versions of any of these arguments + (i.e. want to change to a new stepper), then you will need to build a new + SimulationStepFn. These arguments are fixed for the lifetime + of the SimulationStepFn and cannot change even with JAX recompiles. + + Args: + stepper: Evolves the core profiles. + time_step_calculator: Calculates the dt for each time step. + transport_model: Calculates diffusion and convection coefficients. + pedestal_model: Calculates pedestal coefficients. + """ + self._stepper_fn = stepper + self._time_step_calculator = time_step_calculator + self._transport_model = transport_model + self._pedestal_model = pedestal_model + self._jitted_transport_model = jax_utils.jit( + transport_model.__call__, + ) + + @property + def pedestal_model(self) -> pedestal_model_lib.PedestalModel: + return self._pedestal_model + + @property + def stepper(self) -> stepper_lib.Stepper: + return self._stepper_fn + + @property + def transport_model(self) -> transport_model_lib.TransportModel: + return self._transport_model + + @property + def time_step_calculator(self) -> ts.TimeStepCalculator: + return self._time_step_calculator + + def __call__( + self, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice_provider: runtime_params_slice.DynamicRuntimeParamsSliceProvider, + geometry_provider: geometry_provider_lib.GeometryProvider, + input_state: state.ToraxSimState, + ) -> tuple[state.ToraxSimState, state.SimError]: + """Advances the simulation state one time step. + + Args: + static_runtime_params_slice: Static parameters that, if they change, + should trigger a recompilation of the SimulationStepFn. + dynamic_runtime_params_slice_provider: Object that returns a set of + runtime parameters which may change from time step to time step or + simulation run to run. If these runtime parameters change, it does NOT + trigger a JAX recompilation. + geometry_provider: Provides the magnetic geometry for each time step based + on the ToraxSimState at the start of the time step. The geometry may + change from time step to time step, so the sim needs a function to + provide which geometry to use for a given time step. A GeometryProvider + is any callable (class or function) which takes the ToraxSimState at the + start of a time step and returns the Geometry for that time step. For + most use cases, only the time will be relevant from the ToraxSimState + (in order to support time-dependent geometries). + input_state: State at the start of the time step, including the core + profiles which are being evolved. + + Returns: + ToraxSimState containing: + - the core profiles at the end of the time step. + - time and time step calculator state info. + - core_sources and core_transport at the end of the time step. + - stepper_numeric_outputs. This contains the number of iterations + performed in the stepper and the error state. The error states are: + 0 if solver converged with fine tolerance for this step + 1 if solver did not converge for this step (was above coarse tol) + 2 if solver converged within coarse tolerance. Allowed to pass with + a warning. Occasional error=2 has low impact on final sim state. + SimError indicating if an error has occurred during simulation. + """ + dynamic_runtime_params_slice_t, geo_t = ( + runtime_params_slice.get_consistent_dynamic_runtime_params_slice_and_geometry( + t=input_state.t, + dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider, + geometry_provider=geometry_provider, + ) + ) + + # This only computes sources set to explicit in the + # DynamicSourceConfigSlice. All implicit sources will have their profiles + # set to 0. + explicit_source_profiles = source_models_lib.build_source_profiles( + dynamic_runtime_params_slice=dynamic_runtime_params_slice_t, + static_runtime_params_slice=static_runtime_params_slice, + geo=geo_t, + core_profiles=input_state.core_profiles, + source_models=self.stepper.source_models, + explicit=True, + ) + + # The previous time step's state has an incomplete set of source profiles + # which was computed based on the previous time step's "guess" of the core + # profiles at this time step's t. We can merge those "implicit" source + # profiles with the explicit ones computed here. + input_state.core_sources = source_profiles_lib.SourceProfiles.merge( + explicit_source_profiles=explicit_source_profiles, + implicit_source_profiles=input_state.core_sources, + ) + + dt, time_step_calculator_state = self.init_time_step_calculator( + dynamic_runtime_params_slice_t, + geo_t, + input_state, + ) + + # The stepper needs the geo and dynamic_runtime_params_slice at time t + dt + # for implicit computations in the solver. Once geo_t_plus_dt is calculated + # we can use it to calculate Phibdot for both geo_t and geo_t_plus_dt, which + # then update the initialized Phibdot=0 in the geo instances. + dynamic_runtime_params_slice_t_plus_dt, geo_t, geo_t_plus_dt = ( + _get_geo_and_dynamic_runtime_params_at_t_plus_dt_and_phibdot( + input_state.t, + dt, + dynamic_runtime_params_slice_provider, + geo_t, + geometry_provider, + ) + ) + + output_state = self.step( + dt, + time_step_calculator_state, + static_runtime_params_slice, + dynamic_runtime_params_slice_t, + dynamic_runtime_params_slice_t_plus_dt, + geo_t, + geo_t_plus_dt, + input_state, + explicit_source_profiles, + ) + + if static_runtime_params_slice.adaptive_dt: + # This is a no-op if + # output_state.stepper_numeric_outputs.stepper_error_state == 0. + ( + dynamic_runtime_params_slice_t_plus_dt, + geo_t_plus_dt, + output_state, + ) = self.adaptive_step( + output_state, + static_runtime_params_slice, + dynamic_runtime_params_slice_t, + dynamic_runtime_params_slice_provider, + geo_t, + geometry_provider, + input_state, + explicit_source_profiles, + ) + + sim_state = self.finalize_output( + input_state, + output_state, + dynamic_runtime_params_slice_t_plus_dt, + static_runtime_params_slice, + geo_t_plus_dt, + ) + return sim_state, sim_state.check_for_errors() + + def init_time_step_calculator( + self, + dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice, + geo_t: geometry.Geometry, + input_state: state.ToraxSimState, + ) -> tuple[jnp.ndarray, Any]: + """First phase: Initialize the stepper state. + + Args: + dynamic_runtime_params_slice_t: Runtime parameters at time t. + geo_t: The geometry of the torus during this time step of the simulation. + While the geometry may change, any changes to the grid size can trigger + recompilation of the stepper (if it is jitted) or an error (assuming it + is JAX-compiled and lowered). + input_state: State at the start of the time step, including the core + profiles which are being evolved. + + Returns: + Tuple containing: + - time step duration (dt) + - internal time stepper state + """ + # TODO(b/335598388): We call the transport model both here and in the the + # Stepper / CoeffsCallback. This isn't a problem *so long as all of those + # calls fall within the same jit scope* because can use + # functools.lru_cache to avoid building duplicate expressions for the same + # transport coeffs. We should still refactor the design to more explicitly + # calculate transport coeffs at delta_t = 0 in only one place, so that we + # have some flexibility in where to place the jit boundaries. + pedestal_model_output = self._pedestal_model( + dynamic_runtime_params_slice_t, geo_t, input_state.core_profiles + ) + transport_coeffs = self._jitted_transport_model( + dynamic_runtime_params_slice_t, + geo_t, + input_state.core_profiles, + pedestal_model_output, + ) + + # initialize new dt and reset stepper iterations. + dt, time_step_calculator_state = self._time_step_calculator.next_dt( + dynamic_runtime_params_slice_t, + geo_t, + input_state.core_profiles, + input_state.time_step_calculator_state, + transport_coeffs, + ) + + crosses_t_final = ( + input_state.t < dynamic_runtime_params_slice_t.numerics.t_final + ) * ( + input_state.t + input_state.dt + > dynamic_runtime_params_slice_t.numerics.t_final + ) + dt = jnp.where( + jnp.logical_and( + dynamic_runtime_params_slice_t.numerics.exact_t_final, + crosses_t_final, + ), + dynamic_runtime_params_slice_t.numerics.t_final - input_state.t, + dt, + ) + if jnp.any(jnp.isnan(dt)): + raise ValueError('dt is NaN.') + + return (dt, time_step_calculator_state) + + def step( + self, + dt: jnp.ndarray, + time_step_calculator_state: Any, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice, + dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice, + geo_t: geometry.Geometry, + geo_t_plus_dt: geometry.Geometry, + input_state: state.ToraxSimState, + explicit_source_profiles: source_profiles_lib.SourceProfiles, + ) -> state.ToraxSimState: + """Performs a simulation step with given dt. + + Stepper may fail to converge in which case adaptive_step() can be used to + try smaller time step durations. + + Args: + dt: Time step duration. + time_step_calculator_state: Internal time stepper state. + static_runtime_params_slice: Static parameters that, if they change, + should trigger a recompilation of the SimulationStepFn. + dynamic_runtime_params_slice_t: Runtime parameters at time t. + dynamic_runtime_params_slice_t_plus_dt: Runtime parameters at time t + dt. + geo_t: The geometry of the torus during this time step of the simulation. + geo_t_plus_dt: The geometry of the torus during the next time step of the + simulation. + input_state: State at the start of the time step, including the core + profiles which are being evolved. + explicit_source_profiles: Explicit source profiles computed based on the + core profiles at the start of the time step. + + Returns: + ToraxSimState after the step. + """ + + core_profiles_t = input_state.core_profiles + + # Construct the CoreProfiles object for time t+dt with evolving boundary + # conditions and time-dependent prescribed profiles not directly solved by + # PDE system. + core_profiles_t_plus_dt = _provide_core_profiles_t_plus_dt( + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, + geo_t_plus_dt=geo_t_plus_dt, + core_profiles_t=core_profiles_t, + ) + + # Initial trial for stepper. If did not converge (can happen for nonlinear + # step with large dt) we apply the adaptive time step routine if requested. + core_profiles, core_sources, core_transport, stepper_numeric_outputs = ( + self._stepper_fn( + dt=dt, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice_t=dynamic_runtime_params_slice_t, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, + geo_t=geo_t, + geo_t_plus_dt=geo_t_plus_dt, + core_profiles_t=core_profiles_t, + core_profiles_t_plus_dt=core_profiles_t_plus_dt, + explicit_source_profiles=explicit_source_profiles, + ) + ) + stepper_numeric_outputs.outer_stepper_iterations = 1 + + # post_processed_outputs set to zero since post-processing is done at the + # end of the simulation step following recalculation of explicit + # core_sources to be consistent with the final core_profiles. + return state.ToraxSimState( + t=input_state.t + dt, + dt=dt, + core_profiles=core_profiles, + core_transport=core_transport, + core_sources=core_sources, + post_processed_outputs=state.PostProcessedOutputs.zeros(geo_t_plus_dt), + time_step_calculator_state=time_step_calculator_state, + stepper_numeric_outputs=stepper_numeric_outputs, + ) + + def adaptive_step( + self, + output_state: state.ToraxSimState, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice, + dynamic_runtime_params_slice_provider: runtime_params_slice.DynamicRuntimeParamsSliceProvider, + geo_t: geometry.Geometry, + geometry_provider: geometry_provider_lib.GeometryProvider, + input_state: state.ToraxSimState, + explicit_source_profiles: source_profiles_lib.SourceProfiles, + ) -> tuple[ + runtime_params_slice.DynamicRuntimeParamsSlice, + geometry.Geometry, + state.ToraxSimState, + ]: + """Performs adaptive time stepping until stepper converges. + + If the initial step has converged (i.e. + output_state.stepper_numeric_outputs.stepper_error_state == 0), this + function is a no-op. + + Args: + output_state: State after a full step. + static_runtime_params_slice: Static parameters that, if they change, + should trigger a recompilation of the SimulationStepFn. + dynamic_runtime_params_slice_t: Runtime parameters at time t. + dynamic_runtime_params_slice_provider: Runtime parameters slice provider. + geo_t: The geometry of the torus during this time step of the simulation. + geometry_provider: Provides geometry during the next time step of the + simulation. + input_state: State at the start of the time step, including the core + profiles which are being evolved. + explicit_source_profiles: Explicit source profiles computed based on the + core profiles at the start of the time step. + + Returns: + A tuple containing: + - Runtime parameters at time t + dt, where dt is the actual time step + used. + - Geometry at time t + dt, where dt is the actual time step used. + - ToraxSimState after adaptive time stepping. + """ + core_profiles_t = input_state.core_profiles + + # Check if stepper converged. If not, proceed to body_fun + def cond_fun(updated_output: state.ToraxSimState) -> bool: + if updated_output.stepper_numeric_outputs.stepper_error_state == 1: + do_dt_backtrack = True + else: + do_dt_backtrack = False + return do_dt_backtrack + + # Make a new step with a smaller dt, starting with the original core + # profiles. + # Exit if dt < mindt + def body_fun( + updated_output: state.ToraxSimState, + ) -> state.ToraxSimState: + + dt = ( + updated_output.dt + / dynamic_runtime_params_slice_t.numerics.dt_reduction_factor + ) + if jnp.any(jnp.isnan(dt)): + raise ValueError('dt is NaN.') + if dt < dynamic_runtime_params_slice_t.numerics.mindt: + raise ValueError('dt below minimum timestep following adaptation') + + # Calculate dynamic_runtime_params and geo at t + dt. + # Update geos with phibdot. + # The updated geo_t is renamed to geo_t_with_phibdot due to name shadowing + ( + dynamic_runtime_params_slice_t_plus_dt, + geo_t_with_phibdot, + geo_t_plus_dt, + ) = _get_geo_and_dynamic_runtime_params_at_t_plus_dt_and_phibdot( + input_state.t, + dt, + dynamic_runtime_params_slice_provider, + geo_t, + geometry_provider, + ) + + core_profiles_t_plus_dt = _provide_core_profiles_t_plus_dt( + core_profiles_t=core_profiles_t, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, + static_runtime_params_slice=static_runtime_params_slice, + geo_t_plus_dt=geo_t_plus_dt, + ) + core_profiles, core_sources, core_transport, stepper_numeric_outputs = ( + self._stepper_fn( + dt=dt, + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice_t=dynamic_runtime_params_slice_t, + dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, + geo_t=geo_t_with_phibdot, + geo_t_plus_dt=geo_t_plus_dt, + core_profiles_t=core_profiles_t, + core_profiles_t_plus_dt=core_profiles_t_plus_dt, + explicit_source_profiles=explicit_source_profiles, + ) + ) + stepper_numeric_outputs.outer_stepper_iterations = ( + updated_output.stepper_numeric_outputs.outer_stepper_iterations + 1 + ) + + stepper_numeric_outputs.inner_solver_iterations += ( + updated_output.stepper_numeric_outputs.inner_solver_iterations + ) + return dataclasses.replace( + updated_output, + t=input_state.t + dt, + dt=dt, + core_profiles=core_profiles, + core_transport=core_transport, + core_sources=core_sources, + stepper_numeric_outputs=stepper_numeric_outputs, + ) + + output_state = jax_utils.py_while(cond_fun, body_fun, output_state) + + # Calculate dynamic_runtime_params and geo at t + dt. + # Update geos with phibdot. + ( + dynamic_runtime_params_slice_t_plus_dt, + geo_t, + geo_t_plus_dt, + ) = _get_geo_and_dynamic_runtime_params_at_t_plus_dt_and_phibdot( + input_state.t, + output_state.dt, + dynamic_runtime_params_slice_provider, + geo_t, + geometry_provider, + ) + + return ( + dynamic_runtime_params_slice_t_plus_dt, + geo_t_plus_dt, + output_state, + ) + + def finalize_output( + self, + input_state: state.ToraxSimState, + output_state: state.ToraxSimState, + dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + geo_t_plus_dt: geometry.Geometry, + ) -> state.ToraxSimState: + """Finalizes given output state at the end of the simulation step. + + Args: + input_state: Previous sim state. + output_state: State to be finalized. + dynamic_runtime_params_slice_t_plus_dt: Runtime parameters at time t + dt. + static_runtime_params_slice: Static runtime parameters. + geo_t_plus_dt: The geometry of the torus during the next time step of the + simulation. + + Returns: + Finalized ToraxSimState. + """ + + # Update total current, q, and s profiles based on new psi + q_corr = dynamic_runtime_params_slice_t_plus_dt.numerics.q_correction_factor + output_state.core_profiles = physics.update_jtot_q_face_s_face( + geo=geo_t_plus_dt, + core_profiles=output_state.core_profiles, + q_correction_factor=q_corr, + ) + + # Update ohmic and bootstrap current based on the new core profiles. + output_state.core_profiles = _update_current_distribution( + source_models=self._stepper_fn.source_models, + dynamic_runtime_params_slice=dynamic_runtime_params_slice_t_plus_dt, + static_runtime_params_slice=static_runtime_params_slice, + geo=geo_t_plus_dt, + core_profiles=output_state.core_profiles, + ) + + # Update psidot based on the new core profiles. + # Will include the phibdot calculation since geo=geo_t_plus_dt. + output_state.core_profiles = _update_psidot( + source_models=self._stepper_fn.source_models, + dynamic_runtime_params_slice=dynamic_runtime_params_slice_t_plus_dt, + static_runtime_params_slice=static_runtime_params_slice, + geo=geo_t_plus_dt, + core_profiles=output_state.core_profiles, + ) + output_state = post_processing.make_outputs( + sim_state=output_state, + geo=geo_t_plus_dt, + previous_sim_state=input_state, + ) + + return output_state + + +def _get_geo_and_dynamic_runtime_params_at_t_plus_dt_and_phibdot( + t: jnp.ndarray, + dt: jnp.ndarray, + dynamic_runtime_params_slice_provider: runtime_params_slice.DynamicRuntimeParamsSliceProvider, + geo_t: geometry.Geometry, + geometry_provider: geometry_provider_lib.GeometryProvider, +) -> tuple[ + runtime_params_slice.DynamicRuntimeParamsSlice, + geometry.Geometry, + geometry.Geometry, +]: + """Returns the geos including Phibdot, and dynamic runtime params at t + dt. + + Args: + t: Time at which the simulation is currently at. + dt: Time step duration. + dynamic_runtime_params_slice_provider: Object that returns a set of runtime + parameters which may change from time step to time step or simulation run + to run. If these runtime parameters change, it does NOT trigger a JAX + recompilation. + geo_t: The geometry of the torus during this time step of the simulation. + geometry_provider: Provides the magnetic geometry for each time step based + on the ToraxSimState at the start of the time step. + + Returns: + Tuple containing: + - The dynamic runtime params at time t + dt. + - The geometry of the torus during this time step of the simulation. + - The geometry of the torus during the next time step of the simulation. + """ + dynamic_runtime_params_slice_t_plus_dt, geo_t_plus_dt = ( + runtime_params_slice.get_consistent_dynamic_runtime_params_slice_and_geometry( + t=t + dt, + dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider, + geometry_provider=geometry_provider, + ) + ) + geo_t, geo_t_plus_dt = _add_Phibdot( + dt, dynamic_runtime_params_slice_t_plus_dt, geo_t, geo_t_plus_dt + ) + + return ( + dynamic_runtime_params_slice_t_plus_dt, + geo_t, + geo_t_plus_dt, + ) + + +# pylint: disable=invalid-name +def _add_Phibdot( + dt: jnp.ndarray, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, + geo_t: geometry.Geometry, + geo_t_plus_dt: geometry.Geometry, +) -> tuple[geometry.Geometry, geometry.Geometry]: + """Update Phibdot in the geometry dataclasses used in the time interval. + + Phibdot is used in calc_coeffs to calcuate terms related to time-dependent + geometry. It should be set to be the same for geo_t and geo_t_plus_dt for + each given time interval. This means that geo_t_plus_dt.Phibdot will not + necessarily be the same as the geo_t.Phibdot at the next time step. + + Args: + dt: Time step duration. + dynamic_runtime_params_slice: Runtime parameters which may change from time + step to time step without triggering recompilations. + geo_t: The geometry of the torus during this time step of the simulation. + geo_t_plus_dt: The geometry of the torus during the next time step of the + simulation. + + Returns: + Tuple containing: + - The geometry of the torus during this time step of the simulation. + - The geometry of the torus during the next time step of the simulation. + """ + + # Calculate Phibdot for the time interval. + # If numerics.calcphibdot is False, set Phibdot to be 0 (useful for testing + # purposes) + Phibdot = jnp.where( + dynamic_runtime_params_slice.numerics.calcphibdot, + (geo_t_plus_dt.Phib - geo_t.Phib) / dt, + 0.0, + ) + + geo_t = dataclasses.replace( + geo_t, + Phibdot=Phibdot, + ) + geo_t_plus_dt = dataclasses.replace( + geo_t_plus_dt, + Phibdot=Phibdot, + ) + return geo_t, geo_t_plus_dt + + +# pylint: enable=invalid-name + + +def _update_current_distribution( + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, + source_models: source_models_lib.SourceModels, +) -> state.CoreProfiles: + """Update bootstrap current based on the new core_profiles.""" + + bootstrap_profile = source_models.j_bootstrap.get_value( + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + static_runtime_params_slice=static_runtime_params_slice, + geo=geo, + core_profiles=core_profiles, + ) + + # calculate "External" current profile (e.g. ECCD) + # form of external current on face grid + external_current = source_models.external_current_source( + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + static_runtime_params_slice=static_runtime_params_slice, + geo=geo, + core_profiles=core_profiles, + ) + + johm = ( + core_profiles.currents.jtot + - bootstrap_profile.j_bootstrap + - external_current + ) + + currents = dataclasses.replace( + core_profiles.currents, + j_bootstrap=bootstrap_profile.j_bootstrap, + j_bootstrap_face=bootstrap_profile.j_bootstrap_face, + I_bootstrap=bootstrap_profile.I_bootstrap, + sigma=bootstrap_profile.sigma, + johm=johm, + external_current_source=external_current, + ) + new_core_profiles = dataclasses.replace( + core_profiles, + currents=currents, + ) + return new_core_profiles + + +def _update_psidot( + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, + source_models: source_models_lib.SourceModels, +) -> state.CoreProfiles: + """Update psidot based on new core_profiles.""" + + psidot = dataclasses.replace( + core_profiles.psidot, + value=ohmic_heat_source.calc_psidot( + static_runtime_params_slice, + dynamic_runtime_params_slice, + geo, + core_profiles, + source_models, + ), + ) + + new_core_profiles = dataclasses.replace( + core_profiles, + psidot=psidot, + ) + return new_core_profiles + + +def _provide_core_profiles_t_plus_dt( + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, + dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice, + geo_t_plus_dt: geometry.Geometry, + core_profiles_t: state.CoreProfiles, +) -> state.CoreProfiles: + """Provides state at t_plus_dt with new boundary conditions and prescribed profiles.""" + updated_boundary_conditions = ( + core_profile_setters.compute_boundary_conditions( + static_runtime_params_slice, + dynamic_runtime_params_slice_t_plus_dt, + geo_t_plus_dt, + ) + ) + updated_values = core_profile_setters.get_prescribed_core_profile_values( + static_runtime_params_slice=static_runtime_params_slice, + dynamic_runtime_params_slice=dynamic_runtime_params_slice_t_plus_dt, + geo=geo_t_plus_dt, + core_profiles=core_profiles_t, + ) + temp_ion = dataclasses.replace( + core_profiles_t.temp_ion, + value=updated_values['temp_ion'], + **updated_boundary_conditions['temp_ion'], + ) + temp_el = dataclasses.replace( + core_profiles_t.temp_el, + value=updated_values['temp_el'], + **updated_boundary_conditions['temp_el'], + ) + psi = dataclasses.replace( + core_profiles_t.psi, **updated_boundary_conditions['psi'] + ) + ne = dataclasses.replace( + core_profiles_t.ne, + value=updated_values['ne'], + **updated_boundary_conditions['ne'], + ) + ni = dataclasses.replace( + core_profiles_t.ni, + value=updated_values['ni'], + **updated_boundary_conditions['ni'], + ) + nimp = dataclasses.replace( + core_profiles_t.nimp, + value=updated_values['nimp'], + **updated_boundary_conditions['nimp'], + ) + + # pylint: disable=invalid-name + # Update Z_face with boundary condition Z, needed for cases where temp_el + # is evolving and updated_prescribed_core_profiles is a no-op. + Zi_face = jnp.concatenate( + [ + updated_values['Zi_face'][:-1], + jnp.array([updated_boundary_conditions['Zi_edge']]), + ], + ) + Zimp_face = jnp.concatenate( + [ + updated_values['Zimp_face'][:-1], + jnp.array([updated_boundary_conditions['Zimp_edge']]), + ], + ) + # pylint: enable=invalid-name + core_profiles_t_plus_dt = dataclasses.replace( + core_profiles_t, + temp_ion=temp_ion, + temp_el=temp_el, + psi=psi, + ne=ne, + ni=ni, + nimp=nimp, + Zi=updated_values['Zi'], + Zi_face=Zi_face, + Zimp=updated_values['Zimp'], + Zimp_face=Zimp_face, + ) + return core_profiles_t_plus_dt diff --git a/torax/sim.py b/torax/sim.py index a3274f89..de4fe92e 100644 --- a/torax/sim.py +++ b/torax/sim.py @@ -28,17 +28,14 @@ import dataclasses import time -from typing import Any, Optional +from typing import Optional from absl import logging -import chex import jax import jax.numpy as jnp import numpy as np from torax import core_profile_setters -from torax import jax_utils from torax import output -from torax import physics from torax import post_processing from torax import state from torax.config import config_args @@ -46,8 +43,8 @@ from torax.config import runtime_params_slice from torax.geometry import geometry from torax.geometry import geometry_provider as geometry_provider_lib +from torax.orchestration import step_function from torax.pedestal_model import pedestal_model as pedestal_model_lib -from torax.sources import ohmic_heat_source from torax.sources import source_models as source_models_lib from torax.sources import source_profiles as source_profiles_lib from torax.stepper import stepper as stepper_lib @@ -57,558 +54,11 @@ import xarray as xr -def get_consistent_dynamic_runtime_params_slice_and_geometry( - t: chex.Numeric, - dynamic_runtime_params_slice_provider: runtime_params_slice.DynamicRuntimeParamsSliceProvider, - geometry_provider: geometry_provider_lib.GeometryProvider, -) -> tuple[ - runtime_params_slice.DynamicRuntimeParamsSlice, - geometry.Geometry, -]: - """Returns the dynamic runtime params and geometry for a given time.""" - geo = geometry_provider(t) - dynamic_runtime_params_slice = dynamic_runtime_params_slice_provider( - t=t, - ) - dynamic_runtime_params_slice, geo = runtime_params_slice.make_ip_consistent( - dynamic_runtime_params_slice, geo - ) - return dynamic_runtime_params_slice, geo - - -class SimulationStepFn: - """Advances the TORAX simulation one time step. - - Unlike the Stepper class, which updates certain parts of the state, a - SimulationStepFn takes in the ToraxSimState and outputs the updated - ToraxSimState, which contains not only the CoreProfiles but also extra - simulation state useful for stepping as well as extra outputs useful for - inspection inside the main run loop in `run_simulation()`. It wraps calls to - Stepper with useful features to increase robustness for convergence, like - dt-backtracking. - """ - - def __init__( - self, - stepper: stepper_lib.Stepper, - time_step_calculator: ts.TimeStepCalculator, - transport_model: transport_model_lib.TransportModel, - pedestal_model: pedestal_model_lib.PedestalModel, - ): - """Initializes the SimulationStepFn. - - If you wish to run a simulation with new versions of any of these arguments - (i.e. want to change to a new stepper), then you will need to build a new - SimulationStepFn. These arguments are fixed for the lifetime - of the SimulationStepFn and cannot change even with JAX recompiles. - - Args: - stepper: Evolves the core profiles. - time_step_calculator: Calculates the dt for each time step. - transport_model: Calculates diffusion and convection coefficients. - pedestal_model: Calculates pedestal coefficients. - """ - self._stepper_fn = stepper - self._time_step_calculator = time_step_calculator - self._transport_model = transport_model - self._pedestal_model = pedestal_model - self._jitted_transport_model = jax_utils.jit( - transport_model.__call__, - ) - - @property - def pedestal_model(self) -> pedestal_model_lib.PedestalModel: - return self._pedestal_model - - @property - def stepper(self) -> stepper_lib.Stepper: - return self._stepper_fn - - @property - def transport_model(self) -> transport_model_lib.TransportModel: - return self._transport_model - - @property - def time_step_calculator(self) -> ts.TimeStepCalculator: - return self._time_step_calculator - - def __call__( - self, - static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - dynamic_runtime_params_slice_provider: runtime_params_slice.DynamicRuntimeParamsSliceProvider, - geometry_provider: geometry_provider_lib.GeometryProvider, - input_state: state.ToraxSimState, - ) -> tuple[state.ToraxSimState, state.SimError]: - """Advances the simulation state one time step. - - Args: - static_runtime_params_slice: Static parameters that, if they change, - should trigger a recompilation of the SimulationStepFn. - dynamic_runtime_params_slice_provider: Object that returns a set of - runtime parameters which may change from time step to time step or - simulation run to run. If these runtime parameters change, it does NOT - trigger a JAX recompilation. - geometry_provider: Provides the magnetic geometry for each time step based - on the ToraxSimState at the start of the time step. The geometry may - change from time step to time step, so the sim needs a function to - provide which geometry to use for a given time step. A GeometryProvider - is any callable (class or function) which takes the ToraxSimState at the - start of a time step and returns the Geometry for that time step. For - most use cases, only the time will be relevant from the ToraxSimState - (in order to support time-dependent geometries). - input_state: State at the start of the time step, including the core - profiles which are being evolved. - - Returns: - ToraxSimState containing: - - the core profiles at the end of the time step. - - time and time step calculator state info. - - core_sources and core_transport at the end of the time step. - - stepper_numeric_outputs. This contains the number of iterations - performed in the stepper and the error state. The error states are: - 0 if solver converged with fine tolerance for this step - 1 if solver did not converge for this step (was above coarse tol) - 2 if solver converged within coarse tolerance. Allowed to pass with - a warning. Occasional error=2 has low impact on final sim state. - SimError indicating if an error has occurred during simulation. - """ - dynamic_runtime_params_slice_t, geo_t = ( - get_consistent_dynamic_runtime_params_slice_and_geometry( - input_state.t, - dynamic_runtime_params_slice_provider, - geometry_provider, - ) - ) - - # This only computes sources set to explicit in the - # DynamicSourceConfigSlice. All implicit sources will have their profiles - # set to 0. - explicit_source_profiles = source_models_lib.build_source_profiles( - dynamic_runtime_params_slice=dynamic_runtime_params_slice_t, - static_runtime_params_slice=static_runtime_params_slice, - geo=geo_t, - core_profiles=input_state.core_profiles, - source_models=self.stepper.source_models, - explicit=True, - ) - - # The previous time step's state has an incomplete set of source profiles - # which was computed based on the previous time step's "guess" of the core - # profiles at this time step's t. We can merge those "implicit" source - # profiles with the explicit ones computed here. - input_state.core_sources = merge_source_profiles( - explicit_source_profiles=explicit_source_profiles, - implicit_source_profiles=input_state.core_sources, - ) - - dt, time_step_calculator_state = self.init_time_step_calculator( - dynamic_runtime_params_slice_t, - geo_t, - input_state, - ) - - # The stepper needs the geo and dynamic_runtime_params_slice at time t + dt - # for implicit computations in the solver. Once geo_t_plus_dt is calculated - # we can use it to calculate Phibdot for both geo_t and geo_t_plus_dt, which - # then update the initialized Phibdot=0 in the geo instances. - dynamic_runtime_params_slice_t_plus_dt, geo_t, geo_t_plus_dt = ( - _get_geo_and_dynamic_runtime_params_at_t_plus_dt_and_phibdot( - input_state.t, - dt, - dynamic_runtime_params_slice_provider, - geo_t, - geometry_provider, - ) - ) - - output_state = self.step( - dt, - time_step_calculator_state, - static_runtime_params_slice, - dynamic_runtime_params_slice_t, - dynamic_runtime_params_slice_t_plus_dt, - geo_t, - geo_t_plus_dt, - input_state, - explicit_source_profiles, - ) - - if static_runtime_params_slice.adaptive_dt: - # This is a no-op if - # output_state.stepper_numeric_outputs.stepper_error_state == 0. - ( - dynamic_runtime_params_slice_t_plus_dt, - geo_t_plus_dt, - output_state, - ) = self.adaptive_step( - output_state, - static_runtime_params_slice, - dynamic_runtime_params_slice_t, - dynamic_runtime_params_slice_provider, - geo_t, - geometry_provider, - input_state, - explicit_source_profiles, - ) - - sim_state = self.finalize_output( - input_state, - output_state, - dynamic_runtime_params_slice_t_plus_dt, - static_runtime_params_slice, - geo_t_plus_dt, - ) - return sim_state, sim_state.check_for_errors() - - def init_time_step_calculator( - self, - dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice, - geo_t: geometry.Geometry, - input_state: state.ToraxSimState, - ) -> tuple[jnp.ndarray, Any]: - """First phase: Initialize the stepper state. - - Args: - dynamic_runtime_params_slice_t: Runtime parameters at time t. - geo_t: The geometry of the torus during this time step of the simulation. - While the geometry may change, any changes to the grid size can trigger - recompilation of the stepper (if it is jitted) or an error (assuming it - is JAX-compiled and lowered). - input_state: State at the start of the time step, including the core - profiles which are being evolved. - - Returns: - Tuple containing: - - time step duration (dt) - - internal time stepper state - """ - # TODO(b/335598388): We call the transport model both here and in the the - # Stepper / CoeffsCallback. This isn't a problem *so long as all of those - # calls fall within the same jit scope* because can use - # functools.lru_cache to avoid building duplicate expressions for the same - # transport coeffs. We should still refactor the design to more explicitly - # calculate transport coeffs at delta_t = 0 in only one place, so that we - # have some flexibility in where to place the jit boundaries. - pedestal_model_output = self._pedestal_model( - dynamic_runtime_params_slice_t, geo_t, input_state.core_profiles - ) - transport_coeffs = self._jitted_transport_model( - dynamic_runtime_params_slice_t, - geo_t, - input_state.core_profiles, - pedestal_model_output, - ) - - # initialize new dt and reset stepper iterations. - dt, time_step_calculator_state = self._time_step_calculator.next_dt( - dynamic_runtime_params_slice_t, - geo_t, - input_state.core_profiles, - input_state.time_step_calculator_state, - transport_coeffs, - ) - - crosses_t_final = ( - input_state.t < dynamic_runtime_params_slice_t.numerics.t_final - ) * ( - input_state.t + input_state.dt - > dynamic_runtime_params_slice_t.numerics.t_final - ) - dt = jnp.where( - jnp.logical_and( - dynamic_runtime_params_slice_t.numerics.exact_t_final, - crosses_t_final, - ), - dynamic_runtime_params_slice_t.numerics.t_final - input_state.t, - dt, - ) - if jnp.any(jnp.isnan(dt)): - raise ValueError('dt is NaN.') - - return (dt, time_step_calculator_state) - - def step( - self, - dt: jnp.ndarray, - time_step_calculator_state: Any, - static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice, - geo_t: geometry.Geometry, - geo_t_plus_dt: geometry.Geometry, - input_state: state.ToraxSimState, - explicit_source_profiles: source_profiles_lib.SourceProfiles, - ) -> state.ToraxSimState: - """Performs a simulation step with given dt. - - Stepper may fail to converge in which case adaptive_step() can be used to - try smaller time step durations. - - Args: - dt: Time step duration. - time_step_calculator_state: Internal time stepper state. - static_runtime_params_slice: Static parameters that, if they change, - should trigger a recompilation of the SimulationStepFn. - dynamic_runtime_params_slice_t: Runtime parameters at time t. - dynamic_runtime_params_slice_t_plus_dt: Runtime parameters at time t + dt. - geo_t: The geometry of the torus during this time step of the simulation. - geo_t_plus_dt: The geometry of the torus during the next time step of the - simulation. - input_state: State at the start of the time step, including the core - profiles which are being evolved. - explicit_source_profiles: Explicit source profiles computed based on the - core profiles at the start of the time step. - - Returns: - ToraxSimState after the step. - """ - - core_profiles_t = input_state.core_profiles - - # Construct the CoreProfiles object for time t+dt with evolving boundary - # conditions and time-dependent prescribed profiles not directly solved by - # PDE system. - core_profiles_t_plus_dt = provide_core_profiles_t_plus_dt( - static_runtime_params_slice=static_runtime_params_slice, - dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, - geo_t_plus_dt=geo_t_plus_dt, - core_profiles_t=core_profiles_t, - ) - - # Initial trial for stepper. If did not converge (can happen for nonlinear - # step with large dt) we apply the adaptive time step routine if requested. - core_profiles, core_sources, core_transport, stepper_numeric_outputs = ( - self._stepper_fn( - dt=dt, - static_runtime_params_slice=static_runtime_params_slice, - dynamic_runtime_params_slice_t=dynamic_runtime_params_slice_t, - dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, - geo_t=geo_t, - geo_t_plus_dt=geo_t_plus_dt, - core_profiles_t=core_profiles_t, - core_profiles_t_plus_dt=core_profiles_t_plus_dt, - explicit_source_profiles=explicit_source_profiles, - ) - ) - stepper_numeric_outputs.outer_stepper_iterations = 1 - - # post_processed_outputs set to zero since post-processing is done at the - # end of the simulation step following recalculation of explicit - # core_sources to be consistent with the final core_profiles. - return state.ToraxSimState( - t=input_state.t + dt, - dt=dt, - core_profiles=core_profiles, - core_transport=core_transport, - core_sources=core_sources, - post_processed_outputs=state.PostProcessedOutputs.zeros(geo_t_plus_dt), - time_step_calculator_state=time_step_calculator_state, - stepper_numeric_outputs=stepper_numeric_outputs, - ) - - def adaptive_step( - self, - output_state: state.ToraxSimState, - static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_runtime_params_slice_provider: runtime_params_slice.DynamicRuntimeParamsSliceProvider, - geo_t: geometry.Geometry, - geometry_provider: geometry_provider_lib.GeometryProvider, - input_state: state.ToraxSimState, - explicit_source_profiles: source_profiles_lib.SourceProfiles, - ) -> tuple[ - runtime_params_slice.DynamicRuntimeParamsSlice, - geometry.Geometry, - state.ToraxSimState, - ]: - """Performs adaptive time stepping until stepper converges. - - If the initial step has converged (i.e. - output_state.stepper_numeric_outputs.stepper_error_state == 0), this - function is a no-op. - - Args: - output_state: State after a full step. - static_runtime_params_slice: Static parameters that, if they change, - should trigger a recompilation of the SimulationStepFn. - dynamic_runtime_params_slice_t: Runtime parameters at time t. - dynamic_runtime_params_slice_provider: Runtime parameters slice provider. - geo_t: The geometry of the torus during this time step of the simulation. - geometry_provider: Provides geometry during the next time step of the - simulation. - input_state: State at the start of the time step, including the core - profiles which are being evolved. - explicit_source_profiles: Explicit source profiles computed based on the - core profiles at the start of the time step. - - Returns: - A tuple containing: - - Runtime parameters at time t + dt, where dt is the actual time step - used. - - Geometry at time t + dt, where dt is the actual time step used. - - ToraxSimState after adaptive time stepping. - """ - core_profiles_t = input_state.core_profiles - - # Check if stepper converged. If not, proceed to body_fun - def cond_fun(updated_output: state.ToraxSimState) -> bool: - if updated_output.stepper_numeric_outputs.stepper_error_state == 1: - do_dt_backtrack = True - else: - do_dt_backtrack = False - return do_dt_backtrack - - # Make a new step with a smaller dt, starting with the original core - # profiles. - # Exit if dt < mindt - def body_fun( - updated_output: state.ToraxSimState, - ) -> state.ToraxSimState: - - dt = ( - updated_output.dt - / dynamic_runtime_params_slice_t.numerics.dt_reduction_factor - ) - if jnp.any(jnp.isnan(dt)): - raise ValueError('dt is NaN.') - if dt < dynamic_runtime_params_slice_t.numerics.mindt: - raise ValueError('dt below minimum timestep following adaptation') - - # Calculate dynamic_runtime_params and geo at t + dt. - # Update geos with phibdot. - # The updated geo_t is renamed to geo_t_with_phibdot due to name shadowing - ( - dynamic_runtime_params_slice_t_plus_dt, - geo_t_with_phibdot, - geo_t_plus_dt, - ) = _get_geo_and_dynamic_runtime_params_at_t_plus_dt_and_phibdot( - input_state.t, - dt, - dynamic_runtime_params_slice_provider, - geo_t, - geometry_provider, - ) - - core_profiles_t_plus_dt = provide_core_profiles_t_plus_dt( - core_profiles_t=core_profiles_t, - dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, - static_runtime_params_slice=static_runtime_params_slice, - geo_t_plus_dt=geo_t_plus_dt, - ) - core_profiles, core_sources, core_transport, stepper_numeric_outputs = ( - self._stepper_fn( - dt=dt, - static_runtime_params_slice=static_runtime_params_slice, - dynamic_runtime_params_slice_t=dynamic_runtime_params_slice_t, - dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt, - geo_t=geo_t_with_phibdot, - geo_t_plus_dt=geo_t_plus_dt, - core_profiles_t=core_profiles_t, - core_profiles_t_plus_dt=core_profiles_t_plus_dt, - explicit_source_profiles=explicit_source_profiles, - ) - ) - stepper_numeric_outputs.outer_stepper_iterations = ( - updated_output.stepper_numeric_outputs.outer_stepper_iterations + 1 - ) - - stepper_numeric_outputs.inner_solver_iterations += ( - updated_output.stepper_numeric_outputs.inner_solver_iterations - ) - return dataclasses.replace( - updated_output, - t=input_state.t + dt, - dt=dt, - core_profiles=core_profiles, - core_transport=core_transport, - core_sources=core_sources, - stepper_numeric_outputs=stepper_numeric_outputs, - ) - - output_state = jax_utils.py_while(cond_fun, body_fun, output_state) - - # Calculate dynamic_runtime_params and geo at t + dt. - # Update geos with phibdot. - ( - dynamic_runtime_params_slice_t_plus_dt, - geo_t, - geo_t_plus_dt, - ) = _get_geo_and_dynamic_runtime_params_at_t_plus_dt_and_phibdot( - input_state.t, - output_state.dt, - dynamic_runtime_params_slice_provider, - geo_t, - geometry_provider, - ) - - return ( - dynamic_runtime_params_slice_t_plus_dt, - geo_t_plus_dt, - output_state, - ) - - def finalize_output( - self, - input_state: state.ToraxSimState, - output_state: state.ToraxSimState, - dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice, - static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - geo_t_plus_dt: geometry.Geometry, - ) -> state.ToraxSimState: - """Finalizes given output state at the end of the simulation step. - - Args: - input_state: Previous sim state. - output_state: State to be finalized. - dynamic_runtime_params_slice_t_plus_dt: Runtime parameters at time t + dt. - static_runtime_params_slice: Static runtime parameters. - geo_t_plus_dt: The geometry of the torus during the next time step of the - simulation. - - Returns: - Finalized ToraxSimState. - """ - - # Update total current, q, and s profiles based on new psi - q_corr = dynamic_runtime_params_slice_t_plus_dt.numerics.q_correction_factor - output_state.core_profiles = physics.update_jtot_q_face_s_face( - geo=geo_t_plus_dt, - core_profiles=output_state.core_profiles, - q_correction_factor=q_corr, - ) - - # Update ohmic and bootstrap current based on the new core profiles. - output_state.core_profiles = update_current_distribution( - source_models=self._stepper_fn.source_models, - dynamic_runtime_params_slice=dynamic_runtime_params_slice_t_plus_dt, - static_runtime_params_slice=static_runtime_params_slice, - geo=geo_t_plus_dt, - core_profiles=output_state.core_profiles, - ) - - # Update psidot based on the new core profiles. - # Will include the phibdot calculation since geo=geo_t_plus_dt. - output_state.core_profiles = _update_psidot( - source_models=self._stepper_fn.source_models, - dynamic_runtime_params_slice=dynamic_runtime_params_slice_t_plus_dt, - static_runtime_params_slice=static_runtime_params_slice, - geo=geo_t_plus_dt, - core_profiles=output_state.core_profiles, - ) - output_state = post_processing.make_outputs( - sim_state=output_state, - geo=geo_t_plus_dt, - previous_sim_state=input_state, - ) - - return output_state - - def get_initial_state( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: geometry.Geometry, - step_fn: SimulationStepFn, + step_fn: step_function.SimulationStepFn, ) -> state.ToraxSimState: """Returns the initial state to be used by run_simulation().""" initial_core_profiles = core_profile_setters.initial_core_profiles( @@ -663,7 +113,7 @@ def __init__( dynamic_runtime_params_slice_provider: runtime_params_slice.DynamicRuntimeParamsSliceProvider, geometry_provider: geometry_provider_lib.GeometryProvider, initial_state: state.ToraxSimState, - step_fn: SimulationStepFn, + step_fn: step_function.SimulationStepFn, file_restart: general_runtime_params.FileRestart | None = None, ): self._static_runtime_params_slice = static_runtime_params_slice @@ -704,7 +154,7 @@ def static_runtime_params_slice( return self._static_runtime_params_slice @property - def step_fn(self) -> SimulationStepFn: + def step_fn(self) -> step_function.SimulationStepFn: return self._step_fn @property @@ -797,10 +247,10 @@ def update_base_components( self._geometry_provider = geometry_provider dynamic_runtime_params_slice_for_init, geo_for_init = ( - get_consistent_dynamic_runtime_params_slice_and_geometry( - self._dynamic_runtime_params_slice_provider.runtime_params_provider.numerics.runtime_params_config.t_initial, - self._dynamic_runtime_params_slice_provider, - self._geometry_provider, + runtime_params_slice.get_consistent_dynamic_runtime_params_slice_and_geometry( + t=self._dynamic_runtime_params_slice_provider.runtime_params_provider.numerics.runtime_params_config.t_initial, + dynamic_runtime_params_slice_provider=self._dynamic_runtime_params_slice_provider, + geometry_provider=self._geometry_provider, ) ) self._initial_state = get_initial_state( @@ -901,10 +351,10 @@ def create( # Build dynamic_runtime_params_slice at t_initial for initial conditions. dynamic_runtime_params_slice_for_init, geo_for_init = ( - get_consistent_dynamic_runtime_params_slice_and_geometry( - runtime_params.numerics.t_initial, - dynamic_runtime_params_slice_provider, - geometry_provider, + runtime_params_slice.get_consistent_dynamic_runtime_params_slice_and_geometry( + t=runtime_params.numerics.t_initial, + dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider, + geometry_provider=geometry_provider, ) ) if file_restart is not None and file_restart.do_restart: @@ -950,7 +400,7 @@ def create( ) ) - step_fn = SimulationStepFn( + step_fn = step_function.SimulationStepFn( stepper=stepper, time_step_calculator=time_step_calculator, transport_model=transport_model, @@ -1055,7 +505,7 @@ def _run_simulation( dynamic_runtime_params_slice_provider: runtime_params_slice.DynamicRuntimeParamsSliceProvider, geometry_provider: geometry_provider_lib.GeometryProvider, initial_state: state.ToraxSimState, - step_fn: SimulationStepFn, + step_fn: step_function.SimulationStepFn, log_timestep_info: bool = False, ) -> output.ToraxSimOutputs: """Runs the transport simulation over a prescribed time interval. @@ -1121,10 +571,10 @@ def _run_simulation( wall_clock_step_times = [] dynamic_runtime_params_slice, geo = ( - get_consistent_dynamic_runtime_params_slice_and_geometry( - initial_state.t, - dynamic_runtime_params_slice_provider, - geometry_provider, + runtime_params_slice.get_consistent_dynamic_runtime_params_slice_and_geometry( + t=initial_state.t, + dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider, + geometry_provider=geometry_provider, ) ) @@ -1174,10 +624,10 @@ def _run_simulation( # profiles computed based on the final state. logging.info("Updating last step's source profiles.") dynamic_runtime_params_slice, geo = ( - get_consistent_dynamic_runtime_params_slice_and_geometry( - sim_state.t, - dynamic_runtime_params_slice_provider, - geometry_provider, + runtime_params_slice.get_consistent_dynamic_runtime_params_slice_and_geometry( + t=sim_state.t, + dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider, + geometry_provider=geometry_provider, ) ) explicit_source_profiles = source_models_lib.build_source_profiles( @@ -1188,7 +638,7 @@ def _run_simulation( source_models=step_fn.stepper.source_models, explicit=True, ) - sim_state.core_sources = merge_source_profiles( + sim_state.core_sources = source_profiles_lib.SourceProfiles.merge( explicit_source_profiles=explicit_source_profiles, implicit_source_profiles=sim_state.core_sources, ) @@ -1225,260 +675,6 @@ def _run_simulation( ) -def _get_geo_and_dynamic_runtime_params_at_t_plus_dt_and_phibdot( - t: jnp.ndarray, - dt: jnp.ndarray, - dynamic_runtime_params_slice_provider: runtime_params_slice.DynamicRuntimeParamsSliceProvider, - geo_t: geometry.Geometry, - geometry_provider: geometry_provider_lib.GeometryProvider, -) -> tuple[ - runtime_params_slice.DynamicRuntimeParamsSlice, - geometry.Geometry, - geometry.Geometry, -]: - """Returns the geos including Phibdot, and dynamic runtime params at t + dt. - - Args: - t: Time at which the simulation is currently at. - dt: Time step duration. - dynamic_runtime_params_slice_provider: Object that returns a set of runtime - parameters which may change from time step to time step or simulation run - to run. If these runtime parameters change, it does NOT trigger a JAX - recompilation. - geo_t: The geometry of the torus during this time step of the simulation. - geometry_provider: Provides the magnetic geometry for each time step based - on the ToraxSimState at the start of the time step. - - Returns: - Tuple containing: - - The dynamic runtime params at time t + dt. - - The geometry of the torus during this time step of the simulation. - - The geometry of the torus during the next time step of the simulation. - """ - dynamic_runtime_params_slice_t_plus_dt, geo_t_plus_dt = ( - get_consistent_dynamic_runtime_params_slice_and_geometry( - t + dt, - dynamic_runtime_params_slice_provider, - geometry_provider, - ) - ) - geo_t, geo_t_plus_dt = _add_Phibdot( - dt, dynamic_runtime_params_slice_t_plus_dt, geo_t, geo_t_plus_dt - ) - - return ( - dynamic_runtime_params_slice_t_plus_dt, - geo_t, - geo_t_plus_dt, - ) - - -# pylint: disable=invalid-name -def _add_Phibdot( - dt: jnp.ndarray, - dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - geo_t: geometry.Geometry, - geo_t_plus_dt: geometry.Geometry, -) -> tuple[geometry.Geometry, geometry.Geometry]: - """Update Phibdot in the geometry dataclasses used in the time interval. - - Phibdot is used in calc_coeffs to calcuate terms related to time-dependent - geometry. It should be set to be the same for geo_t and geo_t_plus_dt for - each given time interval. This means that geo_t_plus_dt.Phibdot will not - necessarily be the same as the geo_t.Phibdot at the next time step. - - Args: - dt: Time step duration. - dynamic_runtime_params_slice: Runtime parameters which may change from time - step to time step without triggering recompilations. - geo_t: The geometry of the torus during this time step of the simulation. - geo_t_plus_dt: The geometry of the torus during the next time step of the - simulation. - - Returns: - Tuple containing: - - The geometry of the torus during this time step of the simulation. - - The geometry of the torus during the next time step of the simulation. - """ - - # Calculate Phibdot for the time interval. - # If numerics.calcphibdot is False, set Phibdot to be 0 (useful for testing - # purposes) - Phibdot = jnp.where( - dynamic_runtime_params_slice.numerics.calcphibdot, - (geo_t_plus_dt.Phib - geo_t.Phib) / dt, - 0.0, - ) - - geo_t = dataclasses.replace( - geo_t, - Phibdot=Phibdot, - ) - geo_t_plus_dt = dataclasses.replace( - geo_t_plus_dt, - Phibdot=Phibdot, - ) - return geo_t, geo_t_plus_dt - - -# pylint: enable=invalid-name - - -def update_current_distribution( - static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - geo: geometry.Geometry, - core_profiles: state.CoreProfiles, - source_models: source_models_lib.SourceModels, -) -> state.CoreProfiles: - """Update bootstrap current based on the new core_profiles.""" - - bootstrap_profile = source_models.j_bootstrap.get_value( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - static_runtime_params_slice=static_runtime_params_slice, - geo=geo, - core_profiles=core_profiles, - ) - - # calculate "External" current profile (e.g. ECCD) - # form of external current on face grid - external_current = source_models.external_current_source( - dynamic_runtime_params_slice=dynamic_runtime_params_slice, - static_runtime_params_slice=static_runtime_params_slice, - geo=geo, - core_profiles=core_profiles, - ) - - johm = ( - core_profiles.currents.jtot - - bootstrap_profile.j_bootstrap - - external_current - ) - - currents = dataclasses.replace( - core_profiles.currents, - j_bootstrap=bootstrap_profile.j_bootstrap, - j_bootstrap_face=bootstrap_profile.j_bootstrap_face, - I_bootstrap=bootstrap_profile.I_bootstrap, - sigma=bootstrap_profile.sigma, - johm=johm, - external_current_source=external_current, - ) - new_core_profiles = dataclasses.replace( - core_profiles, - currents=currents, - ) - return new_core_profiles - - -def _update_psidot( - static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - geo: geometry.Geometry, - core_profiles: state.CoreProfiles, - source_models: source_models_lib.SourceModels, -) -> state.CoreProfiles: - """Update psidot based on new core_profiles.""" - - psidot = dataclasses.replace( - core_profiles.psidot, - value=ohmic_heat_source.calc_psidot( - static_runtime_params_slice, - dynamic_runtime_params_slice, - geo, - core_profiles, - source_models, - ), - ) - - new_core_profiles = dataclasses.replace( - core_profiles, - psidot=psidot, - ) - return new_core_profiles - - -def provide_core_profiles_t_plus_dt( - static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, - dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice, - geo_t_plus_dt: geometry.Geometry, - core_profiles_t: state.CoreProfiles, -) -> state.CoreProfiles: - """Provides state at t_plus_dt with new boundary conditions and prescribed profiles.""" - updated_boundary_conditions = ( - core_profile_setters.compute_boundary_conditions( - static_runtime_params_slice, - dynamic_runtime_params_slice_t_plus_dt, - geo_t_plus_dt, - ) - ) - updated_values = core_profile_setters.get_prescribed_core_profile_values( - static_runtime_params_slice=static_runtime_params_slice, - dynamic_runtime_params_slice=dynamic_runtime_params_slice_t_plus_dt, - geo=geo_t_plus_dt, - core_profiles=core_profiles_t, - ) - temp_ion = dataclasses.replace( - core_profiles_t.temp_ion, - value=updated_values['temp_ion'], - **updated_boundary_conditions['temp_ion'], - ) - temp_el = dataclasses.replace( - core_profiles_t.temp_el, - value=updated_values['temp_el'], - **updated_boundary_conditions['temp_el'], - ) - psi = dataclasses.replace( - core_profiles_t.psi, **updated_boundary_conditions['psi'] - ) - ne = dataclasses.replace( - core_profiles_t.ne, - value=updated_values['ne'], - **updated_boundary_conditions['ne'], - ) - ni = dataclasses.replace( - core_profiles_t.ni, - value=updated_values['ni'], - **updated_boundary_conditions['ni'], - ) - nimp = dataclasses.replace( - core_profiles_t.nimp, - value=updated_values['nimp'], - **updated_boundary_conditions['nimp'], - ) - - # pylint: disable=invalid-name - # Update Z_face with boundary condition Z, needed for cases where temp_el - # is evolving and updated_prescribed_core_profiles is a no-op. - Zi_face = jnp.concatenate( - [ - updated_values['Zi_face'][:-1], - jnp.array([updated_boundary_conditions['Zi_edge']]), - ], - ) - Zimp_face = jnp.concatenate( - [ - updated_values['Zimp_face'][:-1], - jnp.array([updated_boundary_conditions['Zimp_edge']]), - ], - ) - # pylint: enable=invalid-name - core_profiles_t_plus_dt = dataclasses.replace( - core_profiles_t, - temp_ion=temp_ion, - temp_el=temp_el, - psi=psi, - ne=ne, - ni=ni, - nimp=nimp, - Zi=updated_values['Zi'], - Zi_face=Zi_face, - Zimp=updated_values['Zimp'], - Zimp_face=Zimp_face, - ) - return core_profiles_t_plus_dt - - def get_initial_source_profiles( static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, @@ -1527,65 +723,13 @@ def get_initial_source_profiles( source_models=source_models, explicit=True, ) - initial_profiles = merge_source_profiles( + initial_profiles = source_profiles_lib.SourceProfiles.merge( explicit_source_profiles=explicit_source_profiles, implicit_source_profiles=implicit_profiles, ) return initial_profiles -# This function can be jitted if source_models is a static argument. However, -# in our tests, jitting this function actually slightly slows down runs, so this -# is left as pure python. -def merge_source_profiles( - explicit_source_profiles: source_profiles_lib.SourceProfiles, - implicit_source_profiles: source_profiles_lib.SourceProfiles, -) -> source_profiles_lib.SourceProfiles: - """Returns a SourceProfiles that merges the input profiles. - - Sources can either be explicit or implicit. The explicit_source_profiles - contain the profiles for all source models that are set to explicit, and it - contains profiles with all zeros for any implicit source. The opposite holds - for the implicit_source_profiles. - - This function adds the two dictionaries of profiles and returns a single - SourceProfiles that includes both. - - Args: - explicit_source_profiles: Profiles from explicit source models. This - SourceProfiles dict will include keys for both the explicit and implicit - sources, but only the explicit sources will have non-zero profiles. See - source.py and runtime_params.py for more info on explicit vs. implicit. - implicit_source_profiles: Profiles from implicit source models. This - SourceProfiles dict will include keys for both the explicit and implicit - sources, but only the implicit sources will have non-zero profiles. See - source.py and runtime_params.py for more info on explicit vs. implicit. - - Returns: - A SourceProfiles with non-zero profiles for all sources, both explicit and - implicit (assuming the source model outputted a non-zero profile). - """ - sum_profiles = lambda a, b: a + b - summed_bootstrap_profile = jax.tree_util.tree_map( - sum_profiles, - explicit_source_profiles.j_bootstrap, - implicit_source_profiles.j_bootstrap, - ) - summed_qei_info = jax.tree_util.tree_map( - sum_profiles, explicit_source_profiles.qei, implicit_source_profiles.qei - ) - summed_other_profiles = jax.tree_util.tree_map( - sum_profiles, - explicit_source_profiles.profiles, - implicit_source_profiles.profiles, - ) - return source_profiles_lib.SourceProfiles( - profiles=summed_other_profiles, - j_bootstrap=summed_bootstrap_profile, - qei=summed_qei_info, - ) - - def _log_timestep( sim_state: state.ToraxSimState, ) -> None: diff --git a/torax/sources/source_profiles.py b/torax/sources/source_profiles.py index b0e9e78a..85456333 100644 --- a/torax/sources/source_profiles.py +++ b/torax/sources/source_profiles.py @@ -50,6 +50,60 @@ def get_profile(self, name: str) -> jax.Array: return self.profiles[name] return jnp.zeros_like(self.j_bootstrap.j_bootstrap) + # This function can be jitted if source_models is a static argument. However, + # in our tests, jitting this function actually slightly slows down runs, so + # this is left as pure python. + @classmethod + def merge( + cls, + explicit_source_profiles: SourceProfiles, + implicit_source_profiles: SourceProfiles, + ) -> SourceProfiles: + """Returns a SourceProfiles that merges the input profiles. + + Sources can either be explicit or implicit. The explicit_source_profiles + contain the profiles for all source models that are set to explicit, and it + contains profiles with all zeros for any implicit source. The opposite holds + for the implicit_source_profiles. + + This function adds the two dictionaries of profiles and returns a single + SourceProfiles that includes both. + + Args: + explicit_source_profiles: Profiles from explicit source models. This + SourceProfiles dict will include keys for both the explicit and implicit + sources, but only the explicit sources will have non-zero profiles. See + source.py and runtime_params.py for more info on explicit vs. implicit. + implicit_source_profiles: Profiles from implicit source models. This + SourceProfiles dict will include keys for both the explicit and implicit + sources, but only the implicit sources will have non-zero profiles. See + source.py and runtime_params.py for more info on explicit vs. implicit. + + Returns: + A SourceProfiles with non-zero profiles for all sources, both explicit and + implicit (assuming the source model outputted a non-zero profile). + + """ + sum_profiles = lambda a, b: a + b + summed_bootstrap_profile = jax.tree_util.tree_map( + sum_profiles, + explicit_source_profiles.j_bootstrap, + implicit_source_profiles.j_bootstrap, + ) + summed_qei_info = jax.tree_util.tree_map( + sum_profiles, explicit_source_profiles.qei, implicit_source_profiles.qei + ) + summed_other_profiles = jax.tree_util.tree_map( + sum_profiles, + explicit_source_profiles.profiles, + implicit_source_profiles.profiles, + ) + return cls( + profiles=summed_other_profiles, + j_bootstrap=summed_bootstrap_profile, + qei=summed_qei_info, + ) + @chex.dataclass(frozen=True) class BootstrapCurrentProfile: diff --git a/torax/tests/sim_output_source_profiles.py b/torax/tests/sim_output_source_profiles.py index 0259cd86..e6a643c9 100644 --- a/torax/tests/sim_output_source_profiles.py +++ b/torax/tests/sim_output_source_profiles.py @@ -30,6 +30,7 @@ from torax.config import runtime_params as general_runtime_params from torax.geometry import geometry from torax.geometry import geometry_provider as geometry_provider_lib +from torax.orchestration import step_function from torax.pedestal_model import set_tped_nped from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib @@ -87,7 +88,7 @@ def test_merging_source_profiles(self): source_models=source_models, value=2.0, ) - merged_profiles = sim_lib.merge_source_profiles( # pylint: disable=protected-access + merged_profiles = source_profiles_lib.SourceProfiles.merge( implicit_source_profiles=fake_implicit_source_profiles, explicit_source_profiles=fake_explicit_source_profiles, ) @@ -191,7 +192,7 @@ def mock_step_fn( time_step_calculator=time_stepper, ) with mock.patch.object( - sim_lib.SimulationStepFn, '__call__', new=mock_step_fn + step_function.SimulationStepFn, '__call__', new=mock_step_fn ): sim_outputs = sim.run() diff --git a/torax/tests/test_lib/sim_test_case.py b/torax/tests/test_lib/sim_test_case.py index 3e9670cb..00aa4100 100644 --- a/torax/tests/test_lib/sim_test_case.py +++ b/torax/tests/test_lib/sim_test_case.py @@ -241,16 +241,14 @@ def _test_torax_sim( if ref_name is None: ref_name = test_lib.get_data_file(config_name[:-3]) - # Load reference profiles ref_profiles, ref_time = self._get_refs(ref_name, profiles) - # Build geo needed for output generation - geo = sim.geometry_provider(sim.initial_state.t) - dynamic_runtime_params_slice = sim.dynamic_runtime_params_slice_provider( - t=sim.initial_state.t, - ) - _, geo = runtime_params_slice.make_ip_consistent( - dynamic_runtime_params_slice, geo + _, geo = ( + runtime_params_slice.get_consistent_dynamic_runtime_params_slice_and_geometry( + t=sim.initial_state.t, + dynamic_runtime_params_slice_provider=sim.dynamic_runtime_params_slice_provider, + geometry_provider=sim.geometry_provider, + ) ) # Run full simulation diff --git a/torax/tests/test_lib/torax_refs.py b/torax/tests/test_lib/torax_refs.py index a50a27c3..29935596 100644 --- a/torax/tests/test_lib/torax_refs.py +++ b/torax/tests/test_lib/torax_refs.py @@ -23,7 +23,6 @@ import numpy as np import torax from torax import fvm -from torax import sim as sim_lib from torax.config import config_args from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice @@ -60,16 +59,16 @@ def build_consistent_dynamic_runtime_params_slice_and_geometry( ) -> tuple[runtime_params_slice.DynamicRuntimeParamsSlice, geometry.Geometry]: """Builds a consistent Geometry and a DynamicRuntimeParamsSlice.""" t = runtime_params.numerics.t_initial if t is None else t - return sim_lib.get_consistent_dynamic_runtime_params_slice_and_geometry( - t, - runtime_params_slice.DynamicRuntimeParamsSliceProvider( + return runtime_params_slice.get_consistent_dynamic_runtime_params_slice_and_geometry( + t=t, + dynamic_runtime_params_slice_provider=runtime_params_slice.DynamicRuntimeParamsSliceProvider( runtime_params, transport=transport_model_params.RuntimeParams(), sources=sources, stepper=stepper_params.RuntimeParams(), torax_mesh=geometry_provider.torax_mesh, ), - geometry_provider, + geometry_provider=geometry_provider, )