Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX Neighborlist for periodic structures #259

Merged
merged 14 commits into from
Apr 8, 2024
28 changes: 12 additions & 16 deletions apax/data/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@ def pad_nl(idx, offsets, max_neighbors):
return idx, offsets


def find_largest_system(inputs: dict[str, np.ndarray], r_max) -> tuple[int]:
def find_largest_system(inputs, r_max) -> tuple[int]:
positions, boxes = inputs["positions"], inputs["box"]
max_atoms = np.max(inputs["n_atoms"])

max_nbrs = 0
for position, box in zip(inputs["positions"], inputs["box"]):
neighbor_idxs, _ = compute_nl(position, box, r_max)
for pos, box in zip(positions, boxes):
neighbor_idxs, _ = compute_nl(pos, box, r_max)
n_neighbors = neighbor_idxs.shape[1]
max_nbrs = max(max_nbrs, n_neighbors)

Expand All @@ -38,7 +39,7 @@ def find_largest_system(inputs: dict[str, np.ndarray], r_max) -> tuple[int]:
class InMemoryDataset:
def __init__(
self,
atoms,
atoms_list,
cutoff,
bs,
n_epochs,
Expand All @@ -50,26 +51,24 @@ def __init__(
ignore_labels=False,
cache_path=".",
) -> None:

self.n_epochs = n_epochs
self.cutoff = cutoff
self.n_jit_steps = n_jit_steps
self.buffer_size = buffer_size
self.n_data = len(atoms)
self.n_data = len(atoms_list)
self.batch_size = self.validate_batch_size(bs)
self.pos_unit = pos_unit

if pre_shuffle:
shuffle(atoms)
self.sample_atoms = atoms[0]
self.inputs = atoms_to_inputs(atoms, self.pos_unit)
shuffle(atoms_list)
self.sample_atoms = atoms_list[0]
self.inputs = atoms_to_inputs(atoms_list, pos_unit)

max_atoms, max_nbrs = find_largest_system(self.inputs, self.cutoff)
self.max_atoms = max_atoms
self.max_nbrs = max_nbrs

if atoms[0].calc and not ignore_labels:
self.labels = atoms_to_labels(atoms, self.pos_unit, energy_unit)
if atoms_list[0].calc and not ignore_labels:
self.labels = atoms_to_labels(atoms_list, pos_unit, energy_unit)
else:
self.labels = None

Expand Down Expand Up @@ -109,9 +108,6 @@ def prepare_data(self, i):
inputs["numbers"] = np.pad(
inputs["numbers"], (0, zeros_to_add), "constant"
).astype(np.int16)
inputs["n_atoms"] = np.pad(
inputs["n_atoms"], (0, zeros_to_add), "constant"
).astype(np.int16)

if not self.labels:
return inputs
Expand All @@ -121,7 +117,6 @@ def prepare_data(self, i):
labels["forces"] = np.pad(
labels["forces"], ((0, zeros_to_add), (0, 0)), "constant"
)

inputs = {k: tf.constant(v) for k, v in inputs.items()}
labels = {k: tf.constant(v) for k, v in labels.items()}
return (inputs, labels)
Expand Down Expand Up @@ -170,6 +165,7 @@ def init_input(self) -> Dict[str, np.ndarray]:
"""Returns first batch of inputs and labels to init the model."""
positions = self.sample_atoms.positions * unit_dict[self.pos_unit]
box = self.sample_atoms.cell.array * unit_dict[self.pos_unit]
# For an input sample, it does not matter whether pos is fractional or cartesian
idx, offsets = compute_nl(positions, box, self.cutoff)
inputs = (
positions,
Expand Down
24 changes: 13 additions & 11 deletions apax/data/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,33 @@
log = logging.getLogger(__name__)


def compute_nl(position, box, r_max):
def compute_nl(positions, box, r_max):
"""Computes the NL for a single structure.
For periodic systems, positions are assumed to be in
fractional coordinates.
"""
if np.all(box < 1e-6):
cell, cell_origin = get_shrink_wrapped_cell(position)
box, box_origin = get_shrink_wrapped_cell(positions)
idxs_i, idxs_j = neighbour_list(
"ij",
positions=position,
positions=positions,
cutoff=r_max,
cell=cell,
cell_origin=cell_origin,
cell=box,
cell_origin=box_origin,
pbc=[False, False, False],
)

neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int32)
neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int16)

