diff --git a/third_party/cutlass b/third_party/cutlass
index 4db6a6140e..3441886dd5 160000
--- a/third_party/cutlass
+++ b/third_party/cutlass
@@ -1 +1 @@
-Subproject commit 4db6a6140e45c4ffe6339c55b43b159602fa1f35
+Subproject commit 3441886dd5adf9ebd6fd74671a0186da5c6b5f0b
diff --git a/xformers/ops/__init__.py b/xformers/ops/__init__.py
new file mode 100644
index 0000000000..c5d69faaf8
--- /dev/null
+++ b/xformers/ops/__init__.py
@@ -0,0 +1,41 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from .memory_efficient_attention import (  # noqa: F401
+    AttentionMask,
+    AttentionOpBase,
+    AttentionOpDispatch,
+    LowerTriangularMask,
+    MemoryEfficientAttentionCutlassFwdFlashBwOp,
+    MemoryEfficientAttentionCutlassOp,
+    MemoryEfficientAttentionFlashAttentionOp,
+    MemoryEfficientAttentionOp,
+    memory_efficient_attention,
+)
+from .unbind import efficient_stack, get_stack_strides, unbind  # noqa: F401
+
+
+def masked_matmul(a, b, mask=None):
+    if torch.overrides.has_torch_function((a, b, mask)):
+        return torch.overrides.handle_torch_function(
+            masked_matmul, (a, b, mask), a, b, mask
+        )
+
+    att = a @ b
+
+    if mask is None:
+        return att
+
+    if mask.dtype == torch.bool:
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0).expand(att.shape[0], -1, -1)
+        # mask is presumed false == ignore
+        att[~mask] = float("-inf")
+    else:
+        # mask is presumed additive
+        att += mask
+    return att
diff --git a/xformers/ops.py b/xformers/ops/memory_efficient_attention.py
similarity index 89%
rename from xformers/ops.py
rename to xformers/ops/memory_efficient_attention.py
index 53a3523720..3f7cc0bdb5 100644
--- a/xformers/ops.py
+++ b/xformers/ops/memory_efficient_attention.py
@@ -7,40 +7,18 @@
 import math
 from dataclasses import dataclass
 from types import SimpleNamespace
-from typing import Any, List, Mapping, Optional, Sequence, Set, Tuple, Type, Union
+from typing import Any, List, Mapping, Optional, Set, Type, Union
 
 import torch
 
 try:
-    from . import _C_flashattention  # type: ignore[attr-defined]
+    from .. import _C_flashattention  # type: ignore[attr-defined]
 
     has_flashattention = True
 except ImportError:
     has_flashattention = False
 
 
