Skip to content

Commit

Permalink
note something weird we did in rotary embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
rowanz committed Apr 27, 2022
1 parent 008c3c1 commit 56815aa
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions mreserve/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ def construct_rotary_sinusoids(coords, rotary_hsize: int = 32, max_freq=10.0, dt

def apply_rotary(query_key, sinusoids):
"""
note: there's possibly a bug here (it differs from the usual rotary embedding. but somehow we got good results
anyways. weird!)
:param query_key: The query, key, or both. [*batch_dims, seq_len, num_heads, size_per_head]
:param sinusoids: [*sin_batch_dims, 2, seq_len, rotary_hsize <= size_per_head // 2]
:return: query_key with rotary applied
Expand All @@ -130,7 +134,11 @@ def apply_rotary(query_key, sinusoids):
cos = sinusoids[..., 1, :, None, :]

qk_rope = query_key[..., :rotary_hsize]

# the bug is here...
qk_rotated_two = jnp.stack([-qk_rope[..., ::2], qk_rope[..., 1::2]], -1).reshape(qk_rope.shape)
# should be = jnp.stack([-qk_rope[..., 1::2], qk_rope[..., ::2]], -1).reshape(qk_rope.shape)

qk_rope = qk_rope * cos + qk_rotated_two * sin
query_key = jnp.concatenate([qk_rope, query_key[..., rotary_hsize:]], -1)
return query_key
Expand Down

0 comments on commit 56815aa

Please sign in to comment.