From 7397275e246b6a0657a6a3126299dc064ecca19b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Sun, 11 Aug 2024 12:33:00 +0200 Subject: [PATCH 01/11] readout optional in feature model --- apax/nn/models.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/apax/nn/models.py b/apax/nn/models.py index b18dc22e..de7df810 100644 --- a/apax/nn/models.py +++ b/apax/nn/models.py @@ -79,8 +79,10 @@ def __call__( perturbation, ) - gm = self.descriptor(dr_vec, Z, idx) - features = jax.vmap(self.readout)(gm) + features = self.descriptor(dr_vec, Z, idx) + + if self.readout: + features = jax.vmap(self.readout)(features) if self.mask_atoms: features = mask_by_atom(features, Z) From e07685a6c7a13a6275ab3d234d38dcaf31fd1e50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Sun, 11 Aug 2024 12:33:14 +0200 Subject: [PATCH 02/11] added ensemble output to make ensemble --- apax/md/ase_calc.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/apax/md/ase_calc.py b/apax/md/ase_calc.py index a5915d6e..aa7a47df 100644 --- a/apax/md/ase_calc.py +++ b/apax/md/ase_calc.py @@ -82,8 +82,10 @@ def make_ensemble(model): def ensemble(positions, Z, idx, box, offsets): results = model(positions, Z, idx, box, offsets) uncertainty = {k + "_uncertainty": jnp.std(v, axis=0) for k, v in results.items()} + ensemble = {k + "_ensemble": v for k, v in results.items()} results = {k: jnp.mean(v, axis=0) for k, v in results.items()} results.update(uncertainty) + results.update(ensemble) return results From 0e4fd3aa6974da1094897adb6f12adc967fb00b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Mon, 2 Sep 2024 10:48:57 +0200 Subject: [PATCH 03/11] added jaxmd properties whitelist --- apax/config/md_config.py | 13 ++++++++-- apax/md/io.py | 25 ++++++++++++++----- apax/md/simulate.py | 5 ++-- apax/utils/helpers.py | 23 ++++++++++------- .../md/md_config_threshold.yaml | 13 +++++++--- tests/integration_tests/md/test_md.py | 12 ++++++++- 6 files changed, 67 insertions(+), 24 deletions(-) diff --git a/apax/config/md_config.py b/apax/config/md_config.py index f126fc52..116297b3 100644 --- a/apax/config/md_config.py +++ b/apax/config/md_config.py @@ -1,12 +1,11 @@ import os - -# from types import UnionType from typing import Literal, Union import yaml from pydantic import BaseModel, Field, NonNegativeInt, PositiveFloat, PositiveInt from typing_extensions import Annotated +from apax.utils.helpers import APAX_PROPERTIES class ConstantTempSchedule(BaseModel, extra="forbid"): """Constant temperature schedule. @@ -234,6 +233,14 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"): extra_capacity : int, default = 0 | JaxMD allocates a maximal number of neighbors. This argument lets you add | additional capacity to avoid recompilation. The default is usually fine. + + dynamics_checks: list[DynamicsCheck] + | List of termination criteria. Currently energy and force uncertainty + | are available + properties: list[str] + | Whitelist of properties to be saved in the trajectory. + | This does not effect what the model will calculate, e.g.. + | an ensemble will still calculate uncertainties. initial_structure : str, required | Path to the starting structure of the simulation. sim_dir : str, default = "." @@ -266,6 +273,8 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"): dynamics_checks: list[DynamicsCheck] = [] + properties: list[str] = APAX_PROPERTIES + initial_structure: str load_momenta: bool = False sim_dir: str = "." diff --git a/apax/md/io.py b/apax/md/io.py index 26452031..d4c509b9 100644 --- a/apax/md/io.py +++ b/apax/md/io.py @@ -14,12 +14,22 @@ class TrajHandler: - def __init__(self) -> None: - self.system: System - self.sampling_rate: int - self.buffer_size: int - self.traj_path: Path - self.time_step: float + def __init__( + self, + system: System, + sampling_rate: int, + buffer_size: int, + traj_path: Path, + time_step: float = 0.5, + properties: list[str] = [], + ) -> None: + self.atomic_numbers = system.atomic_numbers + self.box = system.box + self.fractional = np.any(self.box > 1e-6) + self.sampling_rate = sampling_rate + self.traj_path = traj_path + self.time_step = time_step + self.properties = properties def step(self, state_and_energy, transform=None): pass @@ -53,6 +63,7 @@ def atoms_from_state(self, state, predictions, nbr_kwargs): atoms.pbc = np.diag(atoms.cell.array) > 1e-6 predictions = {k: np.array(v) for k, v in predictions.items()} predictions["energy"] = predictions["energy"].item() + predictions = {k: v for k,v in predictions.items() if k in self.properties} atoms.calc = SinglePointCalculator(atoms, **predictions) return atoms @@ -65,6 +76,7 @@ def __init__( buffer_size: int, traj_path: Path, time_step: float = 0.5, + properties: list[str] = [], ) -> None: self.atomic_numbers = system.atomic_numbers self.box = system.box @@ -72,6 +84,7 @@ def __init__( self.sampling_rate = sampling_rate self.traj_path = traj_path self.time_step = time_step + self.properties = properties self.db = znh5md.IO( self.traj_path, timestep=self.time_step, store="time", save_units=False ) diff --git a/apax/md/simulate.py b/apax/md/simulate.py index eb609a99..974001d8 100644 --- a/apax/md/simulate.py +++ b/apax/md/simulate.py @@ -166,10 +166,10 @@ def run_sim( n_inner: int, extra_capacity: int, rng_key: int, + traj_handler: TrajHandler, load_momenta: bool = False, restart: bool = True, checkpoint_interval: int = 50_000, - traj_handler: TrajHandler = TrajHandler(), dynamics_checks: list[DynamicsCheckBase] = [], disable_pbar: bool = False, ): @@ -520,6 +520,7 @@ def run_md(model_config: Config, md_config: MDConfig, log_level="error"): md_config.buffer_size, traj_path, md_config.ensemble.dt, + properties=md_config.properties, ) # TODO implement correct chunking @@ -531,10 +532,10 @@ def run_md(model_config: Config, md_config: MDConfig, log_level="error"): n_inner=md_config.n_inner, extra_capacity=md_config.extra_capacity, load_momenta=md_config.load_momenta, + traj_handler=traj_handler, rng_key=jax.random.PRNGKey(md_config.seed), restart=md_config.restart, checkpoint_interval=md_config.checkpoint_interval, sim_dir=sim_dir, - traj_handler=traj_handler, dynamics_checks=dynamics_checks, ) diff --git a/apax/utils/helpers.py b/apax/utils/helpers.py index acb80830..7ee0d9e9 100644 --- a/apax/utils/helpers.py +++ b/apax/utils/helpers.py @@ -3,21 +3,26 @@ import yaml +APAX_PROPERTIES = [ + "energy", + "forces", + "stress", + "forces_uncertainty", + "energy_uncertainty", + "stress_uncertainty", + "energy_ensemble", + "forces_ensemble", + "stress_ensemble", +] + + def setup_ase(): """Add uncertainty keys to ASE all properties. from https://github.com/zincware/IPSuite/blob/main/ipsuite/utils/helpers.py#L10 """ from ase.calculators.calculator import all_properties - additional_keys = [ - "forces_uncertainty", - "energy_uncertainty", - "stress_uncertainty", - "energy_ensemble", - "forces_ensemble", - ] - - for val in additional_keys: + for val in APAX_PROPERTIES: if val not in all_properties: all_properties.append(val) diff --git a/tests/integration_tests/md/md_config_threshold.yaml b/tests/integration_tests/md/md_config_threshold.yaml index 13a662f6..0d79d578 100644 --- a/tests/integration_tests/md/md_config_threshold.yaml +++ b/tests/integration_tests/md/md_config_threshold.yaml @@ -1,17 +1,22 @@ ensemble: name: nvt - dt: 0.1 # fs time step + dt: 0.2 # fs time step temperature_schedule: name: piecewise - T0: 5 # K + T0: 50 # K values: [100, 200, 1000] steps: [10, 20, 30] -duration: 100 # fs +duration: 500 # fs n_inner: 1 sampling_rate: 1 checkpoint_interval: 2 restart: True dynamics_checks: - name: forces_uncertainty - threshold: 1.0 + threshold: 0.01 +properties: + - energy + - forces + - energy_uncertainty + - forces_ensemble \ No newline at end of file diff --git a/tests/integration_tests/md/test_md.py b/tests/integration_tests/md/test_md.py index d7e85004..b4df1850 100644 --- a/tests/integration_tests/md/test_md.py +++ b/tests/integration_tests/md/test_md.py @@ -200,7 +200,7 @@ def test_jaxmd_schedule_and_thresold(get_tmp_path, example_dataset): } model_config_dict = load_config_and_run_training(model_confg_path, data_config_mods) - md_confg_path = TEST_PATH / "md_config.yaml" + md_confg_path = TEST_PATH / "md_config_threshold.yaml" with open(md_confg_path.as_posix(), "r") as stream: md_config_dict = yaml.safe_load(stream) @@ -214,3 +214,13 @@ def test_jaxmd_schedule_and_thresold(get_tmp_path, example_dataset): traj = znh5md.IO(md_config.sim_dir + "/" + md_config.traj_name)[:] assert len(traj) < 1000 # num steps + + results_keys = list(traj[0].calc.results.keys()) + + assert "energy" in results_keys + assert "forces" in results_keys + assert "energy_uncertainty" in results_keys + assert "forces_ensemble" in results_keys + + assert "energy_ensemble" not in results_keys + assert "forces_uncertainty" not in results_keys From a54b6860f0e1d0855df83857fd16815b6e5e066a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Sep 2024 08:49:12 +0000 Subject: [PATCH 04/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/integration_tests/md/md_config_threshold.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration_tests/md/md_config_threshold.yaml b/tests/integration_tests/md/md_config_threshold.yaml index 0d79d578..728b8f3e 100644 --- a/tests/integration_tests/md/md_config_threshold.yaml +++ b/tests/integration_tests/md/md_config_threshold.yaml @@ -19,4 +19,4 @@ properties: - energy - forces - energy_uncertainty - - forces_ensemble \ No newline at end of file + - forces_ensemble From f2d0f5e766145098aa354e4f0e503cbc138a485b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Sep 2024 08:51:21 +0000 Subject: [PATCH 05/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- apax/config/md_config.py | 1 + apax/utils/helpers.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/apax/config/md_config.py b/apax/config/md_config.py index 116297b3..7e732c29 100644 --- a/apax/config/md_config.py +++ b/apax/config/md_config.py @@ -7,6 +7,7 @@ from apax.utils.helpers import APAX_PROPERTIES + class ConstantTempSchedule(BaseModel, extra="forbid"): """Constant temperature schedule. diff --git a/apax/utils/helpers.py b/apax/utils/helpers.py index 7ee0d9e9..78efaa12 100644 --- a/apax/utils/helpers.py +++ b/apax/utils/helpers.py @@ -2,7 +2,6 @@ import yaml - APAX_PROPERTIES = [ "energy", "forces", From 9dc30410457d8e796b10ed5f0f124a11e69e65a7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Sep 2024 08:56:38 +0000 Subject: [PATCH 06/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- apax/md/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apax/md/io.py b/apax/md/io.py index d4c509b9..0b95d700 100644 --- a/apax/md/io.py +++ b/apax/md/io.py @@ -63,7 +63,7 @@ def atoms_from_state(self, state, predictions, nbr_kwargs): atoms.pbc = np.diag(atoms.cell.array) > 1e-6 predictions = {k: np.array(v) for k, v in predictions.items()} predictions["energy"] = predictions["energy"].item() - predictions = {k: v for k,v in predictions.items() if k in self.properties} + predictions = {k: v for k, v in predictions.items() if k in self.properties} atoms.calc = SinglePointCalculator(atoms, **predictions) return atoms From 124856082dd3830b1e65dd4f9f5d5d6538b2e4c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Mon, 2 Sep 2024 16:57:55 +0200 Subject: [PATCH 07/11] set all properties as default --- apax/md/io.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/apax/md/io.py b/apax/md/io.py index 0b95d700..39b0ca6d 100644 --- a/apax/md/io.py +++ b/apax/md/io.py @@ -8,6 +8,7 @@ from ase.calculators.singlepoint import SinglePointCalculator from apax.md.sim_utils import System +from apax.utils.helpers import APAX_PROPERTIES from apax.utils.jax_md_reduced import space log = logging.getLogger(__name__) @@ -21,7 +22,7 @@ def __init__( buffer_size: int, traj_path: Path, time_step: float = 0.5, - properties: list[str] = [], + properties: list[str] = APAX_PROPERTIES, ) -> None: self.atomic_numbers = system.atomic_numbers self.box = system.box From 61f0205fd654394236626fc49fbbeb7dccc006fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Thu, 5 Sep 2024 15:55:10 +0200 Subject: [PATCH 08/11] new convention for ensemble prediction order --- apax/md/ase_calc.py | 4 ++++ apax/nn/models.py | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/apax/md/ase_calc.py b/apax/md/ase_calc.py index aa7a47df..6c8d754d 100644 --- a/apax/md/ase_calc.py +++ b/apax/md/ase_calc.py @@ -84,6 +84,10 @@ def ensemble(positions, Z, idx, box, offsets): uncertainty = {k + "_uncertainty": jnp.std(v, axis=0) for k, v in results.items()} ensemble = {k + "_ensemble": v for k, v in results.items()} results = {k: jnp.mean(v, axis=0) for k, v in results.items()} + if "forces_ensemble" in ensemble.keys(): + ensemble["forces_ensemble"] = jnp.transpose(ensemble["forces_ensemble"], (1,2,0)) + if "forces_ensemble" in ensemble.keys(): + ensemble["stress_ensemble"] = jnp.transpose(ensemble["forces_ensemble"], (1,2,0)) results.update(uncertainty) results.update(ensemble) diff --git a/apax/nn/models.py b/apax/nn/models.py index de7df810..53228887 100644 --- a/apax/nn/models.py +++ b/apax/nn/models.py @@ -270,7 +270,9 @@ def __call__( prediction["forces"] = forces_mean prediction["forces_uncertainty"] = jnp.sqrt(forces_variance) - prediction["forces_ensemble"] = forces_ens + + forces_ens + jnp.transpose(forces_ens, (1,2,0)) + prediction["forces_ensemble"] = forces_ens # n_atoms x 3 x n_members else: forces_mean = -jax.grad(mean_energy_fn)(R, Z, neighbor, box, offsets) From d7f914c50cffec6e68c80d1cd766af22b0993bcb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Sep 2024 13:55:26 +0000 Subject: [PATCH 09/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- apax/nn/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apax/nn/models.py b/apax/nn/models.py index 53228887..9df8836b 100644 --- a/apax/nn/models.py +++ b/apax/nn/models.py @@ -271,7 +271,7 @@ def __call__( prediction["forces"] = forces_mean prediction["forces_uncertainty"] = jnp.sqrt(forces_variance) - forces_ens + jnp.transpose(forces_ens, (1,2,0)) + forces_ens + jnp.transpose(forces_ens, (1,2,0)) prediction["forces_ensemble"] = forces_ens # n_atoms x 3 x n_members else: From fc2179f58b3d418b58d00742948a029ce8595219 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Sep 2024 11:45:28 +0000 Subject: [PATCH 10/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- apax/md/ase_calc.py | 8 ++++++-- apax/nn/models.py | 4 ++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/apax/md/ase_calc.py b/apax/md/ase_calc.py index 6c8d754d..3e0d064c 100644 --- a/apax/md/ase_calc.py +++ b/apax/md/ase_calc.py @@ -85,9 +85,13 @@ def ensemble(positions, Z, idx, box, offsets): ensemble = {k + "_ensemble": v for k, v in results.items()} results = {k: jnp.mean(v, axis=0) for k, v in results.items()} if "forces_ensemble" in ensemble.keys(): - ensemble["forces_ensemble"] = jnp.transpose(ensemble["forces_ensemble"], (1,2,0)) + ensemble["forces_ensemble"] = jnp.transpose( + ensemble["forces_ensemble"], (1, 2, 0) + ) if "forces_ensemble" in ensemble.keys(): - ensemble["stress_ensemble"] = jnp.transpose(ensemble["forces_ensemble"], (1,2,0)) + ensemble["stress_ensemble"] = jnp.transpose( + ensemble["forces_ensemble"], (1, 2, 0) + ) results.update(uncertainty) results.update(ensemble) diff --git a/apax/nn/models.py b/apax/nn/models.py index 9df8836b..8c268cb6 100644 --- a/apax/nn/models.py +++ b/apax/nn/models.py @@ -271,8 +271,8 @@ def __call__( prediction["forces"] = forces_mean prediction["forces_uncertainty"] = jnp.sqrt(forces_variance) - forces_ens + jnp.transpose(forces_ens, (1,2,0)) - prediction["forces_ensemble"] = forces_ens # n_atoms x 3 x n_members + forces_ens + jnp.transpose(forces_ens, (1, 2, 0)) + prediction["forces_ensemble"] = forces_ens # n_atoms x 3 x n_members else: forces_mean = -jax.grad(mean_energy_fn)(R, Z, neighbor, box, offsets) From 4d43d8ff0347b100a99367c87fa006914c7e51b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 6 Sep 2024 13:55:50 +0200 Subject: [PATCH 11/11] fix syntax error --- apax/nn/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apax/nn/models.py b/apax/nn/models.py index 9df8836b..20270b7f 100644 --- a/apax/nn/models.py +++ b/apax/nn/models.py @@ -271,7 +271,7 @@ def __call__( prediction["forces"] = forces_mean prediction["forces_uncertainty"] = jnp.sqrt(forces_variance) - forces_ens + jnp.transpose(forces_ens, (1,2,0)) + forces_ens = jnp.transpose(forces_ens, (1,2,0)) prediction["forces_ensemble"] = forces_ens # n_atoms x 3 x n_members else: