Skip to content

Commit

Permalink
fix the wrong function reference bug in BaseTransformerLayer when bat…
Browse files Browse the repository at this point in the history
…ch_first is True (#1418)
  • Loading branch information
gaotongxiao authored Nov 2, 2021
1 parent 426e229 commit c522b47
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 21 deletions.
35 changes: 14 additions & 21 deletions mmcv/cnn/bricks/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,27 +102,6 @@ def __init__(self,

self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
**kwargs)
if self.batch_first:

def _bnc_to_nbc(forward):
"""Because the dataflow('key', 'query', 'value') of
``torch.nn.MultiheadAttention`` is (num_query, batch,
embed_dims), We should adjust the shape of dataflow from
batch_first (batch, num_query, embed_dims) to num_query_first
(num_query ,batch, embed_dims), and recover ``attn_output``
from num_query_first to batch_first."""

def forward_wrapper(**kwargs):
convert_keys = ('key', 'query', 'value')
for key in kwargs.keys():
if key in convert_keys:
kwargs[key] = kwargs[key].transpose(0, 1)
attn_output, attn_output_weights = forward(**kwargs)
return attn_output.transpose(0, 1), attn_output_weights

return forward_wrapper

self.attn.forward = _bnc_to_nbc(self.attn.forward)

self.proj_drop = nn.Dropout(proj_drop)
self.dropout_layer = build_dropout(
Expand Down Expand Up @@ -199,13 +178,27 @@ def forward(self,
if key_pos is not None:
key = key + key_pos

# Because the dataflow ('query', 'key', 'value') of
# ``torch.nn.MultiheadAttention`` is (num_query, batch, embed_dims),
# we transpose the inputs from batch-first
# (batch, num_query, embed_dims) to num_query-first
# (num_query, batch, embed_dims) before the attention call, and
# transpose ``attn_output`` back to batch-first afterwards.
if self.batch_first:
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)

out = self.attn(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask)[0]

if self.batch_first:
out = out.transpose(0, 1)

return identity + self.dropout_layer(self.proj_drop(out))


Expand Down
25 changes: 25 additions & 0 deletions tests/test_cnn/test_transformer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import copy

import pytest
import torch

from mmcv.cnn.bricks.drop import DropPath
from mmcv.cnn.bricks.transformer import (FFN, BaseTransformerLayer,
MultiheadAttention,
TransformerLayerSequence)
from mmcv.runner import ModuleList


def test_multiheadattention():
Expand Down Expand Up @@ -92,6 +95,28 @@ def test_ffn():
ffn(input_tensor).sum() + residual.sum() - input_tensor.sum())


@pytest.mark.skipif(not torch.cuda.is_available(), reason='Cuda not available')
def test_basetransformerlayer_cuda():
    """Run deep-copied batch-first ``BaseTransformerLayer`` copies on CUDA.

    A single layer is built with ``batch_first=True``, deep-copied twice
    into a ``ModuleList`` that is moved to the GPU, and a random
    ``(2, 10, 256)`` tensor is passed through the copies in sequence. The
    output shape must be unchanged after each copy, i.e. the layer's
    behaviour should remain consistent after being deep-copied.
    """
    attn_cfgs = dict(
        type='MultiheadAttention',
        embed_dims=256,
        num_heads=8,
    )
    template = BaseTransformerLayer(
        operation_order=('self_attn', 'ffn'),
        batch_first=True,
        attn_cfgs=attn_cfgs,
    )

    # Only the deep copies are moved to the GPU; ``template`` itself is
    # left where it was created.
    stack = ModuleList([copy.deepcopy(template) for _ in range(2)])
    stack.to('cuda')

    feats = torch.rand(2, 10, 256).cuda()
    expected_shape = torch.Size([2, 10, 256])
    for layer in stack:
        feats = layer(feats)
        assert feats.shape == expected_shape


def test_basetransformerlayer():
attn_cfgs = dict(type='MultiheadAttention', embed_dims=256, num_heads=8),
feedforward_channels = 2048
Expand Down

0 comments on commit c522b47

Please sign in to comment.