From 9cd4f6c744c88561be77760baebed0492e0e55f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Sat, 4 May 2024 10:07:45 +0200 Subject: [PATCH 01/29] correct ordering of labels and predictions --- apax/train/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apax/train/trainer.py b/apax/train/trainer.py index 3da237d4..e776f61c 100644 --- a/apax/train/trainer.py +++ b/apax/train/trainer.py @@ -236,7 +236,7 @@ def calc_loss(params, inputs, labels, loss_fn, model): inputs["offsets"], ) predictions = model(params, R, Z, idx, box, offsets) - loss = loss_fn(inputs, labels, predictions) + loss = loss_fn(inputs, predictions, labels) return loss, predictions From b0d6c9a9b09bdb7da78b50daa41a1bdabe765ce2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Sat, 4 May 2024 10:08:54 +0200 Subject: [PATCH 02/29] added crps and nll loss functions --- apax/train/loss.py | 93 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 86 insertions(+), 7 deletions(-) diff --git a/apax/train/loss.py b/apax/train/loss.py index c2e575d3..36533c38 100644 --- a/apax/train/loss.py +++ b/apax/train/loss.py @@ -2,28 +2,41 @@ from typing import List import einops +import jax import jax.numpy as jnp +import jax.scipy as jsc +import numpy as np from apax.utils.math import normed_dotp def weighted_squared_error( - label: jnp.array, prediction: jnp.array, divisor: float = 1.0, parameters: dict = {} + label: jnp.array, + prediction: jnp.array, + name, + divisor: float = 1.0, + parameters: dict = {}, ) -> jnp.array: """ Squared error function that allows weighting of individual contributions by the number of atoms in the system. """ + label, prediction = label[name], prediction[name] return (label - prediction) ** 2 / divisor def weighted_huber_loss( - label: jnp.array, prediction: jnp.array, divisor: float = 1.0, parameters: dict = {} + label: jnp.array, + prediction: jnp.array, + name, + divisor: float = 1.0, + parameters: dict = {}, ) -> jnp.array: """ Huber loss function that allows weighting of individual contributions by the number of atoms in the system. """ + label, prediction = label[name], prediction[name] if "delta" not in parameters.keys(): raise KeyError("Huber loss function requires 'delta' parameter") delta = parameters["delta"] @@ -32,23 +45,81 @@ def weighted_huber_loss( return loss / divisor +def crps_loss( + label: jax.Array, + prediction: jax.Array, + name, + divisor: float = 1.0, + parameters: dict = {}, +) -> jax.Array: + """Computes the CRPS of a gaussian distribution given + means, targets and vars (uncertainty estimate) + """ + label, means, vars = label[name], prediction[name], prediction[name + "_uncertainty"] + + sigma = jnp.sqrt(vars) + sigma = jnp.clip(sigma, a_min=1e-6) + + norm_x = (label - means) / sigma + cdf = 0.5 * (1 + jsc.special.erf(norm_x / jnp.sqrt(2))) + + normalization = 1 / (jnp.sqrt(2.0 * np.pi)) + + pdf = normalization * jnp.exp(-(norm_x**2) / 2.0) + + crps = sigma * (norm_x * (2 * cdf - 1) + 2 * pdf - 1 / jnp.sqrt(np.pi)) + + return crps # / divisor # TODO how to account for differently sized systems? 
+
+
+def nll_loss(
+    label: jax.Array,
+    prediction: jax.Array,
+    name,
+    divisor: float = 1.0,
+    parameters: dict = {},
+) -> jax.Array:
+    """Computes the gaussian NLL loss means, targets and vars (uncertainty estimate)
+    """
+    label, means, vars = label[name], prediction[name], prediction[name + "_uncertainty"]
+    eps = 1e-4
+    sigma = jnp.sqrt(vars)
+    sigma = jnp.clip(sigma, a_min=1e-4)
+
+    x1 = jnp.log(jnp.maximum(vars, eps))
+    x2 = (means - label) ** 2 / (jnp.maximum(vars, eps))
+    nll = 0.5 * (x1 + x2)
+
+    return nll
+
+
 def force_angle_loss(
-    label: jnp.array, prediction: jnp.array, divisor: float = 1.0, parameters: dict = {}
+    label: jnp.array,
+    prediction: jnp.array,
+    name,
+    divisor: float = 1.0,
+    parameters: dict = {},
 ) -> jnp.array:
     """
     Cosine similarity loss function. Contributions are summed in `Loss`.
     """
+    label, prediction = label[name], prediction[name]
     dotp = normed_dotp(label, prediction)
     return (1.0 - dotp) / divisor
 
 
 def force_angle_div_force_label(
-    label: jnp.array, prediction: jnp.array, divisor: float = 1.0, parameters: dict = {}
+    label: jnp.array,
+    prediction: jnp.array,
+    name,
+    divisor: float = 1.0,
+    parameters: dict = {},
 ):
     """
     Cosine similarity loss function weighted by the norm of the force labels.
     Contributions are summed in `Loss`.
     """
+    label, prediction = label[name], prediction[name]
     dotp = normed_dotp(label, prediction)
     F_0_norm = jnp.linalg.norm(label, ord=2, axis=2, keepdims=False)
     loss = jnp.where(F_0_norm > 1e-6, (1.0 - dotp) / F_0_norm, jnp.zeros_like(dotp))
@@ -56,18 +127,24 @@ def force_angle_div_force_label(
 
 
 def force_angle_exponential_weight(
-    label: jnp.array, prediction: jnp.array, divisor: float = 1.0, parameters: dict = {}
+    label: jnp.array,
+    prediction: jnp.array,
+    name,
+    divisor: float = 1.0,
+    parameters: dict = {},
 ) -> jnp.array:
     """
     Cosine similarity loss function exponentially scaled by the norm of the force
     labels. Contributions are summed in `Loss`.
""" + label, prediction = label[name], prediction[name] dotp = normed_dotp(label, prediction) F_0_norm = jnp.linalg.norm(label, ord=2, axis=2, keepdims=False) return (1.0 - dotp) * jnp.exp(-F_0_norm) / divisor -def stress_tril(label, prediction, divisor=1.0, parameters: dict = {}): +def stress_tril(label, prediction, name, divisor=1.0, parameters: dict = {}): + label, prediction = label[name], prediction[name] idxs = jnp.tril_indices(3) label_tril = label[:, idxs[0], idxs[1]] prediction_tril = prediction[:, idxs[0], idxs[1]] @@ -81,6 +158,8 @@ def stress_tril(label, prediction, divisor=1.0, parameters: dict = {}): "cosine_sim_div_magnitude": force_angle_div_force_label, "cosine_sim_exp_magnitude": force_angle_exponential_weight, "tril": stress_tril, + "crps": crps_loss, + "nll": nll_loss, } @@ -111,7 +190,7 @@ def __call__(self, inputs: dict, prediction: dict, label: dict) -> float: # TODO we may want to insert an additional `mask` argument for this method divisor = self.determine_divisor(inputs["n_atoms"]) batch_losses = self.loss_fn( - label[self.name], prediction[self.name], divisor, self.parameters + label, prediction, self.name, divisor, self.parameters ) loss = self.weight * jnp.sum(jnp.mean(batch_losses, axis=0)) return loss From 084d2aff43e4d27657796e36390ac93d51973237 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Sat, 4 May 2024 10:09:26 +0200 Subject: [PATCH 03/29] added shallow ens to config --- apax/config/train_config.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/apax/config/train_config.py b/apax/config/train_config.py index d7eb5da3..e5ee4a5f 100644 --- a/apax/config/train_config.py +++ b/apax/config/train_config.py @@ -182,6 +182,8 @@ class ModelConfig(BaseModel, extra="forbid"): calc_stress: bool = False + n_shallow_ensemble: int = 0 + descriptor_dtype: Literal["fp32", "fp64"] = "fp64" readout_dtype: Literal["fp32", "fp64"] = "fp32" scale_shift_dtype: Literal["fp32", "fp64"] = "fp32" From 8e1b4d0307ef691bbad5cdbc0311e2c29465de41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Sat, 4 May 2024 10:10:01 +0200 Subject: [PATCH 04/29] added possibility for shallow ens in readout --- apax/layers/readout.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/apax/layers/readout.py b/apax/layers/readout.py index 41955cd0..7cc7f814 100644 --- a/apax/layers/readout.py +++ b/apax/layers/readout.py @@ -12,17 +12,22 @@ class AtomisticReadout(nn.Module): units: List[int] = field(default_factory=lambda: [512, 512]) activation_fn: Callable = swish b_init: str = "normal" + n_shallow_ensemble: int = 0 dtype: Any = jnp.float32 def setup(self): - units = [u for u in self.units] + [1] + readout_unit = [1] + if self.n_shallow_ensemble > 0: + readout_unit = [self.n_shallow_ensemble] + self._n_shallow_ensemble = self.n_shallow_ensemble + + units = [u for u in self.units] + readout_unit dense = [] for ii, n_hidden in enumerate(units): - dense.append( - NTKLinear( - n_hidden, b_init=self.b_init, dtype=self.dtype, name=f"dense_{ii}" - ) + layer = NTKLinear( + n_hidden, b_init=self.b_init, dtype=self.dtype, name=f"dense_{ii}" ) + dense.append(layer) if ii < len(units) - 1: dense.append(swish) self.sequential = nn.Sequential(dense, name="readout") From 359f0376425eabc9e100826ccd4f82acb39c1b30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Sat, 4 May 2024 10:28:54 +0200 Subject: [PATCH 05/29] added preliminary shallow ens class --- apax/model/gmnn.py | 62 
++++++++++++++++++++++++++++++++++++++++++++++++++--
 apax/train/parameters.py |  6 ++--
 2 files changed, 62 insertions(+), 6 deletions(-)

diff --git a/apax/model/gmnn.py b/apax/model/gmnn.py
index 73aa8da0..bf85d189 100644
--- a/apax/model/gmnn.py
+++ b/apax/model/gmnn.py
@@ -122,17 +122,30 @@ def __call__(
         dr_vec = self.displacement(Rj, Ri, perturbation, box)
         dr_vec += offsets
 
+        aux_prediction = {}
+
         # Model Core
+        # shape Natoms
+        # shape shallow ens: Natoms x Nensemble
         atomic_energies = self.atomistic_model(dr_vec, Z, idx)
-        total_energy = fp64_sum(atomic_energies)
+
+        # check for shallow ensemble
+        is_shallow_ensemble = len(atomic_energies.shape) > 1
+        if is_shallow_ensemble:
+            total_energies_ensemble = fp64_sum(atomic_energies, axis=0)
+            # shape Nensemble
+            result = total_energies_ensemble
+        else:
+            # shape ()
+            result = fp64_sum(atomic_energies)
 
         # Corrections
         for correction in self.corrections:
             energy_correction = correction(dr_vec, Z, idx)
-            total_energy = total_energy + energy_correction
+            result = result + energy_correction
 
         # TODO think of nice abstraction for predicting additional properties
-        return total_energy
+        return result
 
 
 class EnergyDerivativeModel(nn.Module):
@@ -165,3 +178,46 @@ def __call__(
             prediction["stress"] = stress
 
         return prediction
+
+
+class ShallowEnsembleModel(nn.Module):
+    """Transforms an EnergyModel into one that also predicts derivatives of the total energy.
+    Can calculate forces and stress tensors.
+    """
+
+    energy_model: EnergyModel = EnergyModel()
+    calc_stress: bool = False
+
+    def __call__(
+        self,
+        R: Array,
+        Z: Array,
+        neighbor: Union[partition.NeighborList, Array],
+        box,
+        offsets,
+    ):
+        energy_ens = self.energy_model(R, Z, neighbor, box, offsets)
+        # forces_ens = - jax.jacrev(self.energy_model)(
+        #     R, Z, neighbor, box, offsets
+        # )
+        forces_mean = -jax.grad(lambda *args: jnp.mean(self.energy_model(*args)))(
+            R, Z, neighbor, box, offsets
+        )
+
+        n_ens = energy_ens.shape[0]
+        divisor = 1 / (n_ens - 1)
+
+        energy_mean = jnp.mean(energy_ens)
+        energy_variance = divisor * fp64_sum((energy_ens - energy_mean) ** 2)
+
+        # forces_mean = jnp.mean(forces_ens, axis=0)
+        # forces_variance = divisor * fp64_sum((forces_ens - forces_mean)**2, axis=0)
+
+        prediction = {
+            "energy": energy_mean,
+            "forces": forces_mean,
+            "energy_uncertainty": energy_variance,
+            # "forces_uncertainty": forces_variance,
+        }
+
+        return prediction
diff --git a/apax/train/parameters.py b/apax/train/parameters.py
index 5adc1830..bf7b6e8a 100644
--- a/apax/train/parameters.py
+++ b/apax/train/parameters.py
@@ -3,8 +3,7 @@
 
 @jax.jit
 def tree_ema(tree1, tree2, alpha):
-    """Exponential moving average of two pytrees.
-    """
+    """Exponential moving average of two pytrees."""
     ema = jax.tree_map(lambda a, b: alpha * a + (1 - alpha) * b, tree1, tree2)
     return ema
 
@@ -20,7 +19,8 @@ class EMAParameters:
     alpha : float, default = 0.9
         How much of the new model to use. 1.0 would mean no averaging, 0.0 no updates.
""" - def __init__(self, ema_start: int , alpha: float = 0.9) -> None: + + def __init__(self, ema_start: int, alpha: float = 0.9) -> None: self.alpha = alpha self.ema_start = ema_start self.ema_params = None From 705f109751e767ba742f4addf84c0e3a517b0d33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Sat, 4 May 2024 10:32:01 +0200 Subject: [PATCH 06/29] added shallow ens to builder --- apax/model/builder.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/apax/model/builder.py b/apax/model/builder.py index 09e1972c..ba6ea122 100644 --- a/apax/model/builder.py +++ b/apax/model/builder.py @@ -6,7 +6,12 @@ from apax.layers.empirical import ZBLRepulsion from apax.layers.readout import AtomisticReadout from apax.layers.scaling import PerElementScaleShift -from apax.model.gmnn import AtomisticModel, EnergyDerivativeModel, EnergyModel +from apax.model.gmnn import ( + AtomisticModel, + EnergyDerivativeModel, + EnergyModel, + ShallowEnsembleModel, +) class ModelBuilder: @@ -121,9 +126,14 @@ def build_energy_derivative_model( init_box=init_box, inference_disp_fn=inference_disp_fn, ) - - model = EnergyDerivativeModel( - energy_model, - calc_stress=self.config["calc_stress"], - ) + if self.config.n_shallow_ensemble > 0: + model = ShallowEnsembleModel( + energy_model, + calc_stress=False, # self.config["calc_stress"], + ) + else: + model = EnergyDerivativeModel( + energy_model, + calc_stress=self.config["calc_stress"], + ) return model From 1520ff5c87c3200aa383737b84670b9964925153 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Sat, 4 May 2024 13:14:26 +0200 Subject: [PATCH 07/29] fixed tests --- apax/layers/readout.py | 5 +- apax/model/builder.py | 2 +- apax/model/gmnn.py | 7 +-- apax/train/loss.py | 3 +- tests/unit_tests/train/test_loss.py | 74 ++++++++++++++++------------- 5 files changed, 50 insertions(+), 41 deletions(-) diff --git a/apax/layers/readout.py b/apax/layers/readout.py index 7cc7f814..3d7611dd 100644 --- a/apax/layers/readout.py +++ b/apax/layers/readout.py @@ -17,10 +17,11 @@ class AtomisticReadout(nn.Module): def setup(self): readout_unit = [1] + print(self.n_shallow_ensemble) if self.n_shallow_ensemble > 0: readout_unit = [self.n_shallow_ensemble] - self._n_shallow_ensemble = self.n_shallow_ensemble - + # self._n_shallow_ensemble = self.n_shallow_ensemble + print(readout_unit) units = [u for u in self.units] + readout_unit dense = [] for ii, n_hidden in enumerate(units): diff --git a/apax/model/builder.py b/apax/model/builder.py index ba6ea122..ab177cc9 100644 --- a/apax/model/builder.py +++ b/apax/model/builder.py @@ -126,7 +126,7 @@ def build_energy_derivative_model( init_box=init_box, inference_disp_fn=inference_disp_fn, ) - if self.config.n_shallow_ensemble > 0: + if self.config["n_shallow_ensemble"] > 0: model = ShallowEnsembleModel( energy_model, calc_stress=False, # self.config["calc_stress"], diff --git a/apax/model/gmnn.py b/apax/model/gmnn.py index bf85d189..21dde1b1 100644 --- a/apax/model/gmnn.py +++ b/apax/model/gmnn.py @@ -122,15 +122,15 @@ def __call__( dr_vec = self.displacement(Rj, Ri, perturbation, box) dr_vec += offsets - aux_prediction = {} - # Model Core # shape Natoms # shape shallow ens: Natoms x Nensemble atomic_energies = self.atomistic_model(dr_vec, Z, idx) + print(atomic_energies.shape) # check for shallow ensemble - is_shallow_ensemble = len(atomic_energies.shape) > 1 + is_shallow_ensemble = atomic_energies.shape[1] > 1 + print(is_shallow_ensemble) if 
is_shallow_ensemble: total_energies_ensemble = fp64_sum(atomic_energies, axis=0) # shape Nensemble @@ -138,6 +138,7 @@ def __call__( else: # shape () result = fp64_sum(atomic_energies) + print(result.shape) # Corrections for correction in self.corrections: diff --git a/apax/train/loss.py b/apax/train/loss.py index 36533c38..987c140b 100644 --- a/apax/train/loss.py +++ b/apax/train/loss.py @@ -79,8 +79,7 @@ def nll_loss( divisor: float = 1.0, parameters: dict = {}, ) -> jax.Array: - """Computes the gaussian NLL loss means, targets and vars (uncertainty estimate) - """ + """Computes the gaussian NLL loss means, targets and vars (uncertainty estimate)""" label, means, vars = label[name], prediction[name], prediction[name + "_uncertainty"] eps = 1e-4 sigma = jnp.sqrt(vars) diff --git a/tests/unit_tests/train/test_loss.py b/tests/unit_tests/train/test_loss.py index 6e3263fd..975751d9 100644 --- a/tests/unit_tests/train/test_loss.py +++ b/tests/unit_tests/train/test_loss.py @@ -11,68 +11,76 @@ def test_weighted_squared_error(): - energy_label = jnp.array([[0.1, 0.4, 0.2, -0.5], [0.1, -0.1, 0.8, 0.6]]) + name = "energy" + label = {"energy": jnp.array([[0.1, 0.4, 0.2, -0.5], [0.1, -0.1, 0.8, 0.6]])} - loss = weighted_squared_error(energy_label, energy_label, divisor=1.0) + loss = weighted_squared_error(label, label, name, divisor=1.0) loss = jnp.sum(loss) ref = 0.0 assert loss.shape == () assert abs(loss - ref) < 1e-6 - pred = jnp.array( - [ - [0.6, 0.4, 0.2, -0.5], - [0.1, -0.1, 0.8, 0.6], - ] - ) - loss = weighted_squared_error(energy_label, pred, divisor=1.0) + pred = { + "energy": jnp.array( + [ + [0.6, 0.4, 0.2, -0.5], + [0.1, -0.1, 0.8, 0.6], + ] + ) + } + loss = weighted_squared_error(label, pred, name, divisor=1.0) loss = jnp.sum(loss) ref = 0.25 assert abs(loss - ref) < 1e-6 - loss = weighted_squared_error(energy_label, pred, divisor=2.0) + loss = weighted_squared_error(label, pred, name, divisor=2.0) loss = jnp.sum(loss) ref = 0.125 assert abs(loss - ref) < 1e-6 def test_force_angle_loss(): - F_pred = jnp.array( - [ + name = "forces" + F_pred = { + "forces": jnp.array( [ - [0.5, 0.0, 0.0], - [0.5, 0.0, 0.0], - [0.5, 0.5, 0.0], - [0.0, 0.5, 0.0], - [0.0, 0.5, 0.0], - [0.0, 0.0, 0.0], # padding + [ + [0.5, 0.0, 0.0], + [0.5, 0.0, 0.0], + [0.5, 0.5, 0.0], + [0.0, 0.5, 0.0], + [0.0, 0.5, 0.0], + [0.0, 0.0, 0.0], # padding + ] ] - ] - ) + ) + } - F_0 = jnp.array( - [ + F_0 = { + "forces": jnp.array( [ - [0.5, 0.0, 0.0], - [0.9, 0.0, 0.0], - [0.5, 0.0, 0.0], - [0.5, 0.0, 0.0], - [0.9, 0.0, 0.0], - [0.0, 0.0, 0.0], # padding + [ + [0.5, 0.0, 0.0], + [0.9, 0.0, 0.0], + [0.5, 0.0, 0.0], + [0.5, 0.0, 0.0], + [0.9, 0.0, 0.0], + [0.0, 0.0, 0.0], # padding + ] ] - ] - ) + ) + } - F_angle_loss = force_angle_loss(F_pred, F_0) + F_angle_loss = force_angle_loss(F_pred, F_0, name) F_angle_loss = jnp.arccos(-F_angle_loss + 1) * 360 / (2 * np.pi) assert F_angle_loss.shape == (1, 6) ref = jnp.array([0.0, 0.0, 45.0, 90.0, 90.0, 90.0]) assert jnp.allclose(F_angle_loss, ref) - F_angle_loss = force_angle_div_force_label(F_pred, F_0) + F_angle_loss = force_angle_div_force_label(F_pred, F_0, name) assert F_angle_loss.shape == (1, 6) - F_angle_loss = force_angle_exponential_weight(F_pred, F_0) + F_angle_loss = force_angle_exponential_weight(F_pred, F_0, name) assert F_angle_loss.shape == (1, 6) From 48dc3ed9fbd3028aa21648d0c7e2a7de7389270c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Tue, 7 May 2024 10:18:12 +0200 Subject: [PATCH 08/29] removed debug print statements --- 
apax/model/gmnn.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/apax/model/gmnn.py b/apax/model/gmnn.py index 21dde1b1..e5acedd9 100644 --- a/apax/model/gmnn.py +++ b/apax/model/gmnn.py @@ -126,11 +126,9 @@ def __call__( # shape Natoms # shape shallow ens: Natoms x Nensemble atomic_energies = self.atomistic_model(dr_vec, Z, idx) - print(atomic_energies.shape) # check for shallow ensemble is_shallow_ensemble = atomic_energies.shape[1] > 1 - print(is_shallow_ensemble) if is_shallow_ensemble: total_energies_ensemble = fp64_sum(atomic_energies, axis=0) # shape Nensemble @@ -138,7 +136,6 @@ def __call__( else: # shape () result = fp64_sum(atomic_energies) - print(result.shape) # Corrections for correction in self.corrections: From df7bee5ad7f1eb0d4b0dcaab8880634217b0e24b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Tue, 7 May 2024 10:18:28 +0200 Subject: [PATCH 09/29] correctly pass n_shallow ens in builder --- apax/model/builder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/apax/model/builder.py b/apax/model/builder.py index ab177cc9..7ea6a9b8 100644 --- a/apax/model/builder.py +++ b/apax/model/builder.py @@ -56,6 +56,7 @@ def build_readout(self): readout = AtomisticReadout( units=self.config["nn"], b_init=self.config["b_init"], + n_shallow_ensemble=self.config["n_shallow_ensemble"], dtype=self.config["readout_dtype"], ) return readout From c7866681741d0513a4203dc0d0eee7d657bbca6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Tue, 7 May 2024 10:18:40 +0200 Subject: [PATCH 10/29] remove debug print statements --- apax/layers/readout.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/apax/layers/readout.py b/apax/layers/readout.py index 3d7611dd..2d077135 100644 --- a/apax/layers/readout.py +++ b/apax/layers/readout.py @@ -17,11 +17,8 @@ class AtomisticReadout(nn.Module): def setup(self): readout_unit = [1] - print(self.n_shallow_ensemble) if self.n_shallow_ensemble > 0: readout_unit = [self.n_shallow_ensemble] - # self._n_shallow_ensemble = self.n_shallow_ensemble - print(readout_unit) units = [u for u in self.units] + readout_unit dense = [] for ii, n_hidden in enumerate(units): From c1f734ae0031a669f14daba5935b2740d9d07a19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Tue, 7 May 2024 10:18:52 +0200 Subject: [PATCH 11/29] docstring --- apax/config/lr_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apax/config/lr_config.py b/apax/config/lr_config.py index b64f4f8c..5d1c67a8 100644 --- a/apax/config/lr_config.py +++ b/apax/config/lr_config.py @@ -34,7 +34,7 @@ class CyclicCosineLR(LRSchedule, frozen=True, extra="forbid"): Parameters ---------- period: int = 20 - Length of a cycle. + Length of a cycle in epochs. decay_factor: NonNegativeFloat = 1.0 Factor by which to decrease the LR after each cycle. 1.0 means no decrease. 
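
A note on the two mechanisms wired up by the patches above, before the force and stress handling below.

A shallow ensemble shares all hidden layers and ensembles only the readout: with `n_shallow_ensemble` set, the final linear layer of the readout has one output per ensemble member instead of a single output, so the atomic energies carry shape (n_atoms, n_ens) and summing over the atom axis yields one total energy per member, as handled in `EnergyModel.__call__` above. A minimal sketch of this bookkeeping, assuming nothing beyond jax.numpy (array names are illustrative, not apax internals):

    import jax
    import jax.numpy as jnp

    n_atoms, n_features, n_ens = 8, 16, 4
    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

    # per-atom features produced by the shared hidden layers
    features = jax.random.normal(k1, (n_atoms, n_features))
    # only the final linear layer is ensembled: one output column per member
    w_head = jax.random.normal(k2, (n_features, n_ens)) / jnp.sqrt(n_features)
    b_head = jax.random.normal(k3, (n_ens,))

    atomic_energies = features @ w_head + b_head  # (n_atoms, n_ens)
    energy_ens = atomic_energies.sum(axis=0)      # (n_ens,): one total energy per member

    energy_mean = energy_ens.mean()
    # unbiased sample variance over members, matching the 1 / (n_ens - 1)
    # factor in ShallowEnsembleModel above
    energy_var = ((energy_ens - energy_mean) ** 2).sum() / (n_ens - 1)
    print(energy_ens.shape, energy_mean, jnp.sqrt(energy_var))

The `crps_loss` added in patch 02 implements the standard closed form of the CRPS for a Gaussian predictive distribution, CRPS(N(mu, sigma^2), y) = sigma * (z * (2 * Phi(z) - 1) + 2 * phi(z) - 1 / sqrt(pi)) with z = (y - mu) / sigma. A self-contained check of that closed form against the sampling definition CRPS = E|X - y| - 0.5 * E|X - X'|, in plain NumPy (function names here are illustrative, not part of apax):

    import numpy as np
    from math import erf, exp, pi, sqrt

    def gaussian_crps(y, mu, sigma):
        # analytic CRPS of N(mu, sigma^2) evaluated at the observation y
        z = (y - mu) / sigma
        cdf = 0.5 * (1 + erf(z / sqrt(2)))
        pdf = exp(-0.5 * z**2) / sqrt(2 * pi)
        return sigma * (z * (2 * cdf - 1) + 2 * pdf - 1 / sqrt(pi))

    def mc_crps(y, mu, sigma, n=200_000, seed=0):
        # CRPS as E|X - y| - 0.5 * E|X - X'| with X, X' ~ N(mu, sigma^2)
        rng = np.random.default_rng(seed)
        x, x_prime = rng.normal(mu, sigma, size=(2, n))
        return np.abs(x - y).mean() - 0.5 * np.abs(x - x_prime).mean()

    print(gaussian_crps(0.3, 0.0, 0.5))  # ~0.1866
    print(mc_crps(0.3, 0.0, 0.5))        # agrees to within ~1e-3
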
From 6a83516e2c4704507a56a8e3e7ed8962c7520c4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Wed, 8 May 2024 11:38:40 +0200 Subject: [PATCH 12/29] return ensemble predictions from model --- apax/model/gmnn.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/apax/model/gmnn.py b/apax/model/gmnn.py index e5acedd9..715cf17e 100644 --- a/apax/model/gmnn.py +++ b/apax/model/gmnn.py @@ -195,12 +195,12 @@ def __call__( offsets, ): energy_ens = self.energy_model(R, Z, neighbor, box, offsets) - # forces_ens = - jax.jacrev(self.energy_model)( - # R, Z, neighbor, box, offsets - # ) - forces_mean = -jax.grad(lambda *args: jnp.mean(self.energy_model(*args)))( + forces_ens = - jax.jacfwd(self.energy_model)( R, Z, neighbor, box, offsets ) + # forces_mean = -jax.grad(lambda *args: jnp.mean(self.energy_model(*args)))( + # R, Z, neighbor, box, offsets + # ) n_ens = energy_ens.shape[0] divisor = 1 / (n_ens - 1) @@ -208,14 +208,16 @@ def __call__( energy_mean = jnp.mean(energy_ens) energy_variance = divisor * fp64_sum((energy_ens - energy_mean) ** 2) - # forces_mean = jnp.mean(forces_ens, axis=0) - # forces_variance = divisor * fp64_sum((forces_ens - forces_mean)**2, axis=0) + forces_mean = jnp.mean(forces_ens, axis=0) + forces_variance = divisor * fp64_sum((forces_ens - forces_mean)**2, axis=0) prediction = { "energy": energy_mean, "forces": forces_mean, "energy_uncertainty": energy_variance, - # "forces_uncertainty": forces_variance, + "forces_uncertainty": forces_variance, + "energy_ensemble": energy_ens, + "forces_ensemble": forces_ens, } return prediction From 6e250699bd4407e901bdfc69de76072cbbfe7089 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 May 2024 17:00:43 +0000 Subject: [PATCH 13/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- apax/model/gmnn.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/apax/model/gmnn.py b/apax/model/gmnn.py index 715cf17e..f3e795d7 100644 --- a/apax/model/gmnn.py +++ b/apax/model/gmnn.py @@ -195,9 +195,7 @@ def __call__( offsets, ): energy_ens = self.energy_model(R, Z, neighbor, box, offsets) - forces_ens = - jax.jacfwd(self.energy_model)( - R, Z, neighbor, box, offsets - ) + forces_ens = -jax.jacfwd(self.energy_model)(R, Z, neighbor, box, offsets) # forces_mean = -jax.grad(lambda *args: jnp.mean(self.energy_model(*args)))( # R, Z, neighbor, box, offsets # ) @@ -209,7 +207,7 @@ def __call__( energy_variance = divisor * fp64_sum((energy_ens - energy_mean) ** 2) forces_mean = jnp.mean(forces_ens, axis=0) - forces_variance = divisor * fp64_sum((forces_ens - forces_mean)**2, axis=0) + forces_variance = divisor * fp64_sum((forces_ens - forces_mean) ** 2, axis=0) prediction = { "energy": energy_mean, From d22dc6c13b06a839cdccd18b49349e68d0af3a63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Wed, 15 May 2024 12:21:33 +0200 Subject: [PATCH 14/29] implemented switch between mean force and ensemble force computation --- apax/model/gmnn.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/apax/model/gmnn.py b/apax/model/gmnn.py index 715cf17e..6b805e54 100644 --- a/apax/model/gmnn.py +++ b/apax/model/gmnn.py @@ -185,6 +185,7 @@ class ShallowEnsembleModel(nn.Module): energy_model: EnergyModel = EnergyModel() calc_stress: bool = False + force_variance: bool = True def __call__( self, 
@@ -195,12 +196,6 @@ def __call__(
         offsets,
     ):
         energy_ens = self.energy_model(R, Z, neighbor, box, offsets)
-        forces_ens = -jax.jacfwd(self.energy_model)(R, Z, neighbor, box, offsets)
-        # forces_mean = -jax.grad(lambda *args: jnp.mean(self.energy_model(*args)))(
-        #     R, Z, neighbor, box, offsets
-        # )
 
         n_ens = energy_ens.shape[0]
         divisor = 1 / (n_ens - 1)
@@ -208,16 +203,24 @@ def __call__(
         energy_mean = jnp.mean(energy_ens)
         energy_variance = divisor * fp64_sum((energy_ens - energy_mean) ** 2)
 
-        forces_mean = jnp.mean(forces_ens, axis=0)
-        forces_variance = divisor * fp64_sum((forces_ens - forces_mean) ** 2, axis=0)
-
         prediction = {
             "energy": energy_mean,
-            "forces": forces_mean,
-            "energy_uncertainty": energy_variance,
             "energy_ensemble": energy_ens,
-            "forces_ensemble": forces_ens,
+            "energy_uncertainty": energy_variance,
         }
 
+        if self.force_variance:
+            forces_ens = -jax.jacrev(self.energy_model)(R, Z, neighbor, box, offsets)
+            forces_mean = jnp.mean(forces_ens, axis=0)
+            forces_variance = divisor * fp64_sum((forces_ens - forces_mean) ** 2, axis=0)
+
+            prediction["forces"] = (forces_mean,)
+            prediction["forces_uncertainty"] = (forces_variance,)
+            prediction["forces_ensemble"] = (forces_ens,)
+        else:
+            forces_mean = -jax.grad(lambda *args: jnp.mean(self.energy_model(*args)))(
+                R, Z, neighbor, box, offsets
+            )
+            prediction["forces"] = (forces_mean,)
+
         return prediction
From af812179bc5fa484142c8f3cd9817aaf975b7677 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?=
Date: Fri, 17 May 2024 12:00:17 +0200
Subject: [PATCH 15/29] implemented stress prediction for shallow ens

---
 apax/model/builder.py |  2 +-
 apax/model/gmnn.py    | 33 +++++++++++++++++++++++++++------
 2 files changed, 28 insertions(+), 7 deletions(-)

diff --git a/apax/model/builder.py b/apax/model/builder.py
index 7ea6a9b8..0e0002be 100644
--- a/apax/model/builder.py
+++ b/apax/model/builder.py
@@ -130,7 +130,7 @@ def build_energy_derivative_model(
         if self.config["n_shallow_ensemble"] > 0:
             model = ShallowEnsembleModel(
                 energy_model,
-                calc_stress=False,  # self.config["calc_stress"],
+                calc_stress=self.config["calc_stress"],
             )
         else:
             model = EnergyDerivativeModel(
diff --git a/apax/model/gmnn.py b/apax/model/gmnn.py
index 6b805e54..28c21b2b 100644
--- a/apax/model/gmnn.py
+++ b/apax/model/gmnn.py
@@ -178,6 +178,22 @@ def __call__(
         return prediction
 
 
+def make_mean_energy_fn(energy_fn):
+    def mean_energy_fn(
+        R: Array,
+        Z: Array,
+        neighbor: Union[partition.NeighborList, Array],
+        box,
+        offsets,
+        perturbation=None,
+    ):
+        e_ens = energy_fn(R, Z, neighbor, box, offsets, perturbation)
+        E_mean = jnp.mean(e_ens)
+        return E_mean
+
+    return mean_energy_fn
+
+
 class ShallowEnsembleModel(nn.Module):
     """Transforms an EnergyModel into one that also predicts derivatives of the total energy.
     Can calculate forces and stress tensors.
@@ -196,6 +212,7 @@ def __call__( offsets, ): energy_ens = self.energy_model(R, Z, neighbor, box, offsets) + mean_energy_fn = make_mean_energy_fn(self.energy_model) n_ens = energy_ens.shape[0] divisor = 1 / (n_ens - 1) @@ -214,13 +231,17 @@ def __call__( forces_mean = jnp.mean(forces_ens, axis=0) forces_variance = divisor * fp64_sum((forces_ens - forces_mean) ** 2, axis=0) - prediction["forces"] = (forces_mean,) - prediction["forces_uncertainty"] = (forces_variance,) - prediction["forces_ensemble"] = (forces_ens,) + prediction["forces"] = forces_mean + prediction["forces_uncertainty"] = forces_variance + prediction["forces_ensemble"] = forces_ens else: - forces_mean = -jax.grad(lambda *args: jnp.mean(self.energy_model(*args)))( - R, Z, neighbor, box, offsets + forces_mean = -jax.grad(mean_energy_fn)(R, Z, neighbor, box, offsets) + prediction["forces"] = forces_mean + + if self.calc_stress: + stress = stress_times_vol( + mean_energy_fn, R, box, Z=Z, neighbor=neighbor, offsets=offsets ) - prediction["forces"] = (forces_mean,) + prediction["stress"] = stress return prediction From df19d043c604af993c7ca7902b754c383db2fc0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 17 May 2024 12:00:33 +0200 Subject: [PATCH 16/29] compatibility with BAL --- apax/bal/feature_maps.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/apax/bal/feature_maps.py b/apax/bal/feature_maps.py index 565439ed..8b41356b 100644 --- a/apax/bal/feature_maps.py +++ b/apax/bal/feature_maps.py @@ -58,7 +58,11 @@ def inner(ll_params): inputs["box"], inputs["offsets"], ) - return model.apply(full_params, R, Z, idx, box, offsets) + out = model.apply(full_params, R, Z, idx, box, offsets) + # take mean in case of shallow ensemble + # no effect for single model + out = jnp.mean(out) + return out g_ll = jax.grad(inner)(ll_params) g_ll = unflatten_dict(g_ll) From 7f51cddbd37802438665df1ddcfac895d38b1b28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 17 May 2024 12:00:58 +0200 Subject: [PATCH 17/29] loss fn divisor --- apax/train/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apax/train/loss.py b/apax/train/loss.py index 987c140b..e5f10e71 100644 --- a/apax/train/loss.py +++ b/apax/train/loss.py @@ -69,7 +69,7 @@ def crps_loss( crps = sigma * (norm_x * (2 * cdf - 1) + 2 * pdf - 1 / jnp.sqrt(np.pi)) - return crps # / divisor # TODO how to account for differently sized systems? 
+    return crps / divisor
From d17a8bf0475bc64455830bc7fb9668d4a490ee67 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?=
Date: Wed, 22 May 2024 11:30:26 +0200
Subject: [PATCH 18/29] added ApaxImport node

---
 apax/nodes/__init__.py |  4 ++--
 apax/nodes/model.py    | 54 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 56 insertions(+), 2 deletions(-)

diff --git a/apax/nodes/__init__.py b/apax/nodes/__init__.py
index ae635831..95ef5325 100644
--- a/apax/nodes/__init__.py
+++ b/apax/nodes/__init__.py
@@ -1,4 +1,4 @@
 from .md import ApaxJaxMD
-from .model import Apax, ApaxEnsemble
+from .model import Apax, ApaxEnsemble, ApaxImport
 
-__all__ = ["Apax", "ApaxEnsemble", "ApaxJaxMD"]
+__all__ = ["Apax", "ApaxEnsemble", "ApaxJaxMD", "ApaxImport"]
diff --git a/apax/nodes/model.py b/apax/nodes/model.py
index 9c74de69..18aa9030 100644
--- a/apax/nodes/model.py
+++ b/apax/nodes/model.py
@@ -145,3 +145,57 @@ def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator:
         transformations=transformations,
     )
     return calc
+
+
+class ApaxImport(zntrack.Node):
+    """Load a previously trained apax model from its training config file.
+
+    Parameters
+    ----------
+    config: str
+        Path to the training config file of the trained model.
+    nl_skin: float
+        Neighborlist skin.
+    transformations: dict
+        Key-parameter dict with function transformations applied
+        to the model function within the ASE calculator.
+        See the apax documentation for available methods.
+    """
+
+    config: str = zntrack.params_path()
+    nl_skin: float = zntrack.params(0.5)
+    transformations: dict[str, dict] = zntrack.params(None)
+
+    _parameter: dict = None
+
+    def _post_load_(self) -> None:
+        self.adapt_parameter_file()
+
+    def adapt_parameter_file(self):
+        with self.state.use_tmp_path():
+            self._parameter = yaml.safe_load(pathlib.Path(self.config).read_text())
+
+    def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator:
+        """Property to return a model specific ase calculator object.
+ + Returns + ------- + calc: + ase calculator object + """ + + directory = self._parameter["data"]["directory"] + exp = self._parameter["data"]["experiment"] + model_dir = directory + "/" + exp + + transformations = [] + if self.transformations: + for transform, params in self.transformations.items(): + transformations.append(available_transformations[transform](**params)) + + calc = ASECalculator( + model_dir, + dr=self.nl_skin, + transformations=transformations, + ) + return calc From 5673431d384c792c68b418d32ae8dc969fdfbf82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Sat, 1 Jun 2024 12:02:20 +0200 Subject: [PATCH 19/29] made import node compatible with jaxmd node --- apax/nodes/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apax/nodes/model.py b/apax/nodes/model.py index 18aa9030..750256f2 100644 --- a/apax/nodes/model.py +++ b/apax/nodes/model.py @@ -169,9 +169,9 @@ class ApaxImport(zntrack.Node): _parameter: dict = None def _post_load_(self) -> None: - self.adapt_parameter_file() + self._handle_parameter_file() - def adapt_parameter_file(self): + def _handle_parameter_file(self): with self.state.use_tmp_path(): self._parameter = yaml.safe_load(pathlib.Path(self.config).read_text()) From ea64e132c6812c9a79e50cda6c6e9a3da5f4c01a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Sat, 1 Jun 2024 12:02:41 +0200 Subject: [PATCH 20/29] updated jaxmd to use new callback api --- apax/md/io.py | 4 ++-- apax/md/nvt.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/apax/md/io.py b/apax/md/io.py index 8f783a0a..4a412678 100644 --- a/apax/md/io.py +++ b/apax/md/io.py @@ -21,7 +21,7 @@ def __init__(self) -> None: self.traj_path: Path self.time_step: float - def step(self, state_and_energy, transform): + def step(self, state_and_energy, transform=None): pass def write(self, x=None, transform=None): @@ -82,7 +82,7 @@ def __init__( def reset_buffer(self): self.buffer = [] - def step(self, state, transform): + def step(self, state, transform=None): state, energy, nbr_kwargs = state if self.step_counter % self.sampling_rate == 0: diff --git a/apax/md/nvt.py b/apax/md/nvt.py index 62922523..b560bd68 100644 --- a/apax/md/nvt.py +++ b/apax/md/nvt.py @@ -10,6 +10,8 @@ from ase.io import read from flax.training import checkpoints from jax.experimental.host_callback import barrier_wait, id_tap +from jax.experimental import io_callback +# from jax import pure_callback from tqdm import trange from tqdm.contrib.logging import logging_redirect_tqdm @@ -217,7 +219,7 @@ def body_fn(i, state): nbr_kwargs = nbr_options(state) neighbor = neighbor.update(state.position, **nbr_kwargs) - id_tap(traj_handler.step, (state, current_energy, nbr_kwargs)) + io_callback(traj_handler.step, None, (state, current_energy, nbr_kwargs)) return state, neighbor state, neighbor = jax.lax.fori_loop(0, n_inner, body_fn, (state, neighbor)) From 8fd0ca19c66f87efe2bbbf8214dd625d2f835130 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Wed, 5 Jun 2024 21:22:52 +0200 Subject: [PATCH 21/29] deleted comment --- apax/md/nvt.py | 1 - 1 file changed, 1 deletion(-) diff --git a/apax/md/nvt.py b/apax/md/nvt.py index b560bd68..ae33095c 100644 --- a/apax/md/nvt.py +++ b/apax/md/nvt.py @@ -11,7 +11,6 @@ from flax.training import checkpoints from jax.experimental.host_callback import barrier_wait, id_tap from jax.experimental import io_callback -# from jax import pure_callback from tqdm import trange from 
tqdm.contrib.logging import logging_redirect_tqdm From 2150bcc5df3fe6a0294e5c5509bff2fd7ab20da5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Wed, 5 Jun 2024 21:25:28 +0200 Subject: [PATCH 22/29] remove old import --- apax/md/simulate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apax/md/simulate.py b/apax/md/simulate.py index 36bef12a..0bdb23c9 100644 --- a/apax/md/simulate.py +++ b/apax/md/simulate.py @@ -9,7 +9,7 @@ from ase import units from ase.io import read from flax.training import checkpoints -from jax.experimental.host_callback import barrier_wait, id_tap +from jax.experimental.host_callback import barrier_wait from jax.experimental import io_callback from tqdm import trange from tqdm.contrib.logging import logging_redirect_tqdm From 9b0963e7a7e4318777b9b474f1ace21561bd8c95 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 5 Jun 2024 19:26:09 +0000 Subject: [PATCH 23/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- apax/md/simulate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apax/md/simulate.py b/apax/md/simulate.py index 0bdb23c9..196b337f 100644 --- a/apax/md/simulate.py +++ b/apax/md/simulate.py @@ -9,8 +9,8 @@ from ase import units from ase.io import read from flax.training import checkpoints -from jax.experimental.host_callback import barrier_wait from jax.experimental import io_callback +from jax.experimental.host_callback import barrier_wait from tqdm import trange from tqdm.contrib.logging import logging_redirect_tqdm From e8e01827126334eb6ca7fc2d5b07f9ca143cc41c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Wed, 5 Jun 2024 21:37:51 +0200 Subject: [PATCH 24/29] made shallow ensembles compatible with jaxmd --- apax/md/simulate.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/apax/md/simulate.py b/apax/md/simulate.py index 0bdb23c9..01eb5e50 100644 --- a/apax/md/simulate.py +++ b/apax/md/simulate.py @@ -9,8 +9,8 @@ from ase import units from ase.io import read from flax.training import checkpoints -from jax.experimental.host_callback import barrier_wait from jax.experimental import io_callback +from jax.experimental.host_callback import barrier_wait from tqdm import trange from tqdm.contrib.logging import logging_redirect_tqdm @@ -29,16 +29,23 @@ log = logging.getLogger(__name__) -def create_energy_fn(model, params, numbers, n_models): - def ensemble(params, R, Z, neighbor, box, offsets, perturbation=None): +def create_energy_fn(model, params, numbers, n_models, shallow=False): + def full_ensemble(params, R, Z, neighbor, box, offsets, perturbation=None): vmodel = jax.vmap(model, (0, None, None, None, None, None, None), 0) energies = vmodel(params, R, Z, neighbor, box, offsets, perturbation) energy = jnp.mean(energies) + return energy + def shallow_ensemble(params, R, Z, neighbor, box, offsets, perturbation=None): + energies = model(params, R, Z, neighbor, box, offsets, perturbation) + energy = jnp.mean(energies) return energy if n_models > 1: - energy_fn = ensemble + if shallow: + energy_fn = shallow_ensemble + else: + energy_fn = full_ensemble else: energy_fn = model @@ -376,8 +383,16 @@ def md_setup(model_config: Config, md_config: MDConfig): _, params = restore_parameters(model_config.data.model_version_path) params = canonicalize_energy_model_parameters(params) + + if 
model_config.n_models > 1: + n_models = model_config.n_models + shallow = False + elif model_config.model.n_shallow_ensemble > 1: + n_models = model_config.model.n_shallow_ensemble + shallow = True + energy_fn = create_energy_fn( - model.apply, params, system.atomic_numbers, model_config.n_models + model.apply, params, system.atomic_numbers, n_models, shallow ) sim_fns = SimulationFunctions(energy_fn, shift_fn, neighbor_fn) return system, sim_fns From cb65848defe1d31596703dda14fbfa3128463fd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Wed, 5 Jun 2024 22:07:09 +0200 Subject: [PATCH 25/29] fixed bug for using jaxmd without ensembles --- apax/md/simulate.py | 1 + 1 file changed, 1 insertion(+) diff --git a/apax/md/simulate.py b/apax/md/simulate.py index 01eb5e50..18af991d 100644 --- a/apax/md/simulate.py +++ b/apax/md/simulate.py @@ -384,6 +384,7 @@ def md_setup(model_config: Config, md_config: MDConfig): _, params = restore_parameters(model_config.data.model_version_path) params = canonicalize_energy_model_parameters(params) + n_models = 1 if model_config.n_models > 1: n_models = model_config.n_models shallow = False From c7e3c58bd6c989e41841929d6b4edf308dd6525c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Wed, 5 Jun 2024 22:41:15 +0200 Subject: [PATCH 26/29] account for non shallow ensemble --- apax/md/simulate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apax/md/simulate.py b/apax/md/simulate.py index 18af991d..881d514b 100644 --- a/apax/md/simulate.py +++ b/apax/md/simulate.py @@ -385,9 +385,9 @@ def md_setup(model_config: Config, md_config: MDConfig): params = canonicalize_energy_model_parameters(params) n_models = 1 + shallow = False if model_config.n_models > 1: n_models = model_config.n_models - shallow = False elif model_config.model.n_shallow_ensemble > 1: n_models = model_config.model.n_shallow_ensemble shallow = True From 6a6ad255cbaeb3af49208d78a6112ac1d0abfa22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Wed, 12 Jun 2024 16:42:13 +0200 Subject: [PATCH 27/29] consistent use of std as uncertainty, not variance --- apax/model/gmnn.py | 4 ++-- apax/train/loss.py | 18 +++++++++--------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/apax/model/gmnn.py b/apax/model/gmnn.py index 28c21b2b..5324d3bc 100644 --- a/apax/model/gmnn.py +++ b/apax/model/gmnn.py @@ -223,7 +223,7 @@ def __call__( prediction = { "energy": energy_mean, "energy_ensemble": energy_ens, - "energy_uncertainty": energy_variance, + "energy_uncertainty": jnp.sqrt(energy_variance), } if self.force_variance: @@ -232,7 +232,7 @@ def __call__( forces_variance = divisor * fp64_sum((forces_ens - forces_mean) ** 2, axis=0) prediction["forces"] = forces_mean - prediction["forces_uncertainty"] = forces_variance + prediction["forces_uncertainty"] = jnp.sqrt(forces_variance) prediction["forces_ensemble"] = forces_ens else: forces_mean = -jax.grad(mean_energy_fn)(R, Z, neighbor, box, offsets) diff --git a/apax/train/loss.py b/apax/train/loss.py index e5f10e71..c4929ff8 100644 --- a/apax/train/loss.py +++ b/apax/train/loss.py @@ -55,19 +55,18 @@ def crps_loss( """Computes the CRPS of a gaussian distribution given means, targets and vars (uncertainty estimate) """ - label, means, vars = label[name], prediction[name], prediction[name + "_uncertainty"] + label, means, sigmas = label[name], prediction[name], prediction[name + "_uncertainty"] - sigma = jnp.sqrt(vars) - sigma = jnp.clip(sigma, 
a_min=1e-6) + sigmas = jnp.clip(sigmas, a_min=1e-6) - norm_x = (label - means) / sigma + norm_x = (label - means) / sigmas cdf = 0.5 * (1 + jsc.special.erf(norm_x / jnp.sqrt(2))) normalization = 1 / (jnp.sqrt(2.0 * np.pi)) pdf = normalization * jnp.exp(-(norm_x**2) / 2.0) - crps = sigma * (norm_x * (2 * cdf - 1) + 2 * pdf - 1 / jnp.sqrt(np.pi)) + crps = sigmas * (norm_x * (2 * cdf - 1) + 2 * pdf - 1 / jnp.sqrt(np.pi)) return crps / divisor @@ -79,14 +78,15 @@ def nll_loss( divisor: float = 1.0, parameters: dict = {}, ) -> jax.Array: - """Computes the gaussian NLL loss means, targets and vars (uncertainty estimate)""" - label, means, vars = label[name], prediction[name], prediction[name + "_uncertainty"] + """Computes the gaussian NLL loss means, targets and variances (uncertainty estimate)""" + label, means, sigmas = label[name], prediction[name], prediction[name + "_uncertainty"] + variances = jnp.pow(sigmas, 2) eps = 1e-4 - sigma = jnp.sqrt(vars) + sigma = jnp.sqrt(variances) sigma = jnp.clip(sigma, a_min=1e-4) x1 = jnp.log(jnp.maximum(vars, eps)) - x2 = (means - label) ** 2 / (jnp.maximum(vars, eps)) + x2 = (means - label) ** 2 / (jnp.maximum(variances, eps)) nll = 0.5 * (x1 + x2) return nll From 8a86bee27f4b3cef80158b47423c1c004543e4ea Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 12 Jun 2024 14:42:29 +0000 Subject: [PATCH 28/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- apax/layers/readout.py | 10 +++++----- apax/train/loss.py | 12 ++++++++++-- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/apax/layers/readout.py b/apax/layers/readout.py index 42f922fd..fbf8b0a3 100644 --- a/apax/layers/readout.py +++ b/apax/layers/readout.py @@ -26,11 +26,11 @@ def setup(self): for ii, n_hidden in enumerate(units): layer = NTKLinear( n_hidden, - w_init=self.w_init, - b_init=self.b_init, - use_ntk=self.use_ntk, - dtype=self.dtype, - name=f"dense_{ii}", + w_init=self.w_init, + b_init=self.b_init, + use_ntk=self.use_ntk, + dtype=self.dtype, + name=f"dense_{ii}", ) dense.append(layer) if ii < len(units) - 1: diff --git a/apax/train/loss.py b/apax/train/loss.py index c4929ff8..3803b97d 100644 --- a/apax/train/loss.py +++ b/apax/train/loss.py @@ -55,7 +55,11 @@ def crps_loss( """Computes the CRPS of a gaussian distribution given means, targets and vars (uncertainty estimate) """ - label, means, sigmas = label[name], prediction[name], prediction[name + "_uncertainty"] + label, means, sigmas = ( + label[name], + prediction[name], + prediction[name + "_uncertainty"], + ) sigmas = jnp.clip(sigmas, a_min=1e-6) @@ -79,7 +83,11 @@ def nll_loss( parameters: dict = {}, ) -> jax.Array: """Computes the gaussian NLL loss means, targets and variances (uncertainty estimate)""" - label, means, sigmas = label[name], prediction[name], prediction[name + "_uncertainty"] + label, means, sigmas = ( + label[name], + prediction[name], + prediction[name + "_uncertainty"], + ) variances = jnp.pow(sigmas, 2) eps = 1e-4 sigma = jnp.sqrt(variances) From 8de78f9faba4e10e3e5a73712b2ebb30e9dc5b96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Wed, 12 Jun 2024 16:45:54 +0200 Subject: [PATCH 29/29] linting --- apax/train/loss.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/apax/train/loss.py b/apax/train/loss.py index c4929ff8..b2f95673 100644 --- a/apax/train/loss.py +++ b/apax/train/loss.py @@ -53,9 +53,11 @@ def 
crps_loss(
     parameters: dict = {},
 ) -> jax.Array:
     """Computes the CRPS of a gaussian distribution given
-    means, targets and vars (uncertainty estimate)
+    means, targets and standard deviations (uncertainty estimate)
     """
-    label, means, sigmas = label[name], prediction[name], prediction[name + "_uncertainty"]
+    label = label[name]
+    means = prediction[name]
+    sigmas = prediction[name + "_uncertainty"]
 
     sigmas = jnp.clip(sigmas, a_min=1e-6)
 
@@ -78,12 +80,17 @@ def nll_loss(
     divisor: float = 1.0,
     parameters: dict = {},
 ) -> jax.Array:
-    """Computes the gaussian NLL loss means, targets and variances (uncertainty estimate)"""
-    label, means, sigmas = label[name], prediction[name], prediction[name + "_uncertainty"]
+    """Computes the gaussian NLL loss given
+    means, targets and standard deviations (uncertainty estimate)
+    """
+    label = label[name]
+    means = prediction[name]
+    sigmas = prediction[name + "_uncertainty"]
+
     variances = jnp.pow(sigmas, 2)
     eps = 1e-4
     sigma = jnp.sqrt(variances)
-    sigma = jnp.clip(sigma, a_min=1e-4)
+    sigma = jnp.clip(sigma, a_min=1e-6)
 
-    x1 = jnp.log(jnp.maximum(vars, eps))
+    x1 = jnp.log(jnp.maximum(variances, eps))
     x2 = (means - label) ** 2 / (jnp.maximum(variances, eps))
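
For reference, the quantity `nll_loss` computes above is the Gaussian negative log-likelihood up to the additive constant 0.5 * log(2 * pi): 0.5 * (log(sigma^2) + (y - mu)^2 / sigma^2). A minimal numeric sketch under the same epsilon-clipping convention (plain NumPy; names are illustrative, not part of apax):

    import numpy as np

    def gaussian_nll(y, mu, sigma, eps=1e-4):
        # 0.5 * (log(var) + (y - mu)^2 / var), dropping the constant 0.5 * log(2 * pi)
        var = np.maximum(sigma**2, eps)
        return 0.5 * (np.log(var) + (y - mu) ** 2 / var)

    print(gaussian_nll(0.3, 0.0, 0.5))  # 0.5 * (log(0.25) + 0.09 / 0.25) ~ -0.513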