Skip to content

Commit

Permalink
feat: make no _remove_self_edges when matscipy graph build
Browse files Browse the repository at this point in the history
  • Loading branch information
hexagonrose committed Nov 27, 2024
1 parent dd5b715 commit 1099215
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions sevenn/train/dataload.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,19 @@ def _graph_build_ase(cutoff: float, pbc, cell, pos):


def _remove_self_edges(edge_src, edge_dst, edge_vec, shifts):
is_zero_idx = np.all(edge_vec == 0, axis=1)
is_self_idx = edge_src == edge_dst
non_trivials = ~(is_zero_idx & is_self_idx)
shifts = np.array(shifts[non_trivials])
if _graph_build_f == _graph_build_matscipy:
return edge_src, edge_dst, edge_vec, shifts
else:
is_zero_idx = np.all(edge_vec == 0, axis=1)
is_self_idx = edge_src == edge_dst
non_trivials = ~(is_zero_idx & is_self_idx)
shifts = np.array(shifts[non_trivials])

edge_vec = edge_vec[non_trivials]
edge_src = edge_src[non_trivials]
edge_dst = edge_dst[non_trivials]
edge_vec = edge_vec[non_trivials]
edge_src = edge_src[non_trivials]
edge_dst = edge_dst[non_trivials]

return edge_src, edge_dst, edge_vec, shifts
return edge_src, edge_dst, edge_vec, shifts


def _correct_scalar(v):
Expand All @@ -105,11 +108,12 @@ def unlabeled_atoms_to_graph(atoms: ase.Atoms, cutoff: float):
pos = atoms.get_positions()
cell = np.array(atoms.get_cell())
pbc = atoms.get_pbc()

edge_src, edge_dst, edge_vec, shifts = _remove_self_edges(
*_graph_build_f(cutoff, pbc, cell, pos)
)


edge_idx = np.array([edge_src, edge_dst])

atomic_numbers = atoms.get_atomic_numbers()
Expand Down

0 comments on commit 1099215

Please sign in to comment.