[Fix] Fix SpatialReductionAttention in PVT. #6488

Merged 2 commits on Nov 15, 2021
mmdet/models/backbones/pvt.py: 40 additions & 0 deletions
@@ -155,6 +155,15 @@ def __init__(self,
# The ret[0] of build_norm_layer is the norm name.
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]

# handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa
from mmdet import mmcv_version, digit_version
if mmcv_version < digit_version('1.3.17'):
warnings.warn('The legacy version of forward function in '
              'SpatialReductionAttention is deprecated in '
              'mmcv>=1.3.17 and will no longer be supported in '
              'the future. Please upgrade your mmcv.')
self.forward = self.legacy_forward

def forward(self, x, hw_shape, identity=None):

x_q = x
@@ -169,6 +178,37 @@ def forward(self, x, hw_shape, identity=None):
if identity is None:
identity = x_q

# Because the dataflow ('key', 'query', 'value') of
# ``torch.nn.MultiheadAttention`` is (num_query, batch,
# embed_dims), we have to adjust the shape of the dataflow from
# batch_first (batch, num_query, embed_dims) to num_query_first
# (num_query, batch, embed_dims), and recover ``attn_output``
# from num_query_first back to batch_first.
if self.batch_first:
x_q = x_q.transpose(0, 1)
x_kv = x_kv.transpose(0, 1)

out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]

if self.batch_first:
out = out.transpose(0, 1)

return identity + self.dropout_layer(self.proj_drop(out))

def legacy_forward(self, x, hw_shape, identity=None):
"""multi head attention forward in mmcv version < 1.3.17."""
x_q = x
if self.sr_ratio > 1:
x_kv = nlc_to_nchw(x, hw_shape)
x_kv = self.sr(x_kv)
x_kv = nchw_to_nlc(x_kv)
x_kv = self.norm(x_kv)
else:
x_kv = x

if identity is None:
identity = x_q

out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]

return identity + self.dropout_layer(self.proj_drop(out))
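
For context on the first hunk: when the installed mmcv predates 1.3.17, the gate rebinds forward on the instance so the old code path keeps working. The snippet below is an illustration only, not mmdet code; it assumes mmdet is importable and exposes mmcv_version and digit_version exactly as imported in the diff above, and the Demo class and its messages are made up for the sketch.

# Illustration only: version-gated method dispatch, assuming mmdet exposes
# mmcv_version / digit_version as imported in the diff above.
import warnings

from mmdet import mmcv_version, digit_version


class Demo:

    def __init__(self):
        if mmcv_version < digit_version('1.3.17'):
            warnings.warn('mmcv < 1.3.17 detected, using the legacy path.')
            # An instance attribute shadows the class-level ``forward``.
            self.forward = self.legacy_forward

    def forward(self):
        return 'new path'

    def legacy_forward(self):
        return 'legacy path'


print(Demo().forward())  # 'legacy path' on old mmcv, 'new path' otherwise

Rebinding on the instance rather than replacing the class method keeps both implementations available, which is what lets the patch ship forward and legacy_forward side by side.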
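And for the second hunk: the transposes exist because torch.nn.MultiheadAttention defaults to batch_first=False, i.e. it expects (num_query, batch, embed_dims) tensors. Below is a standalone sketch of that shape convention using plain PyTorch with made-up sizes; it is not part of the patch.

# Standalone sketch of the shape handling in the new ``forward``: transpose
# batch-first activations before the attention call and transpose back after.
import torch
import torch.nn as nn

batch, num_query, embed_dims, num_heads = 2, 196, 64, 4
attn = nn.MultiheadAttention(embed_dims, num_heads)  # batch_first defaults to False

x = torch.rand(batch, num_query, embed_dims)    # batch-first activations
x_q = x.transpose(0, 1)                         # -> (num_query, batch, embed_dims)
x_kv = x_q                                      # key/value come from x as well here
out = attn(query=x_q, key=x_kv, value=x_kv)[0]  # returns (attn_output, attn_weights)
out = out.transpose(0, 1)                       # back to (batch, num_query, embed_dims)
assert out.shape == (batch, num_query, embed_dims)

In the patch itself the transposes are guarded by self.batch_first, the attribute introduced by the mmcv change referenced in the first hunk, while legacy_forward keeps the old call without them.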