From 3e93bcfe8db866e4c60cbf06f77d7e70ed7b94ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Sat, 24 Aug 2024 15:49:35 +0200 Subject: [PATCH 1/9] added uncertainty checks to jax md --- apax/config/md_config.py | 20 +++++++++++++ apax/md/dynamics_checks.py | 52 ++++++++++++++++++++++++++++++++++ apax/md/simulate.py | 58 +++++++++++++++++++++++++++----------- 3 files changed, 114 insertions(+), 16 deletions(-) create mode 100644 apax/md/dynamics_checks.py diff --git a/apax/config/md_config.py b/apax/config/md_config.py index 4f67b9c2..29ec8cae 100644 --- a/apax/config/md_config.py +++ b/apax/config/md_config.py @@ -2,6 +2,7 @@ # from types import UnionType from typing import Literal, Union +from typing_extensions import Annotated import yaml from pydantic import BaseModel, Field, NonNegativeInt, PositiveFloat, PositiveInt @@ -190,6 +191,23 @@ class NPTOptions(NVTOptions, extra="forbid"): barostat_chain: NHCOptions = NHCOptions(tau=1000) + + +class EnergyUncertaintyCheck(BaseModel, extra="forbid"): + name: Literal["energy_uncertainty"] = "energy_uncertainty" + threshold: PositiveFloat + per_atom: bool = True + +class ForcesUncertaintyCheck(BaseModel, extra="forbid"): + name: Literal["forces_uncertainty"] = "forces_uncertainty" + threshold: PositiveFloat + + +DynamicsCheck = Annotated[ + Union[EnergyUncertaintyCheck, ForcesUncertaintyCheck], Field(discriminator="name") +] + + class MDConfig(BaseModel, frozen=True, extra="forbid"): """ Configuration for a NHC molecular dynamics simulation. @@ -247,6 +265,8 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"): dr_threshold: PositiveFloat = 0.5 extra_capacity: NonNegativeInt = 0 + dynamics_checks: list[DynamicsCheck] = [] + initial_structure: str load_momenta: bool = False sim_dir: str = "." diff --git a/apax/md/dynamics_checks.py b/apax/md/dynamics_checks.py new file mode 100644 index 00000000..ee190f45 --- /dev/null +++ b/apax/md/dynamics_checks.py @@ -0,0 +1,52 @@ + +from dataclasses import dataclass +from typing import Literal, Union +import jax.numpy as jnp +from pydantic import TypeAdapter, BaseModel + + +class DynamicsCheckBase(BaseModel): + def check(self, predictions): + pass + +class EnergyUncertaintyCheck(DynamicsCheckBase, extra="forbid"): + name: Literal["energy_uncertainty"] = "energy_uncertainty" + threshold: float + per_atom: bool = True + + def check(self, predictions): + + if not "energy_uncertainty" in predictions.keys(): + m = "No energy uncertainty found. Are you using a model ensemble?" + raise ValueError(m) + + energy_uncertainty = predictions["energy_uncertainty"] + if self.per_atom: + n_atoms = predictions["forces"].shape[0] + energy_uncertainty = energy_uncertainty/ n_atoms + + check_passed = jnp.all(energy_uncertainty < self.threshold) + return check_passed + + +from jax import debug + +class ForceUncertaintyCheck(DynamicsCheckBase, extra="forbid"): + name: Literal["forces_uncertainty"] = "forces_uncertainty" + threshold: float + + def check(self, predictions): + + if not "forces_uncertainty" in predictions.keys(): + m = "No force uncertainties found. Are you using a model ensemble?" + raise ValueError(m) + + forces_uncertainty = predictions["forces_uncertainty"] + + check_passed = jnp.all(forces_uncertainty < self.threshold) + return check_passed + + +DynamicsChecks = TypeAdapter( + Union[EnergyUncertaintyCheck, ForceUncertaintyCheck] +).validate_python diff --git a/apax/md/simulate.py b/apax/md/simulate.py index 0b374439..afc9644c 100644 --- a/apax/md/simulate.py +++ b/apax/md/simulate.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass import logging import time from functools import partial @@ -17,6 +18,7 @@ from apax.config import Config, MDConfig, parse_config from apax.config.md_config import Integrator from apax.md.ase_calc import make_ensemble, maybe_vmap +from apax.md.dynamics_checks import DynamicsCheckBase, DynamicsChecks from apax.md.io import H5TrajHandler, TrajHandler, truncate_trajectory_to_checkpoint from apax.md.md_checkpoint import load_md_state from apax.md.sim_utils import SimulationFunctions, System @@ -133,19 +135,25 @@ def handle_checkpoints(state, step, system, load_momenta, ckpt_dir, should_load_ return state, step -def create_evaluation_functions(aux_fn, positions, Z, neighbor, box): +def create_evaluation_functions(aux_fn, positions, Z, neighbor, box, dynamics_checks): offsets = jnp.zeros((neighbor.idx.shape[1], 3)) def on_eval(positions, neighbor, box): predictions = aux_fn(positions, Z, neighbor, box, offsets) - return predictions + all_checks_passed = True + + for check in dynamics_checks: + check_passed = check.check(predictions) + all_checks_passed = all_checks_passed & check_passed + return predictions, all_checks_passed predictions = aux_fn(positions, Z, neighbor, box, offsets) dummpy_preds = jax.tree_map(lambda x: jnp.zeros_like(x), predictions) def no_eval(positions, neighbor, box): predictions = dummpy_preds - return predictions + all_checks_passed = True + return predictions, all_checks_passed return on_eval, no_eval @@ -163,6 +171,7 @@ def run_sim( restart: bool = True, checkpoint_interval: int = 50_000, traj_handler: TrajHandler = TrajHandler(), + dynamics_checks: list[DynamicsCheckBase] = [], disable_pbar: bool = False, ): """ @@ -188,7 +197,6 @@ def run_sim( sim_dir : Path Directory where the trajectory and simulation checkpoints will be saved. """ - energy_fn = sim_fns.energy_fn neighbor_fn = sim_fns.neighbor_fn ckpt_dir = sim_dir / "ckpts" ckpt_dir.mkdir(exist_ok=True) @@ -231,12 +239,13 @@ def run_sim( system.atomic_numbers, neighbor, system.box, + dynamics_checks, ) @jax.jit def sim(state, outer_step, neighbor): # TODO make more modular def body_fn(i, state): - state, outer_step, neighbor = state + state, outer_step, neighbor, all_checks_passed = state step = i + outer_step * n_inner apply_fn_kwargs = {} @@ -253,22 +262,25 @@ def body_fn(i, state): neighbor = neighbor.update(state.position, **nbr_kwargs) condition = step % sampling_rate == 0 - predictions = jax.lax.cond( + predictions, check_passed = jax.lax.cond( condition, on_eval, no_eval, state.position, neighbor, box ) + all_checks_passed = all_checks_passed & check_passed + # maybe move this to on_eval io_callback(traj_handler.step, None, (state, predictions, nbr_kwargs)) - return state, outer_step, neighbor + return state, outer_step, neighbor, all_checks_passed - state, outer_step, neighbor = jax.lax.fori_loop( - 0, n_inner, body_fn, (state, outer_step, neighbor) + all_checks_passed = True + state, outer_step, neighbor, all_checks_passed = jax.lax.fori_loop( + 0, n_inner, body_fn, (state, outer_step, neighbor, all_checks_passed) ) current_temperature = ( quantity.temperature(velocity=state.velocity, mass=state.mass) / units.kB ) - return state, neighbor, current_temperature + return state, neighbor, current_temperature, all_checks_passed start = time.time() total_sim_time = n_steps * ensemble.dt / 1000 @@ -285,7 +297,19 @@ def body_fn(i, state): leave=True, ) while step < n_outer: - new_state, neighbor, current_temperature = sim(state, step, neighbor) + new_state, neighbor, current_temperature, all_checks_passed = sim(state, step, neighbor) + + if np.any(np.isnan(state.position)) or np.any(np.isnan(state.velocity)): + raise ValueError( + f"NaN encountered, simulation aborted after {step+1} steps." + ) + + if not all_checks_passed: + with logging_redirect_tqdm(): + log.info( + f"One or more dynamics checks failed at step: {step+1}" + ) + break if neighbor.did_buffer_overflow: with logging_redirect_tqdm(): @@ -298,11 +322,6 @@ def body_fn(i, state): state = new_state step += 1 - if np.any(np.isnan(state.position)) or np.any(np.isnan(state.velocity)): - raise ValueError( - f"NaN encountered, simulation aborted after {step} steps." - ) - if step % checkpoint_interval == 0: with logging_redirect_tqdm(): current_sim_time = step * n_inner * ensemble.dt / 1000 @@ -475,6 +494,12 @@ def run_md(model_config: Config, md_config: MDConfig, log_level="error"): traj_path = sim_dir / md_config.traj_name system, sim_fns = md_setup(model_config, md_config) + + dynamics_checks = [] + if md_config.dynamics_checks: + check_list = [DynamicsChecks(check.model_dump()) for check in md_config.dynamics_checks] + dynamics_checks.extend(check_list) + n_steps = int(np.ceil(md_config.duration / md_config.ensemble.dt)) traj_handler = H5TrajHandler( @@ -499,4 +524,5 @@ def run_md(model_config: Config, md_config: MDConfig, log_level="error"): checkpoint_interval=md_config.checkpoint_interval, sim_dir=sim_dir, traj_handler=traj_handler, + dynamics_checks=dynamics_checks, ) From f62c3bd0ea464b26c4ab2ad8238b219cf72bcf0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Mon, 26 Aug 2024 09:51:44 +0200 Subject: [PATCH 2/9] remove debug import --- apax/md/dynamics_checks.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/apax/md/dynamics_checks.py b/apax/md/dynamics_checks.py index ee190f45..ee1ef63e 100644 --- a/apax/md/dynamics_checks.py +++ b/apax/md/dynamics_checks.py @@ -29,8 +29,6 @@ def check(self, predictions): return check_passed -from jax import debug - class ForceUncertaintyCheck(DynamicsCheckBase, extra="forbid"): name: Literal["forces_uncertainty"] = "forces_uncertainty" threshold: float From e6a36b42b3d563df32fac6cc23b1ec61148a4675 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Aug 2024 07:53:30 +0000 Subject: [PATCH 3/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- apax/md/dynamics_checks.py | 10 +++++----- apax/md/simulate.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/apax/md/dynamics_checks.py b/apax/md/dynamics_checks.py index ee1ef63e..ad95b0b0 100644 --- a/apax/md/dynamics_checks.py +++ b/apax/md/dynamics_checks.py @@ -15,11 +15,11 @@ class EnergyUncertaintyCheck(DynamicsCheckBase, extra="forbid"): per_atom: bool = True def check(self, predictions): - + if not "energy_uncertainty" in predictions.keys(): m = "No energy uncertainty found. Are you using a model ensemble?" raise ValueError(m) - + energy_uncertainty = predictions["energy_uncertainty"] if self.per_atom: n_atoms = predictions["forces"].shape[0] @@ -27,18 +27,18 @@ def check(self, predictions): check_passed = jnp.all(energy_uncertainty < self.threshold) return check_passed - + class ForceUncertaintyCheck(DynamicsCheckBase, extra="forbid"): name: Literal["forces_uncertainty"] = "forces_uncertainty" threshold: float def check(self, predictions): - + if not "forces_uncertainty" in predictions.keys(): m = "No force uncertainties found. Are you using a model ensemble?" raise ValueError(m) - + forces_uncertainty = predictions["forces_uncertainty"] check_passed = jnp.all(forces_uncertainty < self.threshold) diff --git a/apax/md/simulate.py b/apax/md/simulate.py index 03c93425..0198cf2f 100644 --- a/apax/md/simulate.py +++ b/apax/md/simulate.py @@ -303,7 +303,7 @@ def body_fn(i, state): raise ValueError( f"NaN encountered, simulation aborted after {step+1} steps." ) - + if not all_checks_passed: with logging_redirect_tqdm(): log.info( From 93ee222c156519090dd5976be1185c50c49a32ea Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Aug 2024 08:12:03 +0000 Subject: [PATCH 4/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- apax/config/md_config.py | 2 +- apax/md/dynamics_checks.py | 8 ++++---- apax/md/simulate.py | 1 - 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/apax/config/md_config.py b/apax/config/md_config.py index 29ec8cae..e316b620 100644 --- a/apax/config/md_config.py +++ b/apax/config/md_config.py @@ -2,10 +2,10 @@ # from types import UnionType from typing import Literal, Union -from typing_extensions import Annotated import yaml from pydantic import BaseModel, Field, NonNegativeInt, PositiveFloat, PositiveInt +from typing_extensions import Annotated class ConstantTempSchedule(BaseModel, extra="forbid"): diff --git a/apax/md/dynamics_checks.py b/apax/md/dynamics_checks.py index ad95b0b0..55cbf1f0 100644 --- a/apax/md/dynamics_checks.py +++ b/apax/md/dynamics_checks.py @@ -1,8 +1,8 @@ -from dataclasses import dataclass from typing import Literal, Union + import jax.numpy as jnp -from pydantic import TypeAdapter, BaseModel +from pydantic import BaseModel, TypeAdapter class DynamicsCheckBase(BaseModel): @@ -16,7 +16,7 @@ class EnergyUncertaintyCheck(DynamicsCheckBase, extra="forbid"): def check(self, predictions): - if not "energy_uncertainty" in predictions.keys(): + if "energy_uncertainty" not in predictions.keys(): m = "No energy uncertainty found. Are you using a model ensemble?" raise ValueError(m) @@ -35,7 +35,7 @@ class ForceUncertaintyCheck(DynamicsCheckBase, extra="forbid"): def check(self, predictions): - if not "forces_uncertainty" in predictions.keys(): + if "forces_uncertainty" not in predictions.keys(): m = "No force uncertainties found. Are you using a model ensemble?" raise ValueError(m) diff --git a/apax/md/simulate.py b/apax/md/simulate.py index 0198cf2f..64870c58 100644 --- a/apax/md/simulate.py +++ b/apax/md/simulate.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass import logging import time from functools import partial From 49e07e8acd5b45b2abc2c8d7b66020b9b84c5b6f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Aug 2024 08:16:31 +0000 Subject: [PATCH 5/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- apax/config/md_config.py | 3 +-- apax/md/dynamics_checks.py | 6 ++---- apax/md/simulate.py | 16 ++++++++-------- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/apax/config/md_config.py b/apax/config/md_config.py index e316b620..f126fc52 100644 --- a/apax/config/md_config.py +++ b/apax/config/md_config.py @@ -191,13 +191,12 @@ class NPTOptions(NVTOptions, extra="forbid"): barostat_chain: NHCOptions = NHCOptions(tau=1000) - - class EnergyUncertaintyCheck(BaseModel, extra="forbid"): name: Literal["energy_uncertainty"] = "energy_uncertainty" threshold: PositiveFloat per_atom: bool = True + class ForcesUncertaintyCheck(BaseModel, extra="forbid"): name: Literal["forces_uncertainty"] = "forces_uncertainty" threshold: PositiveFloat diff --git a/apax/md/dynamics_checks.py b/apax/md/dynamics_checks.py index 55cbf1f0..f5ef758e 100644 --- a/apax/md/dynamics_checks.py +++ b/apax/md/dynamics_checks.py @@ -1,4 +1,3 @@ - from typing import Literal, Union import jax.numpy as jnp @@ -9,13 +8,13 @@ class DynamicsCheckBase(BaseModel): def check(self, predictions): pass + class EnergyUncertaintyCheck(DynamicsCheckBase, extra="forbid"): name: Literal["energy_uncertainty"] = "energy_uncertainty" threshold: float per_atom: bool = True def check(self, predictions): - if "energy_uncertainty" not in predictions.keys(): m = "No energy uncertainty found. Are you using a model ensemble?" raise ValueError(m) @@ -23,7 +22,7 @@ def check(self, predictions): energy_uncertainty = predictions["energy_uncertainty"] if self.per_atom: n_atoms = predictions["forces"].shape[0] - energy_uncertainty = energy_uncertainty/ n_atoms + energy_uncertainty = energy_uncertainty / n_atoms check_passed = jnp.all(energy_uncertainty < self.threshold) return check_passed @@ -34,7 +33,6 @@ class ForceUncertaintyCheck(DynamicsCheckBase, extra="forbid"): threshold: float def check(self, predictions): - if "forces_uncertainty" not in predictions.keys(): m = "No force uncertainties found. Are you using a model ensemble?" raise ValueError(m) diff --git a/apax/md/simulate.py b/apax/md/simulate.py index 64870c58..c4481b13 100644 --- a/apax/md/simulate.py +++ b/apax/md/simulate.py @@ -296,18 +296,16 @@ def body_fn(i, state): leave=True, ) while step < n_outer: - new_state, neighbor, current_temperature, all_checks_passed = sim(state, step, neighbor) + new_state, neighbor, current_temperature, all_checks_passed = sim( + state, step, neighbor + ) if np.any(np.isnan(state.position)) or np.any(np.isnan(state.velocity)): - raise ValueError( - f"NaN encountered, simulation aborted after {step+1} steps." - ) + raise ValueError(f"NaN encountered, simulation aborted after {step+1} steps.") if not all_checks_passed: with logging_redirect_tqdm(): - log.info( - f"One or more dynamics checks failed at step: {step+1}" - ) + log.info(f"One or more dynamics checks failed at step: {step+1}") break if neighbor.did_buffer_overflow: @@ -509,7 +507,9 @@ def run_md(model_config: Config, md_config: MDConfig, log_level="error"): dynamics_checks = [] if md_config.dynamics_checks: - check_list = [DynamicsChecks(check.model_dump()) for check in md_config.dynamics_checks] + check_list = [ + DynamicsChecks(check.model_dump()) for check in md_config.dynamics_checks + ] dynamics_checks.extend(check_list) n_steps = int(np.ceil(md_config.duration / md_config.ensemble.dt)) From 84610b141e1269df38b867f4797c393bfdf64b2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Mon, 26 Aug 2024 10:18:58 +0200 Subject: [PATCH 6/9] changed threshold logging to critical --- apax/md/simulate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apax/md/simulate.py b/apax/md/simulate.py index c4481b13..eb609a99 100644 --- a/apax/md/simulate.py +++ b/apax/md/simulate.py @@ -305,7 +305,7 @@ def body_fn(i, state): if not all_checks_passed: with logging_redirect_tqdm(): - log.info(f"One or more dynamics checks failed at step: {step+1}") + log.critical(f"One or more dynamics checks failed at step: {step+1}") break if neighbor.did_buffer_overflow: From 9043beae61b805e0a707fb3c3c96ddb0b88cad4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Mon, 26 Aug 2024 14:56:01 +0200 Subject: [PATCH 7/9] added test for schedules and threshold --- tests/conftest.py | 1 + tests/integration_tests/md/config.yaml | 14 +++++-- .../md/md_config_threshold.yaml | 17 +++++++++ tests/integration_tests/md/test_md.py | 38 +++++++++++++++++++ 4 files changed, 67 insertions(+), 3 deletions(-) create mode 100644 tests/integration_tests/md/md_config_threshold.yaml diff --git a/tests/conftest.py b/tests/conftest.py index 45061d4d..38c23f32 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -160,3 +160,4 @@ def load_and_dump_config(config_path, dump_path): def load_config_and_run_training(config_path, updated_config): config_dict = mod_config(config_path, updated_config) run(config_dict) + return config_dict diff --git a/tests/integration_tests/md/config.yaml b/tests/integration_tests/md/config.yaml index 2d7bd124..9bf42fd3 100644 --- a/tests/integration_tests/md/config.yaml +++ b/tests/integration_tests/md/config.yaml @@ -1,12 +1,12 @@ -n_epochs: 2 +n_epochs: 5 data: experiment: apax_dummy data_path: dummy_ds # ds.extxyz #ethanol.traj # ds.extxyz # - n_train: 4 + n_train: 10 n_valid: 2 - batch_size: 2 + batch_size: 1 valid_batch_size: 2 model: @@ -16,6 +16,12 @@ model: name: gaussian n_basis: 5 n_radial: 3 + + ensemble: + kind: shallow + n_members: 4 + force_variance: true + descriptor_dtype: fp64 readout_dtype: fp32 scale_shift_dtype: fp64 @@ -27,4 +33,6 @@ metrics: loss: - name: energy + loss_type: crps - name: forces + loss_type: crps diff --git a/tests/integration_tests/md/md_config_threshold.yaml b/tests/integration_tests/md/md_config_threshold.yaml new file mode 100644 index 00000000..b1d01d11 --- /dev/null +++ b/tests/integration_tests/md/md_config_threshold.yaml @@ -0,0 +1,17 @@ +ensemble: + name: nvt + dt: 0.1 # fs time step + temperature_schedule: + name: piecewise + T0: 5 # K + values: [100, 200, 1000] + steps: [10, 20, 30] + +duration: 100 # fs +n_inner: 1 +sampling_rate: 1 +checkpoint_interval: 2 +restart: True +dynamics_checks: + - name: forces_uncertainty + threshold: 1.0 \ No newline at end of file diff --git a/tests/integration_tests/md/test_md.py b/tests/integration_tests/md/test_md.py index 59188115..40266b0c 100644 --- a/tests/integration_tests/md/test_md.py +++ b/tests/integration_tests/md/test_md.py @@ -1,9 +1,11 @@ import os import pathlib +import uuid import jax import jax.numpy as jnp import numpy as np +import pytest import yaml import znh5md from ase import Atoms @@ -14,6 +16,7 @@ from apax.md import run_md from apax.md.ase_calc import ASECalculator from apax.utils import jax_md_reduced +from tests.conftest import load_config_and_run_training TEST_PATH = pathlib.Path(__file__).parent.resolve() @@ -178,3 +181,38 @@ def test_ase_calc(get_tmp_path): assert "energy_uncertainty" in atoms.calc.results.keys() assert "forces_uncertainty" in atoms.calc.results.keys() assert "stress_uncertainty" in atoms.calc.results.keys() + + + +@pytest.mark.parametrize("num_data", (30,)) +def test_jaxmd_schedule_and_thresold(get_tmp_path, example_dataset): + model_confg_path = TEST_PATH / "config.yaml" + working_dir = get_tmp_path / str(uuid.uuid4()) + data_path = get_tmp_path / "ds.extxyz" + + write(data_path, example_dataset) + + data_config_mods = { + "data": { + "directory": working_dir.as_posix(), + "experiment": "model", + "data_path": data_path.as_posix(), + }, + } + model_config_dict = load_config_and_run_training(model_confg_path, data_config_mods) + + md_confg_path = TEST_PATH / "md_config.yaml" + + with open(md_confg_path.as_posix(), "r") as stream: + md_config_dict = yaml.safe_load(stream) + md_config_dict["sim_dir"] = get_tmp_path.as_posix() + md_config_dict["initial_structure"] = get_tmp_path.as_posix() + "/ds.extxyz" + md_config = MDConfig.model_validate(md_config_dict) + + model_config = Config.model_validate(model_config_dict) + + run_md(model_config, md_config) + + traj = znh5md.IO(md_config.sim_dir + "/" + md_config.traj_name)[:] + assert len(traj) < 1000 # num steps + From ca31e8500c6d392f7cd018741d68f0f144b803a4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Aug 2024 12:56:45 +0000 Subject: [PATCH 8/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/integration_tests/md/md_config_threshold.yaml | 2 +- tests/integration_tests/md/test_md.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/integration_tests/md/md_config_threshold.yaml b/tests/integration_tests/md/md_config_threshold.yaml index b1d01d11..13a662f6 100644 --- a/tests/integration_tests/md/md_config_threshold.yaml +++ b/tests/integration_tests/md/md_config_threshold.yaml @@ -14,4 +14,4 @@ checkpoint_interval: 2 restart: True dynamics_checks: - name: forces_uncertainty - threshold: 1.0 \ No newline at end of file + threshold: 1.0 diff --git a/tests/integration_tests/md/test_md.py b/tests/integration_tests/md/test_md.py index 40266b0c..99f2e43a 100644 --- a/tests/integration_tests/md/test_md.py +++ b/tests/integration_tests/md/test_md.py @@ -215,4 +215,3 @@ def test_jaxmd_schedule_and_thresold(get_tmp_path, example_dataset): traj = znh5md.IO(md_config.sim_dir + "/" + md_config.traj_name)[:] assert len(traj) < 1000 # num steps - From 9088876d31c25d29d9ca568aa9631c881a7b91d5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Aug 2024 12:57:22 +0000 Subject: [PATCH 9/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/integration_tests/md/test_md.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/integration_tests/md/test_md.py b/tests/integration_tests/md/test_md.py index 99f2e43a..d7e85004 100644 --- a/tests/integration_tests/md/test_md.py +++ b/tests/integration_tests/md/test_md.py @@ -183,7 +183,6 @@ def test_ase_calc(get_tmp_path): assert "stress_uncertainty" in atoms.calc.results.keys() - @pytest.mark.parametrize("num_data", (30,)) def test_jaxmd_schedule_and_thresold(get_tmp_path, example_dataset): model_confg_path = TEST_PATH / "config.yaml"