From 93a05f045051052f711d4a5d47709aa72f713db9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Thu, 4 Apr 2024 15:55:21 +0200 Subject: [PATCH 1/9] sketch of nl fix --- apax/data/input_pipeline.py | 22 +++++++++++++--------- apax/data/preprocessing.py | 12 ++++++------ 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py index 082415a0..5664e972 100644 --- a/apax/data/input_pipeline.py +++ b/apax/data/input_pipeline.py @@ -3,8 +3,9 @@ from collections import deque from pathlib import Path from random import shuffle -from typing import Dict, Iterator +from typing import Dict, Iterator, List +from ase import Atoms import jax import jax.numpy as jnp import numpy as np @@ -23,12 +24,12 @@ def pad_nl(idx, offsets, max_neighbors): return idx, offsets -def find_largest_system(inputs: dict[str, np.ndarray], r_max) -> tuple[int]: - max_atoms = np.max(inputs["n_atoms"]) +def find_largest_system(atoms_list: List[Atoms], r_max) -> tuple[int]: + max_atoms = np.max([atoms.numbers.shape[0] for atoms in atoms_list]) max_nbrs = 0 - for position, box in zip(inputs["positions"], inputs["box"]): - neighbor_idxs, _ = compute_nl(position, box, r_max) + for atoms in atoms_list: + neighbor_idxs, _ = compute_nl(atoms, r_max) n_neighbors = neighbor_idxs.shape[1] max_nbrs = max(max_nbrs, n_neighbors) @@ -52,13 +53,16 @@ def __init__( shuffle(atoms) self.sample_atoms = atoms[0] self.inputs = atoms_to_inputs(atoms) + self.atoms = atoms self.n_epochs = n_epochs self.buffer_size = buffer_size - max_atoms, max_nbrs = find_largest_system(self.inputs, cutoff) + max_atoms, max_nbrs = find_largest_system(atoms, cutoff) self.max_atoms = max_atoms self.max_nbrs = max_nbrs + # print(max_atoms, max_nbrs) + # quit() if atoms[0].calc and not ignore_labels: self.labels = atoms_to_labels(atoms) @@ -94,8 +98,8 @@ def validate_batch_size(self, batch_size: int) -> int: return batch_size def prepare_data(self, i): - inputs = {k: v[i] for k, v in self.inputs.items()} - idx, offsets = compute_nl(inputs["positions"], inputs["box"], self.cutoff) + inputs = {k: v[i] for k, v in self.inputs.items()} # inputs["positions"], inputs["box"] + idx, offsets = compute_nl(self.atoms[i], self.cutoff) inputs["idx"], inputs["offsets"] = pad_nl(idx, offsets, self.max_nbrs) zeros_to_add = self.max_atoms - inputs["numbers"].shape[0] @@ -166,7 +170,7 @@ def init_input(self) -> Dict[str, np.ndarray]: """Returns first batch of inputs and labels to init the model.""" positions = self.sample_atoms.positions box = self.sample_atoms.cell.array - idx, offsets = compute_nl(positions, box, self.cutoff) + idx, offsets = compute_nl(self.sample_atoms, self.cutoff) inputs = ( positions, self.sample_atoms.numbers, diff --git a/apax/data/preprocessing.py b/apax/data/preprocessing.py index b52efdbf..32827644 100644 --- a/apax/data/preprocessing.py +++ b/apax/data/preprocessing.py @@ -10,12 +10,12 @@ log = logging.getLogger(__name__) -def compute_nl(position, box, r_max): - if np.all(box < 1e-6): - cell, cell_origin = get_shrink_wrapped_cell(position) +def compute_nl(atoms, r_max): + if np.all(atoms.cell.array < 1e-6): + cell, cell_origin = get_shrink_wrapped_cell(atoms.positions) idxs_i, idxs_j = neighbour_list( "ij", - positions=position, + positions=atoms.positions, cutoff=r_max, cell=cell, cell_origin=cell_origin, @@ -30,11 +30,11 @@ def compute_nl(position, box, r_max): else: idxs_i, idxs_j, offsets = neighbour_list( "ijS", - positions=position, + atoms = atoms, cutoff=r_max, - cell=box, ) neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int32) + box = atoms.cell.array offsets = np.matmul(offsets, box) return neighbor_idxs, offsets From 3fb591791258114dd390c02c7f3b6f35415c24f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Thu, 4 Apr 2024 16:33:25 +0200 Subject: [PATCH 2/9] fixed periodicity check bug --- apax/utils/convert.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/apax/utils/convert.py b/apax/utils/convert.py index b6bfd76e..07d252fb 100644 --- a/apax/utils/convert.py +++ b/apax/utils/convert.py @@ -38,6 +38,14 @@ def prune_dict(data_dict): pruned = {key: val for key, val in data_dict.items() if len(val) != 0} return pruned +def is_periodic(box): + pbc_dims = np.any(np.abs(box) > 1e-6) + if np.all(pbc_dims == True) or np.all(pbc_dims == False): + return pbc_dims + else: + msg = f"Only 3D periodic and gas phase system supported at the moment. Found {box}" + raise ValueError(msg) + def atoms_to_inputs( atoms_list: list[Atoms], @@ -67,19 +75,21 @@ def atoms_to_inputs( } box = atoms_list[0].cell.array - pbc = np.all(box > 1e-6) + pbc = is_periodic(box) for atoms in atoms_list: box = (atoms.cell.array * unit_dict[pos_unit]).astype(DTYPE) box = box.T # takes row and column convention of ase into account inputs["box"].append(box) - if pbc != np.all(box > 1e-6): + current_pbc = is_periodic(box) + + if pbc != current_pbc: raise ValueError( "Apax does not support dataset periodic and non periodic structures" ) - if np.all(box < 1e-6): + if not current_pbc: inputs["positions"].append( (atoms.positions * unit_dict[pos_unit]).astype(DTYPE) ) From 04cc7328fa49fcbeea4e046a9de1e7d382cec9e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Thu, 4 Apr 2024 16:36:39 +0200 Subject: [PATCH 3/9] removed self.inputs form dataset --- apax/data/input_pipeline.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py index 5664e972..65fc2ec7 100644 --- a/apax/data/input_pipeline.py +++ b/apax/data/input_pipeline.py @@ -39,7 +39,7 @@ def find_largest_system(atoms_list: List[Atoms], r_max) -> tuple[int]: class InMemoryDataset: def __init__( self, - atoms, + atoms_list, cutoff, bs, n_epochs, @@ -50,26 +50,29 @@ def __init__( cache_path=".", ) -> None: if pre_shuffle: - shuffle(atoms) - self.sample_atoms = atoms[0] - self.inputs = atoms_to_inputs(atoms) + shuffle(atoms_list) + self.sample_atoms = atoms_list[0] + # self.inputs = atoms_to_inputs(atoms) self.atoms = atoms self.n_epochs = n_epochs self.buffer_size = buffer_size - max_atoms, max_nbrs = find_largest_system(atoms, cutoff) + max_atoms, max_nbrs = find_largest_system(atoms_list, cutoff) self.max_atoms = max_atoms self.max_nbrs = max_nbrs # print(max_atoms, max_nbrs) # quit() - if atoms[0].calc and not ignore_labels: - self.labels = atoms_to_labels(atoms) - else: - self.labels = None + self.compute_labels = False + if atoms_list[0].calc and not ignore_labels: + self.compute_labels = True + # if atoms[0].calc and not ignore_labels: + # self.labels = atoms_to_labels(atoms) + # else: + # self.labels = None - self.n_data = len(atoms) + self.n_data = len(atoms_list) self.count = 0 self.cutoff = cutoff self.buffer = deque() @@ -98,7 +101,9 @@ def validate_batch_size(self, batch_size: int) -> int: return batch_size def prepare_data(self, i): - inputs = {k: v[i] for k, v in self.inputs.items()} # inputs["positions"], inputs["box"] + # inputs = {k: v[i] for k, v in self.inputs.items()} + atoms = self.atoms_list[i] + inputs = atoms_to_inputs(atoms, self.pos_unit) idx, offsets = compute_nl(self.atoms[i], self.cutoff) inputs["idx"], inputs["offsets"] = pad_nl(idx, offsets, self.max_nbrs) From ac6030046f0a98f310ea37dcb06c9a887cb03722 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Thu, 4 Apr 2024 21:21:20 +0200 Subject: [PATCH 4/9] sketch of fix --- apax/data/input_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py index 071b724b..a6775f6b 100644 --- a/apax/data/input_pipeline.py +++ b/apax/data/input_pipeline.py @@ -122,10 +122,10 @@ def prepare_data(self, i): inputs["n_atoms"], (0, zeros_to_add), "constant" ).astype(np.int16) - if not self.labels: + if not self.compute_labels: return inputs - labels = {k: v[i] for k, v in self.labels.items()} + # labels = {k: v[i] for k, v in self.labels.items()} if "forces" in labels: labels["forces"] = np.pad( labels["forces"], ((0, zeros_to_add), (0, 0)), "constant" From 2c081102519f0313f5b1efcc3e16184a081e1507 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 5 Apr 2024 10:41:25 +0200 Subject: [PATCH 5/9] fixed neighborlist computatoin for periodic systems --- apax/data/input_pipeline.py | 43 ++++++++++++++----------------------- apax/data/preprocessing.py | 26 +++++++++++++--------- 2 files changed, 32 insertions(+), 37 deletions(-) diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py index a6775f6b..156f5baf 100644 --- a/apax/data/input_pipeline.py +++ b/apax/data/input_pipeline.py @@ -24,12 +24,13 @@ def pad_nl(idx, offsets, max_neighbors): return idx, offsets -def find_largest_system(atoms_list: List[Atoms], r_max) -> tuple[int]: - max_atoms = np.max([atoms.numbers.shape[0] for atoms in atoms_list]) +def find_largest_system(inputs, r_max) -> tuple[int]: + positions, boxes = inputs["positions"], inputs["box"] + max_atoms = np.max(inputs["n_atoms"]) max_nbrs = 0 - for atoms in atoms_list: - neighbor_idxs, _ = compute_nl(atoms, r_max) + for pos, box in zip(positions, boxes): + neighbor_idxs, _ = compute_nl(pos, box, r_max) n_neighbors = neighbor_idxs.shape[1] max_nbrs = max(max_nbrs, n_neighbors) @@ -63,22 +64,15 @@ def __init__( if pre_shuffle: shuffle(atoms_list) self.sample_atoms = atoms_list[0] - # self.inputs = atoms_to_inputs(atoms, self.pos_unit) - self.atoms_list = atoms_list + self.inputs = atoms_to_inputs(atoms_list, self.pos_unit) - max_atoms, max_nbrs = find_largest_system(atoms_list, self.cutoff) + max_atoms, max_nbrs = find_largest_system(self.inputs, self.cutoff) self.max_atoms = max_atoms self.max_nbrs = max_nbrs - # print(max_atoms, max_nbrs) - # quit() - - self.compute_labels = False if atoms_list[0].calc and not ignore_labels: - self.compute_labels = True - # if atoms[0].calc and not ignore_labels: - # self.labels = atoms_to_labels(atoms) - # else: - # self.labels = None + self.labels = atoms_to_labels(atoms_list) + else: + self.labels = None self.count = 0 self.buffer = deque() @@ -105,10 +99,8 @@ def validate_batch_size(self, batch_size: int) -> int: return batch_size def prepare_data(self, i): - # inputs = {k: v[i] for k, v in self.inputs.items()} - atoms = self.atoms_list[i] - inputs = atoms_to_inputs(atoms, self.pos_unit) - idx, offsets = compute_nl(self.atoms[i], self.cutoff) + inputs = {k: v[i] for k, v in self.inputs.items()} + idx, offsets = compute_nl(inputs["positions"], inputs["box"], self.cutoff) inputs["idx"], inputs["offsets"] = pad_nl(idx, offsets, self.max_nbrs) zeros_to_add = self.max_atoms - inputs["numbers"].shape[0] @@ -118,19 +110,15 @@ def prepare_data(self, i): inputs["numbers"] = np.pad( inputs["numbers"], (0, zeros_to_add), "constant" ).astype(np.int16) - inputs["n_atoms"] = np.pad( - inputs["n_atoms"], (0, zeros_to_add), "constant" - ).astype(np.int16) - if not self.compute_labels: + if not self.labels: return inputs - # labels = {k: v[i] for k, v in self.labels.items()} + labels = {k: v[i] for k, v in self.labels.items()} if "forces" in labels: labels["forces"] = np.pad( labels["forces"], ((0, zeros_to_add), (0, 0)), "constant" ) - inputs = {k: tf.constant(v) for k, v in inputs.items()} labels = {k: tf.constant(v) for k, v in labels.items()} return (inputs, labels) @@ -179,7 +167,8 @@ def init_input(self) -> Dict[str, np.ndarray]: """Returns first batch of inputs and labels to init the model.""" positions = self.sample_atoms.positions * unit_dict[self.pos_unit] box = self.sample_atoms.cell.array * unit_dict[self.pos_unit] - idx, offsets = compute_nl(self.sample_atoms, self.cutoff) + # For an input sample, it does not matter whether pos is fractional or cartesian + idx, offsets = compute_nl(positions, box, self.cutoff) inputs = ( positions, self.sample_atoms.numbers, diff --git a/apax/data/preprocessing.py b/apax/data/preprocessing.py index 32827644..1c6cd0f6 100644 --- a/apax/data/preprocessing.py +++ b/apax/data/preprocessing.py @@ -10,31 +10,37 @@ log = logging.getLogger(__name__) -def compute_nl(atoms, r_max): - if np.all(atoms.cell.array < 1e-6): - cell, cell_origin = get_shrink_wrapped_cell(atoms.positions) +def compute_nl(positions, box, r_max): + """Computes the NL for a single structure. + For periodic systems, positions are assumed to be in + fractional coordinates. + """ + if np.all(box < 1e-6): + box, box_origin = get_shrink_wrapped_cell(positions) idxs_i, idxs_j = neighbour_list( "ij", - positions=atoms.positions, + positions=positions, cutoff=r_max, - cell=cell, - cell_origin=cell_origin, + cell=box, + cell_origin=box_origin, pbc=[False, False, False], ) - neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int32) + neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int16) n_neighbors = neighbor_idxs.shape[1] offsets = np.full([n_neighbors, 3], 0) else: + positions = positions @ box idxs_i, idxs_j, offsets = neighbour_list( "ijS", - atoms = atoms, + positions=positions, cutoff=r_max, + cell=box, + pbc=[True, True, True] ) - neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int32) - box = atoms.cell.array + neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int16) offsets = np.matmul(offsets, box) return neighbor_idxs, offsets From 4aa8eae7117db2f2e785b19da409e7460b9fdad7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 5 Apr 2024 10:42:10 +0200 Subject: [PATCH 6/9] linting --- apax/data/input_pipeline.py | 4 +--- apax/data/preprocessing.py | 6 +----- apax/train/trainer.py | 10 ++++++---- apax/utils/convert.py | 5 ++++- 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py index 156f5baf..43ae1dda 100644 --- a/apax/data/input_pipeline.py +++ b/apax/data/input_pipeline.py @@ -3,9 +3,8 @@ from collections import deque from pathlib import Path from random import shuffle -from typing import Dict, Iterator, List +from typing import Dict, Iterator -from ase import Atoms import jax import jax.numpy as jnp import numpy as np @@ -52,7 +51,6 @@ def __init__( ignore_labels=False, cache_path=".", ) -> None: - self.n_epochs = n_epochs self.cutoff = cutoff self.n_jit_steps = n_jit_steps diff --git a/apax/data/preprocessing.py b/apax/data/preprocessing.py index 1c6cd0f6..c68b33ba 100644 --- a/apax/data/preprocessing.py +++ b/apax/data/preprocessing.py @@ -34,11 +34,7 @@ def compute_nl(positions, box, r_max): else: positions = positions @ box idxs_i, idxs_j, offsets = neighbour_list( - "ijS", - positions=positions, - cutoff=r_max, - cell=box, - pbc=[True, True, True] + "ijS", positions=positions, cutoff=r_max, cell=box, pbc=[True, True, True] ) neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int16) offsets = np.matmul(offsets, box) diff --git a/apax/train/trainer.py b/apax/train/trainer.py index 8c040a3f..6d6bc0f0 100644 --- a/apax/train/trainer.py +++ b/apax/train/trainer.py @@ -107,10 +107,12 @@ def fit( epoch_loss["val_loss"] /= val_steps_per_epoch epoch_loss["val_loss"] = float(epoch_loss["val_loss"]) - epoch_metrics.update({ - f"val_{key}": float(val) - for key, val in val_batch_metrics.compute().items() - }) + epoch_metrics.update( + { + f"val_{key}": float(val) + for key, val in val_batch_metrics.compute().items() + } + ) epoch_metrics.update({**epoch_loss}) diff --git a/apax/utils/convert.py b/apax/utils/convert.py index 07d252fb..c537bcf2 100644 --- a/apax/utils/convert.py +++ b/apax/utils/convert.py @@ -38,12 +38,15 @@ def prune_dict(data_dict): pruned = {key: val for key, val in data_dict.items() if len(val) != 0} return pruned + def is_periodic(box): pbc_dims = np.any(np.abs(box) > 1e-6) if np.all(pbc_dims == True) or np.all(pbc_dims == False): return pbc_dims else: - msg = f"Only 3D periodic and gas phase system supported at the moment. Found {box}" + msg = ( + f"Only 3D periodic and gas phase system supported at the moment. Found {box}" + ) raise ValueError(msg) From 46a8a6519e09aba7e86115ee82d59f19f7022488 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 5 Apr 2024 13:04:13 +0200 Subject: [PATCH 7/9] readded unit conversion --- apax/data/input_pipeline.py | 4 ++-- apax/train/run.py | 4 ++-- apax/train/trainer.py | 1 - 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py index 43ae1dda..f3a0a267 100644 --- a/apax/data/input_pipeline.py +++ b/apax/data/input_pipeline.py @@ -62,13 +62,13 @@ def __init__( if pre_shuffle: shuffle(atoms_list) self.sample_atoms = atoms_list[0] - self.inputs = atoms_to_inputs(atoms_list, self.pos_unit) + self.inputs = atoms_to_inputs(atoms_list, pos_unit) max_atoms, max_nbrs = find_largest_system(self.inputs, self.cutoff) self.max_atoms = max_atoms self.max_nbrs = max_nbrs if atoms_list[0].calc and not ignore_labels: - self.labels = atoms_to_labels(atoms_list) + self.labels = atoms_to_labels(atoms_list, pos_unit, energy_unit) else: self.labels = None diff --git a/apax/train/run.py b/apax/train/run.py index dff8ad31..e81b11ff 100644 --- a/apax/train/run.py +++ b/apax/train/run.py @@ -63,8 +63,8 @@ def initialize_datasets(config: Config): config.n_epochs, config.data.shuffle_buffer_size, config.n_jitted_steps, - config.data.pos_unit, - config.data.energy_unit, + pos_unit=config.data.pos_unit, + energy_unit=config.data.energy_unit, pre_shuffle=True, cache_path=config.data.model_version_path, ) diff --git a/apax/train/trainer.py b/apax/train/trainer.py index 6d6bc0f0..29d8cb2c 100644 --- a/apax/train/trainer.py +++ b/apax/train/trainer.py @@ -84,7 +84,6 @@ def fit( epoch_loss["train_loss"] += jnp.mean(batch_loss) callbacks.on_train_batch_end(batch=batch_idx) - epoch_loss["train_loss"] /= train_steps_per_epoch epoch_loss["train_loss"] = float(epoch_loss["train_loss"]) From 1dc1203fdf2bcc68d46d831419f787f01fd7d0f3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 5 Apr 2024 11:31:26 +0000 Subject: [PATCH 8/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- apax/train/trainer.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/apax/train/trainer.py b/apax/train/trainer.py index 9f77864c..aae54123 100644 --- a/apax/train/trainer.py +++ b/apax/train/trainer.py @@ -130,12 +130,10 @@ def fit( epoch_loss["val_loss"] /= val_steps_per_epoch epoch_loss["val_loss"] = float(epoch_loss["val_loss"]) - epoch_metrics.update( - { - f"val_{key}": float(val) - for key, val in val_batch_metrics.compute().items() - } - ) + epoch_metrics.update({ + f"val_{key}": float(val) + for key, val in val_batch_metrics.compute().items() + }) epoch_metrics.update({**epoch_loss}) From 2305f5045fd4691dd9a2e794a6c2129e1f2a74e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 5 Apr 2024 13:35:05 +0200 Subject: [PATCH 9/9] linting --- apax/utils/convert.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/apax/utils/convert.py b/apax/utils/convert.py index c537bcf2..a7d4cc45 100644 --- a/apax/utils/convert.py +++ b/apax/utils/convert.py @@ -41,7 +41,7 @@ def prune_dict(data_dict): def is_periodic(box): pbc_dims = np.any(np.abs(box) > 1e-6) - if np.all(pbc_dims == True) or np.all(pbc_dims == False): + if np.all(pbc_dims == True) or np.all(pbc_dims == False): # noqa: E712 return pbc_dims else: msg = ( @@ -85,22 +85,22 @@ def atoms_to_inputs( box = box.T # takes row and column convention of ase into account inputs["box"].append(box) - current_pbc = is_periodic(box) + is_pbc = is_periodic(box) - if pbc != current_pbc: + if pbc != is_pbc: raise ValueError( "Apax does not support dataset periodic and non periodic structures" ) - if not current_pbc: - inputs["positions"].append( - (atoms.positions * unit_dict[pos_unit]).astype(DTYPE) - ) - else: + if is_pbc: inv_box = np.linalg.inv(box) pos = (atoms.positions * unit_dict[pos_unit]).astype(DTYPE) frac_pos = space.transform(inv_box, pos) inputs["positions"].append(np.array(frac_pos)) + else: + inputs["positions"].append( + (atoms.positions * unit_dict[pos_unit]).astype(DTYPE) + ) inputs["numbers"].append(atoms.numbers.astype(np.int16)) inputs["n_atoms"].append(len(atoms))