diff --git a/deepmd/pt/model/network/mlp.py b/deepmd/pt/model/network/mlp.py
index 6f44ac95c6..e72ed3b440 100644
--- a/deepmd/pt/model/network/mlp.py
+++ b/deepmd/pt/model/network/mlp.py
@@ -117,6 +117,14 @@ def __init__(
             self._default_normal_init(
                 bavg=bavg, stddev=stddev, generator=random_generator
             )
+        elif init.startswith("custom_glorot"):
+            step_size = float(init.split(":")[-1]) if ":" in init else 0.1
+            self._custom_glorot_normal(
+                step_size=step_size,
+                bavg=bavg,
+                stddev=stddev,
+                generator=random_generator,
+            )
         elif init == "trunc_normal":
             self._trunc_normal_init(1.0, generator=random_generator)
         elif init == "relu":
@@ -166,6 +174,54 @@ def _default_normal_init(
         if self.idt is not None:
             normal_(self.idt.data, mean=0.1, std=0.001, generator=generator)
 
+    def _custom_glorot_normal(
+        self,
+        step_size: float = 0.1,
+        bavg: float = 0.0,
+        stddev: float = 1.0,
+        generator: Optional[torch.Generator] = None,
+    ) -> None:
+        """Initialise the layer with a rescaled Glorot-style normal distribution.
+
+        The weight matrix is drawn with standard deviation
+        ``stddev * sqrt(2 / (step_size**2 * (num_in + num_out)))``, which
+        equals the Glorot-normal value ``sqrt(2 / (num_in + num_out))``
+        multiplied by ``stddev / abs(step_size)``.  No mean is passed for the
+        matrix, so the default of ``normal_`` applies.
+
+        Parameters
+        ----------
+        step_size : float
+            Non-zero scaling factor; smaller magnitudes give wider weights.
+        bavg : float
+            Mean of the normal distribution used for the bias, if present.
+        stddev : float
+            Standard deviation of the bias and multiplier of the weight
+            standard deviation.
+        generator : torch.Generator, optional
+            Random number generator forwarded to every ``normal_`` call.
+
+        Raises
+        ------
+        ValueError
+            If ``step_size`` is zero, which would make the weight standard
+            deviation undefined.
+        """
+        if step_size == 0:
+            raise ValueError(
+                "step_size for custom_glorot initialization must be non-zero"
+            )
+        # step_size enters squared, so only its magnitude affects the result.
+        weight_std = stddev * np.sqrt(
+            2 / (step_size**2 * (self.num_out + self.num_in))
+        )
+        normal_(self.matrix.data, std=weight_std, generator=generator)
+        if self.bias is not None:
+            normal_(self.bias.data, mean=bavg, std=stddev, generator=generator)
+        if self.idt is not None:
+            # Mirrors the idt initialisation in `_default_normal_init`.
+            normal_(self.idt.data, mean=0.1, std=0.001, generator=generator)
+
     def _trunc_normal_init(
         self, scale=1.0, generator: Optional[torch.Generator] = None
     ) -> None: