Skip to content

Commit

Permalink
Merge pull request #315 from apax-hub/ips-move
Browse files Browse the repository at this point in the history
moved BAL selection node from ips to apax
  • Loading branch information
M-R-Schaefer authored Aug 8, 2024
2 parents a74c38f + 5855c97 commit 70b388c
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 3 deletions.
2 changes: 2 additions & 0 deletions apax/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

try:
from .analysis import ApaxBatchPrediction # noqa: F401
from .selection import BatchKernelSelection # noqa: F401

__all__.append("ApaxBatchPrediction")
__all__.append("BatchKernelSelection")
except ImportError:
pass
85 changes: 85 additions & 0 deletions apax/nodes/selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import logging
import typing

import ase.io
import numpy as np
import zntrack.utils
from ipsuite.analysis.ensemble import plot_with_uncertainty
from ipsuite.configuration_selection.base import BatchConfigurationSelection
from matplotlib import pyplot as plt

from apax.bal import kernel_selection
from apax.nodes.model import ApaxBase

log = logging.getLogger(__name__)


class BatchKernelSelection(BatchConfigurationSelection):
"""Interface to the batch active learning methods implemented in apax.
Check the apax documentation for a list and explanation of implemented properties.
Attributes
----------
models: Union[Apax, List[Apax]]
One or more Apax models to construct a feature map from.
base_feature_map: dict
Name and parameters for the feature map transformation.
selection_method: str
Name of the selection method to be used. Choose from:
["max_dist", ]
n_configurations: int
Number of samples to be selected.
processing_batch_size: int
Number of samples to be processed in parallel.
Does not affect the result, just the speed of computing features.
"""

_module_ = "apax.nodes"

models: typing.List[ApaxBase] = zntrack.deps()
base_feature_map: dict = zntrack.params({"name": "ll_grad", "layer_name": "dense_2"})
selection_method: str = zntrack.params("max_dist")
n_configurations: str = zntrack.params()
processing_batch_size: str = zntrack.meta.Text(64)
img_selection = zntrack.outs_path(zntrack.nwd / "selection.png")

def select_atoms(self, atoms_lst: typing.List[ase.Atoms]) -> typing.List[int]:
if isinstance(self.models, list):
param_files = [m._parameter["data"]["directory"] for m in self.models]
else:
param_files = self.models._parameter["data"]["directory"]

selected = kernel_selection(
param_files,
self.train_data,
atoms_lst,
self.base_feature_map,
self.selection_method,
selection_batch_size=self.n_configurations,
processing_batch_size=self.processing_batch_size,
)
self._get_plot(atoms_lst, selected)

return list(selected)

def _get_plot(self, atoms_lst: typing.List[ase.Atoms], indices: typing.List[int]):
energies = np.array([atoms.calc.results["energy"] for atoms in atoms_lst])

if "energy_uncertainty" in atoms_lst[0].calc.results.keys():
uncertainty = np.array(
[atoms.calc.results["energy_uncertainty"] for atoms in atoms_lst]
)
fig, ax, _ = plot_with_uncertainty(
{"mean": energies, "std": uncertainty},
ylabel="energy",
xlabel="configuration",
)
else:
fig, ax = plt.subplots()
ax.plot(energies, label="energy")
ax.set_ylabel("energy")
ax.set_xlabel("configuration")

ax.plot(indices, energies[indices], "x", color="red")

fig.savefig(self.img_selection, bbox_inches="tight")
49 changes: 48 additions & 1 deletion tests/nodes/test_n_model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
import os
import pathlib
import sys

try:
import ipsuite as ips
except ImportError:
pass

import numpy as np
import pytest
import yaml
import zntrack

import apax.nodes
from apax.nodes.model import Apax, ApaxEnsemble
from apax.nodes.utils import AddData

CONFIG_PATH = pathlib.Path(__file__).parent / "example.yaml"
TEST_PATH = pathlib.Path(__file__).parent.resolve()


def save_config_with_seed(path: str, seed: int = 1) -> None:
Expand Down Expand Up @@ -39,18 +49,47 @@ def test_n_train_model(tmp_path, get_md22_stachyose):
assert atoms.get_potential_energy() < 0


@pytest.mark.skipif("ipsuite" not in sys.modules, reason="requires new ipsuite release")
def test_n_train_2_model(tmp_path, get_md22_stachyose):
os.chdir(tmp_path)

save_config_with_seed(tmp_path / "example.yaml")
save_config_with_seed(tmp_path / "example2.yaml", seed=2)

proj = zntrack.Project(automatic_node_names=True)
thermostat = ips.calculators.LangevinThermostat(
time_step=1.0, temperature=100.0, friction=0.01
)
with proj:
data = AddData(file=get_md22_stachyose)
model1 = Apax(data=data.atoms, validation_data=data.atoms, config="example.yaml")
model2 = Apax(data=data.atoms, validation_data=data.atoms, config="example2.yaml")
ensemble = ApaxEnsemble(models=[model1, model2])
md = ips.calculators.ASEMD(
data=data.atoms,
model=ensemble,
thermostat=thermostat,
steps=20,
sampling_rate=1,
)

uncertainty_selection = ips.configuration_selection.ThresholdSelection(
data=md, n_configurations=1, threshold=0.0001
)

selection_batch_size = 3
kernel_selection = apax.nodes.BatchKernelSelection(
data=md.atoms,
train_data=data.atoms,
models=[model1, model2],
n_configurations=selection_batch_size,
processing_batch_size=4,
)

prediction = ips.analysis.Prediction(data=kernel_selection.atoms, model=ensemble)
analysis = ips.analysis.PredictionMetrics(
x=kernel_selection.atoms, y=prediction.atoms
)

proj.run()

Expand All @@ -61,4 +100,12 @@ def test_n_train_2_model(tmp_path, get_md22_stachyose):
atoms.calc = model.get_calculator()

assert atoms.get_potential_energy() < 0
assert atoms.calc.results["energy_uncertainty"] > 0

uncertainty_selection.load()
kernel_selection.load()
md.load()

uncertainties = [x.calc.results["energy_uncertainty"] for x in md.atoms]
assert [md.atoms[np.argmax(uncertainties)]] == uncertainty_selection.atoms

assert len(kernel_selection.atoms) == selection_batch_size
2 changes: 0 additions & 2 deletions tests/regression_tests/apax_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ data:

shift_method: "per_element_regression_shift"
shift_options: {"energy_regularisation": 1.0}
shuffle_buffer_size: 1000

pos_unit: Ang
energy_unit: eV
Expand Down Expand Up @@ -73,7 +72,6 @@ optimizer:
scale_lr: 0.001
shift_lr: 0.05
zbl_lr: 0.001
transition_begin: 0

callbacks:
- name: csv
Expand Down

0 comments on commit 70b388c

Please sign in to comment.