Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Adding Visual Attention #329

Merged
merged 1 commit into from
Jun 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Four blocksparsity layouts from DeepSpeed [#320]
- Support several initialization options [#312]
- Conv2DFeedforward feedforward part [#321]
- VisualAttention [#329]


## [0.0.11] - 2022-05-30
Expand Down
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ Patrick et al., 2021](https://arxiv.org/abs/2106.05392)*
- [2D Pooling](xformers/components/attention/pooling.py)
- *[Metaformer is actually what you need for vision, Yu et al.](https://arxiv.org/pdf/2111.11418v1.pdf)*

- [Visual Attention](xformers/components/attention/visual.py)
- *[`Visual Attention Network`_, Guo et al](https://arxiv.org/pdf/2202.09741.pdf)*

- ... add a new one [see Contribution.md](CONTRIBUTING.md)

</p></details>
Expand Down Expand Up @@ -199,7 +202,7 @@ Patrick et al., 2021](https://arxiv.org/abs/2106.05392)*

<details><summary>Initializations </summary><p>
This is completely optional, and will only occur when generating full models through xFormers, not when picking parts individually.

There are basically two initialization mechanisms exposed, but the user is free to initialize weights as he/she sees fit after the fact.
- Parts can expose a `init_weights()` method, which define sane defaults
- xFormers supports [specific init schemes](xformers/factory/weight_init.py) which *can take precedence* over the init_weights()
Expand Down
9 changes: 5 additions & 4 deletions examples/cifarMetaformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
num_classes=10,
dim=384,
attention="scaled_dot_product",
feedforward="MLP",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure about the defaults here, how to show that you can use these to repro "Visual Attention" for instance ? Should we show different presets ?

layer_norm_style="pre",
use_rotary_embeddings=True,
linear_warmup_ratio=0.1,
Expand All @@ -45,8 +46,7 @@ def __init__(
# Generate the skeleton of our hierarchical Transformer

# This is a small poolformer configuration, adapted to the small CIFAR10 pictures (32x32)
# Any other related config would work,
# and the attention mechanisms don't have to be the same across layers
# Any other related config would work, and the attention mechanisms don't have to be the same across layers
base_hierarchical_configs = [
BasicLayerConfig(
embedding=64,
Expand Down Expand Up @@ -121,8 +121,8 @@ def forward(self, x):

# Adjust batch depending on the available memory on your machine.
# You can also use reversible layers to save memory
REF_BATCH = 512
BATCH = 512 # lower if not enough GPU memory
REF_BATCH = 768
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like a classic default for Cifar10

BATCH = 256 # lower if not enough GPU memory

MAX_EPOCHS = 50
NUM_WORKERS = 4
Expand Down Expand Up @@ -172,6 +172,7 @@ def forward(self, x):
num_classes=num_classes,
attention="scaled_dot_product",
layer_norm_style="pre",
feedforward="MLP",
use_rotary_embeddings=True,
)
trainer = pl.Trainer(
Expand Down
34 changes: 25 additions & 9 deletions tests/test_attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,16 @@
)

BATCH = 2
SEQ = 128 if torch.cuda.is_available() else 32
SEQ = 128 if torch.cuda.is_available() else 36
MODEL = 128 if torch.cuda.is_available() else 16
GLOBAL_ATTENTION_RATIO = (
_DENSITY_THRESHOLD * 0.9
) # Make sure that we test the sparse implementation, no matter the threshold

assert ATTENTION_REGISTRY.keys(), "Attention layers should have been registered"

_non_order_invariant_attentions = ["visual", "pooling"]


def _get_multihead(
attention_name,
Expand Down Expand Up @@ -93,7 +95,9 @@ def noop(x):
@pytest.mark.parametrize("residual_dropout", [0.0, 0.1])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys())
@pytest.mark.parametrize(
"attention_name", ATTENTION_REGISTRY.keys() - _non_order_invariant_attentions
)
@pytest.mark.parametrize("device", DEVICES)
def test_order_invariance(
attention_name: str,
Expand All @@ -104,9 +108,6 @@ def test_order_invariance(
device: torch.device,
):

if int(math.sqrt(SEQ)) ** 2 != SEQ and attention_name == "pooling":
pytest.skip(f"{attention_name} requires squared sequence lengths")

torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

Expand All @@ -120,6 +121,12 @@ def test_order_invariance(
use_seperate_proj_weights=False,
)

if (
int(math.sqrt(SEQ)) ** 2 != SEQ
and multi_head.attention.requires_squared_context
):
pytest.skip(f"{attention_name} requires squared sequence lengths")

# Check that a shuffled input produces the same results
seqs = [SEQ, SEQ // 2] if (attention_name != "blocksparse") else [SEQ]

Expand Down Expand Up @@ -304,12 +311,15 @@ def test_broadcast_batch_dimension(
device: torch.device,
batch_sizes: Tuple[int, int, int],
):
if int(math.sqrt(SEQ)) ** 2 != SEQ and attention_name == "pooling":
pytest.skip(f"{attention_name} requires squared sequence lengths")

Q_BATCH, K_BATCH, V_BATCH = batch_sizes
multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device)

if (
int(math.sqrt(SEQ)) ** 2 != SEQ
and multi_head.attention.requires_squared_context
):
pytest.skip(f"{attention_name} requires squared sequence lengths")

if multi_head.attention.requires_same_k_q_dimensions:
# pyre-fixme[29]: The library function `pytest.skip` is not supported by Pyre.
pytest.skip(f"{attention_name} does not support different k, q dimensions yet.")
Expand Down Expand Up @@ -388,14 +398,20 @@ def test_torch_script_ability(
heads: int,
attn_dropout: float,
):
if attention_name in {"favor", "global", "local", "random", "pooling"}:
if attention_name in {"favor", "global", "local", "random"}:
# pyre-fixme[29]: The library function `pytest.skip` is not supported by Pyre.
pytest.skip(f"{attention_name} does not support scripting yet.")

device = torch.device("cpu")

multi_head = _get_multihead(attention_name, attn_dropout, 0.0, False, heads, device)

if (
int(math.sqrt(SEQ)) ** 2 != SEQ
and multi_head.attention.requires_squared_context
):
pytest.skip(f"{attention_name} requires squared sequence lengths")

# input for tracing the function
q = torch.rand((BATCH, SEQ, MODEL), device=device)
k = torch.rand((BATCH, SEQ, MODEL), device=device)
Expand Down
3 changes: 3 additions & 0 deletions xformers/components/attention/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def __init__(self, dropout: Optional[float] = None, *args, **kwargs):
# so that the MHA wrapper should skip it
self.requires_skip_multi_head = False

# This attention requires a context length which is squared, often due to 2D pooling
self.requires_squared_context = False

# Whether this attention mechanism supports attention masks
self.supports_attention_mask = True
self.supports_key_padding_mask = False
Expand Down
4 changes: 4 additions & 0 deletions xformers/components/attention/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def __init__(
# This operator does not really handle q,k,v
self.requires_same_k_q_dimensions = True

# This attention requires the 2d structure out of the context,
# implictly assumed to be a squared length
self.requires_squared_context = True
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was already true before, but not formalized like this, I think it's cleaner ? "pooling" (PoolingFormer) and "visual" both recover the 2d structure of and assume a squared context length for that


def forward(self, q: torch.Tensor, *_, **__):
# Expose the 2D token structure
B, HW, C = q.shape
Expand Down
96 changes: 96 additions & 0 deletions xformers/components/attention/visual.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import math
from dataclasses import dataclass

import torch
import torch.nn as nn

from xformers.components.attention import Attention, AttentionConfig, register_attention


@dataclass
class VisualAttentionConfig(AttentionConfig):
dim_model: int # dimension of the input sequence


class LKA(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
self.conv_spatial = nn.Conv2d(
dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3
)
self.conv1 = nn.Conv2d(dim, dim, 1)

def forward(self, x: torch.Tensor):
u = x.clone()
attn = self.conv0(x)
attn = self.conv_spatial(attn)
attn = self.conv1(attn)

return u * attn


@register_attention("visual", VisualAttentionConfig)
class Visual(Attention):
def __init__(
self,
dim_model: int,
*_,
**__,
):
"""
Large kernel attention mechanism, as proposed in `Visual Attention Network`_, Guo et al (2022).
The original notation is tentatively kept as is. See https://github.com/Visual-Attention-Network
for the reference implementation

.. Note: compared to the paper, this block contains the LKA (Large Kernel Attention)
and the prior and posterior transformations (Conv2d and activation)

.. _`Visual Attention Network` : https://arxiv.org/pdf/2202.09741.pdf
"""
super().__init__()

self.block = nn.Sequential(
nn.Conv2d(dim_model, dim_model, 1),
nn.GELU(),
LKA(dim_model),
nn.Conv2d(dim_model, dim_model, 1),
)

# MHA related flags:
self.requires_same_k_q_dimensions = (
True # This mechanism only really supports self attention
)
self.supports_attention_mask = False
self.requires_skip_multi_head = (
True # This mechanism skips the multihead attention altogether
)
self.requires_squared_context = (
True # Recovering the 2D structure from context assumes squared content
)

self.requires_input_projection = (
False # This mechanism does not require that the MHA projects inputs
)

def forward(self, q: torch.Tensor, *_, **__):
# Expose the 2D token structure
B, HW, C = q.shape
H = int(math.sqrt(HW))
assert H * H == HW

x = q.transpose(-2, -1).reshape(B, C, H, H)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've not benchmarked that, but maybe that it's beneficial to .contiguous() here, depending on the Conv2D kernels


# Large kernel attention
residual = x.clone()
x = self.block(x)
x = x + residual

# Get back to B HW C
return x.flatten(2, 3).transpose(-2, -1)
5 changes: 4 additions & 1 deletion xformers/factory/block_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,10 @@ def __init__(self, config: xFormerDecoderConfig, **kwargs):
# Expose attention or feedforward specific capabilities
self.supports_attention_mask = mha.attention.supports_attention_mask
self.requires_same_k_q_dimensions = mha.attention.requires_same_k_q_dimensions
self.requires_squared_context_length = feedforward.requires_squared_context
self.requires_squared_context_length = (
feedforward.requires_squared_context
or mha.attention.requires_squared_context
)

self.causal_attention = (
mha.attention.causal if hasattr(mha.attention, "causal") else False
Expand Down