Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

moved BAL selection node from ips to apax #315

Merged
merged 12 commits into from
Aug 8, 2024
Merged
2 changes: 2 additions & 0 deletions apax/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

try:
    from .analysis import ApaxBatchPrediction  # noqa: F401
    from .selection import BatchKernelSelection  # noqa: F401
except ImportError:
    # Best-effort export: if either import fails (selection.py, for one,
    # imports ipsuite), neither node is added to the public API.
    pass
else:
    # Both imports succeeded, so advertise the nodes as public names.
    __all__.extend(("ApaxBatchPrediction", "BatchKernelSelection"))
85 changes: 85 additions & 0 deletions apax/nodes/selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import logging
import typing

import ase.io
import numpy as np
import zntrack.utils
from ipsuite.analysis.ensemble import plot_with_uncertainty
from ipsuite.configuration_selection.base import BatchConfigurationSelection
from matplotlib import pyplot as plt

from apax.bal import kernel_selection
from apax.nodes.model import ApaxBase

log = logging.getLogger(__name__)


class BatchKernelSelection(BatchConfigurationSelection):
    """Interface to the batch active learning methods implemented in apax.
    Check the apax documentation for a list and explanation of implemented properties.

    Attributes
    ----------
    models: Union[Apax, List[Apax]]
        One or more Apax models to construct a feature map from.
    base_feature_map: dict
        Name and parameters for the feature map transformation.
    selection_method: str
        Name of the selection method to be used. Choose from:
        ["max_dist", ]
    n_configurations: int
        Number of samples to be selected.
    processing_batch_size: int
        Number of samples to be processed in parallel.
        Does not affect the result, just the speed of computing features.
    img_selection: path
        Output path of the plot that marks the selected configurations.

    Notes
    -----
    ``self.train_data`` is read in ``select_atoms`` but not declared here;
    it is presumably provided by ``BatchConfigurationSelection`` (verify
    against the ipsuite base class).
    """

    _module_ = "apax.nodes"

    # A single model is accepted as well as a list; see ``select_atoms``.
    models: typing.Union[ApaxBase, typing.List[ApaxBase]] = zntrack.deps()
    base_feature_map: dict = zntrack.params({"name": "ll_grad", "layer_name": "dense_2"})
    selection_method: str = zntrack.params("max_dist")
    # Both sizes are integer counts (previously mis-annotated as ``str``).
    n_configurations: int = zntrack.params()
    processing_batch_size: int = zntrack.meta.Text(64)
    img_selection = zntrack.outs_path(zntrack.nwd / "selection.png")

    def select_atoms(self, atoms_lst: typing.List[ase.Atoms]) -> typing.List[int]:
        """Select configurations from ``atoms_lst`` via apax ``kernel_selection``.

        Parameters
        ----------
        atoms_lst: List[ase.Atoms]
            Candidate configurations to choose from.

        Returns
        -------
        List[int]
            The selection returned by ``kernel_selection``, converted to a
            list (indices into ``atoms_lst``, as used for plotting).
        """
        # Each model contributes the directory stored in its parameter dict.
        if isinstance(self.models, list):
            param_files = [m._parameter["data"]["directory"] for m in self.models]
        else:
            param_files = self.models._parameter["data"]["directory"]

        selected = kernel_selection(
            param_files,
            self.train_data,
            atoms_lst,
            self.base_feature_map,
            self.selection_method,
            selection_batch_size=self.n_configurations,
            processing_batch_size=self.processing_batch_size,
        )
        self._get_plot(atoms_lst, selected)

        return list(selected)

    def _get_plot(self, atoms_lst: typing.List[ase.Atoms], indices: typing.List[int]):
        """Plot the energies of ``atoms_lst`` and mark the selected ``indices``.

        The energies are read from ``atoms.calc.results["energy"]``. If the
        first configuration carries an ``"energy_uncertainty"`` result, the
        uncertainties of all configurations are plotted alongside the mean.
        The figure is written to ``self.img_selection``.
        """
        energies = np.array([atoms.calc.results["energy"] for atoms in atoms_lst])

        if "energy_uncertainty" in atoms_lst[0].calc.results:
            uncertainty = np.array(
                [atoms.calc.results["energy_uncertainty"] for atoms in atoms_lst]
            )
            fig, ax, _ = plot_with_uncertainty(
                {"mean": energies, "std": uncertainty},
                ylabel="energy",
                xlabel="configuration",
            )
        else:
            fig, ax = plt.subplots()
            ax.plot(energies, label="energy")
            ax.set_ylabel("energy")
            ax.set_xlabel("configuration")

        # Highlight the selected configurations with red crosses.
        ax.plot(indices, energies[indices], "x", color="red")

        fig.savefig(self.img_selection, bbox_inches="tight")
        # pyplot keeps figures alive until closed; release it after saving so
        # repeated runs in one process do not accumulate open figures.
        plt.close(fig)
