Skip to content

Commit

Permalink
pt: remove env.DEVICE in all forward functions (#3330)
Browse files Browse the repository at this point in the history
Ensure the saved JIT model can run on both CPUs and GPUs.

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: Chun Cai <[email protected]>
  • Loading branch information
njzjz and caic99 authored Feb 26, 2024
1 parent 261c802 commit a3f4a67
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 36 deletions.
14 changes: 6 additions & 8 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,14 @@ def get_model_sels(self) -> List[List[int]]:
"""Get the sels for each individual models."""
return [model.get_sel() for model in self.models]

def _sort_rcuts_sels(self) -> Tuple[List[float], List[int]]:
def _sort_rcuts_sels(self, device: torch.device) -> Tuple[List[float], List[int]]:
# sort the pair of rcut and sels in ascending order, first based on sel, then on rcut.
rcuts = torch.tensor(
self.get_model_rcuts(), dtype=torch.float64, device=env.DEVICE
)
nsels = torch.tensor(self.get_model_nsels(), device=env.DEVICE)
rcuts = torch.tensor(self.get_model_rcuts(), dtype=torch.float64, device=device)
nsels = torch.tensor(self.get_model_nsels(), device=device)
zipped = torch.stack(
[
torch.tensor(rcuts, device=env.DEVICE),
torch.tensor(nsels, device=env.DEVICE),
torch.tensor(rcuts, device=device),
torch.tensor(nsels, device=device),
],
dim=0,
).T
Expand Down Expand Up @@ -148,7 +146,7 @@ def forward_atomic(
if self.do_grad_r() or self.do_grad_c():
extended_coord.requires_grad_(True)
extended_coord = extended_coord.view(nframes, -1, 3)
sorted_rcuts, sorted_sels = self._sort_rcuts_sels()
sorted_rcuts, sorted_sels = self._sort_rcuts_sels(device=extended_coord.device)
nlists = build_multiple_neighbor_list(
extended_coord,
nlist,
Expand Down
9 changes: 4 additions & 5 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
FittingOutputDef,
OutputVariableDef,
)
from deepmd.pt.utils import (
env,
)
from deepmd.utils.pair_tab import (
PairTab,
)
Expand Down Expand Up @@ -160,15 +157,17 @@ def forward_atomic(
pairwise_rr = self._get_pairwise_dist(
extended_coord, masked_nlist
) # (nframes, nloc, nnei)
self.tab_data = self.tab_data.to(device=env.DEVICE).view(
self.tab_data = self.tab_data.to(device=extended_coord.device).view(
int(self.tab_info[-1]), int(self.tab_info[-1]), int(self.tab_info[2]), 4
)

# to calculate the atomic_energy, we need 3 tensors, i_type, j_type, pairwise_rr
# i_type : (nframes, nloc), this is atype.
# j_type : (nframes, nloc, nnei)
j_type = extended_atype[
torch.arange(extended_atype.size(0), device=env.DEVICE)[:, None, None],
torch.arange(extended_atype.size(0), device=extended_coord.device)[
:, None, None
],
masked_nlist,
]

Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/descriptor/repformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def _update_g1_conv(
else:
gg1 = _apply_switch(gg1, sw)
invnnei = (1.0 / float(nnei)) * torch.ones(
(nb, nloc, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
(nb, nloc, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=gg1.device
)
# nb x nloc x ng2
g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei
Expand Down Expand Up @@ -474,7 +474,7 @@ def _cal_h2g2(
else:
g2 = _apply_switch(g2, sw)
invnnei = (1.0 / float(nnei)) * torch.ones(
(nb, nloc, 1, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
(nb, nloc, 1, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=g2.device
)
# nb x nloc x 3 x ng2
h2g2 = torch.matmul(torch.transpose(h2, -1, -2), g2) * invnnei
Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,9 @@ def forward(
nfnl = dmatrix.shape[0]
# pre-allocate a shape to pass jit
xyz_scatter = torch.zeros(
[nfnl, 4, self.filter_neuron[-1]], dtype=self.prec, device=env.DEVICE
[nfnl, 4, self.filter_neuron[-1]],
dtype=self.prec,
device=extended_coord.device,
)
# nfnl x nnei
exclude_mask = self.emask(nlist, extended_atype).view(nfnl, -1)
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ def _forward_common(
outs = torch.zeros(
(nf, nloc, net_dim_out),
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
device=env.DEVICE,
device=descriptor.device,
) # jit assertion
if self.old_impl:
assert self.filter_layers_old is not None
Expand Down
21 changes: 11 additions & 10 deletions deepmd/pt/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,9 @@ def extend_coord_with_ghosts(
maping extended index to the local index
"""
device = coord.device
nf, nloc = atype.shape
aidx = torch.tile(torch.arange(nloc, device=env.DEVICE).unsqueeze(0), [nf, 1])
aidx = torch.tile(torch.arange(nloc, device=device).unsqueeze(0), [nf, 1])
if cell is None:
nall = nloc
extend_coord = coord.clone()
Expand All @@ -306,17 +307,17 @@ def extend_coord_with_ghosts(
nbuff = torch.ceil(rcut / to_face).to(torch.long)
# 3
nbuff = torch.max(nbuff, dim=0, keepdim=False).values
xi = torch.arange(-nbuff[0], nbuff[0] + 1, 1, device=env.DEVICE)
yi = torch.arange(-nbuff[1], nbuff[1] + 1, 1, device=env.DEVICE)
zi = torch.arange(-nbuff[2], nbuff[2] + 1, 1, device=env.DEVICE)
xi = torch.arange(-nbuff[0], nbuff[0] + 1, 1, device=device)
yi = torch.arange(-nbuff[1], nbuff[1] + 1, 1, device=device)
zi = torch.arange(-nbuff[2], nbuff[2] + 1, 1, device=device)
xyz = xi.view(-1, 1, 1, 1) * torch.tensor(
[1, 0, 0], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
[1, 0, 0], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device
)
xyz = xyz + yi.view(1, -1, 1, 1) * torch.tensor(
[0, 1, 0], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
[0, 1, 0], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device
)
xyz = xyz + zi.view(1, 1, -1, 1) * torch.tensor(
[0, 0, 1], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
[0, 0, 1], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device
)
xyz = xyz.view(-1, 3)
# ns x 3
Expand All @@ -333,7 +334,7 @@ def extend_coord_with_ghosts(
extend_aidx = torch.tile(aidx.unsqueeze(-2), [1, ns, 1])

return (
extend_coord.reshape([nf, nall * 3]).to(env.DEVICE),
extend_atype.view([nf, nall]).to(env.DEVICE),
extend_aidx.view([nf, nall]).to(env.DEVICE),
extend_coord.reshape([nf, nall * 3]).to(device),
extend_atype.view([nf, nall]).to(device),
extend_aidx.view([nf, nall]).to(device),
)
8 changes: 8 additions & 0 deletions source/tests/pt/model/test_deeppot.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,11 @@ def setUp(self):
)
freeze(ns)
self.model = frozen_model

# Note: this can not actually disable cuda device to be used
# only can be used to test whether devices are mismatched
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.mock.patch("deepmd.pt.utils.env.DEVICE", torch.device("cpu"))
@unittest.mock.patch("deepmd.pt.infer.deep_eval.DEVICE", torch.device("cpu"))
def test_dp_test_cpu(self):
self.test_dp_test()
28 changes: 19 additions & 9 deletions source/tests/pt/model/test_dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
from deepmd.pt.utils.utils import (
to_numpy_array,
to_torch_tensor,
)

from .test_env_mat import (
Expand Down Expand Up @@ -298,10 +299,10 @@ def setUp(self):
self.rcut_smth = 0.5
self.sel = [46, 92, 4]
self.nf = 1
self.coord = 2 * torch.rand([self.natoms, 3], dtype=dtype, device="cpu")
cell = torch.rand([3, 3], dtype=dtype, device="cpu")
self.cell = (cell + cell.T) + 5.0 * torch.eye(3, device="cpu")
self.atype = torch.IntTensor([0, 0, 0, 1, 1], device="cpu")
self.coord = 2 * torch.rand([self.natoms, 3], dtype=dtype, device=env.DEVICE)
cell = torch.rand([3, 3], dtype=dtype, device=env.DEVICE)
self.cell = (cell + cell.T) + 5.0 * torch.eye(3, device=env.DEVICE)
self.atype = torch.IntTensor([0, 0, 0, 1, 1], device="cpu").to(env.DEVICE)
self.dd0 = DescrptSeA(self.rcut, self.rcut_smth, self.sel).to(env.DEVICE)
self.ft0 = DipoleFittingNet(
"dipole",
Expand All @@ -322,17 +323,26 @@ def test_auto_diff(self):
atype = self.atype.view(self.nf, self.natoms)

def ff(coord, atype):
return self.model(coord, atype)["global_dipole"].detach().cpu().numpy()
return (
self.model(to_torch_tensor(coord), to_torch_tensor(atype))[
"global_dipole"
]
.detach()
.cpu()
.numpy()
)

fdf = -finite_difference(ff, self.coord, atype, delta=delta)
fdf = -finite_difference(
ff, to_numpy_array(self.coord), to_numpy_array(atype), delta=delta
)
rff = self.model(self.coord, atype)["force"].detach().cpu().numpy()

np.testing.assert_almost_equal(fdf, rff.transpose(0, 2, 1, 3), decimal=places)

def test_deepdipole_infer(self):
atype = self.atype.view(self.nf, self.natoms)
coord = self.coord.reshape(1, 5, 3)
cell = self.cell.reshape(1, 9)
atype = to_numpy_array(self.atype.view(self.nf, self.natoms))
coord = to_numpy_array(self.coord.reshape(1, 5, 3))
cell = to_numpy_array(self.cell.reshape(1, 9))
jit_md = torch.jit.script(self.model)
torch.jit.save(jit_md, self.file_path)
load_md = DeepDipole(self.file_path)
Expand Down

0 comments on commit a3f4a67

Please sign in to comment.