-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Batch evaluation #233
Batch evaluation #233
Conversation
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
apax/md/ase_calc.py
Outdated
@@ -216,6 +230,40 @@ def calculate(self, atoms, properties=["energy"], system_changes=all_changes): | |||
self.results = {k: np.array(v, dtype=np.float64) for k, v in results.items()} | |||
self.results["energy"] = self.results["energy"].item() | |||
|
|||
def batch_eval(self, data, batch_size=64, silent=False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps have data
actually as atoms_list
since that's what appears to be its role here, unless there are other conventions elsewhere in apax that use data
? (in which case changes there may also be nice).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably a minimal docstring since this is a (likely common-use-case) user-facing method.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fair points, I"ll adjust the name and add some type hints + docs.
Running into an issue where, if the batch size isn't a perfect divisor of the number of evaluation samples, I get a list index out of range error. So for example, if I have 10 atoms objects in the list, then batch sizes of 1, 2, 5, 10, work as intended, all other numbers <10 do not. ~/apax/apax/md/ase_calc.py:261, in ASECalculator.batch_eval(self, data, batch_size, silent)
259 for j in range(batch_size):
260 atoms = data[i].copy()
--> 261 atoms.calc = SinglePointCalculator(atoms=atoms, **unpadded_results[j])
262 evaluated_data.append(atoms)
263 pbar.update(batch_size) This isn't really a problem for really large batched evals unless of course one has a dataset sized at a very large number that is also very coincidentally a prime number but either a fix, or documentation of this limitation (?feature?) would be good. |
Potentially minor thing: Right now, the call signature looks something like so: calc = ASECalculator(...)
...
calc.batch_eval(atoms_list) But also requires that all of the atoms objects have a calculator attached. If there is nothing attached, we get the following error: ---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[43], line 1
----> 1 calc.batch_eval(atoms_list, 10)
File ~/apax/apax/md/ase_calc.py:236, in ASECalculator.batch_eval(self, data, batch_size, silent)
234 if self.model is None:
235 self.initialize(data[0])
--> 236 dataset = initialize_dataset(
237 self.model_config, RawDataset(atoms_list=data), calc_stats=False
238 )
239 dataset.set_batch_size(batch_size)
241 evaluated_data = []
File ~/apax/apax/data/initialization.py:56, in initialize_dataset(config, raw_ds, calc_stats)
55 def initialize_dataset(config, raw_ds, calc_stats: bool = True):
---> 56 inputs, labels = create_dict_dataset(
57 raw_ds.atoms_list,
58 r_max=config.model.r_max,
59 external_labels=raw_ds.additional_labels,
60 disable_pbar=config.progress_bar.disable_nl_pbar,
61 pos_unit=config.data.pos_unit,
62 energy_unit=config.data.energy_unit,
63 )
65 if calc_stats:
66 ds_stats = compute_scale_shift_parameters(
67 inputs,
68 labels,
(...)
72 config.data.scale_options,
73 )
File ~/apax/apax/data/input_pipeline.py:102, in create_dict_dataset(atoms_list, r_max, external_labels, disable_pbar, pos_unit, energy_unit)
94 def create_dict_dataset(
95 atoms_list: list,
96 r_max: float,
(...)
100 energy_unit: str = "eV",
101 ) -> tuple[dict]:
--> 102 inputs, labels = atoms_to_arrays(atoms_list, pos_unit, energy_unit)
104 if external_labels:
105 for shape, label in external_labels.items():
File ~/apax/apax/utils/convert.py:112, in atoms_to_arrays(atoms_list, pos_unit, energy_unit)
110 inputs["ragged"]["numbers"].append(atoms.numbers)
111 inputs["fixed"]["n_atoms"].append(len(atoms))
--> 112 for key, val in atoms.calc.results.items():
113 if key == "forces":
114 labels["ragged"][key].append(
115 val * unit_dict[energy_unit] / unit_dict[pos_unit]
116 )
AttributeError: 'NoneType' object has no attribute 'results' |
for more information, see https://pre-commit.ci
This required a non trivial overhaul of the input pipeline. I have implemented a draft, but I'll consult a colleague to see what he thinks. I don't particularly like it and the tests would need to be adjusted. |
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
pre-commit.ci autofix |
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
Added a
batch_eval
method to the ASE calculator that processes data in the same way as we do during training. That means padding of inputs and unpading of outputs.This should drastically accelerate the evaluation of whole datasets, especially if they consist of differently sized structures (which would trigger recompilations until now).