refactor dot_product_attention to be more general but also check arguments better. add self_attention function

dlwh committed Nov 3, 2023
1 parent 08b8852 commit a2e311e
Showing 6 changed files with 182 additions and 33 deletions.
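
At a glance, the refactor moves `dot_product_attention` from `(QPos, KPos, KeySize, ...)` to `(KPos, Key, ...)`: the query-position axis is no longer passed, the attended-to axes may be more than one, and the head axis is named `Key`. A minimal sketch of the new call shape, adapted from the tests added in this commit (axis names are illustrative):

```python
import haliax as hax
from haliax.nn.attention import dot_product_attention, self_attention

Pos = hax.Axis("Pos", 20)         # query positions
KeyPos = hax.Axis("Pos_key", 20)  # key/value positions, distinct from Pos
Heads = hax.Axis("NumHeads", 1)
Hid = hax.Axis("Hid", 8)          # head dimension (the "Key" axis)

query = hax.ones((Heads, Pos, Hid))
kv = hax.ones((Heads, KeyPos, Hid))

# new signature: attended-to axes first, then the head axis
out = dot_product_attention(KeyPos, Hid, query, kv, kv)
assert out.axes == (Heads, Pos, Hid)

# new convenience wrapper: q/k/v share one Pos axis, causality is explicit
out2 = self_attention(Pos, Hid, query, query, query, is_causal=True)
```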
1 change: 1 addition & 0 deletions docs/nn.md
@@ -39,6 +39,7 @@ We don't provide an explicit attention module, but we do provide an attention function

:::haliax.nn.attention.dot_product_attention
:::haliax.nn.attention.dot_product_attention_weights
:::haliax.nn.attention.self_attention

### Masks
::: haliax.nn.attention.causal_mask
1 change: 1 addition & 0 deletions src/haliax/__init__.py
@@ -191,6 +191,7 @@ def concatenate(axis: AxisSelector, arrays: Sequence[NamedArray]) -> NamedArray:
# we want to use the axis name for `axis`, because it's not uncommon for those to be different lengths in the arrays
axes = axes[:axis_index] + (aname,) + axes[axis_index + 1 :]
arrays = [a.rearrange(axes) for a in arrays]
print([a.axes for a in arrays])

new_axes = arrays[0].axes[:axis_index] + (axis,) + arrays[0].axes[axis_index + 1 :]
return NamedArray(jnp.concatenate([a.array for a in arrays], axis=axis_index), new_axes)
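The hunk keys `concatenate` on the axis *name* because, as the comment notes, the inputs may have different lengths along that axis (the `print` in this hunk is a stray debug statement left in the commit). A small sketch of the behavior this supports, under the assumption that a bare axis name is accepted as the selector:

```python
import haliax as hax

Hid = hax.Axis("Hid", 4)
x = hax.ones((hax.Axis("Seq", 3), Hid))
y = hax.ones((hax.Axis("Seq", 5), Hid))

# "Seq" has size 3 in x and 5 in y, so the arrays are lined up by name;
# the result's Seq axis has the summed size.
z = hax.concatenate("Seq", [x, y])
assert z.axis_size("Seq") == 8
```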
25 changes: 21 additions & 4 deletions src/haliax/axis.py
@@ -281,14 +281,31 @@ def overlapping_axes(ax1: AxisSelection, ax2: AxisSelection) -> Tuple[AxisSelect
return tuple(out)


def axis_name(ax: AxisSelector) -> str:
@overload
def axis_name(ax: AxisSelector) -> str: # type: ignore
...


@overload
def axis_name(ax: Sequence[AxisSelector]) -> Tuple[str, ...]: # type: ignore
...


def axis_name(ax: AxisSelection) -> Union[str, Tuple[str, ...]]:
"""
Returns the name of the axis. If ax is a string, returns ax. If ax is an Axis, returns ax.name
"""
if isinstance(ax, Axis):
return ax.name

def _ax_name(ax: AxisSelector) -> str:
if isinstance(ax, Axis):
return ax.name
else:
return ax

if isinstance(ax, (Axis, str)):
return _ax_name(ax)
else:
return ax
return tuple(_ax_name(x) for x in ax)


class dslice(eqx.Module):
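`axis_name` is now overloaded over a single selector and a sequence of selectors; a quick sketch of both call shapes, following the implementation above:

```python
import haliax as hax
from haliax.axis import axis_name

Pos = hax.Axis("Pos", 20)

assert axis_name(Pos) == "Pos"                     # Axis -> its name
assert axis_name("Pos") == "Pos"                   # str passes through
assert axis_name((Pos, "Hid")) == ("Pos", "Hid")   # sequence -> tuple of names
```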
5 changes: 3 additions & 2 deletions src/haliax/core.py
@@ -180,9 +180,10 @@ def resolve_axis(self, axes: AxisSelection) -> AxisSpec:  # type: ignore
raise ValueError(f"Axis {axes} not found")
else:
result = []
for i in indices:
assert isinstance(axes, Sequence)
for i, ax in zip(indices, axes):
if i is None:
raise ValueError(f"Axis {axes} not found")
raise ValueError(f"Axis {ax} not found in {self.shape}")
result.append(self.axes[i])
return tuple(result)

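The `resolve_axis` change reports the specific axis that is missing instead of echoing the whole selection. A sketch of the difference, assuming the tuple-selection form the sequence branch handles:

```python
import haliax as hax

Pos = hax.Axis("Pos", 20)
Hid = hax.Axis("Hid", 8)
x = hax.ones((Pos, Hid))

assert x.resolve_axis(("Pos", "Hid")) == (Pos, Hid)

try:
    x.resolve_axis(("Pos", "Missing"))
except ValueError as e:
    # the error now names the absent axis, not the full selection
    assert "Missing" in str(e)
```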
118 changes: 92 additions & 26 deletions src/haliax/nn/attention.py
@@ -7,15 +7,14 @@
from jaxtyping import PRNGKeyArray

import haliax
import haliax.nn.activations
import haliax.nn.normalization
import haliax.random as hrandom
from haliax.axis import Axis, AxisSelection, AxisSelector, AxisSpec
from haliax.axis import Axis, AxisSelection, AxisSelector, AxisSpec, axis_name
from haliax.core import NamedArray
from haliax.types import PrecisionLike
from haliax.util import ensure_tuple


# With attention we usually distinguish between the mask and the bias, though the former is just a special case of the
# With attention, we usually distinguish between the mask and the bias, though the former is just a special case of the
# latter. In practice, the mask is a boolean array that is applied using `where` to the logits, while the bias is a
# float array that is added to the logits. The mask is usually used to prevent attention to certain positions, while
# the bias is usually used to encourage or discourage attention to certain positions.
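The distinction the comment draws, as a small sketch using helpers defined later in this file (the exact bias values are an assumption based on the default `mask_value`):

```python
import haliax as hax
from haliax.nn.attention import causal_mask, mask_to_bias

Pos = hax.Axis("Pos", 4)
KeyPos = hax.Axis("Pos_key", 4)

mask = causal_mask(Pos, KeyPos)  # boolean: applied to logits with `where`
bias = mask_to_bias(mask)        # float: ~0 where allowed, -1e9 where masked; added to logits
```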
@@ -26,7 +26,7 @@


def dot_product_attention_weights(
Head: AxisSelector,
Key: AxisSelector,
KPos: AxisSelection,
query: NamedArray,
key: NamedArray,
@@ -39,8 +38,8 @@ def dot_product_attention_weights(
NamedArray version of dot product attention. Computes the logits for the attention weights. Note that the
"Pos" axis in query must be distinct from the "Pos" axis in key.
:param Head: Axis of head dimension
:param KPos: Axis of key sequence length. Can be an AxisSpec to attend along more than one axis.
:param Key: Axis of head dimension
:param KPos: Axis or axes that are attended to
:param query: NamedArray of shape (QPos, KeySize)
:param key: NamedArray of shape (KPos, KeySize)
:param mask: Optional[NamedArray] broadcast compatible with (KeySize, QPos, KPos). Should be boolean
@@ -52,28 +51,27 @@
# cf https://github.com/google/flax/blob/509bf97ea272e130d932920f45307ac98947d994/flax/linen/attention.py#L40

orig_dtype = query.dtype
query = query / jnp.sqrt(query.axis_size(Head))
query = query / jnp.sqrt(query.axis_size(Key))

if attention_dtype is not None:
query = query.astype(attention_dtype)
key = key.astype(attention_dtype)

weights = haliax.dot(Head, query, key, precision=precision)
weights = haliax.dot(Key, query, key, precision=precision)

if bias is not None:
weights = weights + bias
if mask is not None:
weights = haliax.where(mask, weights, -1e9)

weights = haliax.nn.normalization.softmax(weights, axis=KPos)
weights = haliax.nn.softmax(weights, axis=KPos)

return weights.astype(orig_dtype)
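A sketch of calling the weights function directly under the renamed parameters, where `Key` is the contracted head axis and `KPos` the attended-to axis:

```python
import haliax as hax
from haliax.nn.attention import dot_product_attention_weights

Pos = hax.Axis("Pos", 20)
KeyPos = hax.Axis("Pos_key", 20)
Hid = hax.Axis("Hid", 8)

q = hax.ones((Pos, Hid))
k = hax.ones((KeyPos, Hid))

# Hid is contracted away; the result carries Pos and KeyPos,
# with the softmax taken over KeyPos.
w = dot_product_attention_weights(Hid, KeyPos, q, k)
```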


def dot_product_attention(
QPos: Axis,
KPos: Axis,
KeySize: Axis,
KPos: AxisSelection,
Key: AxisSelector,
query: NamedArray,
key: NamedArray,
value: NamedArray,
@@ -85,12 +83,11 @@ def dot_product_attention(
"""
NamedArray version of dot product attention. This can be multi-headed or not.
:param QPos: Axis of sequence length
:param KPos: Axis of key sequence length
:param KeySize: Axis of head dimension
:param query: NamedArray of shape (QPos, KeySize)
:param key: NamedArray of shape (KPos, KeySize)
:param value: NamedArray of shape (KPos, KeySize)
:param Key: Axis of head dimension
:param query: NamedArray of shape {..., QPos, KeySize}
:param key: NamedArray of shape {..., KPos, KeySize}
:param value: NamedArray of shape {..., KPos, KeySize}
:param mask: Optional[NamedArray] broadcast compatible with (KeySize, QPos, KPos). Should be boolean
:param bias: Optional[NamedArray] broadcast compatible with (KeySize, QPos, KPos). Should be float
:param attention_dtype: Optional dtype to use for attention
@@ -101,17 +98,86 @@
For example, mask is frequently just a boolean array of shape (QPos, KPos), while bias is frequently a float
array of shape (KeySize, QPos, KPos) or (KeySize, KPos)
"""
# cf https://github.com/google/flax/blob/509bf97ea272e130d932920f45307ac98947d994/flax/linen/attention.py#L125
if not isinstance(query, NamedArray):
raise TypeError(
f"query must be a NamedArray, got {type(query)}. Probably you are still using the old signature"
"of dot_product_attention. It no longer takes a QPos argument."
)
KPos = ensure_tuple(key.resolve_axis(KPos))
# any axis in KPos that's in query is a problem
for axis in KPos:
if axis in query.axes:
raise ValueError(
f"Axis {axis} in KPos is also in query. Attended-to axes must be distinct from query axis"
)

weights = dot_product_attention_weights(
Key, KPos, query, key, mask=mask, bias=bias, attention_dtype=attention_dtype, precision=precision
)

# rename key/value length axis if it's the same as the query length axis
if KPos == QPos:
KPos = QPos.alias(KPos.name + "_key")
key = key.rename({KPos: QPos})
value = value.rename({KPos: QPos})
return haliax.dot(KPos, weights, value)

weights = dot_product_attention_weights(KeySize, KPos, query, key, mask, bias, attention_dtype, precision)

return haliax.dot(KPos, weights, value)
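
The two new argument checks, in the shapes the tests below exercise (illustrative axes):

```python
import haliax as hax
from haliax.nn.attention import dot_product_attention

Pos = hax.Axis("Pos", 20)
KeyPos = hax.Axis("Pos_key", 20)
Hid = hax.Axis("Hid", 8)

q = hax.ones((Pos, Hid))
kv = hax.ones((KeyPos, Hid))

out = dot_product_attention(KeyPos, Hid, q, kv, kv)  # ok

# KPos must resolve in `key`: raises ValueError("... not found ...")
# dot_product_attention(Pos, Hid, q, kv, kv)

# KPos must not appear in `query`: raises ValueError("... must be distinct ...")
# dot_product_attention(KeyPos, Hid, kv, kv, kv)
```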
def self_attention(
Pos: AxisSelection,
Key: AxisSelector,
query: NamedArray,
key: NamedArray,
value: NamedArray,
is_causal: bool, # make people be explicit about this
mask: Optional[NamedArray] = None,
bias: Optional[NamedArray] = None,
attention_dtype: Optional[jnp.dtype] = None,
precision: PrecisionLike = None,
) -> NamedArray:
"""
Convenience function for self attention. This is just a wrapper around dot_product_attention that makes sure
the query and key axes are distinct. This is a common mistake and it's better to catch it early.
Note that mask and bias's Pos axis/axes should be key axes, not query axes. You can't use
query axes in a mask/bias with this method. Use dot_product_attention directly if you need to do that.
As an exception, if is_causal is True, then we create a causal mask for you.
Args:
Pos: Axis of sequence length
Key: Axis of head dimension
query: NamedArray of shape {..., Pos, KeySize}
key: NamedArray of shape {..., Pos, KeySize}
value: NamedArray of shape {..., Pos, KeySize}
is_causal: whether to use a causal mask
mask: Optional[NamedArray] broadcast compatible with (KeySize, Pos, Pos). Should be boolean
bias: Optional[NamedArray] broadcast compatible with (KeySize, Pos, Pos). Should be float
attention_dtype: Optional dtype to use for attention
precision: PrecisionLike for dot product. See precision argument to jax.lax.dot_general
"""
Pos = ensure_tuple(key.resolve_axis(Pos))

# rename key/value length axes if necessary
QPos, renames = _get_query_pos_renames(Pos)

query = query.rename(renames)

if is_causal:
# require that QPos is a single axis
if len(Pos) != 1:
raise ValueError("QPos must be a single axis for causal self attention")
mask = causal_mask(QPos[0], Pos[0])

out = dot_product_attention(Pos, Key, query, key, value, mask, bias, attention_dtype, precision)
# now rename back
return out.rename({v: k for k, v in renames.items()})
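
A usage sketch mirroring `test_self_attention_basically_works` below:

```python
import haliax as hax
from haliax.nn.attention import self_attention

Heads = hax.Axis("NumHeads", 1)
Pos = hax.Axis("Pos", 20)
Hid = hax.Axis("Hid", 8)

x = hax.ones((Heads, Pos, Hid))

# q, k, and v share Pos; the wrapper renames the query axis internally,
# builds the causal mask, and renames back before returning.
y = self_attention(Pos, Hid, x, x, x, is_causal=True)
assert y.axes == (Heads, Pos, Hid)
```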


def _get_query_pos_renames(Pos):
new_Pos: list[Axis] = []
renames: dict[str, str] = {}
for i, axis in enumerate(Pos):
ax_name = axis_name(axis)
axis = axis.alias(f"q_{ax_name}")
renames[ax_name] = axis.name
new_Pos.append(axis)

return tuple(new_Pos), renames
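
For concreteness, what the helper returns for a single `Pos` axis (a sketch; `_get_query_pos_renames` is private):

```python
from haliax import Axis
from haliax.nn.attention import _get_query_pos_renames

new_pos, renames = _get_query_pos_renames((Axis("Pos", 20),))
assert new_pos == (Axis("q_Pos", 20),)
assert renames == {"Pos": "q_Pos"}
```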


def mask_to_bias(mask: NamedArray, mask_value: float = -1e9) -> NamedArray:
65 changes: 64 additions & 1 deletion tests/test_attention.py
@@ -1,11 +1,74 @@
import jax.numpy as jnp
import numpy as np
from jax.random import PRNGKey

import haliax as hax
from haliax.nn.attention import alibi_attention_bias, dot_product_attention_weights, forgetful_causal_mask
from haliax.nn.attention import (
alibi_attention_bias,
causal_mask,
dot_product_attention,
dot_product_attention_weights,
forgetful_causal_mask,
self_attention,
)
from test_utils import skip_if_no_torch


def test_dot_product_attention_requires_axis_to_be_present():
Pos = hax.Axis("Pos", 20)
KeyPos = hax.Axis("Pos_key", 20)
NumHeads = hax.Axis("NumHeads", 1)
Hid = hax.Axis("Hid", 8)

query = hax.ones((NumHeads, KeyPos, Hid)) # NB: KeyPos not Pos
key = hax.ones((KeyPos, NumHeads, Hid))
value = hax.ones((KeyPos, NumHeads, Hid))

try:
dot_product_attention(Pos, Hid, query, key, value)
except ValueError as e:
assert "not found" in str(e)
else:
raise AssertionError("Should have raised an error")


def test_attention_doesnt_allow_overlapping_axes():
KeyPos = hax.Axis("Pos_key", 20)
NumHeads = hax.Axis("NumHeads", 1)
Hid = hax.Axis("Hid", 8)

query = hax.ones((NumHeads, KeyPos, Hid)) # NB: KeyPos not Pos
key = hax.ones((KeyPos, NumHeads, Hid))
value = hax.ones((KeyPos, NumHeads, Hid))

try:
dot_product_attention(KeyPos, Hid, query, key, value)
except ValueError as e:
assert "must be distinct" in str(e)
else:
raise AssertionError("Should have raised an error")


def test_self_attention_basically_works():
Pos = hax.Axis("Pos", 20)
KeyPos = hax.Axis("Pos_key", 20)
NumHeads = hax.Axis("NumHeads", 1)
Hid = hax.Axis("Hid", 8)

query = hax.ones((NumHeads, Pos, Hid))

result = self_attention(Pos, Hid, query, query, query, is_causal=True)
assert result.axes == (NumHeads, Pos, Hid)

k = query.rename({Pos: KeyPos})
cmask = causal_mask(Pos, KeyPos)
result2 = dot_product_attention(KeyPos, Hid, query, k, k, mask=cmask)
assert result2.axes == (NumHeads, Pos, Hid)

# tight tolerances because it should be exactly the same computation
assert jnp.allclose(result.array, result2.array)


def test_alibi_attention_bias():
KeyPos = hax.Axis("KeyPos", 20)
NumHeads = hax.Axis("NumHeads", 1)
