diff --git a/apax/config/md_config.py b/apax/config/md_config.py index 4f67b9c2..f126fc52 100644 --- a/apax/config/md_config.py +++ b/apax/config/md_config.py @@ -5,6 +5,7 @@ import yaml from pydantic import BaseModel, Field, NonNegativeInt, PositiveFloat, PositiveInt +from typing_extensions import Annotated class ConstantTempSchedule(BaseModel, extra="forbid"): @@ -190,6 +191,22 @@ class NPTOptions(NVTOptions, extra="forbid"): barostat_chain: NHCOptions = NHCOptions(tau=1000) +class EnergyUncertaintyCheck(BaseModel, extra="forbid"): + name: Literal["energy_uncertainty"] = "energy_uncertainty" + threshold: PositiveFloat + per_atom: bool = True + + +class ForcesUncertaintyCheck(BaseModel, extra="forbid"): + name: Literal["forces_uncertainty"] = "forces_uncertainty" + threshold: PositiveFloat + + +DynamicsCheck = Annotated[ + Union[EnergyUncertaintyCheck, ForcesUncertaintyCheck], Field(discriminator="name") +] + + class MDConfig(BaseModel, frozen=True, extra="forbid"): """ Configuration for a NHC molecular dynamics simulation. @@ -247,6 +264,8 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"): dr_threshold: PositiveFloat = 0.5 extra_capacity: NonNegativeInt = 0 + dynamics_checks: list[DynamicsCheck] = [] + initial_structure: str load_momenta: bool = False sim_dir: str = "." diff --git a/apax/md/dynamics_checks.py b/apax/md/dynamics_checks.py new file mode 100644 index 00000000..f5ef758e --- /dev/null +++ b/apax/md/dynamics_checks.py @@ -0,0 +1,48 @@ +from typing import Literal, Union + +import jax.numpy as jnp +from pydantic import BaseModel, TypeAdapter + + +class DynamicsCheckBase(BaseModel): + def check(self, predictions): + pass + + +class EnergyUncertaintyCheck(DynamicsCheckBase, extra="forbid"): + name: Literal["energy_uncertainty"] = "energy_uncertainty" + threshold: float + per_atom: bool = True + + def check(self, predictions): + if "energy_uncertainty" not in predictions.keys(): + m = "No energy uncertainty found. Are you using a model ensemble?" 
+ raise ValueError(m) + + energy_uncertainty = predictions["energy_uncertainty"] + if self.per_atom: + n_atoms = predictions["forces"].shape[0] + energy_uncertainty = energy_uncertainty / n_atoms + + check_passed = jnp.all(energy_uncertainty < self.threshold) + return check_passed + + +class ForceUncertaintyCheck(DynamicsCheckBase, extra="forbid"): + name: Literal["forces_uncertainty"] = "forces_uncertainty" + threshold: float + + def check(self, predictions): + if "forces_uncertainty" not in predictions.keys(): + m = "No force uncertainties found. Are you using a model ensemble?" + raise ValueError(m) + + forces_uncertainty = predictions["forces_uncertainty"] + + check_passed = jnp.all(forces_uncertainty < self.threshold) + return check_passed + + +DynamicsChecks = TypeAdapter( + Union[EnergyUncertaintyCheck, ForceUncertaintyCheck] +).validate_python diff --git a/apax/md/simulate.py b/apax/md/simulate.py index 5f42c8f2..eb609a99 100644 --- a/apax/md/simulate.py +++ b/apax/md/simulate.py @@ -17,6 +17,7 @@ from apax.config import Config, MDConfig, parse_config from apax.config.md_config import Integrator from apax.md.ase_calc import make_ensemble, maybe_vmap +from apax.md.dynamics_checks import DynamicsCheckBase, DynamicsChecks from apax.md.io import H5TrajHandler, TrajHandler, truncate_trajectory_to_checkpoint from apax.md.md_checkpoint import load_md_state from apax.md.sim_utils import SimulationFunctions, System @@ -133,19 +134,25 @@ def handle_checkpoints(state, step, system, load_momenta, ckpt_dir, should_load_ return state, step -def create_evaluation_functions(aux_fn, positions, Z, neighbor, box): +def create_evaluation_functions(aux_fn, positions, Z, neighbor, box, dynamics_checks): offsets = jnp.zeros((neighbor.idx.shape[1], 3)) def on_eval(positions, neighbor, box): predictions = aux_fn(positions, Z, neighbor, box, offsets) - return predictions + all_checks_passed = True + + for check in dynamics_checks: + check_passed = check.check(predictions) + 
all_checks_passed = all_checks_passed & check_passed + return predictions, all_checks_passed predictions = aux_fn(positions, Z, neighbor, box, offsets) dummpy_preds = jax.tree_map(lambda x: jnp.zeros_like(x), predictions) def no_eval(positions, neighbor, box): predictions = dummpy_preds - return predictions + all_checks_passed = True + return predictions, all_checks_passed return on_eval, no_eval @@ -163,6 +170,7 @@ def run_sim( restart: bool = True, checkpoint_interval: int = 50_000, traj_handler: TrajHandler = TrajHandler(), + dynamics_checks: list[DynamicsCheckBase] = [], disable_pbar: bool = False, ): """ @@ -188,7 +196,6 @@ def run_sim( sim_dir : Path Directory where the trajectory and simulation checkpoints will be saved. """ - energy_fn = sim_fns.energy_fn neighbor_fn = sim_fns.neighbor_fn ckpt_dir = sim_dir / "ckpts" ckpt_dir.mkdir(exist_ok=True) @@ -231,12 +238,13 @@ def run_sim( system.atomic_numbers, neighbor, system.box, + dynamics_checks, ) @jax.jit def sim(state, outer_step, neighbor): # TODO make more modular def body_fn(i, state): - state, outer_step, neighbor = state + state, outer_step, neighbor, all_checks_passed = state step = i + outer_step * n_inner apply_fn_kwargs = {} @@ -253,22 +261,25 @@ def body_fn(i, state): neighbor = neighbor.update(state.position, **nbr_kwargs) condition = step % sampling_rate == 0 - predictions = jax.lax.cond( + predictions, check_passed = jax.lax.cond( condition, on_eval, no_eval, state.position, neighbor, box ) + all_checks_passed = all_checks_passed & check_passed + # maybe move this to on_eval io_callback(traj_handler.step, None, (state, predictions, nbr_kwargs)) - return state, outer_step, neighbor + return state, outer_step, neighbor, all_checks_passed - state, outer_step, neighbor = jax.lax.fori_loop( - 0, n_inner, body_fn, (state, outer_step, neighbor) + all_checks_passed = True + state, outer_step, neighbor, all_checks_passed = jax.lax.fori_loop( + 0, n_inner, body_fn, (state, outer_step, neighbor, 
all_checks_passed) ) current_temperature = ( quantity.temperature(velocity=state.velocity, mass=state.mass) / units.kB ) - return state, neighbor, current_temperature + return state, neighbor, current_temperature, all_checks_passed start = time.time() total_sim_time = n_steps * ensemble.dt / 1000 @@ -285,7 +296,17 @@ def body_fn(i, state): leave=True, ) while step < n_outer: - new_state, neighbor, current_temperature = sim(state, step, neighbor) + new_state, neighbor, current_temperature, all_checks_passed = sim( + state, step, neighbor + ) + + if np.any(np.isnan(new_state.position)) or np.any(np.isnan(new_state.velocity)): + raise ValueError(f"NaN encountered, simulation aborted after {step+1} steps.") + + if not all_checks_passed: + with logging_redirect_tqdm(): + log.critical(f"One or more dynamics checks failed at step: {step+1}") + break if neighbor.did_buffer_overflow: with logging_redirect_tqdm(): @@ -298,11 +319,6 @@ def body_fn(i, state): state = new_state step += 1 - if np.any(np.isnan(state.position)) or np.any(np.isnan(state.velocity)): - raise ValueError( - f"NaN encountered, simulation aborted after {step} steps." 
- ) - if step % checkpoint_interval == 0: with logging_redirect_tqdm(): current_sim_time = step * n_inner * ensemble.dt / 1000 @@ -488,6 +504,14 @@ def run_md(model_config: Config, md_config: MDConfig, log_level="error"): traj_path = sim_dir / md_config.traj_name system, sim_fns = md_setup(model_config, md_config) + + dynamics_checks = [] + if md_config.dynamics_checks: + check_list = [ + DynamicsChecks(check.model_dump()) for check in md_config.dynamics_checks + ] + dynamics_checks.extend(check_list) + n_steps = int(np.ceil(md_config.duration / md_config.ensemble.dt)) traj_handler = H5TrajHandler( @@ -512,4 +536,5 @@ def run_md(model_config: Config, md_config: MDConfig, log_level="error"): checkpoint_interval=md_config.checkpoint_interval, sim_dir=sim_dir, traj_handler=traj_handler, + dynamics_checks=dynamics_checks, ) diff --git a/tests/conftest.py b/tests/conftest.py index 45061d4d..38c23f32 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -160,3 +160,4 @@ def load_and_dump_config(config_path, dump_path): def load_config_and_run_training(config_path, updated_config): config_dict = mod_config(config_path, updated_config) run(config_dict) + return config_dict diff --git a/tests/integration_tests/md/config.yaml b/tests/integration_tests/md/config.yaml index 2d7bd124..9bf42fd3 100644 --- a/tests/integration_tests/md/config.yaml +++ b/tests/integration_tests/md/config.yaml @@ -1,12 +1,12 @@ -n_epochs: 2 +n_epochs: 5 data: experiment: apax_dummy data_path: dummy_ds # ds.extxyz #ethanol.traj # ds.extxyz # - n_train: 4 + n_train: 10 n_valid: 2 - batch_size: 2 + batch_size: 1 valid_batch_size: 2 model: @@ -16,6 +16,12 @@ model: name: gaussian n_basis: 5 n_radial: 3 + + ensemble: + kind: shallow + n_members: 4 + force_variance: true + descriptor_dtype: fp64 readout_dtype: fp32 scale_shift_dtype: fp64 @@ -27,4 +33,6 @@ metrics: loss: - name: energy + loss_type: crps - name: forces + loss_type: crps diff --git a/tests/integration_tests/md/md_config_threshold.yaml 
b/tests/integration_tests/md/md_config_threshold.yaml new file mode 100644 index 00000000..13a662f6 --- /dev/null +++ b/tests/integration_tests/md/md_config_threshold.yaml @@ -0,0 +1,17 @@ +ensemble: + name: nvt + dt: 0.1 # fs time step + temperature_schedule: + name: piecewise + T0: 5 # K + values: [100, 200, 1000] + steps: [10, 20, 30] + +duration: 100 # fs +n_inner: 1 +sampling_rate: 1 +checkpoint_interval: 2 +restart: True +dynamics_checks: + - name: forces_uncertainty + threshold: 1.0 diff --git a/tests/integration_tests/md/test_md.py b/tests/integration_tests/md/test_md.py index 59188115..d7e85004 100644 --- a/tests/integration_tests/md/test_md.py +++ b/tests/integration_tests/md/test_md.py @@ -1,9 +1,11 @@ import os import pathlib +import uuid import jax import jax.numpy as jnp import numpy as np +import pytest import yaml import znh5md from ase import Atoms @@ -14,6 +16,7 @@ from apax.md import run_md from apax.md.ase_calc import ASECalculator from apax.utils import jax_md_reduced +from tests.conftest import load_config_and_run_training TEST_PATH = pathlib.Path(__file__).parent.resolve() @@ -178,3 +181,36 @@ def test_ase_calc(get_tmp_path): assert "energy_uncertainty" in atoms.calc.results.keys() assert "forces_uncertainty" in atoms.calc.results.keys() assert "stress_uncertainty" in atoms.calc.results.keys() + + +@pytest.mark.parametrize("num_data", (30,)) +def test_jaxmd_schedule_and_thresold(get_tmp_path, example_dataset): + model_confg_path = TEST_PATH / "config.yaml" + working_dir = get_tmp_path / str(uuid.uuid4()) + data_path = get_tmp_path / "ds.extxyz" + + write(data_path, example_dataset) + + data_config_mods = { + "data": { + "directory": working_dir.as_posix(), + "experiment": "model", + "data_path": data_path.as_posix(), + }, + } + model_config_dict = load_config_and_run_training(model_confg_path, data_config_mods) + + md_confg_path = TEST_PATH / "md_config_threshold.yaml" + + with open(md_confg_path.as_posix(), "r") as stream: + md_config_dict = 
yaml.safe_load(stream) + md_config_dict["sim_dir"] = get_tmp_path.as_posix() + md_config_dict["initial_structure"] = get_tmp_path.as_posix() + "/ds.extxyz" + md_config = MDConfig.model_validate(md_config_dict) + + model_config = Config.model_validate(model_config_dict) + + run_md(model_config, md_config) + + traj = znh5md.IO(md_config.sim_dir + "/" + md_config.traj_name)[:] + assert len(traj) < 1000 # num steps