diff --git a/deepmd/dpmodel/descriptor/dpa3.py b/deepmd/dpmodel/descriptor/dpa3.py
index 789dff131d..de6c50f57a 100644
--- a/deepmd/dpmodel/descriptor/dpa3.py
+++ b/deepmd/dpmodel/descriptor/dpa3.py
@@ -42,6 +42,10 @@ def __init__(
         only_e_ln: bool = False,
         pre_bn: bool = False,
         only_e_bn: bool = False,
+        use_unet: bool = False,
+        use_unet_n: bool = True,
+        use_unet_e: bool = True,
+        use_unet_a: bool = True,
         bn_moment: float = 0.1,
         n_update_has_a: bool = False,
         n_update_has_a_first_sum: bool = False,
@@ -135,6 +139,10 @@ def __init__(
         self.e_a_reduce_use_sqrt = e_a_reduce_use_sqrt
         self.n_update_has_a = n_update_has_a
         self.n_update_has_a_first_sum = n_update_has_a_first_sum
+        self.use_unet = use_unet
+        self.use_unet_n = use_unet_n
+        self.use_unet_e = use_unet_e
+        self.use_unet_a = use_unet_a
 
     def __getitem__(self, key):
         if hasattr(self, key):
diff --git a/deepmd/pt/model/descriptor/dpa3.py b/deepmd/pt/model/descriptor/dpa3.py
index 78f7ea2f7f..30bc6affc3 100644
--- a/deepmd/pt/model/descriptor/dpa3.py
+++ b/deepmd/pt/model/descriptor/dpa3.py
@@ -181,6 +181,10 @@ def init_subclass_params(sub_data, sub_class):
             only_e_ln=self.repflow_args.only_e_ln,
             pre_bn=self.repflow_args.pre_bn,
             only_e_bn=self.repflow_args.only_e_bn,
+            use_unet=self.repflow_args.use_unet,
+            use_unet_n=self.repflow_args.use_unet_n,
+            use_unet_e=self.repflow_args.use_unet_e,
+            use_unet_a=self.repflow_args.use_unet_a,
             bn_moment=self.repflow_args.bn_moment,
             skip_stat=self.repflow_args.skip_stat,
             exclude_types=exclude_types,
diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py
index e9613e2c39..03d3e805c4 100644
--- a/deepmd/pt/model/descriptor/repflow_layer.py
+++ b/deepmd/pt/model/descriptor/repflow_layer.py
@@ -420,6 +420,9 @@ def __init__(
                     seed=child_seed(seed, 9),
                 )
             else:
+                # use split
+                assert self.n_a_compress_dim <= self.n_dim
+                assert self.e_a_compress_dim <= self.e_dim
                 self.a_compress_n_linear = None
                 self.a_compress_e_linear = None
diff --git a/deepmd/pt/model/descriptor/repflows.py b/deepmd/pt/model/descriptor/repflows.py
index 0a8d3e42e5..20e4c9fbac 100644
--- a/deepmd/pt/model/descriptor/repflows.py
+++ b/deepmd/pt/model/descriptor/repflows.py
@@ -115,6 +115,10 @@ def __init__(
         only_e_ln: bool = False,
         pre_bn: bool = False,
         only_e_bn: bool = False,
+        use_unet: bool = False,
+        use_unet_n: bool = True,
+        use_unet_e: bool = True,
+        use_unet_a: bool = True,
         bn_moment: float = 0.1,
         a_norm_use_max_v: bool = False,
         e_norm_use_max_v: bool = False,
@@ -268,6 +272,10 @@ def __init__(
         self.pre_bn = pre_bn
         self.only_e_bn = only_e_bn
         self.bn_moment = bn_moment
+        self.use_unet = use_unet
+        self.use_unet_n = use_unet_n
+        self.use_unet_e = use_unet_e
+        self.use_unet_a = use_unet_a
         # self.out_ln = None
         # if self.pre_ln:
         #     self.out_ln = torch.nn.LayerNorm(
@@ -296,6 +304,15 @@ def __init__(
         else:
             self.h1_embd = None
         layers = []
+        self.unet_scale = [1.0 for _ in range(self.nlayers)]
+        self.unet_first_half = int((self.nlayers + 1) / 2)
+        self.unet_rest_half = int(self.nlayers / 2)
+        if self.use_unet:
+            self.unet_scale = [(0.5**i) for i in range(self.unet_first_half)] + [
+                (0.5 ** (self.unet_rest_half - 1 - i))
+                for i in range(self.unet_rest_half)
+            ]
+
         for ii in range(nlayers):
             layers.append(
                 RepFlowLayer(
@@ -306,9 +323,15 @@
                     a_rcut_smth=self.a_rcut_smth,
                     a_sel=self.a_sel,
                     ntypes=self.ntypes,
-                    n_dim=self.n_dim,
-                    e_dim=self.e_dim,
-                    a_dim=self.a_dim,
+                    n_dim=self.n_dim
+                    if (not self.use_unet or not self.use_unet_n)
+                    else int(self.n_dim * self.unet_scale[ii]),
+                    e_dim=self.e_dim
+                    if (not self.use_unet or not self.use_unet_e)
+                    else int(self.e_dim * self.unet_scale[ii]),
+                    a_dim=self.a_dim
+                    if (not self.use_unet or not self.use_unet_a)
+                    else int(self.a_dim * self.unet_scale[ii]),
                     a_compress_rate=self.a_compress_rate,
                     a_mess_has_n=self.a_mess_has_n,
                     a_use_e_mess=self.a_use_e_mess,
@@ -555,13 +578,19 @@ def forward(
         else:
             mapping3 = None
 
+        unet_list_node = []
+        unet_list_edge = []
+        unet_list_angle = []
+
         for idx, ll in enumerate(self.layers):
             # node_ebd: nb x nloc x n_dim
             # node_ebd_ext: nb x nall x n_dim
             if comm_dict is None:
                 assert mapping is not None
                 assert mapping3 is not None
-                node_ebd_ext = torch.gather(node_ebd, 1, mapping)
+                node_ebd_ext = torch.gather(
+                    node_ebd, 1, mapping[:, :, : node_ebd.shape[-1]]
+                )
             if self.has_h1:
                 assert h1 is not None
                 h1_ext = torch.gather(h1, 1, mapping3)
@@ -639,6 +668,38 @@
+            if self.use_unet:
+                if idx < self.unet_first_half - 1:
+                    # stack half output
+                    tmp_n_dim = int(self.n_dim * self.unet_scale[idx + 1])
+                    tmp_e_dim = int(self.e_dim * self.unet_scale[idx + 1])
+                    tmp_a_dim = int(self.a_dim * self.unet_scale[idx + 1])
+                    if self.use_unet_n:
+                        stack_node_ebd, node_ebd = torch.split(
+                            node_ebd, [tmp_n_dim, tmp_n_dim], dim=-1
+                        )
+                        unet_list_node.append(stack_node_ebd)
+                    if self.use_unet_e:
+                        stack_edge_ebd, edge_ebd = torch.split(
+                            edge_ebd, [tmp_e_dim, tmp_e_dim], dim=-1
+                        )
+                        unet_list_edge.append(stack_edge_ebd)
+                    if self.use_unet_a:
+                        stack_angle_ebd, angle_ebd = torch.split(
+                            angle_ebd, [tmp_a_dim, tmp_a_dim], dim=-1
+                        )
+                        unet_list_angle.append(stack_angle_ebd)
+                elif self.unet_rest_half - 1 < idx < self.nlayers - 1:
+                    # skip connection, concat the half output
+                    if self.use_unet_n:
+                        stack_node_ebd = unet_list_node.pop()
+                        node_ebd = torch.cat([stack_node_ebd, node_ebd], dim=-1)
+                    if self.use_unet_e:
+                        stack_edge_ebd = unet_list_edge.pop()
+                        edge_ebd = torch.cat([stack_edge_ebd, edge_ebd], dim=-1)
+                    if self.use_unet_a:
+                        stack_angle_ebd = unet_list_angle.pop()
+                        angle_ebd = torch.cat([stack_angle_ebd, angle_ebd], dim=-1)
         # nb x nloc x 3 x e_dim
         h2g2 = RepFlowLayer._cal_hg(edge_ebd, h2, nlist_mask, sw)
         # (nb x nloc) x e_dim x 3
diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py
index 20e96e5359..44305e0fe7 100644
--- a/deepmd/utils/argcheck.py
+++ b/deepmd/utils/argcheck.py
@@ -1673,6 +1673,30 @@ def dpa3_repflow_args():
            optional=True,
            default=False,
        ),
+        Argument(
+            "use_unet",
+            bool,
+            optional=True,
+            default=False,
+        ),
+        Argument(
+            "use_unet_n",
+            bool,
+            optional=True,
+            default=True,
+        ),
+        Argument(
+            "use_unet_e",
+            bool,
+            optional=True,
+            default=True,
+        ),
+        Argument(
+            "use_unet_a",
+            bool,
+            optional=True,
+            default=True,
+        ),
     ]
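
For reference, a small self-contained sketch (not part of the patch) of what the new use_unet options do: it mirrors the unet_scale schedule built in __init__ of repflows.py and the split/stash/concat bookkeeping in forward(), using illustrative values nlayers=6 and n_dim=128. The first half of the layers run at progressively halved widths; the second half concatenate the stashed halves back, U-Net style.

# Illustration only -- not part of the patch. The standalone variables below
# stand in for the corresponding repflows attributes; nlayers=6, n_dim=128 are made up.
import torch

nlayers = 6
n_dim = 128

first_half = int((nlayers + 1) / 2)  # layers on the "down" path (width halved)
rest_half = int(nlayers / 2)         # layers on the "up" path (width restored)
unet_scale = [0.5**i for i in range(first_half)] + [
    0.5 ** (rest_half - 1 - i) for i in range(rest_half)
]
print(unet_scale)                            # [1.0, 0.5, 0.25, 0.25, 0.5, 1.0]
print([int(n_dim * s) for s in unet_scale])  # per-layer n_dim: [128, 64, 32, 32, 64, 128]

# Down path: split off half of the embedding and stash it; up path: pop the
# stash and concatenate it back (the U-Net skip connection).
stack = []
node_ebd = torch.randn(2, 5, n_dim)  # nb x nloc x n_dim
for idx in range(nlayers):
    if idx < first_half - 1:
        half = int(n_dim * unet_scale[idx + 1])
        stashed, node_ebd = torch.split(node_ebd, [half, half], dim=-1)
        stack.append(stashed)
    elif rest_half - 1 < idx < nlayers - 1:
        node_ebd = torch.cat([stack.pop(), node_ebd], dim=-1)
assert node_ebd.shape[-1] == n_dim  # full width is restored before the last layer

The mapping[:, :, : node_ebd.shape[-1]] slice in forward() is presumably what keeps the neighbor gather consistent with these varying per-layer widths.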
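
And a hypothetical input fragment showing how the argcheck flags above would be toggled from the repflow section of a DPA3 descriptor; apart from the four use_unet* keys added here, the surrounding keys are illustrative placeholders, and the fragment is written as a Python dict rather than JSON for brevity.

# Hypothetical config fragment -- only the four use_unet* flags come from this patch.
descriptor = {
    "type": "dpa3",
    "repflow": {
        # ... existing repflow settings ...
        "use_unet": True,    # default False: turn the U-Net width schedule on
        "use_unet_n": True,  # default True: apply it to node embeddings
        "use_unet_e": True,  # default True: apply it to edge embeddings
        "use_unet_a": True,  # default True: apply it to angle embeddings
    },
}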