Skip to content

Commit

Permalink
add auto_batchsize
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Jan 30, 2025
1 parent 0a9b96b commit e2c9df6
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 15 deletions.
2 changes: 2 additions & 0 deletions deepmd/dpmodel/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
bn_moment: float = 0.1,
n_update_has_a: bool = False,
n_update_has_a_first_sum: bool = False,
auto_batchsize: int = 0,
) -> None:
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.
Expand Down Expand Up @@ -143,6 +144,7 @@ def __init__(
self.use_unet_n = use_unet_n
self.use_unet_e = use_unet_e
self.use_unet_a = use_unet_a
self.auto_batchsize = auto_batchsize

def __getitem__(self, key):
if hasattr(self, key):
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def init_subclass_params(sub_data, sub_class):
use_unet_e=self.repflow_args.use_unet_e,
use_unet_a=self.repflow_args.use_unet_a,
bn_moment=self.repflow_args.bn_moment,
auto_batchsize=self.repflow_args.auto_batchsize,
skip_stat=self.repflow_args.skip_stat,
exclude_types=exclude_types,
env_protection=env_protection,
Expand Down
6 changes: 5 additions & 1 deletion deepmd/pt/model/descriptor/repflow_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,7 @@ def forward(
a_nlist_mask: torch.Tensor, # nf x nloc x a_nnei
a_sw: torch.Tensor, # switch func, nf x nloc x a_nnei
h1_ext: Optional[torch.Tensor], # nf x nall x 3 x h1_dim
node_ebd_split: Optional[torch.Tensor] = None, # nf x nloc x n_dim
):
"""
Parameters
Expand Down Expand Up @@ -665,7 +666,10 @@ def forward(
"""
nb, nloc, nnei, _ = edge_ebd.shape
nall = node_ebd_ext.shape[1]
node_ebd, _ = torch.split(node_ebd_ext, [nloc, nall - nloc], dim=1)
if node_ebd_split is None:
node_ebd, _ = torch.split(node_ebd_ext, [nloc, nall - nloc], dim=1)
else:
node_ebd = node_ebd_split
assert (nb, nloc) == node_ebd.shape[:2]
assert (nb, nloc, nnei) == h2.shape[:3]
del a_nlist # may be used in the future
Expand Down
91 changes: 77 additions & 14 deletions deepmd/pt/model/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def __init__(
use_unet_e: bool = True,
use_unet_a: bool = True,
bn_moment: float = 0.1,
auto_batchsize: int = 0,
a_norm_use_max_v: bool = False,
e_norm_use_max_v: bool = False,
e_a_reduce_use_sqrt: bool = True,
Expand Down Expand Up @@ -231,6 +232,7 @@ def __init__(
self.n_update_has_a_first_sum = n_update_has_a_first_sum
self.n_attn_hidden = n_attn_hidden
self.n_attn_head = n_attn_head
self.auto_batchsize = auto_batchsize

self.n_dim = n_dim
self.e_dim = e_dim
Expand Down Expand Up @@ -654,20 +656,81 @@ def forward(
node_ebd_ext = concat_switch_virtual(
node_ebd_real_ext, node_ebd_virtual_ext, real_nloc
)
node_ebd, edge_ebd, angle_ebd, h1 = ll.forward(
node_ebd_ext,
edge_ebd,
h2,
angle_ebd,
nlist,
nlist_mask,
sw,
a_nlist,
a_nlist_mask,
a_sw,
h1_ext,
)

if self.auto_batchsize != 0 and nframes * nloc > self.auto_batchsize:
node_ebd_full, _ = torch.split(node_ebd_ext, [nloc, nall - nloc], dim=1)
node_ebd_chunks = torch.split(node_ebd_full, self.auto_batchsize, dim=1)
edge_ebd_chunks = torch.split(edge_ebd, self.auto_batchsize, dim=1)
h2_chunks = torch.split(h2, self.auto_batchsize, dim=1)
angle_ebd_chunks = torch.split(angle_ebd, self.auto_batchsize, dim=1)
nlist_chunks = torch.split(nlist, self.auto_batchsize, dim=1)
nlist_mask_chunks = torch.split(nlist_mask, self.auto_batchsize, dim=1)
sw_chunks = torch.split(sw, self.auto_batchsize, dim=1)
a_nlist_chunks = torch.split(a_nlist, self.auto_batchsize, dim=1)
a_nlist_mask_chunks = torch.split(
a_nlist_mask, self.auto_batchsize, dim=1
)
a_sw_chunks = torch.split(a_sw, self.auto_batchsize, dim=1)

node_ebd_list = []
edge_ebd_list = []
angle_ebd_list = []
for (
node_ebd_sub,
edge_ebd_sub,
h2_sub,
angle_ebd_sub,
nlist_sub,
nlist_mask_sub,
sw_sub,
a_nlist_sub,
a_nlist_mask_sub,
a_sw_sub,
) in zip(
node_ebd_chunks,
edge_ebd_chunks,
h2_chunks,
angle_ebd_chunks,
nlist_chunks,
nlist_mask_chunks,
sw_chunks,
a_nlist_chunks,
a_nlist_mask_chunks,
a_sw_chunks,
):
node_ebd_tmp, edge_ebd_tmp, angle_ebd_tmp, h1_tmp = ll.forward(
node_ebd_ext,
edge_ebd_sub,
h2_sub,
angle_ebd_sub,
nlist_sub,
nlist_mask_sub,
sw_sub,
a_nlist_sub,
a_nlist_mask_sub,
a_sw_sub,
h1_ext,
node_ebd_split=node_ebd_sub,
)
node_ebd_list.append(node_ebd_tmp)
edge_ebd_list.append(edge_ebd_tmp)
angle_ebd_list.append(angle_ebd_tmp)
node_ebd = torch.cat(node_ebd_list, dim=1)
edge_ebd = torch.cat(edge_ebd_list, dim=1)
angle_ebd = torch.cat(angle_ebd_list, dim=1)
else:
node_ebd, edge_ebd, angle_ebd, h1 = ll.forward(
node_ebd_ext,
edge_ebd,
h2,
angle_ebd,
nlist,
nlist_mask,
sw,
a_nlist,
a_nlist_mask,
a_sw,
h1_ext,
)
if self.use_unet:
if idx < self.unet_first_half - 1:
# stack half output
Expand Down
6 changes: 6 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1697,6 +1697,12 @@ def dpa3_repflow_args():
optional=True,
default=True,
),
Argument(
"auto_batchsize",
int,
optional=True,
default=0,
),
]


Expand Down

0 comments on commit e2c9df6

Please sign in to comment.