Skip to content

Commit

Permalink
Merge pull request #275 from apax-hub/shallow_ens
Browse files Browse the repository at this point in the history
Shallow ensembles
  • Loading branch information
M-R-Schaefer authored Jul 9, 2024
2 parents 73c8754 + b284f63 commit 6c8a91a
Show file tree
Hide file tree
Showing 13 changed files with 336 additions and 72 deletions.
6 changes: 5 additions & 1 deletion apax/bal/feature_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@ def inner(ll_params):
inputs["box"],
inputs["offsets"],
)
return model.apply(full_params, R, Z, idx, box, offsets)
out = model.apply(full_params, R, Z, idx, box, offsets)
# take mean in case of shallow ensemble
# no effect for single model
out = jnp.mean(out)
return out

g_ll = jax.grad(inner)(ll_params)
g_ll = unflatten_dict(g_ll)
Expand Down
2 changes: 1 addition & 1 deletion apax/config/lr_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class CyclicCosineLR(LRSchedule, frozen=True, extra="forbid"):
Parameters
----------
period: int = 20
Length of a cycle.
Length of a cycle in epochs.
decay_factor: NonNegativeFloat = 1.0
Factor by which to decrease the LR after each cycle.
1.0 means no decrease.
Expand Down
2 changes: 2 additions & 0 deletions apax/config/train_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,8 @@ class ModelConfig(BaseModel, extra="forbid"):

calc_stress: bool = False

n_shallow_ensemble: int = 0

descriptor_dtype: Literal["fp32", "fp64"] = "fp64"
readout_dtype: Literal["fp32", "fp64"] = "fp32"
scale_shift_dtype: Literal["fp32", "fp64"] = "fp32"
Expand Down
23 changes: 13 additions & 10 deletions apax/layers/readout.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,25 @@ class AtomisticReadout(nn.Module):
w_init: str = "normal"
b_init: str = "zeros"
use_ntk: bool = True
n_shallow_ensemble: int = 0
dtype: Any = jnp.float32

def setup(self):
units = [u for u in self.units] + [1]
readout_unit = [1]
if self.n_shallow_ensemble > 0:
readout_unit = [self.n_shallow_ensemble]
units = [u for u in self.units] + readout_unit
dense = []
for ii, n_hidden in enumerate(units):
dense.append(
NTKLinear(
n_hidden,
w_init=self.w_init,
b_init=self.b_init,
use_ntk=self.use_ntk,
dtype=self.dtype,
name=f"dense_{ii}",
)
layer = NTKLinear(
n_hidden,
w_init=self.w_init,
b_init=self.b_init,
use_ntk=self.use_ntk,
dtype=self.dtype,
name=f"dense_{ii}",
)
dense.append(layer)
if ii < len(units) - 1:
dense.append(swish)
self.sequential = nn.Sequential(dense, name="readout")
Expand Down
4 changes: 2 additions & 2 deletions apax/md/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self) -> None:
self.traj_path: Path
self.time_step: float

def step(self, state_and_energy, transform):
def step(self, state_and_energy, transform=None):
pass

