Skip to content

Commit

Permalink
pt: process frames in parallel for env mat stat
Browse files Browse the repository at this point in the history
Resolves deepmodeling#3285 (comment)

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Feb 18, 2024
1 parent 6451cdb commit 117acd9
Showing 1 changed file with 12 additions and 30 deletions.
42 changes: 12 additions & 30 deletions deepmd/pt/utils/env_mat_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,38 +128,20 @@ def iter(
# TODO: export rcut_smth from DescriptorBlock
self.descriptor.rcut_smth,
)
# reshape to nframes * nloc at the atom level,
# so nframes/mixed_type do not matter
env_mat = env_mat.view(
coord.shape[0], coord.shape[1], self.descriptor.get_nsel(), 4
coord.shape[0] * coord.shape[1], self.descriptor.get_nsel(), 4
)

if "real_natoms_vec" not in system:
end_indexes = torch.cumsum(natoms[0, 2:], 0)
start_indexes = torch.cat(
[
torch.zeros(1, dtype=torch.int32, device=env.DEVICE),
end_indexes[:-1],
]
)
for type_i in range(self.descriptor.get_ntypes()):
dd = env_mat[
:, start_indexes[type_i] : end_indexes[type_i], :, :
] # all descriptors for this element
env_mats = {}
env_mats[f"r_{type_i}"] = dd[:, :, :, :1]
env_mats[f"a_{type_i}"] = dd[:, :, :, 1:]
yield self.compute_stat(env_mats)
else:
for frame_item in range(env_mat.shape[0]):
dd_ff = env_mat[frame_item]
atype_frame = atype[frame_item]
for type_i in range(self.descriptor.get_ntypes()):
type_idx = atype_frame == type_i
dd = dd_ff[type_idx]
dd = dd.reshape([-1, 4]) # typen_atoms * nnei, 4
env_mats = {}
env_mats[f"r_{type_i}"] = dd[:, :1]
env_mats[f"a_{type_i}"] = dd[:, 1:]
yield self.compute_stat(env_mats)
atype = atype.view(coord.shape[0] * coord.shape[1])
for type_i in range(self.descriptor.get_ntypes()):
type_idx = atype == type_i
dd = env_mat[type_idx]
dd = dd.reshape([-1, 4]) # typen_atoms * nnei, 4
env_mats = {}
env_mats[f"r_{type_i}"] = dd[:, :1]
env_mats[f"a_{type_i}"] = dd[:, 1:]
yield self.compute_stat(env_mats)

def get_hash(self) -> str:
"""Get the hash of the environment matrix.
Expand Down

0 comments on commit 117acd9

Please sign in to comment.