Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shallow ensembles #275

Merged
merged 38 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
9cd4f6c
correct ordering of labels and predictions
M-R-Schaefer May 4, 2024
b0d6c9a
added crps and nll loss functions
M-R-Schaefer May 4, 2024
084d2af
added shallow ens to config
M-R-Schaefer May 4, 2024
8e1b4d0
added possibility for shallow ens in readout
M-R-Schaefer May 4, 2024
359f037
added preliminary shallow ens class
M-R-Schaefer May 4, 2024
705f109
added shallow ens to builder
M-R-Schaefer May 4, 2024
1520ff5
fixed tests
M-R-Schaefer May 4, 2024
48dc3ed
removed debug print statements
M-R-Schaefer May 7, 2024
df7bee5
correctly pass n_shallow ens in builder
M-R-Schaefer May 7, 2024
c786668
remove debug print statements
M-R-Schaefer May 7, 2024
c1f734a
docstring
M-R-Schaefer May 7, 2024
6a83516
return ensemble predictions from model
M-R-Schaefer May 8, 2024
b9a9ca1
Merge branch 'dev' into shallow_ens
M-R-Schaefer May 8, 2024
6e25069
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 10, 2024
d22dc6c
implemented switch between mean force and ensemble force computation
M-R-Schaefer May 15, 2024
4fcdd8d
Merge branch 'shallow_ens' of https://github.com/apax-hub/apax into s…
M-R-Schaefer May 15, 2024
af81217
implemented stress prediction for shallow ens
M-R-Schaefer May 17, 2024
df19d04
compatibility with BAL
M-R-Schaefer May 17, 2024
7f51cdd
loss fn divisor
M-R-Schaefer May 17, 2024
d17a8bf
added ApaxImport node
M-R-Schaefer May 22, 2024
5673431
made import node compatible with jaxmd node
M-R-Schaefer Jun 1, 2024
ea64e13
updated jaxmd to use new callback api
M-R-Schaefer Jun 1, 2024
8fd0ca1
deleted comment
M-R-Schaefer Jun 5, 2024
9f4c901
Merge branch 'dev' into shallow_ens
M-R-Schaefer Jun 5, 2024
2150bcc
remove old import
M-R-Schaefer Jun 5, 2024
9b0963e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 5, 2024
e8e0182
made shallow ensembles compatible with jaxmd
M-R-Schaefer Jun 5, 2024
ded97f6
Merge branch 'shallow_ens' of https://github.com/apax-hub/apax into s…
M-R-Schaefer Jun 5, 2024
cb65848
fixed bug for using jaxmd without ensembles
M-R-Schaefer Jun 5, 2024
c7e3c58
account for non shallow ensemble
M-R-Schaefer Jun 5, 2024
76a867f
Merge branch 'dev' into shallow_ens
M-R-Schaefer Jun 8, 2024
8427281
Merge branch 'dev' into shallow_ens
M-R-Schaefer Jun 12, 2024
6a6ad25
consistent use of std as uncertainty, not variance
M-R-Schaefer Jun 12, 2024
8a86bee
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 12, 2024
8de78f9
linting
M-R-Schaefer Jun 12, 2024
b7e2b42
Merge branch 'shallow_ens' of https://github.com/apax-hub/apax into s…
M-R-Schaefer Jun 12, 2024
fdf859a
Merge branch 'dev' into shallow_ens
M-R-Schaefer Jun 20, 2024
b284f63
Merge branch 'dev' into shallow_ens
M-R-Schaefer Jul 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion apax/bal/feature_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@ def inner(ll_params):
inputs["box"],
inputs["offsets"],
)
return model.apply(full_params, R, Z, idx, box, offsets)
out = model.apply(full_params, R, Z, idx, box, offsets)
# take mean in case of shallow ensemble
# no effect for single model
out = jnp.mean(out)
return out

g_ll = jax.grad(inner)(ll_params)
g_ll = unflatten_dict(g_ll)
Expand Down
2 changes: 1 addition & 1 deletion apax/config/lr_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class CyclicCosineLR(LRSchedule, frozen=True, extra="forbid"):
Parameters
----------
period: int = 20
Length of a cycle.
Length of a cycle in epochs.
decay_factor: NonNegativeFloat = 1.0
Factor by which to decrease the LR after each cycle.
1.0 means no decrease.
Expand Down
2 changes: 2 additions & 0 deletions apax/config/train_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,8 @@ class ModelConfig(BaseModel, extra="forbid"):

calc_stress: bool = False

