Batch evaluation #233

Merged
merged 30 commits into dev on Feb 28, 2024

Conversation

M-R-Schaefer
Contributor

Added a batch_eval method to the ASE calculator that processes data in the same way as we do during training, i.e. padding of inputs and unpadding of outputs.
This should drastically accelerate the evaluation of whole datasets, especially if they consist of differently sized structures (which, until now, would trigger recompilations).
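
For context, a minimal usage sketch of the new method, assuming the signature added in this diff (batch_eval(data, batch_size=64, silent=False)); the model path, dataset file, and constructor arguments below are placeholders, not part of the PR:

from ase.io import read
from apax.md.ase_calc import ASECalculator

# Placeholder model directory; the exact ASECalculator constructor
# arguments are an assumption here.
calc = ASECalculator("path/to/model_dir")
atoms_list = read("dataset.extxyz", index=":")  # placeholder dataset file

# Evaluate the whole list in padded batches instead of one Atoms object at a
# time, so differently sized structures no longer each trigger a recompilation.
evaluated = calc.batch_eval(atoms_list, batch_size=64)

# Each returned Atoms object carries its results via a SinglePointCalculator.
energies = [a.get_potential_energy() for a in evaluated]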

M-R-Schaefer added the enhancement label on Feb 26, 2024
@@ -216,6 +230,40 @@ def calculate(self, atoms, properties=["energy"], system_changes=all_changes):
self.results = {k: np.array(v, dtype=np.float64) for k, v in results.items()}
self.results["energy"] = self.results["energy"].item()

def batch_eval(self, data, batch_size=64, silent=False):
Collaborator

Perhaps rename data to atoms_list, since that appears to be its role here, unless there are other conventions elsewhere in apax that use data (in which case changes there may also be nice)?

Collaborator

Probably add a minimal docstring, since this is a (likely common-use-case) user-facing method.
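
For illustration, a sketch of the kind of minimal docstring and type hints being asked for; the atoms_list name follows the renaming suggestion above, and the exact wording and return type are assumptions:

from typing import List
from ase import Atoms

def batch_eval(
    self, atoms_list: List[Atoms], batch_size: int = 64, silent: bool = False
) -> List[Atoms]:
    """Evaluate a list of Atoms objects in padded batches.

    Parameters
    ----------
    atoms_list : List[Atoms]
        Structures to evaluate.
    batch_size : int
        Number of structures per padded batch.
    silent : bool
        If True, suppress the progress bar.

    Returns
    -------
    List[Atoms]
        Copies of the input structures with the predicted properties
        attached via a SinglePointCalculator.
    """
    ...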

Contributor Author

Fair points, I'll adjust the name and add some type hints + docs.

@Chronum94
Collaborator

Chronum94 commented Feb 26, 2024

Running into an issue where, if the batch size isn't a perfect divisor of the number of evaluation samples, I get a list index out of range error.

So for example, if I have 10 Atoms objects in the list, batch sizes of 1, 2, 5, and 10 work as intended; all other numbers < 10 do not.

~/apax/apax/md/ase_calc.py:261, in ASECalculator.batch_eval(self, data, batch_size, silent)
    259 for j in range(batch_size):
    260     atoms = data[i].copy()
--> 261     atoms.calc = SinglePointCalculator(atoms=atoms, **unpadded_results[j])
    262     evaluated_data.append(atoms)
    263 pbar.update(batch_size)

This isn't really a problem for large batched evaluations, unless of course the dataset size happens to be a very large number that is also, coincidentally, prime, but either a fix or documentation of this limitation (feature?) would be good.
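
For reference, a sketch of one way the trailing, smaller batch could be handled: size the inner loop by the number of structures actually remaining rather than by batch_size. Apart from the names visible in the traceback (data, batch_size, unpadded_results, evaluated_data, pbar), everything here is an assumption, not necessarily the fix that ended up in the PR:

from ase.calculators.singlepoint import SinglePointCalculator

n_data = len(data)
for start in range(0, n_data, batch_size):
    # The last batch may contain fewer than batch_size structures.
    n_in_batch = min(batch_size, n_data - start)
    # ... run the padded model on this batch and unpad into unpadded_results ...
    for j in range(n_in_batch):
        atoms = data[start + j].copy()
        atoms.calc = SinglePointCalculator(atoms=atoms, **unpadded_results[j])
        evaluated_data.append(atoms)
    pbar.update(n_in_batch)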

@Chronum94
Collaborator

Potentially minor thing: Right now, the call signature looks something like this:

calc = ASECalculator(...)
...
calc.batch_eval(atoms_list)

But it also requires that all of the Atoms objects have a calculator attached. If none is attached, we get the following error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[43], line 1
----> 1 calc.batch_eval(atoms_list, 10)

File ~/apax/apax/md/ase_calc.py:236, in ASECalculator.batch_eval(self, data, batch_size, silent)
    234 if self.model is None:
    235     self.initialize(data[0])
--> 236 dataset = initialize_dataset(
    237     self.model_config, RawDataset(atoms_list=data), calc_stats=False
    238 )
    239 dataset.set_batch_size(batch_size)
    241 evaluated_data = []

File ~/apax/apax/data/initialization.py:56, in initialize_dataset(config, raw_ds, calc_stats)
     55 def initialize_dataset(config, raw_ds, calc_stats: bool = True):
---> 56     inputs, labels = create_dict_dataset(
     57         raw_ds.atoms_list,
     58         r_max=config.model.r_max,
     59         external_labels=raw_ds.additional_labels,
     60         disable_pbar=config.progress_bar.disable_nl_pbar,
     61         pos_unit=config.data.pos_unit,
     62         energy_unit=config.data.energy_unit,
     63     )
     65     if calc_stats:
     66         ds_stats = compute_scale_shift_parameters(
     67             inputs,
     68             labels,
   (...)
     72             config.data.scale_options,
     73         )

File ~/apax/apax/data/input_pipeline.py:102, in create_dict_dataset(atoms_list, r_max, external_labels, disable_pbar, pos_unit, energy_unit)
     94 def create_dict_dataset(
     95     atoms_list: list,
     96     r_max: float,
   (...)
    100     energy_unit: str = "eV",
    101 ) -> tuple[dict]:
--> 102     inputs, labels = atoms_to_arrays(atoms_list, pos_unit, energy_unit)
    104     if external_labels:
    105         for shape, label in external_labels.items():

File ~/apax/apax/utils/convert.py:112, in atoms_to_arrays(atoms_list, pos_unit, energy_unit)
    110 inputs["ragged"]["numbers"].append(atoms.numbers)
    111 inputs["fixed"]["n_atoms"].append(len(atoms))
--> 112 for key, val in atoms.calc.results.items():
    113     if key == "forces":
    114         labels["ragged"][key].append(
    115             val * unit_dict[energy_unit] / unit_dict[pos_unit]
    116         )

AttributeError: 'NoneType' object has no attribute 'results'

@M-R-Schaefer
Contributor Author

Potentially minor thing: Right now, the call signature looks something like this:

calc = ASECalculator(...)
...
calc.batch_eval(atoms_list)

But it also requires that all of the Atoms objects have a calculator attached. If none is attached, we get the following error:

This required a non-trivial overhaul of the input pipeline. I have implemented a draft, but I'll consult a colleague to see what he thinks. I don't particularly like it, and the tests would need to be adjusted.
I'll see if there is a more sensible option.
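
For illustration only, the simplest conceivable guard in atoms_to_arrays would be to skip label extraction when no calculator (or no results) is attached; this is just a sketch built around the loop shown in the traceback above, not the pipeline overhaul referred to here:

for atoms in atoms_list:
    inputs["ragged"]["numbers"].append(atoms.numbers)
    inputs["fixed"]["n_atoms"].append(len(atoms))

    # Inference-only input: no attached calculator, so nothing to read as labels.
    if atoms.calc is None or not atoms.calc.results:
        continue

    for key, val in atoms.calc.results.items():
        if key == "forces":
            labels["ragged"][key].append(
                val * unit_dict[energy_unit] / unit_dict[pos_unit]
            )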

@M-R-Schaefer
Contributor Author

pre-commit.ci autofix

M-R-Schaefer merged commit 047ae66 into dev on Feb 28, 2024
3 checks passed