Skip to content

Commit

Permalink
add unet norm
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Feb 7, 2025
1 parent 47d67ba commit ed44850
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 0 deletions.
2 changes: 2 additions & 0 deletions deepmd/dpmodel/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
use_unet_n: bool = True,
use_unet_e: bool = True,
use_unet_a: bool = True,
unet_norm: str = "None",
unet_rate: float = 0.5,
bn_moment: float = 0.1,
n_update_has_a: bool = False,
Expand Down Expand Up @@ -147,6 +148,7 @@ def __init__(
self.use_unet_e = use_unet_e
self.use_unet_a = use_unet_a
self.unet_rate = unet_rate
self.unet_norm = unet_norm
self.auto_batchsize = auto_batchsize
self.optim_update = optim_update

Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def init_subclass_params(sub_data, sub_class):
use_unet_e=self.repflow_args.use_unet_e,
use_unet_a=self.repflow_args.use_unet_a,
unet_rate=self.repflow_args.unet_rate,
unet_norm=self.repflow_args.unet_norm,
bn_moment=self.repflow_args.bn_moment,
auto_batchsize=self.repflow_args.auto_batchsize,
optim_update=self.repflow_args.optim_update,
Expand Down
90 changes: 90 additions & 0 deletions deepmd/pt/model/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def __init__(
use_unet_e: bool = True,
use_unet_a: bool = True,
unet_rate: float = 0.5,
unet_norm: str = "None",
bn_moment: float = 0.1,
auto_batchsize: int = 0,
optim_update: bool = True,
Expand Down Expand Up @@ -282,6 +283,7 @@ def __init__(
self.use_unet_e = use_unet_e
self.use_unet_a = use_unet_a
self.unet_rate = unet_rate
self.unet_norm = unet_norm
# self.out_ln = None
# if self.pre_ln:
# self.out_ln = torch.nn.LayerNorm(
Expand Down Expand Up @@ -320,6 +322,78 @@ def __init__(
(self.unet_rate ** (self.unet_rest_half - 1 - i))
for i in range(self.unet_rest_half)
]
self.unet_norm_n = None
self.unet_norm_e = None
self.unet_norm_a = None
if self.unet_norm != "None":
norm_idx = self.unet_first_half - 1
if self.unet_norm == "batchnorm":
self.unet_norm_n = (
torch.nn.BatchNorm1d(
int(self.n_dim * self.unet_scale[norm_idx]),
affine=False,
device=env.DEVICE,
dtype=self.prec,
momentum=self.bn_moment,
)
if self.use_unet_n
else None
)
self.unet_norm_e = (
torch.nn.BatchNorm1d(
int(self.e_dim * self.unet_scale[norm_idx]),
affine=False,
device=env.DEVICE,
dtype=self.prec,
momentum=self.bn_moment,
)
if self.use_unet_e
else None
)
self.unet_norm_a = (
torch.nn.BatchNorm1d(
int(self.a_dim * self.unet_scale[norm_idx]),
affine=False,
device=env.DEVICE,
dtype=self.prec,
momentum=self.bn_moment,
)
if self.use_unet_a
else None
)
elif self.unet_norm == "layernorm":
self.unet_norm_n = (
torch.nn.LayerNorm(
int(self.n_dim * self.unet_scale[norm_idx]),
device=env.DEVICE,
dtype=self.prec,
elementwise_affine=False,
)
if self.use_unet_n
else None
)
self.unet_norm_e = (
torch.nn.LayerNorm(
int(self.e_dim * self.unet_scale[norm_idx]),
device=env.DEVICE,
dtype=self.prec,
elementwise_affine=False,
)
if self.use_unet_e
else None
)
self.unet_norm_a = (
torch.nn.LayerNorm(
int(self.a_dim * self.unet_scale[norm_idx]),
device=env.DEVICE,
dtype=self.prec,
elementwise_affine=False,
)
if self.use_unet_a
else None
)
else:
raise ValueError(f"Unsupported unet norm {self.unet_norm}!")

for ii in range(nlayers):
layers.append(
Expand Down Expand Up @@ -739,6 +813,22 @@ def forward(
h1_ext,
)
if self.use_unet:
if self.unet_norm != "None" and idx == self.unet_first_half - 1:
if self.use_unet_n:
assert self.unet_norm_n is not None
node_ebd = self.unet_norm_n(
node_ebd.view(nframes * nloc, -1)
).view(nframes, nloc, -1)
if self.use_unet_e:
assert self.unet_norm_e is not None
edge_ebd = self.unet_norm_e(
edge_ebd.view(nframes * nloc * self.nnei, -1)
).view(nframes, nloc, nnei, -1)
if self.use_unet_a:
assert self.unet_norm_a is not None
angle_ebd = self.unet_norm_a(
angle_ebd.view(nframes * nloc * self.a_sel * self.a_sel, -1)
).view(nframes, nloc, self.a_sel, self.a_sel, -1)
if idx < self.unet_first_half - 1:
# stack half output
tmp_n_dim = int(self.n_dim * self.unet_scale[idx + 1])
Expand Down
6 changes: 6 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1715,6 +1715,12 @@ def dpa3_repflow_args():
optional=True,
default=0.5,
),
Argument(
"unet_norm",
str,
optional=True,
default="None",
),
]


Expand Down

0 comments on commit ed44850

Please sign in to comment.