n_shallow_ensemble: int = 0

descriptor_dtype: Literal["fp32", "fp64"] = "fp64"
readout_dtype: Literal["fp32", "fp64"] = "fp32"
scale_shift_dtype: Literal["fp32", "fp64"] = "fp32"
Expand Down
23 changes: 13 additions & 10 deletions apax/layers/readout.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,25 @@ class AtomisticReadout(nn.Module):
w_init: str = "normal"
b_init: str = "zeros"
use_ntk: bool = True
n_shallow_ensemble: int = 0
dtype: Any = jnp.float32

def setup(self):
units = [u for u in self.units] + [1]
readout_unit = [1]
if self.n_shallow_ensemble > 0:
readout_unit = [self.n_shallow_ensemble]
units = [u for u in self.units] + readout_unit
dense = []
for ii, n_hidden in enumerate(units):
dense.append(
NTKLinear(
n_hidden,
w_init=self.w_init,
b_init=self.b_init,
use_ntk=self.use_ntk,
dtype=self.dtype,
name=f"dense_{ii}",
)
layer = NTKLinear(
n_hidden,
w_init=self.w_init,
b_init=self.b_init,
use_ntk=self.use_ntk,
dtype=self.dtype,
name=f"dense_{ii}",
)
dense.append(layer)
if ii < len(units) - 1:
dense.append(swish)
self.sequential = nn.Sequential(dense, name="readout")
Expand Down
4 changes: 2 additions & 2 deletions apax/md/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self) -> None:
self.traj_path: Path
self.time_step: float

def step(self, state_and_energy, transform):
def step(self, state_and_energy, transform=None):
pass

def write(self, x=None, transform=None):
Expand Down Expand Up @@ -82,7 +82,7 @@ def __init__(
def reset_buffer(self):
self.buffer = []

def step(self, state, transform):
def step(self, state, transform=None):
state, energy, nbr_kwargs = state

if self.step_counter % self.sampling_rate == 0:
Expand Down
29 changes: 23 additions & 6 deletions apax/md/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from ase import units
from ase.io import read
from flax.training import checkpoints
from jax.experimental.host_callback import barrier_wait, id_tap
from jax.experimental import io_callback
from jax.experimental.host_callback import barrier_wait
from tqdm import trange
from tqdm.contrib.logging import logging_redirect_tqdm

Expand All @@ -28,16 +29,23 @@
log = logging.getLogger(__name__)


def create_energy_fn(model, params, numbers, n_models):
def ensemble(params, R, Z, neighbor, box, offsets, perturbation=None):
def create_energy_fn(model, params, numbers, n_models, shallow=False):
def full_ensemble(params, R, Z, neighbor, box, offsets, perturbation=None):
vmodel = jax.vmap(model, (0, None, None, None, None, None, None), 0)
energies = vmodel(params, R, Z, neighbor, box, offsets, perturbation)
energy = jnp.mean(energies)
return energy

def shallow_ensemble(params, R, Z, neighbor, box, offsets, perturbation=None):
energies = model(params, R, Z, neighbor, box, offsets, perturbation)
energy = jnp.mean(energies)
return energy

if n_models > 1:
energy_fn = ensemble
if shallow:
energy_fn = shallow_ensemble
else:
energy_fn = full_ensemble
else:
energy_fn = model

Expand Down Expand Up @@ -218,7 +226,7 @@ def body_fn(i, state):
nbr_kwargs = nbr_options(state)
neighbor = neighbor.update(state.position, **nbr_kwargs)

id_tap(traj_handler.step, (state, current_energy, nbr_kwargs))
io_callback(traj_handler.step, None, (state, current_energy, nbr_kwargs))
return state, neighbor

state, neighbor = jax.lax.fori_loop(0, n_inner, body_fn, (state, neighbor))
Expand Down Expand Up @@ -375,8 +383,17 @@ def md_setup(model_config: Config, md_config: MDConfig):

_, params = restore_parameters(model_config.data.model_version_path)
params = canonicalize_energy_model_parameters(params)

n_models = 1
shallow = False
if model_config.n_models > 1:
n_models = model_config.n_models
elif model_config.model.n_shallow_ensemble > 1:
n_models = model_config.model.n_shallow_ensemble
shallow = True

energy_fn = create_energy_fn(
model.apply, params, system.atomic_numbers, model_config.n_models
model.apply, params, system.atomic_numbers, n_models, shallow
)
sim_fns = SimulationFunctions(energy_fn, shift_fn, neighbor_fn)
return system, sim_fns
Expand Down
23 changes: 17 additions & 6 deletions apax/model/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
from apax.layers.empirical import ZBLRepulsion
from apax.layers.readout import AtomisticReadout
from apax.layers.scaling import PerElementScaleShift
from apax.model.gmnn import AtomisticModel, EnergyDerivativeModel, EnergyModel
from apax.model.gmnn import (
AtomisticModel,
EnergyDerivativeModel,
EnergyModel,
ShallowEnsembleModel,
)


class ModelBuilder:
Expand Down Expand Up @@ -79,6 +84,7 @@ def build_readout(self):
b_init=self.config["b_init"],
w_init=self.config["w_init"],
use_ntk=self.config["use_ntk"],
n_shallow_ensemble=self.config["n_shallow_ensemble"],
dtype=self.config["readout_dtype"],
)
return readout
Expand Down Expand Up @@ -149,9 +155,14 @@ def build_energy_derivative_model(
init_box=init_box,
inference_disp_fn=inference_disp_fn,
)

