diff --git a/.gitmodules b/.gitmodules
index 6d5c0f8734..ab23324aec 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -3,4 +3,4 @@
 	url = https://github.com/HazyResearch/flash-attention.git
 [submodule "third_party/cutlass"]
 	path = third_party/cutlass
-	url = https://github.com/hwu36/cutlass.git
+	url = https://github.com/NVIDIA/cutlass.git
diff --git a/tests/test_block_factory.py b/tests/test_block_factory.py
index 275fcd5f56..79e85f7a1d 100644
--- a/tests/test_block_factory.py
+++ b/tests/test_block_factory.py
@@ -8,7 +8,7 @@
 
 # Automatically fetch all registered attentions and Feedforwards
 from xformers.components import Activation
-from xformers.components.attention import ATTENTION_REGISTRY
+from xformers.components.attention import ATTENTION_REGISTRY, AttentionMask
 from xformers.components.feedforward import FEEDFORWARD_REGISTRY
 from xformers.factory import (
     xFormerDecoderBlock,
@@ -112,10 +112,12 @@ def test_xformer_encoder_block(
     _ = block(inputs)
 
     # Check that we support attention masking, at least interface wise (do not check correctness yet)
-    att_mask = torch.ones(SEQ, SEQ, dtype=torch.bool, device=device)
+    att_mask_tensor = torch.ones(SEQ, SEQ, dtype=torch.bool, device=device)
+    att_mask = AttentionMask.from_bool(att_mask_tensor)
 
     if block.supports_attention_mask:
         _ = block(inputs, att_mask=att_mask)
+        _ = block(inputs, att_mask=att_mask_tensor)
     else:
         with pytest.raises(AssertionError):
             # Check that passing an attention mask to a mechanism which does not support it raises
@@ -226,7 +228,8 @@
     )  # NOTE: does not make a lot of sense, just checking dimensions
 
     # Check that we support masking, at least interface wise (do not check correctness yet)
-    att_mask = torch.ones(SEQ, SEQ, dtype=torch.bool, device=device)
+    att_mask_tensor = torch.ones(SEQ, SEQ, dtype=torch.bool, device=device)
+    att_mask = AttentionMask.from_bool(att_mask_tensor)
     input_mask = torch.randn(SEQ, dtype=torch.float, device=device)
     input_mask[input_mask < 0.0] = -float("inf")
 
@@ -235,6 +238,9 @@
     _ = decoder_block(
         inputs, encoded, encoder_att_mask=att_mask, input_mask=input_mask
     )
+    _ = decoder_block(
+        inputs, encoded, encoder_att_mask=att_mask_tensor, input_mask=input_mask
+    )
 
     # Test different sequence lengths when encoding and decoding
     if (
diff --git a/tests/test_core_attention.py b/tests/test_core_attention.py
index 3a9b01fba5..5e022b96cd 100644
--- a/tests/test_core_attention.py
+++ b/tests/test_core_attention.py
@@ -59,6 +59,15 @@ def test_core_attention_mask_types():
     # Now properly handled
     assert torch.allclose(r_dense_add, r_sparse_add)
 
+    # Test additive mask with mismatched batch dim
+    d = b // 2
+    mask = torch.rand(d, s, s) > prob
+    float_mask_add = torch.zeros_like(mask, dtype=torch.float)
+    float_mask_add = float_mask_add.masked_fill(mask, float("-inf"))
+
+    # Make sure masking doesn't return errors
+    r_dense_add = scaled_dot_product_attention(a, a, a, float_mask_add)
+
 
 @pytest.mark.parametrize("device", _devices)
 def test_amp_attention_dense_no_mask(device):
diff --git a/third_party/cutlass b/third_party/cutlass
index 012c62c748..06eb90cc0d 160000
--- a/third_party/cutlass
+++ b/third_party/cutlass
@@ -1 +1 @@
-Subproject commit 012c62c748bdd5b2badc3ebe83e0891dee7c4e31
+Subproject commit 06eb90cc0daae633b1e25e80ace1ef81ac158baa
diff --git a/xformers/components/attention/core.py b/xformers/components/attention/core.py
index 443bbdacb0..6709eaa84a 100644
--- a/xformers/components/attention/core.py
+++ b/xformers/components/attention/core.py
@@ -106,6 +106,16 @@ def _matmul_with_mask(
             att[~mask] = float("-inf")
         else:
             # mask is presumed additive
+            # repeat if batch sizes don't match
+            if (
+                not isinstance(mask, SparseCS)
+                and mask.ndim == 3
+                and mask.shape[0] != att.shape[0]
+                and (att.shape[0] % mask.shape[0]) == 0
+            ):
+                repeat_factor = att.shape[0] // mask.shape[0]
+                mask = mask.repeat([repeat_factor, 1, 1])
+                logger.info("Mismatched batch dimensions for mask, repeating mask.")
             att += mask
     return att
 
diff --git a/xformers/components/attention/global_tokens.py b/xformers/components/attention/global_tokens.py
index d0a6f0166e..653ed619c8 100644
--- a/xformers/components/attention/global_tokens.py
+++ b/xformers/components/attention/global_tokens.py
@@ -5,7 +5,7 @@
 
 
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 import torch.nn as nn
@@ -88,7 +88,7 @@ def forward(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        att_mask: Optional[torch.Tensor] = None,
+        att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
         *_,
         **__,
     ):
@@ -101,7 +101,9 @@
             if att_mask.dtype == torch.bool and isinstance(
                 self.attention_mask, AttentionMask
             ):
-                mask = self.attention_mask + AttentionMask.from_bool(att_mask)
+                if not isinstance(att_mask, AttentionMask):
+                    att_mask = AttentionMask.from_bool(att_mask)
+                mask = self.attention_mask + att_mask
             else:
                 mask = self.attention_mask & att_mask
         else:
diff --git a/xformers/components/attention/local.py b/xformers/components/attention/local.py
index 68df4bca3d..3220a8d401 100644
--- a/xformers/components/attention/local.py
+++ b/xformers/components/attention/local.py
@@ -5,7 +5,7 @@
 
 
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 import torch.nn as nn
@@ -13,6 +13,7 @@
 from xformers.components.attention import (
     Attention,
     AttentionConfig,
+    AttentionMask,
     maybe_sparsify,
     register_attention,
     sparsify,
@@ -97,7 +98,7 @@ def forward(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        att_mask: Optional[torch.Tensor] = None,
+        att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
         *args,
         **kwargs,
     ):
@@ -106,9 +107,13 @@
             self.attention_mask = self._get_local_mask(q.shape).to(q.device)
 
         # Take into account the optional user mask
-        mask = (
-            self.attention_mask if att_mask is None else self.attention_mask & att_mask
-        )
+        if att_mask is None:
+            mask = self.attention_mask
+        else:
+            if isinstance(att_mask, AttentionMask):
+                # Needed because & op not defined for SparseCS with AttentionMask
+                att_mask = att_mask.to_bool()
+            mask = self.attention_mask & att_mask
 
         return scaled_dot_product_attention(
             q=q, k=k, v=v, att_mask=mask, dropout=self.attn_drop
diff --git a/xformers/components/attention/ortho.py b/xformers/components/attention/ortho.py
index 392ed96f36..3737f6cdd0 100644
--- a/xformers/components/attention/ortho.py
+++ b/xformers/components/attention/ortho.py
@@ -7,14 +7,19 @@
 import logging
 from dataclasses import dataclass
 from enum import Enum
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 import torch.autograd.profiler as profiler
 import torch.nn as nn
 import torch.nn.functional as Fn
 
-from xformers.components.attention import Attention, AttentionConfig, register_attention
+from xformers.components.attention import (
+    Attention,
+    AttentionConfig,
+    AttentionMask,
+    register_attention,
+)
 from xformers.components.attention.core import (
     scaled_dot_product_attention,
     scaled_query_key_softmax,
@@ -83,7 +88,7 @@ def forward(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        att_mask: Optional[torch.Tensor] = None,
+        att_mask: Optional[Union[AttentionMask, torch.Tensor]] = None,
         *args,
         **kwargs,
     ):
diff --git a/xformers/components/attention/random.py b/xformers/components/attention/random.py
index 5e3ee08e69..e07e6c8679 100644
--- a/xformers/components/attention/random.py
+++ b/xformers/components/attention/random.py
@@ -5,7 +5,7 @@
 
 
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 import torch.nn as nn
@@ -91,7 +91,7 @@ def forward(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        att_mask: Optional[torch.Tensor] = None,
+        att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
         *args,
         **kwargs,
     ):
@@ -106,6 +106,9 @@
             ):
                 mask = self.rand_attention_mask + AttentionMask.from_bool(att_mask)
             else:
+                if isinstance(att_mask, AttentionMask):
+                    # Needed because & op not defined for SparseCS with AttentionMask
+                    att_mask = att_mask.to_bool()
                 mask = self.rand_attention_mask & att_mask
         else:
             mask = self.rand_attention_mask
diff --git a/xformers/components/swiglu/cuda/dual_gemm_silu_identity_mul.cu b/xformers/components/swiglu/cuda/dual_gemm_silu_identity_mul.cu
index a51cb42a6f..572167caf9 100644
--- a/xformers/components/swiglu/cuda/dual_gemm_silu_identity_mul.cu
+++ b/xformers/components/swiglu/cuda/dual_gemm_silu_identity_mul.cu
@@ -84,7 +84,7 @@ std::tuple dual_gemm_silu_identity_mul_(
       EpilogueOutputOp01,
       EpilogueOutputOp01,
       EpilogueOutputOp2,
-      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>,
       kStages,
       kStoreD0,
       kStoreD1,
diff --git a/xformers/factory/block_factory.py b/xformers/factory/block_factory.py
index e139a2a6fc..113f440fe0 100644
--- a/xformers/factory/block_factory.py
+++ b/xformers/factory/block_factory.py
@@ -20,6 +20,7 @@
     build_multi_head_attention,
     build_patch_embedding,
 )
+from xformers.components.attention import AttentionMask
 from xformers.components.feedforward import build_feedforward
 from xformers.components.positional_embedding import build_positional_embedding
 from xformers.components.residual import get_deepnorm_coefficients
@@ -206,7 +207,7 @@ def get_reversible_layer(config) -> Tuple[nn.Module, nn.Module]:
     def forward(
         self,
         x: torch.Tensor,
-        att_mask: Optional[torch.Tensor] = None,
+        att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
         input_mask: Optional[torch.Tensor] = None,
     ):
         if self.patch_emb is not None:
@@ -327,8 +328,8 @@ def forward(
         self,
         target: torch.Tensor,
         memory: torch.Tensor,
-        encoder_att_mask: Optional[torch.Tensor] = None,
-        decoder_att_mask: Optional[torch.Tensor] = None,
+        encoder_att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
+        decoder_att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
         input_mask: Optional[torch.Tensor] = None,
     ):
         if self.pose_encoding is not None:
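For context, the snippet below is a usage sketch only, not part of the diff. It mirrors what the updated tests exercise and assumes an xformers build that contains these patches: `AttentionMask.from_bool` wraps a boolean mask so that either the wrapper or the raw tensor can be passed as `att_mask`, and the new branch in `_matmul_with_mask` repeats an additive mask whose batch dimension evenly divides the attention batch. Shapes and values here are arbitrary.

```python
# Minimal sketch (not part of the diff), assuming an xformers build with these patches.
import torch

from xformers.components.attention import AttentionMask
from xformers.components.attention.core import scaled_dot_product_attention

B, S, D = 8, 16, 32
q = k = v = torch.randn(B, S, D)

# The attention mechanisms and factory blocks touched above now accept either a
# bool tensor or an AttentionMask built from it (see tests/test_block_factory.py),
# e.g. block(inputs, att_mask=att_mask) or block(inputs, att_mask=bool_mask).
bool_mask = torch.ones(S, S, dtype=torch.bool)
att_mask = AttentionMask.from_bool(bool_mask)

# Additive float mask with a smaller batch dimension: the new branch in
# _matmul_with_mask repeats it along the batch axis to match the attention
# scores (see tests/test_core_attention.py).
additive_mask = torch.zeros(B // 2, S, S)
additive_mask[:, :, -1] = float("-inf")  # hide the last key position
_ = scaled_dot_product_attention(q, k, v, additive_mask)
```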