Skip to content

Commit

Permalink
Merge pull request #328 from apax-hub/dynamics_checks
Browse files Browse the repository at this point in the history
added uncertainty checks to jax md
  • Loading branch information
M-R-Schaefer authored Aug 26, 2024
2 parents 50af544 + 9088876 commit 67a0f23
Show file tree
Hide file tree
Showing 7 changed files with 173 additions and 19 deletions.
19 changes: 19 additions & 0 deletions apax/config/md_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import yaml
from pydantic import BaseModel, Field, NonNegativeInt, PositiveFloat, PositiveInt
from typing_extensions import Annotated


class ConstantTempSchedule(BaseModel, extra="forbid"):
Expand Down Expand Up @@ -190,6 +191,22 @@ class NPTOptions(NVTOptions, extra="forbid"):
barostat_chain: NHCOptions = NHCOptions(tau=1000)


class EnergyUncertaintyCheck(BaseModel, extra="forbid"):
name: Literal["energy_uncertainty"] = "energy_uncertainty"
threshold: PositiveFloat
per_atom: bool = True


class ForcesUncertaintyCheck(BaseModel, extra="forbid"):
name: Literal["forces_uncertainty"] = "forces_uncertainty"
threshold: PositiveFloat


DynamicsCheck = Annotated[
Union[EnergyUncertaintyCheck, ForcesUncertaintyCheck], Field(discriminator="name")
]


