add mup configuration
blahBlahhhJ committed Jan 22, 2025
1 parent 81bbb78 commit 9599211
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/haliax/nn/attention.py
@@ -33,6 +33,7 @@ def dot_product_attention_weights(
     bias: Optional[NamedArray] = None,
     attention_dtype: Optional[jnp.dtype] = None,
     precision: PrecisionLike = None,
+    use_mup: bool = False,
 ) -> NamedArray:
     """
     NamedArray version of dot product attention. Computes the logits for the attention weights. Note that the
@@ -51,7 +52,10 @@
     # cf https://github.com/google/flax/blob/509bf97ea272e130d932920f45307ac98947d994/flax/linen/attention.py#L40
 
     orig_dtype = query.dtype
-    query = query / jnp.sqrt(query.axis_size(Key))
+    if use_mup:
+        query = query / query.axis_size(Key)
+    else:
+        query = query / jnp.sqrt(query.axis_size(Key))
 
     if attention_dtype is not None:
         query = query.astype(attention_dtype)
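For context: under the maximal-update parameterization (muP), attention logits are scaled by 1/d_head rather than the standard 1/sqrt(d_head), which keeps logit magnitudes roughly constant as the head dimension grows. A minimal usage sketch of the new flag follows; the axis names and sizes are illustrative and not part of the commit, and the positional argument order is assumed to follow the signature shown in the diff above.

    import jax.random as jrandom
    import haliax as hax
    from haliax.nn.attention import dot_product_attention_weights

    # Illustrative axes: 16 query/key positions, head dimension 64.
    Pos = hax.Axis("position", 16)
    KPos = Pos.alias("key_position")  # keys get a distinct position axis
    Key = hax.Axis("key", 64)

    q = hax.random.normal(jrandom.PRNGKey(0), (Pos, Key))
    k = hax.random.normal(jrandom.PRNGKey(1), (KPos, Key))

    # Default behavior: logits scaled by 1 / sqrt(query.axis_size(Key)).
    weights = dot_product_attention_weights(Key, KPos, q, k)

    # With the new flag: muP-style 1 / query.axis_size(Key) scaling.
    weights_mup = dot_product_attention_weights(Key, KPos, q, k, use_mup=True)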
