From 8e0c95d77b9b7de4919a257b1caf68ac2b0ad784 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Tue, 6 Aug 2024 12:35:03 +0200 Subject: [PATCH 1/8] moved BAL selection node from ips to apax --- apax/nodes/__init__.py | 2 + apax/nodes/selection.py | 89 +++++++++++++++++++++++++++++++++++++ tests/nodes/test_n_model.py | 40 ++++++++++++++++- 3 files changed, 130 insertions(+), 1 deletion(-) create mode 100644 apax/nodes/selection.py diff --git a/apax/nodes/__init__.py b/apax/nodes/__init__.py index 7c326b11..ebc5d418 100644 --- a/apax/nodes/__init__.py +++ b/apax/nodes/__init__.py @@ -5,7 +5,9 @@ try: from .analysis import ApaxBatchPrediction # noqa: F401 + from .selection import BatchKernelSelection # noqa: F401 __all__.append("ApaxBatchPrediction") + __all__.append("BatchKernelSelection") except ImportError: pass diff --git a/apax/nodes/selection.py b/apax/nodes/selection.py new file mode 100644 index 00000000..078d0030 --- /dev/null +++ b/apax/nodes/selection.py @@ -0,0 +1,89 @@ +import logging +import typing + +import ase.io +import numpy as np +import zntrack.utils +from ipsuite.analysis.ensemble import plot_with_uncertainty +from ipsuite.configuration_selection.base import BatchConfigurationSelection +from ipsuite.utils.combine import get_flat_data_from_dict +from matplotlib import pyplot as plt + +from apax.bal import kernel_selection +from apax.nodes.model import ApaxBase + +log = logging.getLogger(__name__) + + +class BatchKernelSelection(BatchConfigurationSelection): + """Interface to the batch active learning methods implemented in apax. + Check the apax documentation for a list and explanation of implemented properties. + + Attributes + ---------- + models: Union[Apax, List[Apax]] + One or more Apax models to construct a feature map from. + base_feature_map: dict + Name and parameters for the feature map transformation. + selection_method: str + Name of the selection method to be used. Choose from: + ["max_dist", ] + n_configurations: int + Number of samples to be selected. + processing_batch_size: int + Number of samples to be processed in parallel. + Does not affect the result, just the speed of computing features. + """ + + _module_ = "apax.nodes" + + models: typing.List[ApaxBase] = zntrack.deps() + base_feature_map: dict = zntrack.params({"name": "ll_grad", "layer_name": "dense_2"}) + selection_method: str = zntrack.params("max_dist") + n_configurations: str = zntrack.params() + processing_batch_size: str = zntrack.meta.Text(64) + img_selection = zntrack.outs_path(zntrack.nwd / "selection.png") + + def select_atoms(self, atoms_lst: typing.List[ase.Atoms]) -> typing.List[int]: + if isinstance(self.models, list): + param_files = [m._parameter["data"]["directory"] for m in self.models] + else: + param_files = self.models._parameter["data"]["directory"] + + if isinstance(self.train_data, dict): + self.train_data = get_flat_data_from_dict(self.train_data) + + selected = kernel_selection( + param_files, + self.train_data, + atoms_lst, + self.base_feature_map, + self.selection_method, + selection_batch_size=self.n_configurations, + processing_batch_size=self.processing_batch_size, + ) + self._get_plot(atoms_lst, selected) + + return list(selected) + + def _get_plot(self, atoms_lst: typing.List[ase.Atoms], indices: typing.List[int]): + energies = np.array([atoms.calc.results["energy"] for atoms in atoms_lst]) + + if "energy_uncertainty" in atoms_lst[0].calc.results.keys(): + uncertainty = np.array( + [atoms.calc.results["energy_uncertainty"] for atoms in atoms_lst] + ) + fig, ax, _ = plot_with_uncertainty( + {"mean": energies, "std": uncertainty}, + ylabel="energy", + xlabel="configuration", + ) + else: + fig, ax = plt.subplots() + ax.plot(energies, label="energy") + ax.set_ylabel("energy") + ax.set_xlabel("configuration") + + ax.plot(indices, energies[indices], "x", color="red") + + fig.savefig(self.img_selection, bbox_inches="tight") diff --git a/tests/nodes/test_n_model.py b/tests/nodes/test_n_model.py index 06b13899..9155da7e 100644 --- a/tests/nodes/test_n_model.py +++ b/tests/nodes/test_n_model.py @@ -1,13 +1,17 @@ import os import pathlib +import ipsuite as ips +import numpy as np import yaml import zntrack +import apax.nodes from apax.nodes.model import Apax, ApaxEnsemble from apax.nodes.utils import AddData CONFIG_PATH = pathlib.Path(__file__).parent / "example.yaml" +TEST_PATH = pathlib.Path(__file__).parent.resolve() def save_config_with_seed(path: str, seed: int = 1) -> None: @@ -46,11 +50,37 @@ def test_n_train_2_model(tmp_path, get_md22_stachyose): save_config_with_seed(tmp_path / "example2.yaml", seed=2) proj = zntrack.Project(automatic_node_names=True) + thermostat = ips.calculators.LangevinThermostat( + time_step=1.0, temperature=100.0, friction=0.01 + ) with proj: data = AddData(file=get_md22_stachyose) model1 = Apax(data=data.atoms, validation_data=data.atoms, config="example.yaml") model2 = Apax(data=data.atoms, validation_data=data.atoms, config="example2.yaml") ensemble = ApaxEnsemble(models=[model1, model2]) + md = ips.calculators.ASEMD( + data=data.atoms, + model=ensemble, + thermostat=thermostat, + steps=20, + sampling_rate=1, + ) + + uncertainty_selection = ips.configuration_selection.ThresholdSelection( + data=md, n_configurations=1, threshold=0.0001 + ) + + selection_batch_size = 3 + kernel_selection = apax.nodes.BatchKernelSelection( + data=md.atoms, + train_data=data.atoms, + models=[model1, model2], + n_configurations=selection_batch_size, + processing_batch_size=4, + ) + + prediction = ips.analysis.Prediction(data=kernel_selection.atoms, model=ensemble) + analysis = ips.analysis.PredictionMetrics(x=kernel_selection.atoms, y=prediction.atoms) proj.run() @@ -61,4 +91,12 @@ def test_n_train_2_model(tmp_path, get_md22_stachyose): atoms.calc = model.get_calculator() assert atoms.get_potential_energy() < 0 - assert atoms.calc.results["energy_uncertainty"] > 0 + + uncertainty_selection.load() + kernel_selection.load() + md.load() + + uncertainties = [x.calc.results["energy_uncertainty"] for x in md.atoms] + assert [md.atoms[np.argmax(uncertainties)]] == uncertainty_selection.atoms + + assert len(kernel_selection.atoms) == selection_batch_size From 2107ff5cd1a826ea710608290e1395fe417e42e0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Aug 2024 10:43:32 +0000 Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- apax/nodes/selection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apax/nodes/selection.py b/apax/nodes/selection.py index 078d0030..4f664d58 100644 --- a/apax/nodes/selection.py +++ b/apax/nodes/selection.py @@ -34,7 +34,7 @@ class BatchKernelSelection(BatchConfigurationSelection): Number of samples to be processed in parallel. Does not affect the result, just the speed of computing features. """ - + _module_ = "apax.nodes" models: typing.List[ApaxBase] = zntrack.deps() From f2041cd87ab9500735628d8096c5edab1b5a895e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Tue, 6 Aug 2024 15:28:03 +0200 Subject: [PATCH 3/8] disable ips test on gh actions --- tests/nodes/test_n_model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/nodes/test_n_model.py b/tests/nodes/test_n_model.py index 9155da7e..5c92ae82 100644 --- a/tests/nodes/test_n_model.py +++ b/tests/nodes/test_n_model.py @@ -1,8 +1,10 @@ import os import pathlib +import sys import ipsuite as ips import numpy as np +import pytest import yaml import zntrack @@ -43,6 +45,8 @@ def test_n_train_model(tmp_path, get_md22_stachyose): assert atoms.get_potential_energy() < 0 +@pytest.mark.skipif('ipsuite' not in sys.modules, + reason="requires new ipsuite release") def test_n_train_2_model(tmp_path, get_md22_stachyose): os.chdir(tmp_path) From 714922bcadaac722645a0079080be935b04b4c86 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Aug 2024 13:28:14 +0000 Subject: [PATCH 4/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/nodes/test_n_model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/nodes/test_n_model.py b/tests/nodes/test_n_model.py index 5c92ae82..29800869 100644 --- a/tests/nodes/test_n_model.py +++ b/tests/nodes/test_n_model.py @@ -45,8 +45,7 @@ def test_n_train_model(tmp_path, get_md22_stachyose): assert atoms.get_potential_energy() < 0 -@pytest.mark.skipif('ipsuite' not in sys.modules, - reason="requires new ipsuite release") +@pytest.mark.skipif("ipsuite" not in sys.modules, reason="requires new ipsuite release") def test_n_train_2_model(tmp_path, get_md22_stachyose): os.chdir(tmp_path) @@ -84,7 +83,9 @@ def test_n_train_2_model(tmp_path, get_md22_stachyose): ) prediction = ips.analysis.Prediction(data=kernel_selection.atoms, model=ensemble) - analysis = ips.analysis.PredictionMetrics(x=kernel_selection.atoms, y=prediction.atoms) + analysis = ips.analysis.PredictionMetrics( + x=kernel_selection.atoms, y=prediction.atoms + ) proj.run() From da2b8362510b5210864c0bc2e0ea377d2f4926d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Thu, 8 Aug 2024 16:35:32 +0200 Subject: [PATCH 5/8] fixed tests --- apax/nodes/selection.py | 3 --- tests/regression_tests/apax_config.yaml | 2 -- 2 files changed, 5 deletions(-) diff --git a/apax/nodes/selection.py b/apax/nodes/selection.py index 4f664d58..afe92cd0 100644 --- a/apax/nodes/selection.py +++ b/apax/nodes/selection.py @@ -50,9 +50,6 @@ def select_atoms(self, atoms_lst: typing.List[ase.Atoms]) -> typing.List[int]: else: param_files = self.models._parameter["data"]["directory"] - if isinstance(self.train_data, dict): - self.train_data = get_flat_data_from_dict(self.train_data) - selected = kernel_selection( param_files, self.train_data, diff --git a/tests/regression_tests/apax_config.yaml b/tests/regression_tests/apax_config.yaml index 41760757..1c84b4ce 100644 --- a/tests/regression_tests/apax_config.yaml +++ b/tests/regression_tests/apax_config.yaml @@ -15,7 +15,6 @@ data: shift_method: "per_element_regression_shift" shift_options: {"energy_regularisation": 1.0} - shuffle_buffer_size: 1000 pos_unit: Ang energy_unit: eV @@ -73,7 +72,6 @@ optimizer: scale_lr: 0.001 shift_lr: 0.05 zbl_lr: 0.001 - transition_begin: 0 callbacks: - name: csv From ca5d6a28f89c4d1c6342022a7f1c2f84d5ed9392 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Aug 2024 14:35:48 +0000 Subject: [PATCH 6/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- apax/nodes/selection.py | 1 - 1 file changed, 1 deletion(-) diff --git a/apax/nodes/selection.py b/apax/nodes/selection.py index afe92cd0..df2f12b0 100644 --- a/apax/nodes/selection.py +++ b/apax/nodes/selection.py @@ -6,7 +6,6 @@ import zntrack.utils from ipsuite.analysis.ensemble import plot_with_uncertainty from ipsuite.configuration_selection.base import BatchConfigurationSelection -from ipsuite.utils.combine import get_flat_data_from_dict from matplotlib import pyplot as plt from apax.bal import kernel_selection From e7c1b1bc1e06ee3b1e1da5381e24f668f8f64737 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Thu, 8 Aug 2024 17:01:07 +0200 Subject: [PATCH 7/8] remove import --- apax/nodes/selection.py | 1 - 1 file changed, 1 deletion(-) diff --git a/apax/nodes/selection.py b/apax/nodes/selection.py index afe92cd0..df2f12b0 100644 --- a/apax/nodes/selection.py +++ b/apax/nodes/selection.py @@ -6,7 +6,6 @@ import zntrack.utils from ipsuite.analysis.ensemble import plot_with_uncertainty from ipsuite.configuration_selection.base import BatchConfigurationSelection -from ipsuite.utils.combine import get_flat_data_from_dict from matplotlib import pyplot as plt from apax.bal import kernel_selection From dd04b0e2e2c2da59d0c574f47afa382af4afdfa8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Thu, 8 Aug 2024 17:02:40 +0200 Subject: [PATCH 8/8] try except ipsuite import --- tests/nodes/test_n_model.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/nodes/test_n_model.py b/tests/nodes/test_n_model.py index 29800869..02fdcc22 100644 --- a/tests/nodes/test_n_model.py +++ b/tests/nodes/test_n_model.py @@ -2,7 +2,11 @@ import pathlib import sys -import ipsuite as ips +try: + import ipsuite as ips +except ImportError: + pass + import numpy as np import pytest import yaml