Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added uncertainty checks to jax md #328

Merged
merged 10 commits into from
Aug 26, 2024
19 changes: 19 additions & 0 deletions apax/config/md_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import yaml
from pydantic import BaseModel, Field, NonNegativeInt, PositiveFloat, PositiveInt
from typing_extensions import Annotated


class ConstantTempSchedule(BaseModel, extra="forbid"):
Expand Down Expand Up @@ -190,6 +191,22 @@ class NPTOptions(NVTOptions, extra="forbid"):
barostat_chain: NHCOptions = NHCOptions(tau=1000)


class EnergyUncertaintyCheck(BaseModel, extra="forbid"):
    """Configuration entry requesting an energy-uncertainty check during MD.

    Unknown keys are rejected (``extra="forbid"``).
    """

    # Tag identifying this kind of check in the configuration.
    name: Literal["energy_uncertainty"] = "energy_uncertainty"
    # Limit for the energy uncertainty; must be strictly positive.
    threshold: PositiveFloat
    # NOTE(review): presumably selects comparing the uncertainty per atom
    # rather than for the whole structure - confirm against the MD-side check.
    per_atom: bool = True


class ForcesUncertaintyCheck(BaseModel, extra="forbid"):
    """Configuration entry requesting a force-uncertainty check during MD.

    Unknown keys are rejected (``extra="forbid"``).
    """

    # Tag identifying this kind of check in the configuration.
    name: Literal["forces_uncertainty"] = "forces_uncertainty"
    # Limit for the force uncertainties; must be strictly positive.
    threshold: PositiveFloat


# Tagged union of every configurable dynamics check. Pydantic selects the
# concrete model by looking at the `name` field of the input.
_CheckModels = Union[EnergyUncertaintyCheck, ForcesUncertaintyCheck]
DynamicsCheck = Annotated[_CheckModels, Field(discriminator="name")]


class MDConfig(BaseModel, frozen=True, extra="forbid"):
"""
Configuration for a NHC molecular dynamics simulation.
Expand Down Expand Up @@ -247,6 +264,8 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"):
dr_threshold: PositiveFloat = 0.5
extra_capacity: NonNegativeInt = 0

dynamics_checks: list[DynamicsCheck] = []

initial_structure: str
load_momenta: bool = False
sim_dir: str = "."
Expand Down
48 changes: 48 additions & 0 deletions apax/md/dynamics_checks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from typing import Literal, Union

import jax.numpy as jnp
from pydantic import BaseModel, TypeAdapter


class DynamicsCheckBase(BaseModel):
    """Common base class for checks evaluated on model predictions during MD.

    Concrete checks subclass this model and override `check`.
    """

    def check(self, predictions):
        """Evaluate the check on a set of predictions.

        Parameters
        ----------
        predictions :
            Container of model outputs, indexed by property name.

        Returns
        -------
        A boolean-like value that is true when the check passed.

        Raises
        ------
        NotImplementedError
            Always, in the base class. Returning ``None`` here would only
            surface later as an unrelated error when the result is combined
            with other check results.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement `check`."
        )


class EnergyUncertaintyCheck(DynamicsCheckBase, extra="forbid"):
    """Check that the predicted energy uncertainty stays below a threshold."""

    name: Literal["energy_uncertainty"] = "energy_uncertainty"
    threshold: float
    per_atom: bool = True

    def check(self, predictions):
        """Compare the energy uncertainty against `threshold`.

        Parameters
        ----------
        predictions :
            Model outputs. Must contain ``"energy_uncertainty"``; ``"forces"``
            is additionally read when `per_atom` is set.

        Returns
        -------
        Result of ``jnp.all`` over the element-wise ``<`` comparison.

        Raises
        ------
        ValueError
            If ``"energy_uncertainty"`` is not among the predictions.
        """
        if "energy_uncertainty" not in predictions:
            raise ValueError(
                "No energy uncertainty found. Are you using a model ensemble?"
            )

        uncertainty = predictions["energy_uncertainty"]
        if self.per_atom:
            # The leading dimension of the forces is used as the atom count.
            uncertainty = uncertainty / predictions["forces"].shape[0]

        return jnp.all(uncertainty < self.threshold)


class ForceUncertaintyCheck(DynamicsCheckBase, extra="forbid"):
    """Check that every predicted force uncertainty stays below a threshold."""

    name: Literal["forces_uncertainty"] = "forces_uncertainty"
    threshold: float

    def check(self, predictions):
        """Compare the force uncertainties against `threshold`.

        Parameters
        ----------
        predictions :
            Model outputs. Must contain ``"forces_uncertainty"``.

        Returns
        -------
        Result of ``jnp.all`` over the element-wise ``<`` comparison.

        Raises
        ------
        ValueError
            If ``"forces_uncertainty"`` is not among the predictions.
        """
        if "forces_uncertainty" not in predictions:
            raise ValueError(
                "No force uncertainties found. Are you using a model ensemble?"
            )

        return jnp.all(predictions["forces_uncertainty"] < self.threshold)