49 changes: 48 additions & 1 deletion tests/nodes/test_n_model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
import os
import pathlib
import sys

try:
import ipsuite as ips
except ImportError:
pass

import numpy as np
import pytest
import yaml
import zntrack

import apax.nodes
from apax.nodes.model import Apax, ApaxEnsemble
from apax.nodes.utils import AddData

# Example configuration file located next to this test module.
CONFIG_PATH = pathlib.Path(__file__).parent.joinpath("example.yaml")
# Resolved (absolute) directory that contains this test module.
TEST_PATH = pathlib.Path(__file__).parent.resolve()


def save_config_with_seed(path: str, seed: int = 1) -> None:
Expand Down Expand Up @@ -39,18 +49,47 @@ def test_n_train_model(tmp_path, get_md22_stachyose):
assert atoms.get_potential_energy() < 0


@pytest.mark.skipif("ipsuite" not in sys.modules, reason="requires new ipsuite release")
def test_n_train_2_model(tmp_path, get_md22_stachyose):
os.chdir(tmp_path)

save_config_with_seed(tmp_path / "example.yaml")
save_config_with_seed(tmp_path / "example2.yaml", seed=2)

proj = zntrack.Project(automatic_node_names=True)
thermostat = ips.calculators.LangevinThermostat(
time_step=1.0, temperature=100.0, friction=0.01
)
with proj:
data = AddData(file=get_md22_stachyose)
model1 = Apax(data=data.atoms, validation_data=data.atoms, config="example.yaml")
model2 = Apax(data=data.atoms, validation_data=data.atoms, config="example2.yaml")
ensemble = ApaxEnsemble(models=[model1, model2])
md = ips.calculators.ASEMD(
data=data.atoms,
model=ensemble,
thermostat=thermostat,
steps=20,
sampling_rate=1,
)

uncertainty_selection = ips.configuration_selection.ThresholdSelection(
data=md, n_configurations=1, threshold=0.0001
)

selection_batch_size = 3
kernel_selection = apax.nodes.BatchKernelSelection(
data=md.atoms,
train_data=data.atoms,
models=[model1, model2],
n_configurations=selection_batch_size,
processing_batch_size=4,
)

prediction = ips.analysis.Prediction(data=kernel_selection.atoms, model=ensemble)
analysis = ips.analysis.PredictionMetrics(
x=kernel_selection.atoms, y=prediction.atoms
)

proj.run()

Expand All @@ -61,4 +100,12 @@ def test_n_train_2_model(tmp_path, get_md22_stachyose):
atoms.calc = model.get_calculator()

assert atoms.get_potential_energy() < 0
assert atoms.calc.results["energy_uncertainty"] > 0

uncertainty_selection.load()
kernel_selection.load()
md.load()

uncertainties = [x.calc.results["energy_uncertainty"] for x in md.atoms]
assert [md.atoms[np.argmax(uncertainties)]] == uncertainty_selection.atoms

assert len(kernel_selection.atoms) == selection_batch_size
2 changes: 0 additions & 2 deletions tests/regression_tests/apax_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ data:

shift_method: "per_element_regression_shift"
shift_options: {"energy_regularisation": 1.0}
shuffle_buffer_size: 1000

pos_unit: Ang
energy_unit: eV
Expand Down Expand Up @@ -73,7 +72,6 @@ optimizer:
scale_lr: 0.001
shift_lr: 0.05
zbl_lr: 0.001
transition_begin: 0

callbacks:
- name: csv
Expand Down
Loading