Merge pull request #233 from apax-hub/batch_eval
Batch evaluation
M-R-Schaefer authored Feb 28, 2024
2 parents 864558c + 9db40cb commit 047ae66
Showing 13 changed files with 283 additions and 181 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -57,6 +57,7 @@ main.py
 tmp/
 .npz
 .traj
+.h5
 events.out.*

 # Translations
3 changes: 3 additions & 0 deletions apax/__init__.py
@@ -4,3 +4,6 @@

 os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
 jax_config.update("jax_enable_x64", True)
+from apax.utils.helpers import setup_ase
+
+setup_ase()
5 changes: 2 additions & 3 deletions apax/bal/api.py
@@ -8,7 +8,6 @@
 from tqdm import trange

 from apax.bal import feature_maps, kernel, selection, transforms
-from apax.data.initialization import RawDataset
 from apax.data.input_pipeline import AtomisticDataset
 from apax.model.builder import ModelBuilder
 from apax.model.gmnn import EnergyModel
@@ -55,7 +54,7 @@ def compute_features(feature_fn, dataset: AtomisticDataset):
     ds = dataset.batch()

     pbar = trange(n_data, desc="Computing features", ncols=100, leave=True)
-    for i, (inputs, _) in enumerate(ds):
+    for inputs in ds:
         g = feature_fn(inputs)
         features.append(np.asarray(g))
         pbar.update(g.shape[0])
@@ -88,7 +87,7 @@ def kernel_selection(

     n_train = len(train_atoms)
     dataset = initialize_dataset(
-        config, RawDataset(atoms_list=train_atoms + pool_atoms), calc_stats=False
+        config, train_atoms + pool_atoms, read_labels=False, calc_stats=False
     )
     dataset.set_batch_size(processing_batch_size)
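For context, a minimal usage sketch of the batch-active-learning entry point under the new signature. It only mirrors the updated call in kernel_selection above; `config`, `train_atoms`, `pool_atoms`, and `processing_batch_size` stand in for the objects that function already has in scope.

# Sketch only, assuming config/atoms objects as in kernel_selection.
from apax.data.initialization import initialize_dataset

# Pool structures are typically unlabeled, so labels are not read and the
# scale/shift statistics (which require labels) are disabled.
dataset = initialize_dataset(
    config,
    train_atoms + pool_atoms,
    read_labels=False,
    calc_stats=False,
)
dataset.set_batch_size(processing_batch_size)

# Iteration now yields bare input batches rather than (inputs, labels)
# pairs, matching the simplified loop in compute_features.
for inputs in dataset.batch():
    ...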
3 changes: 3 additions & 0 deletions apax/config/train_config.py
@@ -42,6 +42,8 @@ class DataConfig(BaseModel, extra="forbid"):
     batch_size: Number of training examples to be evaluated at once.
     valid_batch_size: Number of validation examples to be evaluated at once.
     shuffle_buffer_size: Size of the `tf.data` shuffle buffer.
+    additional_properties_info:
+        dict of property name, shape (ragged or fixed) pairs
     energy_regularisation: Magnitude of the regularization in the per-element
         energy regression.
     """
@@ -58,6 +60,7 @@ class DataConfig(BaseModel, extra="forbid"):
     batch_size: PositiveInt = 32
     valid_batch_size: PositiveInt = 100
     shuffle_buffer_size: PositiveInt = 1000
+    additional_properties_info: dict[str, str] = {}

     shift_method: str = "per_element_regression_shift"
     shift_options: dict = {"energy_regularisation": 1.0}
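A small sketch of what the new field might hold, going only by the docstring above; the property names are illustrative, not part of apax.

# Hypothetical values for the new config field. Each extra property is
# mapped to its shape category: "ragged" for per-atom quantities whose
# length varies between structures, "fixed" for per-structure quantities.
additional_properties_info = {
    "charges": "ragged",  # one scalar per atom
    "dipole": "fixed",    # one 3-vector per structure
}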
52 changes: 26 additions & 26 deletions apax/data/initialization.py
@@ -1,34 +1,25 @@
-import dataclasses
 import logging
-from typing import Optional

 import numpy as np
-from ase import Atoms

-from apax.data.input_pipeline import AtomisticDataset, create_dict_dataset
+from apax.data.input_pipeline import AtomisticDataset, process_inputs
 from apax.data.statistics import compute_scale_shift_parameters
-from apax.utils.data import load_data, split_atoms, split_idxs, split_label
+from apax.utils.convert import atoms_to_labels
+from apax.utils.data import load_data, split_atoms, split_idxs

 log = logging.getLogger(__name__)


-@dataclasses.dataclass
-class RawDataset:
-    atoms_list: list[Atoms]
-    additional_labels: Optional[dict] = None
-
-
 def load_data_files(data_config):
     log.info("Running Input Pipeline")
     if data_config.data_path is not None:
         log.info(f"Read data file {data_config.data_path}")
-        atoms_list, label_dict = load_data(data_config.data_path)
+        atoms_list = load_data(data_config.data_path)

         train_idxs, val_idxs = split_idxs(
             atoms_list, data_config.n_train, data_config.n_valid
         )
         train_atoms_list, val_atoms_list = split_atoms(atoms_list, train_idxs, val_idxs)
-        train_label_dict, val_label_dict = split_label(label_dict, train_idxs, val_idxs)

         np.savez(
             data_config.model_version_path / "train_val_idxs",
@@ -39,26 +30,35 @@ def load_data_files(data_config):
     elif data_config.train_data_path and data_config.val_data_path is not None:
         log.info(f"Read training data file {data_config.train_data_path}")
         log.info(f"Read validation data file {data_config.val_data_path}")
-        train_atoms_list, train_label_dict = load_data(data_config.train_data_path)
-        val_atoms_list, val_label_dict = load_data(data_config.val_data_path)
+        train_atoms_list = load_data(data_config.train_data_path)
+        val_atoms_list = load_data(data_config.val_data_path)
     else:
         raise ValueError("input data path/paths not defined")

-    train_raw_ds = RawDataset(
-        atoms_list=train_atoms_list, additional_labels=train_label_dict
-    )
-    val_raw_ds = RawDataset(atoms_list=val_atoms_list, additional_labels=val_label_dict)
-
-    return train_raw_ds, val_raw_ds
+    return train_atoms_list, val_atoms_list


-def initialize_dataset(config, raw_ds, calc_stats: bool = True):
-    inputs, labels = create_dict_dataset(
-        raw_ds.atoms_list,
+def initialize_dataset(
+    config,
+    atoms_list,
+    read_labels: bool = True,
+    calc_stats: bool = True,
+):
+    if calc_stats and not read_labels:
+        raise ValueError(
+            "Cannot calculate scale/shift parameters without reading labels."
+        )
+    inputs = process_inputs(
+        atoms_list,
         r_max=config.model.r_max,
-        external_labels=raw_ds.additional_labels,
         disable_pbar=config.progress_bar.disable_nl_pbar,
         pos_unit=config.data.pos_unit,
     )
+    labels = atoms_to_labels(
+        atoms_list,
+        additional_properties_info=config.data.additional_properties_info,
+        read_labels=read_labels,
+        pos_unit=config.data.pos_unit,
+        energy_unit=config.data.energy_unit,
+    )

@@ -74,8 +74,8 @@ def initialize_dataset(config, raw_ds, calc_stats: bool = True):

     dataset = AtomisticDataset(
         inputs,
-        labels,
         config.n_epochs,
+        labels=labels,
         buffer_size=config.data.shuffle_buffer_size,
     )
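The new read_labels/calc_stats combination is guarded explicitly. A short sketch of the contract, with hypothetical `config` and `atoms_list` objects in scope:

# Label-free datasets are fine for evaluation and feature extraction...
ds = initialize_dataset(config, atoms_list, read_labels=False, calc_stats=False)

# ...but scale/shift statistics need labels, so this combination raises.
try:
    initialize_dataset(config, atoms_list, read_labels=False, calc_stats=True)
except ValueError as err:
    # "Cannot calculate scale/shift parameters without reading labels."
    print(err)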
92 changes: 48 additions & 44 deletions apax/data/input_pipeline.py
@@ -1,13 +1,13 @@
 import logging
-from typing import Dict, Iterator
+from typing import Dict, Iterator, Optional

 import jax
 import jax.numpy as jnp
 import numpy as np
 import tensorflow as tf

 from apax.data.preprocessing import dataset_neighborlist, prefetch_to_single_device
-from apax.utils.convert import atoms_to_arrays
+from apax.utils.convert import atoms_to_inputs

 log = logging.getLogger(__name__)
@@ -35,17 +35,15 @@ def __init__(self, max_atoms: int, max_nbrs: int) -> None:
         self.max_atoms = max_atoms
         self.max_nbrs = max_nbrs

-    def __call__(
-        self, r_inputs: dict, f_inputs: dict, r_labels: dict, f_labels: dict
-    ) -> tuple[dict, dict]:
+    def __call__(self, inputs: dict, labels: dict = None) -> tuple[dict, dict]:
         """
         Arguments
         ---------
         r_inputs :
-            Inputs of ragged shape. Untrainable system-determining properties.
+            Inputs of ragged shape.
         f_inputs :
-            Inputs of fixed shape. Untrainable system-determining properties.
+            Inputs of fixed shape.
         r_labels :
             Labels of ragged shape. Trainable system properties.
         f_labels :
@@ -58,6 +56,8 @@ def __call__(
         labels:
             Contains all labels and all entries are uniformly shaped.
         """
+        r_inputs = inputs["ragged"]
+        f_inputs = inputs["fixed"]
         for key, val in r_inputs.items():
             if self.max_atoms is None:
                 r_inputs[key] = val.to_tensor()
@@ -75,36 +75,34 @@ def __call__(
                 padded_shape = [shape[0], self.max_atoms, shape[2]]  # batch, atoms, 3
                 r_inputs[key] = val.to_tensor(shape=padded_shape)

-        for key, val in r_labels.items():
-            if self.max_atoms is None:
-                r_labels[key] = val.to_tensor()
-            else:
-                padded_shape = [shape[0], self.max_atoms, shape[2]]
-                r_labels[key] = val.to_tensor(default_value=0.0, shape=padded_shape)
+        new_inputs = r_inputs.copy()
+        new_inputs.update(f_inputs)

-        inputs = r_inputs.copy()
-        inputs.update(f_inputs)
+        if labels:
+            r_labels = labels["ragged"]
+            f_labels = labels["fixed"]
+            for key, val in r_labels.items():
+                if self.max_atoms is None:
+                    r_labels[key] = val.to_tensor()
+                else:
+                    padded_shape = [shape[0], self.max_atoms, shape[2]]
+                    r_labels[key] = val.to_tensor(default_value=0.0, shape=padded_shape)

-        labels = r_labels.copy()
-        labels.update(f_labels)
+            new_labels = r_labels.copy()
+            new_labels.update(f_labels)

-        return inputs, labels
+            return new_inputs, new_labels
+        else:
+            return new_inputs


-def create_dict_dataset(
+def process_inputs(
     atoms_list: list,
     r_max: float,
-    external_labels: dict = {},
     disable_pbar=False,
     pos_unit: str = "Ang",
-    energy_unit: str = "eV",
-) -> tuple[dict]:
-    inputs, labels = atoms_to_arrays(atoms_list, pos_unit, energy_unit)
-
-    if external_labels:
-        for shape, label in external_labels.items():
-            labels[shape].update(label)
-
+) -> dict:
+    inputs = atoms_to_inputs(atoms_list, pos_unit)
     idx, offsets = dataset_neighborlist(
         inputs["ragged"]["positions"],
         box=inputs["fixed"]["box"],
@@ -115,11 +113,11 @@ def create_dict_dataset(

     inputs["ragged"]["idx"] = idx
     inputs["ragged"]["offsets"] = offsets
-    return inputs, labels
+    return inputs


 def dataset_from_dicts(
-    inputs: Dict[str, np.ndarray], labels: Dict[str, np.ndarray]
+    inputs: Dict[str, np.ndarray], labels: Optional[Dict[str, np.ndarray]] = None
 ) -> tf.data.Dataset:
     # tf.RaggedTensors should be created from `tf.ragged.stack`
     # instead of `tf.ragged.constant` for performance reasons.
@@ -129,17 +127,18 @@ def dataset_from_dicts(
     for key, val in inputs["fixed"].items():
         inputs["fixed"][key] = tf.constant(val)

-    for key, val in labels["ragged"].items():
-        labels["ragged"][key] = tf.ragged.stack(val)
-    for key, val in labels["fixed"].items():
-        labels["fixed"][key] = tf.constant(val)
-
-    ds = tf.data.Dataset.from_tensor_slices((
-        inputs["ragged"],
-        inputs["fixed"],
-        labels["ragged"],
-        labels["fixed"],
-    ))
+    if labels:
+        for key, val in labels["ragged"].items():
+            labels["ragged"][key] = tf.ragged.stack(val)
+        for key, val in labels["fixed"].items():
+            labels["fixed"][key] = tf.constant(val)
+
+        tensors = (inputs, labels)
+    else:
+        tensors = inputs
+
+    ds = tf.data.Dataset.from_tensor_slices(tensors)

     return ds
@@ -149,8 +148,8 @@ class AtomisticDataset:
     def __init__(
         self,
         inputs,
-        labels,
         n_epoch: int,
+        labels=None,
         buffer_size: int = 1000,
     ) -> None:
         """Processes inputs/labels and makes them accessible for training.
@@ -181,7 +180,10 @@ def __init__(

         self.n_data = len(inputs["fixed"]["n_atoms"])

-        self.ds = dataset_from_dicts(inputs, labels)
+        if labels:
+            self.ds = dataset_from_dicts(inputs, labels)
+        else:
+            self.ds = dataset_from_dicts(inputs)

     def set_batch_size(self, batch_size: int):
         self.batch_size = self.validate_batch_size(batch_size)
@@ -214,12 +216,14 @@ def steps_per_epoch(self) -> int:

     def init_input(self) -> Dict[str, np.ndarray]:
         """Returns first batch of inputs and labels to init the model."""
-        inputs, _ = next(
+        inputs = next(
             self.ds.batch(1)
             .map(PadToSpecificSize(self.max_atoms, self.max_nbrs))
             .take(1)
             .as_numpy_iterator()
         )
+        if isinstance(inputs, tuple):
+            inputs = inputs[0]  # remove labels

         inputs = jax.tree_map(lambda x: jnp.array(x[0]), inputs)
         init_box = np.array(inputs["box"])
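Taken together, the pipeline now yields (inputs, labels) tuples when labels are present and bare input dicts otherwise. A consumption sketch matching compute_features and init_input above, with a hypothetical `dataset` object:

# Sketch of consuming a batched dataset after this change.
ds = dataset.batch()

for batch in ds:
    if isinstance(batch, tuple):  # labeled data: (inputs, labels)
        inputs, labels = batch
    else:                         # unlabeled data: inputs only
        inputs = batch
    ...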