n_neighbors = neighbor_idxs.shape[1]
offsets = np.full([n_neighbors, 3], 0)

else:
positions = positions @ box
idxs_i, idxs_j, offsets = neighbour_list(
"ijS",
positions=position,
cutoff=r_max,
cell=box,
"ijS", positions=positions, cutoff=r_max, cell=box, pbc=[True, True, True]
)
neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int32)
neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int16)
offsets = np.matmul(offsets, box)
return neighbor_idxs, offsets

Expand Down
4 changes: 2 additions & 2 deletions apax/train/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def initialize_datasets(config: Config):
config.n_epochs,
config.data.shuffle_buffer_size,
config.n_jitted_steps,
config.data.pos_unit,
config.data.energy_unit,
pos_unit=config.data.pos_unit,
energy_unit=config.data.energy_unit,
pre_shuffle=True,
cache_path=config.data.model_version_path,
)
Expand Down
27 changes: 20 additions & 7 deletions apax/utils/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ def prune_dict(data_dict):
return pruned


def is_periodic(box):
    """Return whether ``box`` describes a fully 3D periodic cell.

    A component counts as non-zero when its absolute value exceeds 1e-6.
    The check is done per row and per column so that it gives the same
    answer for a cell matrix and for its transpose.

    Args:
        box: 3x3 array-like cell matrix.

    Returns:
        ``True`` if every row and every column has a non-zero component
        (3D periodic), ``False`` if all components are zero (gas phase).

    Raises:
        ValueError: If only some rows/columns are non-zero, i.e. the
            system is neither fully periodic nor gas phase.
    """
    nonzero = np.abs(np.asarray(box)) > 1e-6

    # Reducing without an axis would collapse everything into one scalar and
    # make the mixed-periodicity check below impossible to trigger.
    row_dims = np.any(nonzero, axis=1)
    col_dims = np.any(nonzero, axis=0)

    if not np.any(nonzero):
        return False
    if np.all(row_dims) and np.all(col_dims):
        return True

    msg = (
        f"Only 3D periodic and gas phase system supported at the moment. Found {box}"
    )
    raise ValueError(msg)


def atoms_to_inputs(
atoms_list: list[Atoms],
pos_unit: str = "Ang",
Expand Down Expand Up @@ -67,27 +78,29 @@ def atoms_to_inputs(
}

box = atoms_list[0].cell.array
pbc = np.all(box > 1e-6)
pbc = is_periodic(box)

for atoms in atoms_list:
box = (atoms.cell.array * unit_dict[pos_unit]).astype(DTYPE)
box = box.T # takes row and column convention of ase into account
inputs["box"].append(box)

if pbc != np.all(box > 1e-6):
is_pbc = is_periodic(box)

if pbc != is_pbc:
raise ValueError(
"Apax does not support dataset periodic and non periodic structures"
)

if np.all(box < 1e-6):
inputs["positions"].append(
(atoms.positions * unit_dict[pos_unit]).astype(DTYPE)
)
else:
if is_pbc:
inv_box = np.linalg.inv(box)
pos = (atoms.positions * unit_dict[pos_unit]).astype(DTYPE)
frac_pos = space.transform(inv_box, pos)
inputs["positions"].append(np.array(frac_pos))
else:
inputs["positions"].append(
(atoms.positions * unit_dict[pos_unit]).astype(DTYPE)
)

inputs["numbers"].append(atoms.numbers.astype(np.int16))
inputs["n_atoms"].append(len(atoms))
Expand Down
Loading