Split xformers.ops #486

Merged · 5 commits · Nov 10, 2022
41 changes: 41 additions & 0 deletions xformers/ops/__init__.py
@@ -0,0 +1,41 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import torch

from .memory_efficient_attention import ( # noqa: F401
AttentionMask,
AttentionOpBase,
AttentionOpDispatch,
LowerTriangularMask,
MemoryEfficientAttentionCutlassFwdFlashBwOp,
MemoryEfficientAttentionCutlassOp,
MemoryEfficientAttentionFlashAttentionOp,
MemoryEfficientAttentionOp,
memory_efficient_attention,
)
from .unbind import efficient_stack, get_stack_strides, unbind # noqa: F401
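
Since the new __init__.py re-exports everything under the same xformers.ops namespace, existing user code that imports from xformers.ops should keep working unchanged after the split. A minimal sketch, using only names this module exposes:

from xformers.ops import (
    LowerTriangularMask,
    masked_matmul,
    memory_efficient_attention,
    unbind,
)

# The same names are also reachable as module attributes:
import xformers.ops as xops
assert xops.memory_efficient_attention is memory_efficient_attention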


def masked_matmul(a, b, mask=None):
if torch.overrides.has_torch_function((a, b, mask)):
return torch.overrides.handle_torch_function(
masked_matmul, (a, b, mask), a, b, mask
)

att = a @ b

if mask is None:
return att

if mask.dtype == torch.bool:
if mask.ndim == 2:
mask = mask.unsqueeze(0).expand(att.shape[0], -1, -1)
# mask is presumed false == ignore
att[~mask] = float("-inf")
else:
# mask is presumed additive
att += mask
return att
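
For context, a quick usage sketch of the masked_matmul helper above; the shapes are arbitrary and only meant to show the boolean vs. additive mask paths:

import torch
from xformers.ops import masked_matmul

a = torch.randn(2, 4, 8)   # (batch, queries, dim)
b = torch.randn(2, 8, 6)   # (batch, dim, keys)

# Boolean mask: positions where the mask is False are filled with -inf.
bool_mask = torch.rand(4, 6) > 0.5      # 2D masks are broadcast over the batch
att = masked_matmul(a, b, bool_mask)

# Floating-point mask: treated as an additive bias on the logits.
bias = torch.zeros(2, 4, 6)
att = masked_matmul(a, b, bias)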
96 changes: 2 additions & 94 deletions xformers/ops.py → xformers/ops/memory_efficient_attention.py
@@ -7,40 +7,18 @@
import math
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, List, Mapping, Optional, Sequence, Set, Tuple, Type, Union
from typing import Any, List, Mapping, Optional, Set, Type, Union

import torch

try:
from . import _C_flashattention # type: ignore[attr-defined]
from .. import _C_flashattention # type: ignore[attr-defined]

has_flashattention = True
except ImportError:
has_flashattention = False


def masked_matmul(a, b, mask=None):
if torch.overrides.has_torch_function((a, b, mask)):
return torch.overrides.handle_torch_function(
masked_matmul, (a, b, mask), a, b, mask
)

att = a @ b

if mask is None:
return att

if mask.dtype == torch.bool:
if mask.ndim == 2:
mask = mask.unsqueeze(0).expand(att.shape[0], -1, -1)
# mask is presumed false == ignore
att[~mask] = float("-inf")
else:
# mask is presumed additive
att += mask
return att


def _get_xformers_operator(name: str):
def no_such_operator(*args, **kwargs):
raise RuntimeError(
@@ -751,76 +729,6 @@ def from_arguments(
)


def get_stack_strides(
tensors: Sequence[torch.Tensor], dim: int
) -> Optional[Tuple[int, ...]]:
"""
If the tensors are already stacked, returns the strides of the stacked
tensors. Otherwise returns None.
"""
if len(tensors) <= 1 or dim > tensors[0].ndim:
return None

final_stride = []
for i in range(tensors[0].ndim + 1):
if i == dim:
final_stride.append(
tensors[1].storage_offset() - tensors[0].storage_offset()
)
continue
if i > dim:
i -= 1
final_stride.append(tensors[0].stride(i))

for i, x in enumerate(tensors):
# Sanity checks
if x.shape != tensors[0].shape:
return None
# Actual storage check
if x.storage().data_ptr() != tensors[0].storage().data_ptr():
return None
if x.stride() != tensors[0].stride():
return None
if x.storage_offset() != tensors[0].storage_offset() + i * final_stride[dim]:
return None
return tuple(final_stride)


def efficient_stack(
tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int
) -> torch.Tensor:
strides = get_stack_strides(tensors, dim)
if strides is not None:
input_shape = list(tensors[0].shape)
input_shape.insert(dim, len(tensors))
return tensors[0].as_strided(input_shape, strides)
return torch.stack(tensors, dim=dim)


class _Unbind(torch.autograd.Function):
"""
Splits a packed `qkv` tensor into query, key and value tensors.
The magic happens in the backward: we want to `torch.stack` the gradients
together, but we don't need to if they already share the same storage
(and that is something that our attention operators support).
"""

@staticmethod
# type: ignore
def forward(ctx, x: torch.Tensor, dim: int):
ctx.dim = dim
return x.unbind(dim)

@classmethod
# type: ignore
def backward(cls, ctx, *tensors: torch.Tensor):
return efficient_stack(tensors, ctx.dim), None


def unbind(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, ...]:
return _Unbind.apply(x, dim)


def memory_efficient_attention(
query: torch.Tensor,
key: torch.Tensor,
78 changes: 78 additions & 0 deletions xformers/ops/unbind.py
@@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Sequence, Tuple, Union

import torch


def get_stack_strides(
tensors: Sequence[torch.Tensor], dim: int
) -> Optional[Tuple[int, ...]]:
"""
If the tensors are already stacked, returns the strides of the stacked
tensors. Otherwise returns None.
"""
if len(tensors) <= 1 or dim > tensors[0].ndim:
return None

final_stride = []
for i in range(tensors[0].ndim + 1):
if i == dim:
final_stride.append(
tensors[1].storage_offset() - tensors[0].storage_offset()
)
continue
if i > dim:
i -= 1
final_stride.append(tensors[0].stride(i))

for i, x in enumerate(tensors):
# Sanity checks
if x.shape != tensors[0].shape:
return None
# Actual storage check
if x.storage().data_ptr() != tensors[0].storage().data_ptr():
return None
if x.stride() != tensors[0].stride():
return None
if x.storage_offset() != tensors[0].storage_offset() + i * final_stride[dim]:
return None
return tuple(final_stride)


def efficient_stack(
tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int
) -> torch.Tensor:
strides = get_stack_strides(tensors, dim)
if strides is not None:
input_shape = list(tensors[0].shape)
input_shape.insert(dim, len(tensors))
return tensors[0].as_strided(input_shape, strides)
return torch.stack(tensors, dim=dim)
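
To make the fast path above concrete: views carved out of one contiguous buffer can be re-stacked without a copy, while independently allocated tensors fall back to a regular torch.stack. A small sketch (shapes and values are arbitrary):

import torch
from xformers.ops import efficient_stack, get_stack_strides

packed = torch.randn(3, 2, 5)           # one contiguous buffer
views = packed.unbind(0)                # three views into the same storage

print(get_stack_strides(views, dim=0))  # (10, 5, 1): same strides as `packed`
restacked = efficient_stack(views, dim=0)
print(restacked.data_ptr() == packed.data_ptr())  # True, no copy was made

fresh = [torch.randn(2, 5) for _ in range(3)]
print(get_stack_strides(fresh, dim=0))  # None: efficient_stack would fall back to torch.stack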


class _Unbind(torch.autograd.Function):
"""
Splits a packed `qkv` tensor into query, key and value tensors.
The magic happens in the backward: we want to `torch.stack` the gradients
together, but we don't need to if they already share the same storage
(and that is something that our attention operators support).
"""

@staticmethod
# type: ignore
def forward(ctx, x: torch.Tensor, dim: int):
ctx.dim = dim
return x.unbind(dim)

@classmethod
# type: ignore
def backward(cls, ctx, *tensors: torch.Tensor):
return efficient_stack(tensors, ctx.dim), None


def unbind(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, ...]:
return _Unbind.apply(x, dim)
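
Finally, a small end-to-end sketch of unbind itself. The packed (batch, seq, 3, head_dim) layout and the toy loss below are illustrative assumptions, not requirements of the API; in this toy case the three gradients are freshly allocated, so the backward falls back to torch.stack, whereas with the attention operators mentioned in the docstring they can share one storage and the re-stack becomes free:

import torch
from xformers.ops import unbind

qkv = torch.randn(2, 16, 3, 64, requires_grad=True)  # hypothetical packed qkv projection
q, k, v = unbind(qkv, dim=2)                          # three views, no copy in the forward

loss = (q * k).sum() + v.sum()   # stand-in for an attention computation
loss.backward()                  # gradients are stacked back into qkv.grad
assert qkv.grad.shape == qkv.shape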