Skip to content

Commit

Permalink
feat: make no _remove_self_edges when matscipy graph build
Browse files Browse the repository at this point in the history
  • Loading branch information
hexagonrose committed Nov 27, 2024
1 parent 1099215 commit 2fe92c7
Showing 1 changed file with 12 additions and 23 deletions.
35 changes: 12 additions & 23 deletions sevenn/train/dataload.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,33 +66,27 @@ def _graph_build_ase(cutoff: float, pbc, cell, pos):
'ijDS', pbc, cell, pos, cutoff, self_interaction=True
)

is_zero_idx = np.all(edge_vec == 0, axis=1)
is_self_idx = edge_src == edge_dst
non_trivials = ~(is_zero_idx & is_self_idx)
shifts = np.array(shifts[non_trivials])

edge_vec = edge_vec[non_trivials]
edge_src = edge_src[non_trivials]
edge_dst = edge_dst[non_trivials]

return edge_src, edge_dst, edge_vec, shifts


_graph_build_f = _graph_build_ase
try:
from matscipy.neighbours import neighbour_list

_graph_build_f = _graph_build_matscipy
except ImportError:
pass


def _remove_self_edges(edge_src, edge_dst, edge_vec, shifts):
if _graph_build_f == _graph_build_matscipy:
return edge_src, edge_dst, edge_vec, shifts
else:
is_zero_idx = np.all(edge_vec == 0, axis=1)
is_self_idx = edge_src == edge_dst
non_trivials = ~(is_zero_idx & is_self_idx)
shifts = np.array(shifts[non_trivials])

edge_vec = edge_vec[non_trivials]
edge_src = edge_src[non_trivials]
edge_dst = edge_dst[non_trivials]

return edge_src, edge_dst, edge_vec, shifts


def _correct_scalar(v):
if isinstance(v, np.ndarray):
v = v.squeeze()
Expand All @@ -108,11 +102,8 @@ def unlabeled_atoms_to_graph(atoms: ase.Atoms, cutoff: float):
pos = atoms.get_positions()
cell = np.array(atoms.get_cell())
pbc = atoms.get_pbc()

edge_src, edge_dst, edge_vec, shifts = _remove_self_edges(
*_graph_build_f(cutoff, pbc, cell, pos)
)

edge_src, edge_dst, edge_vec, shifts = _graph_build_f(cutoff, pbc, cell, pos)

edge_idx = np.array([edge_src, edge_dst])

Expand Down Expand Up @@ -197,9 +188,7 @@ def atoms_to_graph(
cell = np.array(atoms.get_cell())
pbc = atoms.get_pbc()

edge_src, edge_dst, edge_vec, shifts = _remove_self_edges(
*_graph_build_f(cutoff, pbc, cell, pos)
)
edge_src, edge_dst, edge_vec, shifts = _graph_build_f(cutoff, pbc, cell, pos)

edge_idx = np.array([edge_src, edge_dst])
atomic_numbers = atoms.get_atomic_numbers()
Expand Down

0 comments on commit 2fe92c7

Please sign in to comment.