Skip to content

Commit

Permalink
Merge pull request #124 from MDIL-SNU/ase2matsci
Browse files Browse the repository at this point in the history
[Feat] change ase.neighborlist to from matscipy.neighbours
  • Loading branch information
YutackPark authored Nov 23, 2024
2 parents 32e1357 + d7efb62 commit dd5b715
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 26 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Changelog
All notable changes to this project will be documented in this file.

## [0.10.2]
### Added
- Accelerated graph build routine if matscipy is installed @hexagonerose

## [0.10.1]
### Added
- experimental `SevenNetAtomsDataset` which is memory efficient, can be enabled with `dataset_type='atoms'`
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "sevenn"
version = "0.10.1"
version = "0.10.2.dev"
authors = [
{ name = "Yutack Park", email = "[email protected]" },
{ name = "Jaesun Kim" },
Expand All @@ -25,7 +25,10 @@ dependencies = [
"scikit-learn",
"torch_geometric>=2.5.0",
"numpy<2.0",
#"matscipy",
]
[project.optional-dependencies]
matscipy = ["matscipy"]


[project.scripts]
Expand Down
95 changes: 70 additions & 25 deletions sevenn/train/dataload.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,68 @@
from .dataset import AtomGraphDataset


def _graph_build_matscipy(cutoff: float, pbc, cell, pos):
pbc_x = pbc[0]
pbc_y = pbc[1]
pbc_z = pbc[2]

identity = np.identity(3, dtype=float)
max_positions = np.max(np.absolute(pos)) + 1

# Extend cell in non-periodic directions
# For models with more than 5 layers,
# the multiplicative constant needs to be increased.
if not pbc_x:
cell[0, :] = max_positions * 5 * cutoff * identity[0, :]
if not pbc_y:
cell[1, :] = max_positions * 5 * cutoff * identity[1, :]
if not pbc_z:
cell[2, :] = max_positions * 5 * cutoff * identity[2, :]
# it does not have self-interaction
edge_src, edge_dst, edge_vec, shifts = neighbour_list(
quantities='ijDS',
pbc=pbc,
cell=cell,
positions=pos,
cutoff=cutoff,
)
# dtype issue
edge_src = edge_src.astype(np.int64)
edge_dst = edge_dst.astype(np.int64)

return edge_src, edge_dst, edge_vec, shifts


def _graph_build_ase(cutoff: float, pbc, cell, pos):
# building neighbor list
edge_src, edge_dst, edge_vec, shifts = primitive_neighbor_list(
'ijDS', pbc, cell, pos, cutoff, self_interaction=True
)

return edge_src, edge_dst, edge_vec, shifts


_graph_build_f = _graph_build_ase
try:
from matscipy.neighbours import neighbour_list
_graph_build_f = _graph_build_matscipy
except ImportError:
pass


def _remove_self_edges(edge_src, edge_dst, edge_vec, shifts):
is_zero_idx = np.all(edge_vec == 0, axis=1)
is_self_idx = edge_src == edge_dst
non_trivials = ~(is_zero_idx & is_self_idx)
shifts = np.array(shifts[non_trivials])

edge_vec = edge_vec[non_trivials]
edge_src = edge_src[non_trivials]
edge_dst = edge_dst[non_trivials]

return edge_src, edge_dst, edge_vec, shifts


def _correct_scalar(v):
if isinstance(v, np.ndarray):
v = v.squeeze()
Expand All @@ -42,20 +104,12 @@ def _correct_scalar(v):
def unlabeled_atoms_to_graph(atoms: ase.Atoms, cutoff: float):
pos = atoms.get_positions()
cell = np.array(atoms.get_cell())
pbc = atoms.get_pbc()

# building neighbor list
edge_src, edge_dst, edge_vec, shifts = primitive_neighbor_list(
'ijDS', atoms.get_pbc(), cell, pos, cutoff, self_interaction=True
edge_src, edge_dst, edge_vec, shifts = _remove_self_edges(
*_graph_build_f(cutoff, pbc, cell, pos)
)

is_zero_idx = np.all(edge_vec == 0, axis=1)
is_self_idx = edge_src == edge_dst
non_trivials = ~(is_zero_idx & is_self_idx)
cell_shift = np.array(shifts[non_trivials])

edge_vec = edge_vec[non_trivials]
edge_src = edge_src[non_trivials]
edge_dst = edge_dst[non_trivials]
edge_idx = np.array([edge_src, edge_dst])

atomic_numbers = atoms.get_atomic_numbers()
Expand All @@ -72,7 +126,7 @@ def unlabeled_atoms_to_graph(atoms: ase.Atoms, cutoff: float):
KEY.EDGE_IDX: edge_idx,
KEY.EDGE_VEC: edge_vec,
KEY.CELL: cell,
KEY.CELL_SHIFT: cell_shift,
KEY.CELL_SHIFT: shifts,
KEY.CELL_VOLUME: vol,
KEY.NUM_ATOMS: _correct_scalar(len(atomic_numbers)),
}
Expand Down Expand Up @@ -137,22 +191,13 @@ def atoms_to_graph(

pos = atoms.get_positions()
cell = np.array(atoms.get_cell())
pbc = atoms.get_pbc()

# building neighbor list
edge_src, edge_dst, edge_vec, shifts = primitive_neighbor_list(
'ijDS', atoms.get_pbc(), cell, pos, cutoff, self_interaction=True
edge_src, edge_dst, edge_vec, shifts = _remove_self_edges(
*_graph_build_f(cutoff, pbc, cell, pos)
)

is_zero_idx = np.all(edge_vec == 0, axis=1)
is_self_idx = edge_src == edge_dst
non_trivials = ~(is_zero_idx & is_self_idx)
cell_shift = np.array(shifts[non_trivials])

edge_vec = edge_vec[non_trivials]
edge_src = edge_src[non_trivials]
edge_dst = edge_dst[non_trivials]
edge_idx = np.array([edge_src, edge_dst])

atomic_numbers = atoms.get_atomic_numbers()

cell = np.array(cell)
Expand All @@ -170,7 +215,7 @@ def atoms_to_graph(
KEY.FORCE: y_force,
KEY.STRESS: y_stress.reshape(1, 6), # to make batch have (n_node, 6)
KEY.CELL: cell,
KEY.CELL_SHIFT: cell_shift,
KEY.CELL_SHIFT: shifts,
KEY.CELL_VOLUME: vol,
KEY.NUM_ATOMS: _correct_scalar(len(atomic_numbers)),
KEY.PER_ATOM_ENERGY: _correct_scalar(y_energy / len(pos)),
Expand Down

0 comments on commit dd5b715

Please sign in to comment.