-def masked_matmul(a, b, mask=None):
-    if torch.overrides.has_torch_function((a, b, mask)):
-        return torch.overrides.handle_torch_function(
-            masked_matmul, (a, b, mask), a, b, mask
-        )
-
-    att = a @ b
-
-    if mask is None:
-        return att
-
-    if mask.dtype == torch.bool:
-        if mask.ndim == 2:
-            mask = mask.unsqueeze(0).expand(att.shape[0], -1, -1)
-        # mask is presumed false == ignore
-        att[~mask] = float("-inf")
-    else:
-        # mask is presumed additive
-        att += mask
-    return att
-
-
 def _get_xformers_operator(name: str):
     def no_such_operator(*args, **kwargs):
         raise RuntimeError(
@@ -751,76 +729,6 @@ def from_arguments(
         )
 
 
-def get_stack_strides(
-    tensors: Sequence[torch.Tensor], dim: int
-) -> Optional[Tuple[int, ...]]:
-    """
-    If the tensors are already stacked, returns the strides of the stacked
-    tensors. Otherwise returns None.
-    """
-    if len(tensors) <= 1 or dim > tensors[0].ndim:
-        return None
-
-    final_stride = []
-    for i in range(tensors[0].ndim + 1):
-        if i == dim:
-            final_stride.append(
-                tensors[1].storage_offset() - tensors[0].storage_offset()
-            )
-            continue
-        if i > dim:
-            i -= 1
-        final_stride.append(tensors[0].stride(i))
-
-    for i, x in enumerate(tensors):
-        # Sanity checks
-        if x.shape != tensors[0].shape:
-            return None
-        # Actual storage check
-        if x.storage().data_ptr() != tensors[0].storage().data_ptr():
-            return None
-        if x.stride() != tensors[0].stride():
-            return None
-        if x.storage_offset() != tensors[0].storage_offset() + i * final_stride[dim]:
-            return None
-    return tuple(final_stride)
-
-
-def efficient_stack(
-    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int
-) -> torch.Tensor:
-    strides = get_stack_strides(tensors, dim)
-    if strides is not None:
-        input_shape = list(tensors[0].shape)
-        input_shape.insert(dim, len(tensors))
-        return tensors[0].as_strided(input_shape, strides)
-    return torch.stack(tensors, dim=dim)
-
-
-class _Unbind(torch.autograd.Function):
-    """
-    Splits a packed `qkv` tensor into query, key and values.
-    The magic happens in the backward. We want to `torch.stack` the tensors
-    together, but we don't need to if the gradients have already the same storage
-    (and that is something that our attention operators support)
-    """
-
-    @staticmethod
-    # type: ignore
-    def forward(ctx, x: torch.Tensor, dim: int):
-        ctx.dim = dim
-        return x.unbind(dim)
-
-    @classmethod
-    # type: ignore
-    def backward(cls, ctx, *tensors: torch.Tensor):
-        return efficient_stack(tensors, ctx.dim), None
-
-
 def memory_efficient_attention(
     query: torch.Tensor,
     key: torch.Tensor,
diff --git a/xformers/ops/unbind.py b/xformers/ops/unbind.py
new file mode 100644
index 0000000000..62bfeb7851
--- /dev/null
+++ b/xformers/ops/unbind.py
@@ -0,0 +1,78 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Optional, Sequence, Tuple, Union
+
+import torch
+
+
+def get_stack_strides(
+    tensors: Sequence[torch.Tensor], dim: int
+) -> Optional[Tuple[int, ...]]:
+    """
+    If the tensors are already stacked, returns the strides of the stacked
+    tensors. Otherwise returns None.
+    """
+    if len(tensors) <= 1 or dim > tensors[0].ndim:
+        return None
+
+    final_stride = []
+    for i in range(tensors[0].ndim + 1):
+        if i == dim:
+            final_stride.append(
+                tensors[1].storage_offset() - tensors[0].storage_offset()
+            )
+            continue
+        if i > dim:
+            i -= 1
+        final_stride.append(tensors[0].stride(i))
+
+    for i, x in enumerate(tensors):
+        # Sanity checks
+        if x.shape != tensors[0].shape:
+            return None
+        # Actual storage check
+        if x.storage().data_ptr() != tensors[0].storage().data_ptr():
+            return None
+        if x.stride() != tensors[0].stride():
+            return None
+        if x.storage_offset() != tensors[0].storage_offset() + i * final_stride[dim]:
+            return None
+    return tuple(final_stride)
+
+
+def efficient_stack(
+    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int
+) -> torch.Tensor:
+    strides = get_stack_strides(tensors, dim)
+    if strides is not None:
+        input_shape = list(tensors[0].shape)
+        input_shape.insert(dim, len(tensors))
+        return tensors[0].as_strided(input_shape, strides)
+    return torch.stack(tensors, dim=dim)
+
+
+class _Unbind(torch.autograd.Function):
+    """
+    Splits a packed `qkv` tensor into query, key and values.
+    The magic happens in the backward. We want to `torch.stack` the tensors
+    together, but we don't need to if the gradients have already the same storage
+    (and that is something that our attention operators support)
+    """
+
+    @staticmethod
+    # type: ignore
+    def forward(ctx, x: torch.Tensor, dim: int):
+        ctx.dim = dim
+        return x.unbind(dim)
+
+    @classmethod
+    # type: ignore
+    def backward(cls, ctx, *tensors: torch.Tensor):
+        return efficient_stack(tensors, ctx.dim), None
+
+
+def unbind(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, ...]:
+    return _Unbind.apply(x, dim)