Skip to content

Commit

Permalink
workaround for bfloat16
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Nov 13, 2024
1 parent 4ba6e97 commit e43134d
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions deepmd/dpmodel/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,10 @@ def call(self, x: np.ndarray) -> np.ndarray:
if self.b is not None
else xp.matmul(x, self.w)
)
if y.dtype != x.dtype:
# workaround for bfloat16
# https://github.com/jax-ml/ml_dtypes/issues/235
y = xp.astype(y, x.dtype)
y = fn(y)
if self.idt is not None:
y *= self.idt
Expand Down

0 comments on commit e43134d

Please sign in to comment.