From 59e86fde20f3fcafbe79e80e94ea85c7e43b28e9 Mon Sep 17 00:00:00 2001
From: RangiLyu
Date: Mon, 15 Nov 2021 19:08:51 +0800
Subject: [PATCH] [Fix] Fix SpatialReductionAttention in PVT. (#6488)

* [Fix] Fix SpatialReductionAttention in PVT

* Add warning
---
 mmdet/models/backbones/pvt.py | 40 +++++++++++++++++++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/mmdet/models/backbones/pvt.py b/mmdet/models/backbones/pvt.py
index 243ff603db8..11c1eb0e8ca 100644
--- a/mmdet/models/backbones/pvt.py
+++ b/mmdet/models/backbones/pvt.py
@@ -155,6 +155,15 @@ def __init__(self,
         # The ret[0] of build_norm_layer is norm name.
         self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
 
+        # handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418  # noqa
+        from mmdet import mmcv_version, digit_version
+        if mmcv_version < digit_version('1.3.17'):
+            warnings.warn('The legacy version of the forward function in '
+                          'SpatialReductionAttention is deprecated in '
+                          'mmcv>=1.3.17 and will no longer be supported in '
+                          'the future. Please upgrade your mmcv.')
+            self.forward = self.legacy_forward
+
     def forward(self, x, hw_shape, identity=None):
 
         x_q = x
@@ -169,6 +178,37 @@ def forward(self, x, hw_shape, identity=None):
         if identity is None:
             identity = x_q
 
+        # Because the dataflow ('key', 'query', 'value') of
+        # ``torch.nn.MultiheadAttention`` is (num_query, batch,
+        # embed_dims), we should adjust the shape of the dataflow from
+        # batch_first (batch, num_query, embed_dims) to num_query_first
+        # (num_query, batch, embed_dims), and recover ``attn_output``
+        # from num_query_first to batch_first.
+        if self.batch_first:
+            x_q = x_q.transpose(0, 1)
+            x_kv = x_kv.transpose(0, 1)
+
+        out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
+
+        if self.batch_first:
+            out = out.transpose(0, 1)
+
+        return identity + self.dropout_layer(self.proj_drop(out))
+
+    def legacy_forward(self, x, hw_shape, identity=None):
+        """Multi head attention forward in mmcv version < 1.3.17."""
+        x_q = x
+        if self.sr_ratio > 1:
+            x_kv = nlc_to_nchw(x, hw_shape)
+            x_kv = self.sr(x_kv)
+            x_kv = nchw_to_nlc(x_kv)
+            x_kv = self.norm(x_kv)
+        else:
+            x_kv = x
+
+        if identity is None:
+            identity = x_q
+
         out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
 
         return identity + self.dropout_layer(self.proj_drop(out))
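
The new forward path relies on the fact that a plain ``torch.nn.MultiheadAttention`` consumes (num_query, batch, embed_dims) tensors, so batch-first inputs must be transposed before the call and the output transposed back before the residual addition with ``identity``. A minimal standalone sketch of that shape handling, using made-up sizes and plain PyTorch rather than the mmdet module:

import torch
import torch.nn as nn

batch, num_query, embed_dims = 2, 196, 64
attn = nn.MultiheadAttention(embed_dims, num_heads=8)

x_q = torch.randn(batch, num_query, embed_dims)  # batch-first queries
x_kv = torch.randn(batch, 49, embed_dims)        # spatially reduced keys/values

# batch_first (batch, num_query, embed_dims) -> (num_query, batch, embed_dims)
q, kv = x_q.transpose(0, 1), x_kv.transpose(0, 1)
out = attn(query=q, key=kv, value=kv)[0]
# recover attn_output to batch_first before the residual addition
out = out.transpose(0, 1)
assert out.shape == (batch, num_query, embed_dims)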
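
``legacy_forward`` keeps the pre-1.3.17 behaviour, and both forwards share the ``sr_ratio > 1`` branch that shrinks the key/value sequence before attention. A rough sketch of that spatial reduction, with stand-in ``nlc_to_nchw``/``nchw_to_nlc`` helpers and an assumed strided-conv ``sr`` layer in place of the module's own:

import torch
import torch.nn as nn

def nlc_to_nchw(x, hw_shape):
    """(batch, H*W, C) -> (batch, C, H, W); stand-in for the mmcv helper."""
    h, w = hw_shape
    return x.transpose(1, 2).reshape(x.size(0), -1, h, w)

def nchw_to_nlc(x):
    """(batch, C, H, W) -> (batch, H*W, C); stand-in for the mmcv helper."""
    return x.flatten(2).transpose(1, 2)

embed_dims, sr_ratio, hw_shape = 64, 2, (14, 14)
# assumed reduction layer: a conv with kernel_size=stride=sr_ratio
sr = nn.Conv2d(embed_dims, embed_dims, kernel_size=sr_ratio, stride=sr_ratio)
norm = nn.LayerNorm(embed_dims)

x = torch.randn(2, 14 * 14, embed_dims)           # NLC tokens on a 14x14 grid
x_kv = nchw_to_nlc(sr(nlc_to_nchw(x, hw_shape)))  # -> (2, 49, embed_dims)
x_kv = norm(x_kv)

With ``sr_ratio=2`` on a 14x14 token grid, the key/value length drops from 196 to 49, which is what keeps the attention in PVT affordable at high resolutions.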