class MDConfig(BaseModel, frozen=True, extra="forbid"):
"""
Configuration for a NHC molecular dynamics simulation.
Expand Down Expand Up @@ -247,6 +264,8 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"):
dr_threshold: PositiveFloat = 0.5
extra_capacity: NonNegativeInt = 0

dynamics_checks: list[DynamicsCheck] = []

initial_structure: str
load_momenta: bool = False
sim_dir: str = "."
Expand Down
48 changes: 48 additions & 0 deletions apax/md/dynamics_checks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from typing import Literal, Union

import jax.numpy as jnp
from pydantic import BaseModel, TypeAdapter


class DynamicsCheckBase(BaseModel):
def check(self, predictions):
pass


class EnergyUncertaintyCheck(DynamicsCheckBase, extra="forbid"):
name: Literal["energy_uncertainty"] = "energy_uncertainty"
threshold: float
per_atom: bool = True

def check(self, predictions):
if "energy_uncertainty" not in predictions.keys():
m = "No energy uncertainty found. Are you using a model ensemble?"
raise ValueError(m)

energy_uncertainty = predictions["energy_uncertainty"]
if self.per_atom:
n_atoms = predictions["forces"].shape[0]
energy_uncertainty = energy_uncertainty / n_atoms

check_passed = jnp.all(energy_uncertainty < self.threshold)
return check_passed


class ForceUncertaintyCheck(DynamicsCheckBase, extra="forbid"):
name: Literal["forces_uncertainty"] = "forces_uncertainty"
threshold: float

def check(self, predictions):
if "forces_uncertainty" not in predictions.keys():
m = "No force uncertainties found. Are you using a model ensemble?"
raise ValueError(m)

forces_uncertainty = predictions["forces_uncertainty"]

check_passed = jnp.all(forces_uncertainty < self.threshold)
return check_passed


DynamicsChecks = TypeAdapter(
Union[EnergyUncertaintyCheck, ForceUncertaintyCheck]
).validate_python
57 changes: 41 additions & 16 deletions apax/md/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from apax.config import Config, MDConfig, parse_config
from apax.config.md_config import Integrator
from apax.md.ase_calc import make_ensemble, maybe_vmap
from apax.md.dynamics_checks import DynamicsCheckBase, DynamicsChecks
from apax.md.io import H5TrajHandler, TrajHandler, truncate_trajectory_to_checkpoint
from apax.md.md_checkpoint import load_md_state
from apax.md.sim_utils import SimulationFunctions, System
Expand Down Expand Up @@ -133,19 +134,25 @@ def handle_checkpoints(state, step, system, load_momenta, ckpt_dir, should_load_
return state, step


def create_evaluation_functions(aux_fn, positions, Z, neighbor, box):
def create_evaluation_functions(aux_fn, positions, Z, neighbor, box, dynamics_checks):
offsets = jnp.zeros((neighbor.idx.shape[1], 3))

def on_eval(positions, neighbor, box):
predictions = aux_fn(positions, Z, neighbor, box, offsets)
return predictions
all_checks_passed = True

for check in dynamics_checks:
check_passed = check.check(predictions)
all_checks_passed = all_checks_passed & check_passed
return predictions, all_checks_passed

predictions = aux_fn(positions, Z, neighbor, box, offsets)
dummpy_preds = jax.tree_map(lambda x: jnp.zeros_like(x), predictions)

def no_eval(positions, neighbor, box):
predictions = dummpy_preds
return predictions
all_checks_passed = True
return predictions, all_checks_passed

return on_eval, no_eval

Expand All @@ -163,6 +170,7 @@ def run_sim(
restart: bool = True,
checkpoint_interval: int = 50_000,
traj_handler: TrajHandler = TrajHandler(),
dynamics_checks: list[DynamicsCheckBase] = [],
disable_pbar: bool = False,
):
"""
Expand All @@ -188,7 +196,6 @@ def run_sim(
sim_dir : Path
Directory where the trajectory and simulation checkpoints will be saved.
"""
energy_fn = sim_fns.energy_fn
neighbor_fn = sim_fns.neighbor_fn
ckpt_dir = sim_dir / "ckpts"
ckpt_dir.mkdir(exist_ok=True)
Expand Down Expand Up @@ -231,12 +238,13 @@ def run_sim(
system.atomic_numbers,
neighbor,
system.box,
dynamics_checks,
)

@jax.jit
def sim(state, outer_step, neighbor): # TODO make more modular
def body_fn(i, state):
state, outer_step, neighbor = state
state, outer_step, neighbor, all_checks_passed = state
step = i + outer_step * n_inner

apply_fn_kwargs = {}
Expand All @@ -253,22 +261,25 @@ def body_fn(i, state):
neighbor = neighbor.update(state.position, **nbr_kwargs)

condition = step % sampling_rate == 0
predictions = jax.lax.cond(
predictions, check_passed = jax.lax.cond(
condition, on_eval, no_eval, state.position, neighbor, box
)

all_checks_passed = all_checks_passed & check_passed

# maybe move this to on_eval
io_callback(traj_handler.step, None, (state, predictions, nbr_kwargs))
return state, outer_step, neighbor
return state, outer_step, neighbor, all_checks_passed

state, outer_step, neighbor = jax.lax.fori_loop(
0, n_inner, body_fn, (state, outer_step, neighbor)
all_checks_passed = True
state, outer_step, neighbor, all_checks_passed = jax.lax.fori_loop(
0, n_inner, body_fn, (state, outer_step, neighbor, all_checks_passed)
)
current_temperature = (
quantity.temperature(velocity=state.velocity, mass=state.mass) / units.kB
)

return state, neighbor, current_temperature
return state, neighbor, current_temperature, all_checks_passed

start = time.time()
total_sim_time = n_steps * ensemble.dt / 1000
Expand All @@ -285,7 +296,17 @@ def body_fn(i, state):
leave=True,
)
while step < n_outer:
new_state, neighbor, current_temperature = sim(state, step, neighbor)
new_state, neighbor, current_temperature, all_checks_passed = sim(
state, step, neighbor
)

if np.any(np.isnan(state.position)) or np.any(np.isnan(state.velocity)):
raise ValueError(f"NaN encountered, simulation aborted after {step+1} steps.")

if not all_checks_passed:
with logging_redirect_tqdm():
log.critical(f"One or more dynamics checks failed at step: {step+1}")
break

if neighbor.did_buffer_overflow:
with logging_redirect_tqdm():
Expand All @@ -298,11 +319,6 @@ def body_fn(i, state):
state = new_state
step += 1

if np.any(np.isnan(state.position)) or np.any(np.isnan(state.velocity)):
raise ValueError(
f"NaN encountered, simulation aborted after {step} steps."
)

if step % checkpoint_interval == 0:
with logging_redirect_tqdm():
current_sim_time = step * n_inner * ensemble.dt / 1000
Expand Down Expand Up @@ -488,6 +504,14 @@ def run_md(model_config: Config, md_config: MDConfig, log_level="error"):
traj_path = sim_dir / md_config.traj_name

system, sim_fns = md_setup(model_config, md_config)

dynamics_checks = []
if md_config.dynamics_checks:
check_list = [
DynamicsChecks(check.model_dump()) for check in md_config.dynamics_checks
]
dynamics_checks.extend(check_list)

n_steps = int(np.ceil(md_config.duration / md_config.ensemble.dt))

traj_handler = H5TrajHandler(
Expand All @@ -512,4 +536,5 @@ def run_md(model_config: Config, md_config: MDConfig, log_level="error"):
checkpoint_interval=md_config.checkpoint_interval,
sim_dir=sim_dir,
traj_handler=traj_handler,
dynamics_checks=dynamics_checks,
)
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,4 @@ def load_and_dump_config(config_path, dump_path):
def load_config_and_run_training(config_path, updated_config):
config_dict = mod_config(config_path, updated_config)
run(config_dict)
return config_dict
14 changes: 11 additions & 3 deletions tests/integration_tests/md/config.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
n_epochs: 2
n_epochs: 5

data:
experiment: apax_dummy
data_path: dummy_ds # ds.extxyz #ethanol.traj # ds.extxyz #

n_train: 4
n_train: 10
n_valid: 2
batch_size: 2
batch_size: 1
valid_batch_size: 2

model:
Expand All @@ -16,6 +16,12 @@ model:
name: gaussian
n_basis: 5
n_radial: 3

ensemble:
kind: shallow
n_members: 4
force_variance: true

descriptor_dtype: fp64
readout_dtype: fp32
scale_shift_dtype: fp64
Expand All @@ -27,4 +33,6 @@ metrics:

loss:
- name: energy
loss_type: crps
- name: forces
loss_type: crps
17 changes: 17 additions & 0 deletions tests/integration_tests/md/md_config_threshold.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
ensemble:
name: nvt
dt: 0.1 # fs time step
temperature_schedule:
name: piecewise
T0: 5 # K
values: [100, 200, 1000]
steps: [10, 20, 30]

duration: 100 # fs
n_inner: 1
sampling_rate: 1
checkpoint_interval: 2
restart: True
dynamics_checks:
- name: forces_uncertainty
threshold: 1.0
36 changes: 36 additions & 0 deletions tests/integration_tests/md/test_md.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
import pathlib
import uuid

import jax
import jax.numpy as jnp
import numpy as np
import pytest
import yaml
import znh5md
from ase import Atoms
Expand All @@ -14,6 +16,7 @@
from apax.md import run_md
from apax.md.ase_calc import ASECalculator
from apax.utils import jax_md_reduced
from tests.conftest import load_config_and_run_training

TEST_PATH = pathlib.Path(__file__).parent.resolve()

Expand Down Expand Up @@ -178,3 +181,36 @@ def test_ase_calc(get_tmp_path):
assert "energy_uncertainty" in atoms.calc.results.keys()
assert "forces_uncertainty" in atoms.calc.results.keys()
assert "stress_uncertainty" in atoms.calc.results.keys()


@pytest.mark.parametrize("num_data", (30,))
def test_jaxmd_schedule_and_thresold(get_tmp_path, example_dataset):
model_confg_path = TEST_PATH / "config.yaml"
working_dir = get_tmp_path / str(uuid.uuid4())
data_path = get_tmp_path / "ds.extxyz"

write(data_path, example_dataset)

data_config_mods = {
"data": {
"directory": working_dir.as_posix(),
"experiment": "model",
"data_path": data_path.as_posix(),
},
}
model_config_dict = load_config_and_run_training(model_confg_path, data_config_mods)

md_confg_path = TEST_PATH / "md_config.yaml"

with open(md_confg_path.as_posix(), "r") as stream:
md_config_dict = yaml.safe_load(stream)
md_config_dict["sim_dir"] = get_tmp_path.as_posix()
md_config_dict["initial_structure"] = get_tmp_path.as_posix() + "/ds.extxyz"
md_config = MDConfig.model_validate(md_config_dict)

model_config = Config.model_validate(model_config_dict)

run_md(model_config, md_config)

traj = znh5md.IO(md_config.sim_dir + "/" + md_config.traj_name)[:]
assert len(traj) < 1000 # num steps

0 comments on commit 67a0f23

Please sign in to comment.