# Callable that validates plain Python data (e.g. a dict) into whichever
# check model of the union it matches.
_CheckUnion = Union[EnergyUncertaintyCheck, ForceUncertaintyCheck]
DynamicsChecks = TypeAdapter(_CheckUnion).validate_python
57 changes: 41 additions & 16 deletions apax/md/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from apax.config import Config, MDConfig, parse_config
from apax.config.md_config import Integrator
from apax.md.ase_calc import make_ensemble, maybe_vmap
from apax.md.dynamics_checks import DynamicsCheckBase, DynamicsChecks
from apax.md.io import H5TrajHandler, TrajHandler, truncate_trajectory_to_checkpoint
from apax.md.md_checkpoint import load_md_state
from apax.md.sim_utils import SimulationFunctions, System
Expand Down Expand Up @@ -133,19 +134,25 @@ def handle_checkpoints(state, step, system, load_momenta, ckpt_dir, should_load_
return state, step


def create_evaluation_functions(aux_fn, positions, Z, neighbor, box):
def create_evaluation_functions(aux_fn, positions, Z, neighbor, box, dynamics_checks):
offsets = jnp.zeros((neighbor.idx.shape[1], 3))

def on_eval(positions, neighbor, box):
predictions = aux_fn(positions, Z, neighbor, box, offsets)
return predictions
all_checks_passed = True

for check in dynamics_checks:
check_passed = check.check(predictions)
all_checks_passed = all_checks_passed & check_passed
return predictions, all_checks_passed

predictions = aux_fn(positions, Z, neighbor, box, offsets)
dummpy_preds = jax.tree_map(lambda x: jnp.zeros_like(x), predictions)

def no_eval(positions, neighbor, box):
predictions = dummpy_preds
return predictions
all_checks_passed = True
return predictions, all_checks_passed

return on_eval, no_eval

Expand All @@ -163,6 +170,7 @@ def run_sim(
restart: bool = True,
checkpoint_interval: int = 50_000,
traj_handler: TrajHandler = TrajHandler(),
dynamics_checks: list[DynamicsCheckBase] = [],
disable_pbar: bool = False,
):
"""
Expand All @@ -188,7 +196,6 @@ def run_sim(
sim_dir : Path
Directory where the trajectory and simulation checkpoints will be saved.
"""
energy_fn = sim_fns.energy_fn
neighbor_fn = sim_fns.neighbor_fn
ckpt_dir = sim_dir / "ckpts"
ckpt_dir.mkdir(exist_ok=True)
Expand Down Expand Up @@ -231,12 +238,13 @@ def run_sim(
system.atomic_numbers,
neighbor,
system.box,
dynamics_checks,
)

@jax.jit
def sim(state, outer_step, neighbor): # TODO make more modular
def body_fn(i, state):
state, outer_step, neighbor = state
state, outer_step, neighbor, all_checks_passed = state
step = i + outer_step * n_inner

apply_fn_kwargs = {}
Expand All @@ -253,22 +261,25 @@ def body_fn(i, state):
neighbor = neighbor.update(state.position, **nbr_kwargs)

condition = step % sampling_rate == 0
predictions = jax.lax.cond(
predictions, check_passed = jax.lax.cond(
condition, on_eval, no_eval, state.position, neighbor, box
)

all_checks_passed = all_checks_passed & check_passed

# maybe move this to on_eval
io_callback(traj_handler.step, None, (state, predictions, nbr_kwargs))
return state, outer_step, neighbor
return state, outer_step, neighbor, all_checks_passed

state, outer_step, neighbor = jax.lax.fori_loop(
0, n_inner, body_fn, (state, outer_step, neighbor)
all_checks_passed = True
state, outer_step, neighbor, all_checks_passed = jax.lax.fori_loop(
0, n_inner, body_fn, (state, outer_step, neighbor, all_checks_passed)
)
current_temperature = (
quantity.temperature(velocity=state.velocity, mass=state.mass) / units.kB
)

return state, neighbor, current_temperature
return state, neighbor, current_temperature, all_checks_passed

start = time.time()
total_sim_time = n_steps * ensemble.dt / 1000
Expand All @@ -285,7 +296,17 @@ def body_fn(i, state):
leave=True,
)
while step < n_outer:
new_state, neighbor, current_temperature = sim(state, step, neighbor)
new_state, neighbor, current_temperature, all_checks_passed = sim(
state, step, neighbor
)

if np.any(np.isnan(state.position)) or np.any(np.isnan(state.velocity)):
raise ValueError(f"NaN encountered, simulation aborted after {step+1} steps.")

if not all_checks_passed:
with logging_redirect_tqdm():
log.critical(f"One or more dynamics checks failed at step: {step+1}")
break

if neighbor.did_buffer_overflow:
with logging_redirect_tqdm():
Expand All @@ -298,11 +319,6 @@ def body_fn(i, state):
state = new_state
step += 1

if np.any(np.isnan(state.position)) or np.any(np.isnan(state.velocity)):
raise ValueError(
f"NaN encountered, simulation aborted after {step} steps."
)

if step % checkpoint_interval == 0:
with logging_redirect_tqdm():
current_sim_time = step * n_inner * ensemble.dt / 1000
Expand Down Expand Up @@ -488,6 +504,14 @@ def run_md(model_config: Config, md_config: MDConfig, log_level="error"):
traj_path = sim_dir / md_config.traj_name

system, sim_fns = md_setup(model_config, md_config)

dynamics_checks = []
if md_config.dynamics_checks:
check_list = [
DynamicsChecks(check.model_dump()) for check in md_config.dynamics_checks
]
dynamics_checks.extend(check_list)

n_steps = int(np.ceil(md_config.duration / md_config.ensemble.dt))

traj_handler = H5TrajHandler(
Expand All @@ -512,4 +536,5 @@ def run_md(model_config: Config, md_config: MDConfig, log_level="error"):
checkpoint_interval=md_config.checkpoint_interval,
sim_dir=sim_dir,
traj_handler=traj_handler,
dynamics_checks=dynamics_checks,
)
Loading