diff --git a/apax/nodes/__init__.py b/apax/nodes/__init__.py index 7c326b11..ebc5d418 100644 --- a/apax/nodes/__init__.py +++ b/apax/nodes/__init__.py @@ -5,7 +5,9 @@ try: from .analysis import ApaxBatchPrediction # noqa: F401 + from .selection import BatchKernelSelection # noqa: F401 __all__.append("ApaxBatchPrediction") + __all__.append("BatchKernelSelection") except ImportError: pass diff --git a/apax/nodes/selection.py b/apax/nodes/selection.py new file mode 100644 index 00000000..df2f12b0 --- /dev/null +++ b/apax/nodes/selection.py @@ -0,0 +1,85 @@ +import logging +import typing + +import ase.io +import numpy as np +import zntrack.utils +from ipsuite.analysis.ensemble import plot_with_uncertainty +from ipsuite.configuration_selection.base import BatchConfigurationSelection +from matplotlib import pyplot as plt + +from apax.bal import kernel_selection +from apax.nodes.model import ApaxBase + +log = logging.getLogger(__name__) + + +class BatchKernelSelection(BatchConfigurationSelection): + """Interface to the batch active learning methods implemented in apax. + Check the apax documentation for a list and explanation of implemented properties. + + Attributes + ---------- + models: Union[Apax, List[Apax]] + One or more Apax models to construct a feature map from. + base_feature_map: dict + Name and parameters for the feature map transformation. + selection_method: str + Name of the selection method to be used. Choose from: + ["max_dist", ] + n_configurations: int + Number of samples to be selected. + processing_batch_size: int + Number of samples to be processed in parallel. + Does not affect the result, just the speed of computing features. + """ + + _module_ = "apax.nodes" + + models: typing.List[ApaxBase] = zntrack.deps() + base_feature_map: dict = zntrack.params({"name": "ll_grad", "layer_name": "dense_2"}) + selection_method: str = zntrack.params("max_dist") + n_configurations: str = zntrack.params() + processing_batch_size: str = zntrack.meta.Text(64) + img_selection = zntrack.outs_path(zntrack.nwd / "selection.png") + + def select_atoms(self, atoms_lst: typing.List[ase.Atoms]) -> typing.List[int]: + if isinstance(self.models, list): + param_files = [m._parameter["data"]["directory"] for m in self.models] + else: + param_files = self.models._parameter["data"]["directory"] + + selected = kernel_selection( + param_files, + self.train_data, + atoms_lst, + self.base_feature_map, + self.selection_method, + selection_batch_size=self.n_configurations, + processing_batch_size=self.processing_batch_size, + ) + self._get_plot(atoms_lst, selected) + + return list(selected) + + def _get_plot(self, atoms_lst: typing.List[ase.Atoms], indices: typing.List[int]): + energies = np.array([atoms.calc.results["energy"] for atoms in atoms_lst]) + + if "energy_uncertainty" in atoms_lst[0].calc.results.keys(): + uncertainty = np.array( + [atoms.calc.results["energy_uncertainty"] for atoms in atoms_lst] + ) + fig, ax, _ = plot_with_uncertainty( + {"mean": energies, "std": uncertainty}, + ylabel="energy", + xlabel="configuration", + ) + else: + fig, ax = plt.subplots() + ax.plot(energies, label="energy") + ax.set_ylabel("energy") + ax.set_xlabel("configuration") + + ax.plot(indices, energies[indices], "x", color="red") + + fig.savefig(self.img_selection, bbox_inches="tight") diff --git a/tests/nodes/test_n_model.py b/tests/nodes/test_n_model.py index 06b13899..02fdcc22 100644 --- a/tests/nodes/test_n_model.py +++ b/tests/nodes/test_n_model.py @@ -1,13 +1,23 @@ import os import pathlib +import sys +try: + import ipsuite as ips +except ImportError: + pass + +import numpy as np +import pytest import yaml import zntrack +import apax.nodes from apax.nodes.model import Apax, ApaxEnsemble from apax.nodes.utils import AddData CONFIG_PATH = pathlib.Path(__file__).parent / "example.yaml" +TEST_PATH = pathlib.Path(__file__).parent.resolve() def save_config_with_seed(path: str, seed: int = 1) -> None: @@ -39,6 +49,7 @@ def test_n_train_model(tmp_path, get_md22_stachyose): assert atoms.get_potential_energy() < 0 +@pytest.mark.skipif("ipsuite" not in sys.modules, reason="requires new ipsuite release") def test_n_train_2_model(tmp_path, get_md22_stachyose): os.chdir(tmp_path) @@ -46,11 +57,39 @@ def test_n_train_2_model(tmp_path, get_md22_stachyose): save_config_with_seed(tmp_path / "example2.yaml", seed=2) proj = zntrack.Project(automatic_node_names=True) + thermostat = ips.calculators.LangevinThermostat( + time_step=1.0, temperature=100.0, friction=0.01 + ) with proj: data = AddData(file=get_md22_stachyose) model1 = Apax(data=data.atoms, validation_data=data.atoms, config="example.yaml") model2 = Apax(data=data.atoms, validation_data=data.atoms, config="example2.yaml") ensemble = ApaxEnsemble(models=[model1, model2]) + md = ips.calculators.ASEMD( + data=data.atoms, + model=ensemble, + thermostat=thermostat, + steps=20, + sampling_rate=1, + ) + + uncertainty_selection = ips.configuration_selection.ThresholdSelection( + data=md, n_configurations=1, threshold=0.0001 + ) + + selection_batch_size = 3 + kernel_selection = apax.nodes.BatchKernelSelection( + data=md.atoms, + train_data=data.atoms, + models=[model1, model2], + n_configurations=selection_batch_size, + processing_batch_size=4, + ) + + prediction = ips.analysis.Prediction(data=kernel_selection.atoms, model=ensemble) + analysis = ips.analysis.PredictionMetrics( + x=kernel_selection.atoms, y=prediction.atoms + ) proj.run() @@ -61,4 +100,12 @@ def test_n_train_2_model(tmp_path, get_md22_stachyose): atoms.calc = model.get_calculator() assert atoms.get_potential_energy() < 0 - assert atoms.calc.results["energy_uncertainty"] > 0 + + uncertainty_selection.load() + kernel_selection.load() + md.load() + + uncertainties = [x.calc.results["energy_uncertainty"] for x in md.atoms] + assert [md.atoms[np.argmax(uncertainties)]] == uncertainty_selection.atoms + + assert len(kernel_selection.atoms) == selection_batch_size diff --git a/tests/regression_tests/apax_config.yaml b/tests/regression_tests/apax_config.yaml index 41760757..1c84b4ce 100644 --- a/tests/regression_tests/apax_config.yaml +++ b/tests/regression_tests/apax_config.yaml @@ -15,7 +15,6 @@ data: shift_method: "per_element_regression_shift" shift_options: {"energy_regularisation": 1.0} - shuffle_buffer_size: 1000 pos_unit: Ang energy_unit: eV @@ -73,7 +72,6 @@ optimizer: scale_lr: 0.001 shift_lr: 0.05 zbl_lr: 0.001 - transition_begin: 0 callbacks: - name: csv