diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py index 28751230ee..c603effbc4 100644 --- a/deepmd/pt/model/descriptor/repflow_layer.py +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -623,48 +623,47 @@ def optim_angle_update( angle_ebd: torch.Tensor, node_ebd: torch.Tensor, edge_ebd: torch.Tensor, - angle_dim: int, - node_dim: int, - edge_dim: int, feat: str = "edge", ) -> torch.Tensor: + angle_dim = angle_ebd.shape[-1] + node_dim = node_ebd.shape[-1] + edge_dim = edge_ebd.shape[-1] + sub_angle_idx = (0, angle_dim) + sub_node_idx = (angle_dim, angle_dim + node_dim) + sub_edge_idx_ij = (angle_dim + node_dim, angle_dim + node_dim + edge_dim) + sub_edge_idx_ik = ( + angle_dim + node_dim + edge_dim, + angle_dim + node_dim + 2 * edge_dim, + ) + if feat == "edge": - # fot jit - sub_angle_matrix = self.edge_angle_linear1.matrix[:angle_dim] - sub_node_matrix = self.edge_angle_linear1.matrix[ - angle_dim : angle_dim + node_dim - ] - sub_edge_matrix_ij = self.edge_angle_linear1.matrix[ - angle_dim + node_dim : angle_dim + node_dim + edge_dim - ] - sub_edge_matrix_ik = self.edge_angle_linear1.matrix[ - angle_dim + node_dim + edge_dim : angle_dim + node_dim + 2 * edge_dim - ] - bias = self.edge_angle_linear1.bias + matrix, bias = self.edge_angle_linear1.matrix, self.edge_angle_linear1.bias elif feat == "angle": - sub_angle_matrix = self.angle_self_linear.matrix[:angle_dim] - sub_node_matrix = self.angle_self_linear.matrix[ - angle_dim : angle_dim + node_dim - ] - sub_edge_matrix_ij = self.angle_self_linear.matrix[ - angle_dim + node_dim : angle_dim + node_dim + edge_dim - ] - sub_edge_matrix_ik = self.angle_self_linear.matrix[ - angle_dim + node_dim + edge_dim : angle_dim + node_dim + 2 * edge_dim - ] - bias = self.angle_self_linear.bias + matrix, bias = self.angle_self_linear.matrix, self.angle_self_linear.bias else: + matrix, bias = None, None raise NotImplementedError + assert matrix is not None + assert bias is not None + assert angle_dim + node_dim + 2 * edge_dim == matrix.size()[0] # nf * nloc * a_sel * a_sel * angle_dim - sub_angle_update = torch.matmul(angle_ebd, sub_angle_matrix) + sub_angle_update = torch.matmul( + angle_ebd, matrix[sub_angle_idx[0] : sub_angle_idx[1]] + ) # nf * nloc * angle_dim - sub_node_update = torch.matmul(node_ebd, sub_node_matrix) + sub_node_update = torch.matmul( + node_ebd, matrix[sub_node_idx[0] : sub_node_idx[1]] + ) # nf * nloc * a_nnei * angle_dim - sub_edge_update_ij = torch.matmul(edge_ebd, sub_edge_matrix_ij) - sub_edge_update_ik = torch.matmul(edge_ebd, sub_edge_matrix_ik) + sub_edge_update_ij = torch.matmul( + edge_ebd, matrix[sub_edge_idx_ij[0] : sub_edge_idx_ij[1]] + ) + sub_edge_update_ik = torch.matmul( + edge_ebd, matrix[sub_edge_idx_ik[0] : sub_edge_idx_ik[1]] + ) result_update = ( sub_angle_update @@ -1038,9 +1037,6 @@ def forward( angle_ebd, node_ebd_for_angle, edge_for_angle, - self.a_dim, - self.a_dim, - self.a_dim, "edge", ) ) @@ -1105,9 +1101,6 @@ def forward( angle_ebd, node_ebd_for_angle, edge_for_angle, - self.a_dim, - self.a_dim, - self.a_dim, "angle", ) )