def write(self, x=None, transform=None):
Expand Down Expand Up @@ -82,7 +82,7 @@ def __init__(
def reset_buffer(self):
self.buffer = []

def step(self, state, transform):
def step(self, state, transform=None):
state, energy, nbr_kwargs = state

if self.step_counter % self.sampling_rate == 0:
Expand Down
29 changes: 23 additions & 6 deletions apax/md/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from ase import units
from ase.io import read
from flax.training import checkpoints
from jax.experimental.host_callback import barrier_wait, id_tap
from jax.experimental import io_callback
from jax.experimental.host_callback import barrier_wait
from tqdm import trange
from tqdm.contrib.logging import logging_redirect_tqdm

Expand All @@ -28,16 +29,23 @@
log = logging.getLogger(__name__)


def create_energy_fn(model, params, numbers, n_models):
def ensemble(params, R, Z, neighbor, box, offsets, perturbation=None):
def create_energy_fn(model, params, numbers, n_models, shallow=False):
def full_ensemble(params, R, Z, neighbor, box, offsets, perturbation=None):
vmodel = jax.vmap(model, (0, None, None, None, None, None, None), 0)
energies = vmodel(params, R, Z, neighbor, box, offsets, perturbation)
energy = jnp.mean(energies)
return energy

def shallow_ensemble(params, R, Z, neighbor, box, offsets, perturbation=None):
energies = model(params, R, Z, neighbor, box, offsets, perturbation)
energy = jnp.mean(energies)
return energy

if n_models > 1:
energy_fn = ensemble
if shallow:
energy_fn = shallow_ensemble
else:
energy_fn = full_ensemble
else:
energy_fn = model

Expand Down Expand Up @@ -218,7 +226,7 @@ def body_fn(i, state):
nbr_kwargs = nbr_options(state)
neighbor = neighbor.update(state.position, **nbr_kwargs)

id_tap(traj_handler.step, (state, current_energy, nbr_kwargs))
io_callback(traj_handler.step, None, (state, current_energy, nbr_kwargs))
return state, neighbor

state, neighbor = jax.lax.fori_loop(0, n_inner, body_fn, (state, neighbor))
Expand Down Expand Up @@ -375,8 +383,17 @@ def md_setup(model_config: Config, md_config: MDConfig):

_, params = restore_parameters(model_config.data.model_version_path)
params = canonicalize_energy_model_parameters(params)

n_models = 1
shallow = False
if model_config.n_models > 1:
n_models = model_config.n_models
elif model_config.model.n_shallow_ensemble > 1:
n_models = model_config.model.n_shallow_ensemble
shallow = True

energy_fn = create_energy_fn(
model.apply, params, system.atomic_numbers, model_config.n_models
model.apply, params, system.atomic_numbers, n_models, shallow
)
sim_fns = SimulationFunctions(energy_fn, shift_fn, neighbor_fn)
return system, sim_fns
Expand Down
23 changes: 17 additions & 6 deletions apax/model/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
from apax.layers.empirical import ZBLRepulsion
from apax.layers.readout import AtomisticReadout
from apax.layers.scaling import PerElementScaleShift
from apax.model.gmnn import AtomisticModel, EnergyDerivativeModel, EnergyModel
from apax.model.gmnn import (
AtomisticModel,
EnergyDerivativeModel,
EnergyModel,
ShallowEnsembleModel,
)


class ModelBuilder:
Expand Down Expand Up @@ -79,6 +84,7 @@ def build_readout(self):
b_init=self.config["b_init"],
w_init=self.config["w_init"],
use_ntk=self.config["use_ntk"],
n_shallow_ensemble=self.config["n_shallow_ensemble"],
dtype=self.config["readout_dtype"],
)
return readout
Expand Down Expand Up @@ -149,9 +155,14 @@ def build_energy_derivative_model(
init_box=init_box,
inference_disp_fn=inference_disp_fn,
)