model = EnergyDerivativeModel(
energy_model,
calc_stress=self.config["calc_stress"],
)
if self.config["n_shallow_ensemble"] > 0:
model = ShallowEnsembleModel(
energy_model,
calc_stress=self.config["calc_stress"],
)
else:
model = EnergyDerivativeModel(
energy_model,
calc_stress=self.config["calc_stress"],
)
return model
86 changes: 83 additions & 3 deletions apax/model/gmnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,16 +123,27 @@ def __call__(
dr_vec += offsets

# Model Core
# shape Natoms
# shape shallow ens: Natoms x Nensemble
atomic_energies = self.atomistic_model(dr_vec, Z, idx)
total_energy = fp64_sum(atomic_energies)

# check for shallow ensemble
is_shallow_ensemble = atomic_energies.shape[1] > 1
if is_shallow_ensemble:
total_energies_ensemble = fp64_sum(atomic_energies, axis=0)
# shape Nensemble
result = total_energies_ensemble
else:
# shape ()
result = fp64_sum(atomic_energies)

# Corrections
for correction in self.corrections:
energy_correction = correction(dr_vec, Z, idx)
total_energy = total_energy + energy_correction
result = result + energy_correction

# TODO think of nice abstraction for predicting additional properties
return total_energy
return result


class EnergyDerivativeModel(nn.Module):
Expand Down Expand Up @@ -165,3 +176,72 @@ def __call__(
prediction["stress"] = stress

return prediction


def make_mean_energy_fn(energy_fn):
    """Wrap an ensemble energy function so that it returns a single mean value.

    Parameters
    ----------
    energy_fn:
        Callable taking ``(R, Z, neighbor, box, offsets, perturbation)`` and
        returning the energies predicted by the ensemble members.

    Returns
    -------
    mean_energy_fn:
        Callable with the same arguments as ``energy_fn`` that reduces the
        result with ``jnp.mean`` over all entries.
    """

    def mean_energy_fn(
        R: Array,
        Z: Array,
        neighbor: Union[partition.NeighborList, Array],
        box,
        offsets,
        perturbation=None,
    ):
        # Forward everything unchanged and average over the ensemble output.
        return jnp.mean(energy_fn(R, Z, neighbor, box, offsets, perturbation))

    return mean_energy_fn


class ShallowEnsembleModel(nn.Module):
    """Wraps a shallow-ensemble ``EnergyModel`` and adds derivatives and uncertainties.

    The wrapped energy model is evaluated once and its output is indexed along
    the first axis as the ensemble dimension. From it the model reports the
    ensemble mean energy, the individual ensemble energies and their sample
    standard deviation. Forces and, optionally, the stress are derived from
    the energy.

    Attributes
    ----------
    energy_model: EnergyModel
        Model returning the ensemble of total energies.
    calc_stress: bool
        If True, a ``"stress"`` entry computed with ``stress_times_vol`` from
        the mean energy is added to the prediction.
    force_variance: bool
        If True, forces are computed for every ensemble member (Jacobian of
        the ensemble energies) and their mean, standard deviation and the
        full ensemble are returned. If False, only the gradient of the mean
        energy is computed and no force uncertainty is reported.

    Notes
    -----
    Despite the attribute name ``force_variance``, all reported uncertainties
    are standard deviations (square roots of the sample variances).
    """

    energy_model: EnergyModel = EnergyModel()
    calc_stress: bool = False
    force_variance: bool = True

    def __call__(
        self,
        R: Array,
        Z: Array,
        neighbor: Union[partition.NeighborList, Array],
        box,
        offsets,
    ):
        """Predict the energy, forces and optionally stress with uncertainties.

        Returns
        -------
        prediction: dict
            Always contains ``"energy"``, ``"energy_ensemble"``,
            ``"energy_uncertainty"`` and ``"forces"``. Contains
            ``"forces_uncertainty"`` and ``"forces_ensemble"`` when
            ``force_variance`` is set, and ``"stress"`` when ``calc_stress``
            is set.
        """
        energy_ens = self.energy_model(R, Z, neighbor, box, offsets)
        mean_energy_fn = make_mean_energy_fn(self.energy_model)

        # Unbiased (Bessel-corrected) estimator; this requires at least two
        # ensemble members, otherwise the divisor is undefined.
        n_ens = energy_ens.shape[0]
        divisor = 1 / (n_ens - 1)

        energy_mean = jnp.mean(energy_ens)
        energy_variance = divisor * fp64_sum((energy_ens - energy_mean) ** 2)

        prediction = {
            "energy": energy_mean,
            "energy_ensemble": energy_ens,
            "energy_uncertainty": jnp.sqrt(energy_variance),
        }

        if self.force_variance:
            # Jacobian of all ensemble energies w.r.t. the positions: the
            # leading axis enumerates the ensemble members.
            forces_ens = -jax.jacrev(self.energy_model)(R, Z, neighbor, box, offsets)
            forces_mean = jnp.mean(forces_ens, axis=0)
            forces_variance = divisor * fp64_sum((forces_ens - forces_mean) ** 2, axis=0)

            prediction["forces"] = forces_mean
            prediction["forces_uncertainty"] = jnp.sqrt(forces_variance)
            prediction["forces_ensemble"] = forces_ens
        else:
            # Cheaper path: a single gradient of the mean energy.
            forces_mean = -jax.grad(mean_energy_fn)(R, Z, neighbor, box, offsets)
            prediction["forces"] = forces_mean

        if self.calc_stress:
            # The stress is always derived from the mean energy.
            stress = stress_times_vol(
                mean_energy_fn, R, box, Z=Z, neighbor=neighbor, offsets=offsets
            )
            prediction["stress"] = stress

        return prediction
4 changes: 2 additions & 2 deletions apax/nodes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .md import ApaxJaxMD
from .model import Apax, ApaxEnsemble
from .model import Apax, ApaxEnsemble, ApaxImport

__all__ = ["Apax", "ApaxEnsemble", "ApaxJaxMD"]
__all__ = ["Apax", "ApaxEnsemble", "ApaxJaxMD", "ApaxImport"]

try:
from .analysis import ApaxBatchPrediction # noqa: F401
Expand Down
54 changes: 54 additions & 0 deletions apax/nodes/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,57 @@ def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator:
transformations=transformations,
)
return calc


class ApaxImport(zntrack.Node):
    """Import an apax model from an existing config file for use in ASE.

    Parameters
    ----------
    config: str
        Path to the apax YAML config file. Its ``data.directory`` and
        ``data.experiment`` entries are joined to locate the model directory.
    nl_skin: float
        Neighborlist skin.
    transformations: dict
        Key-parameter dict with function transformations applied
        to the model function within the ASE calculator.
        See the apax documentation for available methods.
    """

    config: str = zntrack.params_path()
    nl_skin: float = zntrack.params(0.5)
    transformations: dict[str, dict] = zntrack.params(None)

    # Parsed content of the config file, filled in by `_handle_parameter_file`.
    _parameter: dict = None

    def _post_load_(self) -> None:
        self._handle_parameter_file()

    def _handle_parameter_file(self):
        """Read the YAML config file into ``self._parameter``."""
        with self.state.use_tmp_path():
            self._parameter = yaml.safe_load(pathlib.Path(self.config).read_text())

    def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator:
        """Property to return a model specific ase calculator object.

        Returns
        -------
        calc:
            ase calculator object
        """

        directory = self._parameter["data"]["directory"]
        exp = self._parameter["data"]["experiment"]
        model_dir = directory + "/" + exp

        transformations = []
        if self.transformations:
            # Instantiate every requested transformation with its parameters.
            transformations = [
                available_transformations[transform](**params)
                for transform, params in self.transformations.items()
            ]

        calc = ASECalculator(
            model_dir,
            dr=self.nl_skin,
            transformations=transformations,
        )
        return calc
Loading
Loading