Skip to content

Commit

Permalink
add _custom_glorot_normal
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Jan 28, 2025
1 parent 04f03ae commit 0a9b96b
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions deepmd/pt/model/network/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,14 @@ def __init__(
self._default_normal_init(
bavg=bavg, stddev=stddev, generator=random_generator
)
elif init.startswith("custom_glorot"):
step_size = float(init.split(":")[-1]) if ":" in init else 0.1
self._custom_glorot_normal(
step_size=step_size,
bavg=bavg,
stddev=stddev,
generator=random_generator,
)
elif init == "trunc_normal":
self._trunc_normal_init(1.0, generator=random_generator)
elif init == "relu":
Expand Down Expand Up @@ -166,6 +174,23 @@ def _default_normal_init(
if self.idt is not None:
normal_(self.idt.data, mean=0.1, std=0.001, generator=generator)

def _custom_glorot_normal(
self,
step_size: float = 0.1,
bavg: float = 0.0,
stddev: float = 1.0,
generator: Optional[torch.Generator] = None,
) -> None:
normal_(
self.matrix.data,
std=stddev * np.sqrt(2 / (step_size**2 * (self.num_out + self.num_in))),
generator=generator,
)
if self.bias is not None:
normal_(self.bias.data, mean=bavg, std=stddev, generator=generator)
if self.idt is not None:
normal_(self.idt.data, mean=0.1, std=0.001, generator=generator)

def _trunc_normal_init(
self, scale=1.0, generator: Optional[torch.Generator] = None
) -> None:
Expand Down

0 comments on commit 0a9b96b

Please sign in to comment.