model = EnergyDerivativeModel(
energy_model,
calc_stress=self.config["calc_stress"],
)
if self.config["n_shallow_ensemble"] > 0:
model = ShallowEnsembleModel(
energy_model,
calc_stress=self.config["calc_stress"],
)
else:
model = EnergyDerivativeModel(
energy_model,
calc_stress=self.config["calc_stress"],
)
return model
86 changes: 83 additions & 3 deletions apax/model/gmnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,16 +123,27 @@ def __call__(
dr_vec += offsets

# Model Core
# shape Natoms
# shape shallow ens: Natoms x Nensemble
atomic_energies = self.atomistic_model(dr_vec, Z, idx)
total_energy = fp64_sum(atomic_energies)

# check for shallow ensemble
is_shallow_ensemble = atomic_energies.shape[1] > 1
if is_shallow_ensemble:
total_energies_ensemble = fp64_sum(atomic_energies, axis=0)
# shape Nensemble
result = total_energies_ensemble
else:
# shape ()
result = fp64_sum(atomic_energies)

# Corrections
for correction in self.corrections:
energy_correction = correction(dr_vec, Z, idx)
total_energy = total_energy + energy_correction
result = result + energy_correction

# TODO think of nice abstraction for predicting additional properties
return total_energy
return result


class EnergyDerivativeModel(nn.Module):
Expand Down Expand Up @@ -165,3 +176,72 @@ def __call__(
prediction["stress"] = stress

return prediction


def make_mean_energy_fn(energy_fn):
    """Wrap ``energy_fn`` so that it returns the mean of its output.

    Parameters
    ----------
    energy_fn:
        Callable invoked as
        ``energy_fn(R, Z, neighbor, box, offsets, perturbation)``.

    Returns
    -------
    mean_energy_fn:
        Callable with the same arguments (``perturbation`` defaults to
        ``None``) that returns ``jnp.mean`` of whatever ``energy_fn``
        returned.
    """

    def mean_energy_fn(
        R: Array,
        Z: Array,
        neighbor: Union[partition.NeighborList, Array],
        box,
        offsets,
        perturbation=None,
    ):
        # Reduce over every axis of the wrapped function's output.
        return jnp.mean(energy_fn(R, Z, neighbor, box, offsets, perturbation))

    return mean_energy_fn


class ShallowEnsembleModel(nn.Module):
    """Turn an ensemble-valued EnergyModel into a model with uncertainties.

    The wrapped ``energy_model`` is indexed along its first axis as the
    ensemble axis. The returned prediction dict contains

    - ``"energy"``: mean over the ensemble,
    - ``"energy_ensemble"``: the raw output of ``energy_model``,
    - ``"energy_uncertainty"``: sample standard deviation over the ensemble,
    - ``"forces"``: negative derivative of the energy w.r.t. ``R``,
    - ``"forces_uncertainty"`` and ``"forces_ensemble"``: only when
      ``force_variance`` is True,
    - ``"stress"``: only when ``calc_stress`` is True.

    Attributes
    ----------
    energy_model: EnergyModel
        Model called as ``energy_model(R, Z, neighbor, box, offsets)``.
    calc_stress: bool
        Whether to add the result of ``stress_times_vol`` for the mean energy.
    force_variance: bool
        If True, forces are computed for each ensemble member with
        ``jax.jacrev`` and their spread is reported. Otherwise only the
        gradient of the mean energy is computed.
    """

    energy_model: EnergyModel = EnergyModel()
    calc_stress: bool = False
    force_variance: bool = True

    def __call__(
        self,
        R: Array,
        Z: Array,
        neighbor: Union[partition.NeighborList, Array],
        box,
        offsets,
    ):
        energy_ens = self.energy_model(R, Z, neighbor, box, offsets)
        mean_energy_fn = make_mean_energy_fn(self.energy_model)

        n_ens = energy_ens.shape[0]
        # Bessel-corrected sample variance. A one-member ensemble would give
        # 1 / 0 here; clamp the denominator so its spread is reported as
        # zero instead of raising ZeroDivisionError.
        divisor = 1 / max(n_ens - 1, 1)

        energy_mean = jnp.mean(energy_ens)
        energy_variance = divisor * fp64_sum((energy_ens - energy_mean) ** 2)

        prediction = {
            "energy": energy_mean,
            "energy_ensemble": energy_ens,
            "energy_uncertainty": jnp.sqrt(energy_variance),
        }

        if self.force_variance:
            # The Jacobian w.r.t. R puts the ensemble axis first, so all
            # reductions below run over axis 0.
            forces_ens = -jax.jacrev(self.energy_model)(R, Z, neighbor, box, offsets)
            forces_mean = jnp.mean(forces_ens, axis=0)
            forces_variance = divisor * fp64_sum(
                (forces_ens - forces_mean) ** 2, axis=0
            )

            prediction["forces"] = forces_mean
            prediction["forces_uncertainty"] = jnp.sqrt(forces_variance)
            prediction["forces_ensemble"] = forces_ens
        else:
            # Only the mean forces are needed: one gradient of the mean energy.
            forces_mean = -jax.grad(mean_energy_fn)(R, Z, neighbor, box, offsets)
            prediction["forces"] = forces_mean

        if self.calc_stress:
            stress = stress_times_vol(
                mean_energy_fn, R, box, Z=Z, neighbor=neighbor, offsets=offsets
            )
            prediction["stress"] = stress

        return prediction
4 changes: 2 additions & 2 deletions apax/nodes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .md import ApaxJaxMD
from .model import Apax, ApaxEnsemble
from .model import Apax, ApaxEnsemble, ApaxImport

__all__ = ["Apax", "ApaxEnsemble", "ApaxJaxMD"]
__all__ = ["Apax", "ApaxEnsemble", "ApaxJaxMD", "ApaxImport"]

try:
from .analysis import ApaxBatchPrediction # noqa: F401
Expand Down
54 changes: 54 additions & 0 deletions apax/nodes/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,57 @@ def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator:
transformations=transformations,
)
return calc


class ApaxImport(zntrack.Node):
    """Expose a model described by an existing apax config as an ASE calculator.

    Parameters
    ----------
    config: str
        Path to a YAML config file. Its ``data.directory`` and
        ``data.experiment`` entries are joined to form the model directory.
    nl_skin: float
        Neighborlist skin.
    transformations: dict
        Key-parameter dict with function transformations applied
        to the model function within the ASE calculator.
        See the apax documentation for available methods.
    """

    config: str = zntrack.params_path()
    nl_skin: float = zntrack.params(0.5)
    transformations: dict[str, dict] = zntrack.params(None)

    # Parsed content of ``config``; filled by ``_handle_parameter_file``.
    _parameter: dict = None

    def _post_load_(self) -> None:
        self._handle_parameter_file()

    def _handle_parameter_file(self):
        """Parse the YAML file at ``self.config`` into ``self._parameter``."""
        with self.state.use_tmp_path():
            self._parameter = yaml.safe_load(pathlib.Path(self.config).read_text())

    def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator:
        """Property to return a model specific ase calculator object.

        Returns
        -------
        calc:
            ase calculator object
        """
        # ``_parameter`` is only populated by ``_post_load_``; load it on
        # demand so a missing load does not surface as an opaque
        # "'NoneType' object is not subscriptable" error.
        if self._parameter is None:
            self._handle_parameter_file()

        data = self._parameter["data"]
        model_dir = f"{data['directory']}/{data['experiment']}"

        # ``transformations`` defaults to None, which means "no transformations".
        transformations = [
            available_transformations[name](**params)
            for name, params in (self.transformations or {}).items()
        ]

        calc = ASECalculator(
            model_dir,
            dr=self.nl_skin,
            transformations=transformations,
        )
        return calc
Loading

0 comments on commit 6c8a91a

Please sign in to comment.