diff --git a/apax/config/md_config.py b/apax/config/md_config.py index f126fc52..7e732c29 100644 --- a/apax/config/md_config.py +++ b/apax/config/md_config.py @@ -1,12 +1,12 @@ import os - -# from types import UnionType from typing import Literal, Union import yaml from pydantic import BaseModel, Field, NonNegativeInt, PositiveFloat, PositiveInt from typing_extensions import Annotated +from apax.utils.helpers import APAX_PROPERTIES + class ConstantTempSchedule(BaseModel, extra="forbid"): """Constant temperature schedule. @@ -234,6 +234,14 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"): extra_capacity : int, default = 0 | JaxMD allocates a maximal number of neighbors. This argument lets you add | additional capacity to avoid recompilation. The default is usually fine. + + dynamics_checks: list[DynamicsCheck] + | List of termination criteria. Currently energy and force uncertainty + | are available + properties: list[str] + | Whitelist of properties to be saved in the trajectory. + | This does not effect what the model will calculate, e.g.. + | an ensemble will still calculate uncertainties. initial_structure : str, required | Path to the starting structure of the simulation. sim_dir : str, default = "." @@ -266,6 +274,8 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"): dynamics_checks: list[DynamicsCheck] = [] + properties: list[str] = APAX_PROPERTIES + initial_structure: str load_momenta: bool = False sim_dir: str = "." diff --git a/apax/md/ase_calc.py b/apax/md/ase_calc.py index a5915d6e..3e0d064c 100644 --- a/apax/md/ase_calc.py +++ b/apax/md/ase_calc.py @@ -82,8 +82,18 @@ def make_ensemble(model): def ensemble(positions, Z, idx, box, offsets): results = model(positions, Z, idx, box, offsets) uncertainty = {k + "_uncertainty": jnp.std(v, axis=0) for k, v in results.items()} + ensemble = {k + "_ensemble": v for k, v in results.items()} results = {k: jnp.mean(v, axis=0) for k, v in results.items()} + if "forces_ensemble" in ensemble.keys(): + ensemble["forces_ensemble"] = jnp.transpose( + ensemble["forces_ensemble"], (1, 2, 0) + ) + if "forces_ensemble" in ensemble.keys(): + ensemble["stress_ensemble"] = jnp.transpose( + ensemble["forces_ensemble"], (1, 2, 0) + ) results.update(uncertainty) + results.update(ensemble) return results diff --git a/apax/md/io.py b/apax/md/io.py index 26452031..39b0ca6d 100644 --- a/apax/md/io.py +++ b/apax/md/io.py @@ -8,18 +8,29 @@ from ase.calculators.singlepoint import SinglePointCalculator from apax.md.sim_utils import System +from apax.utils.helpers import APAX_PROPERTIES from apax.utils.jax_md_reduced import space log = logging.getLogger(__name__) class TrajHandler: - def __init__(self) -> None: - self.system: System - self.sampling_rate: int - self.buffer_size: int - self.traj_path: Path - self.time_step: float + def __init__( + self, + system: System, + sampling_rate: int, + buffer_size: int, + traj_path: Path, + time_step: float = 0.5, + properties: list[str] = APAX_PROPERTIES, + ) -> None: + self.atomic_numbers = system.atomic_numbers + self.box = system.box + self.fractional = np.any(self.box > 1e-6) + self.sampling_rate = sampling_rate + self.traj_path = traj_path + self.time_step = time_step + self.properties = properties def step(self, state_and_energy, transform=None): pass @@ -53,6 +64,7 @@ def atoms_from_state(self, state, predictions, nbr_kwargs): atoms.pbc = np.diag(atoms.cell.array) > 1e-6 predictions = {k: np.array(v) for k, v in predictions.items()} predictions["energy"] = predictions["energy"].item() + predictions = {k: v for k, v in predictions.items() if k in self.properties} atoms.calc = SinglePointCalculator(atoms, **predictions) return atoms @@ -65,6 +77,7 @@ def __init__( buffer_size: int, traj_path: Path, time_step: float = 0.5, + properties: list[str] = [], ) -> None: self.atomic_numbers = system.atomic_numbers self.box = system.box @@ -72,6 +85,7 @@ def __init__( self.sampling_rate = sampling_rate self.traj_path = traj_path self.time_step = time_step + self.properties = properties self.db = znh5md.IO( self.traj_path, timestep=self.time_step, store="time", save_units=False ) diff --git a/apax/md/simulate.py b/apax/md/simulate.py index eb609a99..974001d8 100644 --- a/apax/md/simulate.py +++ b/apax/md/simulate.py @@ -166,10 +166,10 @@ def run_sim( n_inner: int, extra_capacity: int, rng_key: int, + traj_handler: TrajHandler, load_momenta: bool = False, restart: bool = True, checkpoint_interval: int = 50_000, - traj_handler: TrajHandler = TrajHandler(), dynamics_checks: list[DynamicsCheckBase] = [], disable_pbar: bool = False, ): @@ -520,6 +520,7 @@ def run_md(model_config: Config, md_config: MDConfig, log_level="error"): md_config.buffer_size, traj_path, md_config.ensemble.dt, + properties=md_config.properties, ) # TODO implement correct chunking @@ -531,10 +532,10 @@ def run_md(model_config: Config, md_config: MDConfig, log_level="error"): n_inner=md_config.n_inner, extra_capacity=md_config.extra_capacity, load_momenta=md_config.load_momenta, + traj_handler=traj_handler, rng_key=jax.random.PRNGKey(md_config.seed), restart=md_config.restart, checkpoint_interval=md_config.checkpoint_interval, sim_dir=sim_dir, - traj_handler=traj_handler, dynamics_checks=dynamics_checks, ) diff --git a/apax/nn/models.py b/apax/nn/models.py index b18dc22e..8dfa17cc 100644 --- a/apax/nn/models.py +++ b/apax/nn/models.py @@ -79,8 +79,10 @@ def __call__( perturbation, ) - gm = self.descriptor(dr_vec, Z, idx) - features = jax.vmap(self.readout)(gm) + features = self.descriptor(dr_vec, Z, idx) + + if self.readout: + features = jax.vmap(self.readout)(features) if self.mask_atoms: features = mask_by_atom(features, Z) @@ -268,7 +270,9 @@ def __call__( prediction["forces"] = forces_mean prediction["forces_uncertainty"] = jnp.sqrt(forces_variance) - prediction["forces_ensemble"] = forces_ens + + forces_ens = jnp.transpose(forces_ens, (1, 2, 0)) + prediction["forces_ensemble"] = forces_ens # n_atoms x 3 x n_members else: forces_mean = -jax.grad(mean_energy_fn)(R, Z, neighbor, box, offsets) diff --git a/apax/utils/helpers.py b/apax/utils/helpers.py index acb80830..78efaa12 100644 --- a/apax/utils/helpers.py +++ b/apax/utils/helpers.py @@ -2,6 +2,18 @@ import yaml +APAX_PROPERTIES = [ + "energy", + "forces", + "stress", + "forces_uncertainty", + "energy_uncertainty", + "stress_uncertainty", + "energy_ensemble", + "forces_ensemble", + "stress_ensemble", +] + def setup_ase(): """Add uncertainty keys to ASE all properties. @@ -9,15 +21,7 @@ def setup_ase(): """ from ase.calculators.calculator import all_properties - additional_keys = [ - "forces_uncertainty", - "energy_uncertainty", - "stress_uncertainty", - "energy_ensemble", - "forces_ensemble", - ] - - for val in additional_keys: + for val in APAX_PROPERTIES: if val not in all_properties: all_properties.append(val) diff --git a/tests/integration_tests/md/md_config_threshold.yaml b/tests/integration_tests/md/md_config_threshold.yaml index 13a662f6..728b8f3e 100644 --- a/tests/integration_tests/md/md_config_threshold.yaml +++ b/tests/integration_tests/md/md_config_threshold.yaml @@ -1,17 +1,22 @@ ensemble: name: nvt - dt: 0.1 # fs time step + dt: 0.2 # fs time step temperature_schedule: name: piecewise - T0: 5 # K + T0: 50 # K values: [100, 200, 1000] steps: [10, 20, 30] -duration: 100 # fs +duration: 500 # fs n_inner: 1 sampling_rate: 1 checkpoint_interval: 2 restart: True dynamics_checks: - name: forces_uncertainty - threshold: 1.0 + threshold: 0.01 +properties: + - energy + - forces + - energy_uncertainty + - forces_ensemble diff --git a/tests/integration_tests/md/test_md.py b/tests/integration_tests/md/test_md.py index d7e85004..b4df1850 100644 --- a/tests/integration_tests/md/test_md.py +++ b/tests/integration_tests/md/test_md.py @@ -200,7 +200,7 @@ def test_jaxmd_schedule_and_thresold(get_tmp_path, example_dataset): } model_config_dict = load_config_and_run_training(model_confg_path, data_config_mods) - md_confg_path = TEST_PATH / "md_config.yaml" + md_confg_path = TEST_PATH / "md_config_threshold.yaml" with open(md_confg_path.as_posix(), "r") as stream: md_config_dict = yaml.safe_load(stream) @@ -214,3 +214,13 @@ def test_jaxmd_schedule_and_thresold(get_tmp_path, example_dataset): traj = znh5md.IO(md_config.sim_dir + "/" + md_config.traj_name)[:] assert len(traj) < 1000 # num steps + + results_keys = list(traj[0].calc.results.keys()) + + assert "energy" in results_keys + assert "forces" in results_keys + assert "energy_uncertainty" in results_keys + assert "forces_ensemble" in results_keys + + assert "energy_ensemble" not in results_keys + assert "forces_uncertainty" not in results_keys