From 2f9636d1e7cc1ed21960404f1108fcac229a666f Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 23 Feb 2024 21:03:57 -0500 Subject: [PATCH 1/3] pt: remove env.DEVICE in all forward functions Ensure the saved model can run on both CPUs and GPUs. Signed-off-by: Jinzhe Zeng --- .../model/atomic_model/linear_atomic_model.py | 14 ++++++------- .../atomic_model/pairtab_atomic_model.py | 9 ++++---- deepmd/pt/model/descriptor/repformer_layer.py | 4 ++-- deepmd/pt/model/descriptor/se_a.py | 4 +++- deepmd/pt/model/task/fitting.py | 2 +- deepmd/pt/utils/nlist.py | 21 ++++++++++--------- source/tests/pt/model/test_deeppot.py | 8 +++++++ 7 files changed, 35 insertions(+), 27 deletions(-) diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index 70afbcb0bc..f90fa5f237 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -92,16 +92,14 @@ def get_model_sels(self) -> List[List[int]]: """Get the sels for each individual models.""" return [model.get_sel() for model in self.models] - def _sort_rcuts_sels(self) -> Tuple[List[float], List[int]]: + def _sort_rcuts_sels(self, device: torch.device) -> Tuple[List[float], List[int]]: # sort the pair of rcut and sels in ascending order, first based on sel, then on rcut. - rcuts = torch.tensor( - self.get_model_rcuts(), dtype=torch.float64, device=env.DEVICE - ) - nsels = torch.tensor(self.get_model_nsels(), device=env.DEVICE) + rcuts = torch.tensor(self.get_model_rcuts(), dtype=torch.float64, device=device) + nsels = torch.tensor(self.get_model_nsels(), device=device) zipped = torch.stack( [ - torch.tensor(rcuts, device=env.DEVICE), - torch.tensor(nsels, device=env.DEVICE), + torch.tensor(rcuts, device=device), + torch.tensor(nsels, device=device), ], dim=0, ).T @@ -148,7 +146,7 @@ def forward_atomic( if self.do_grad_r() or self.do_grad_c(): extended_coord.requires_grad_(True) extended_coord = extended_coord.view(nframes, -1, 3) - sorted_rcuts, sorted_sels = self._sort_rcuts_sels() + sorted_rcuts, sorted_sels = self._sort_rcuts_sels(device=extended_coord.device) nlists = build_multiple_neighbor_list( extended_coord, nlist, diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py index cf5a70eb88..eff445e799 100644 --- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py +++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py @@ -12,9 +12,6 @@ FittingOutputDef, OutputVariableDef, ) -from deepmd.pt.utils import ( - env, -) from deepmd.utils.pair_tab import ( PairTab, ) @@ -160,7 +157,7 @@ def forward_atomic( pairwise_rr = self._get_pairwise_dist( extended_coord, masked_nlist ) # (nframes, nloc, nnei) - self.tab_data = self.tab_data.to(device=env.DEVICE).view( + self.tab_data = self.tab_data.to(device=extended_coord.device).view( int(self.tab_info[-1]), int(self.tab_info[-1]), int(self.tab_info[2]), 4 ) @@ -168,7 +165,9 @@ def forward_atomic( # i_type : (nframes, nloc), this is atype. 
# j_type : (nframes, nloc, nnei) j_type = extended_atype[ - torch.arange(extended_atype.size(0), device=env.DEVICE)[:, None, None], + torch.arange(extended_atype.size(0), device=extended_coord.device)[ + :, None, None + ], masked_nlist, ] diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py index 66ce38c0f7..55a2cba708 100644 --- a/deepmd/pt/model/descriptor/repformer_layer.py +++ b/deepmd/pt/model/descriptor/repformer_layer.py @@ -446,7 +446,7 @@ def _update_g1_conv( else: gg1 = _apply_switch(gg1, sw) invnnei = (1.0 / float(nnei)) * torch.ones( - (nb, nloc, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE + (nb, nloc, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=gg1.device ) # nb x nloc x ng2 g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei @@ -474,7 +474,7 @@ def _cal_h2g2( else: g2 = _apply_switch(g2, sw) invnnei = (1.0 / float(nnei)) * torch.ones( - (nb, nloc, 1, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE + (nb, nloc, 1, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=g2.device ) # nb x nloc x 3 x ng2 h2g2 = torch.matmul(torch.transpose(h2, -1, -2), g2) * invnnei diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 0550488ecf..7fd8b1dc7d 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -467,7 +467,9 @@ def forward( nfnl = dmatrix.shape[0] # pre-allocate a shape to pass jit xyz_scatter = torch.zeros( - [nfnl, 4, self.filter_neuron[-1]], dtype=self.prec, device=env.DEVICE + [nfnl, 4, self.filter_neuron[-1]], + dtype=self.prec, + device=extended_coord.device, ) # nfnl x nnei exclude_mask = self.emask(nlist, extended_atype).view(nfnl, -1) diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 080bfb5172..e4be16c66b 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -567,7 +567,7 @@ def _forward_common( outs = torch.zeros( (nf, nloc, net_dim_out), dtype=env.GLOBAL_PT_FLOAT_PRECISION, - device=env.DEVICE, + device=descriptor.device, ) # jit assertion if self.old_impl: assert self.filter_layers_old is not None diff --git a/deepmd/pt/utils/nlist.py b/deepmd/pt/utils/nlist.py index 0e2d9785f8..cfc75d9438 100644 --- a/deepmd/pt/utils/nlist.py +++ b/deepmd/pt/utils/nlist.py @@ -288,8 +288,9 @@ def extend_coord_with_ghosts( maping extended index to the local index """ + device = coord.device nf, nloc = atype.shape - aidx = torch.tile(torch.arange(nloc, device=env.DEVICE).unsqueeze(0), [nf, 1]) + aidx = torch.tile(torch.arange(nloc, device=device).unsqueeze(0), [nf, 1]) if cell is None: nall = nloc extend_coord = coord.clone() @@ -306,17 +307,17 @@ def extend_coord_with_ghosts( nbuff = torch.ceil(rcut / to_face).to(torch.long) # 3 nbuff = torch.max(nbuff, dim=0, keepdim=False).values - xi = torch.arange(-nbuff[0], nbuff[0] + 1, 1, device=env.DEVICE) - yi = torch.arange(-nbuff[1], nbuff[1] + 1, 1, device=env.DEVICE) - zi = torch.arange(-nbuff[2], nbuff[2] + 1, 1, device=env.DEVICE) + xi = torch.arange(-nbuff[0], nbuff[0] + 1, 1, device=device) + yi = torch.arange(-nbuff[1], nbuff[1] + 1, 1, device=device) + zi = torch.arange(-nbuff[2], nbuff[2] + 1, 1, device=device) xyz = xi.view(-1, 1, 1, 1) * torch.tensor( - [1, 0, 0], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE + [1, 0, 0], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device ) xyz = xyz + yi.view(1, -1, 1, 1) * torch.tensor( - [0, 1, 0], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE + [0, 1, 0], 
dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device ) xyz = xyz + zi.view(1, 1, -1, 1) * torch.tensor( - [0, 0, 1], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE + [0, 0, 1], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device ) xyz = xyz.view(-1, 3) # ns x 3 @@ -333,7 +334,7 @@ def extend_coord_with_ghosts( extend_aidx = torch.tile(aidx.unsqueeze(-2), [1, ns, 1]) return ( - extend_coord.reshape([nf, nall * 3]).to(env.DEVICE), - extend_atype.view([nf, nall]).to(env.DEVICE), - extend_aidx.view([nf, nall]).to(env.DEVICE), + extend_coord.reshape([nf, nall * 3]).to(device), + extend_atype.view([nf, nall]).to(device), + extend_aidx.view([nf, nall]).to(device), ) diff --git a/source/tests/pt/model/test_deeppot.py b/source/tests/pt/model/test_deeppot.py index ee04942ae7..5cd5729640 100644 --- a/source/tests/pt/model/test_deeppot.py +++ b/source/tests/pt/model/test_deeppot.py @@ -115,3 +115,11 @@ def setUp(self): ) freeze(ns) self.model = frozen_model + + # Note: this can not actually disable cuda device to be used + # only can be used to test whether devices are dismatched + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @unittest.mock.patch("deepmd.pt.utils.env.DEVICE", torch.device("cpu")) + @unittest.mock.patch("deepmd.pt.infer.deep_eval.DEVICE", torch.device("cpu")) + def test_dp_test_cpu(self): + self.test_dp_test() From e03623f1069f9ca844be6d301abf6b310cab10b7 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 23 Feb 2024 21:12:44 -0500 Subject: [PATCH 2/3] fix comment Co-authored-by: Chun Cai Signed-off-by: Jinzhe Zeng --- source/tests/pt/model/test_deeppot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/pt/model/test_deeppot.py b/source/tests/pt/model/test_deeppot.py index 5cd5729640..334206a2b0 100644 --- a/source/tests/pt/model/test_deeppot.py +++ b/source/tests/pt/model/test_deeppot.py @@ -117,7 +117,7 @@ def setUp(self): self.model = frozen_model # Note: this can not actually disable cuda device to be used - # only can be used to test whether devices are dismatched + # only can be used to test whether devices are mismatched @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @unittest.mock.patch("deepmd.pt.utils.env.DEVICE", torch.device("cpu")) @unittest.mock.patch("deepmd.pt.infer.deep_eval.DEVICE", torch.device("cpu")) From 4f7b1c646c2b64e5d8cd0204f8419f6266859b65 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 26 Feb 2024 09:11:08 -0500 Subject: [PATCH 3/3] fix tests Signed-off-by: Jinzhe Zeng --- source/tests/pt/model/test_dipole_fitting.py | 28 +++++++++++++------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/source/tests/pt/model/test_dipole_fitting.py b/source/tests/pt/model/test_dipole_fitting.py index fb04e49484..fcdd408726 100644 --- a/source/tests/pt/model/test_dipole_fitting.py +++ b/source/tests/pt/model/test_dipole_fitting.py @@ -30,6 +30,7 @@ ) from deepmd.pt.utils.utils import ( to_numpy_array, + to_torch_tensor, ) from .test_env_mat import ( @@ -298,10 +299,10 @@ def setUp(self): self.rcut_smth = 0.5 self.sel = [46, 92, 4] self.nf = 1 - self.coord = 2 * torch.rand([self.natoms, 3], dtype=dtype, device="cpu") - cell = torch.rand([3, 3], dtype=dtype, device="cpu") - self.cell = (cell + cell.T) + 5.0 * torch.eye(3, device="cpu") - self.atype = torch.IntTensor([0, 0, 0, 1, 1], device="cpu") + self.coord = 2 * torch.rand([self.natoms, 3], dtype=dtype, device=env.DEVICE) + cell = torch.rand([3, 3], dtype=dtype, device=env.DEVICE) + self.cell = (cell + cell.T) + 5.0 
* torch.eye(3, device=env.DEVICE) + self.atype = torch.IntTensor([0, 0, 0, 1, 1], device="cpu").to(env.DEVICE) self.dd0 = DescrptSeA(self.rcut, self.rcut_smth, self.sel).to(env.DEVICE) self.ft0 = DipoleFittingNet( "dipole", @@ -322,17 +323,26 @@ def test_auto_diff(self): atype = self.atype.view(self.nf, self.natoms) def ff(coord, atype): - return self.model(coord, atype)["global_dipole"].detach().cpu().numpy() + return ( + self.model(to_torch_tensor(coord), to_torch_tensor(atype))[ + "global_dipole" + ] + .detach() + .cpu() + .numpy() + ) - fdf = -finite_difference(ff, self.coord, atype, delta=delta) + fdf = -finite_difference( + ff, to_numpy_array(self.coord), to_numpy_array(atype), delta=delta + ) rff = self.model(self.coord, atype)["force"].detach().cpu().numpy() np.testing.assert_almost_equal(fdf, rff.transpose(0, 2, 1, 3), decimal=places) def test_deepdipole_infer(self): - atype = self.atype.view(self.nf, self.natoms) - coord = self.coord.reshape(1, 5, 3) - cell = self.cell.reshape(1, 9) + atype = to_numpy_array(self.atype.view(self.nf, self.natoms)) + coord = to_numpy_array(self.coord.reshape(1, 5, 3)) + cell = to_numpy_array(self.cell.reshape(1, 9)) jit_md = torch.jit.script(self.model) torch.jit.save(jit_md, self.file_path) load_md = DeepDipole(self.file_path)
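
To make the pattern behind this series concrete, here is a minimal sketch of the device-propagation idea the hunks above apply: tensors created at forward time take their device from an input tensor (here coord.device) instead of the global env.DEVICE, so a scripted/frozen module runs on whichever device its inputs live on. ToyAtomicModel and its arithmetic are invented for illustration and are not part of deepmd-kit.

import torch


class ToyAtomicModel(torch.nn.Module):
    """Illustrative module: every temporary tensor follows the inputs' device."""

    def forward(self, coord: torch.Tensor, atype: torch.Tensor) -> torch.Tensor:
        # derive the device from an input tensor, not from a module-level constant
        device = coord.device
        nf = atype.shape[0]
        nloc = atype.shape[1]
        # per-atom indices built on the input device (same pattern as extend_coord_with_ghosts)
        aidx = torch.tile(torch.arange(nloc, device=device).unsqueeze(0), [nf, 1])
        # pre-allocated output inherits dtype and device from the inputs
        outs = torch.zeros([nf, nloc], dtype=coord.dtype, device=device)
        outs = outs + coord.square().sum(dim=-1) * (atype + aidx).to(coord.dtype)
        return outs


# once scripted (or frozen and saved), the module runs wherever its inputs are placed
model = torch.jit.script(ToyAtomicModel())
coord = torch.rand(1, 5, 3)
atype = torch.zeros(1, 5, dtype=torch.long)
print(model(coord, atype).shape)  # CPU inputs -> CPU execution
if torch.cuda.is_available():
    print(model(coord.cuda(), atype.cuda()).shape)  # GPU inputs -> GPU execution, same module

The same idea applies to every torch.zeros/torch.ones/torch.arange call touched in the hunks above: the device argument is taken from a tensor that is already part of the computation, so no device choice gets baked into the saved graph.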