refactor attn scaling factor
blahBlahhhJ committed Jan 24, 2025
1 parent 9599211 commit 1ab237f
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/haliax/nn/attention.py
@@ -33,7 +33,7 @@ def dot_product_attention_weights(
     bias: Optional[NamedArray] = None,
     attention_dtype: Optional[jnp.dtype] = None,
     precision: PrecisionLike = None,
-    use_mup: bool = False,
+    scaling_factor: Optional[float] = None,
 ) -> NamedArray:
     """
     NamedArray version of dot product attention. Computes the logits for the attention weights. Note that the
@@ -52,10 +52,10 @@ def dot_product_attention_weights(
     # cf https://github.com/google/flax/blob/509bf97ea272e130d932920f45307ac98947d994/flax/linen/attention.py#L40
 
     orig_dtype = query.dtype
-    if use_mup:
-        query = query / query.axis_size(Key)
-    else:
-        query = query / jnp.sqrt(query.axis_size(Key))
+    if scaling_factor is None:
+        scaling_factor = 1 / jnp.sqrt(query.axis_size(Key))
+
+    query = query * scaling_factor
 
     if attention_dtype is not None:
         query = query.astype(attention_dtype)
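
For callers migrating off the removed use_mup flag, the sketch below shows how the new scaling_factor argument might be used. This is a minimal example, assuming the usual haliax call pattern dot_product_attention_weights(Key, KPos, query, key, ...); the axis names and sizes are illustrative and not taken from this commit.

import jax
import haliax as hax
from haliax.nn.attention import dot_product_attention_weights

# Illustrative axes (not part of this commit).
Pos = hax.Axis("position", 16)
KPos = Pos.alias("key_position")
Key = hax.Axis("key", 64)

q_key, k_key = jax.random.split(jax.random.PRNGKey(0))
query = hax.random.normal(q_key, (Pos, Key))
key = hax.random.normal(k_key, (KPos, Key))

# Default: scaling_factor=None falls back to 1 / sqrt(d_k),
# matching the old use_mup=False behavior.
weights = dot_product_attention_weights(Key, KPos, query, key)

# muP-style attention: pass 1 / d_k explicitly, replacing the old use_mup=True flag.
mup_weights = dot_product_attention_weights(Key, KPos, query, key, scaling_factor=1 / Key.size)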
