Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Support value_proj_ratio in MultiScaleDeformableAttention #2452

Merged
merged 10 commits
Dec 11, 2022
10 changes: 7 additions & 3 deletions mmcv/ops/multi_scale_deform_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ class MultiScaleDeformableAttention(BaseModule):
Default: None.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
value_proj_ratio (float): The expansion ratio of value_proj.
Default: 1.0.
"""

def __init__(self,
Expand All @@ -193,7 +195,8 @@ def __init__(self,
dropout: float = 0.1,
batch_first: bool = False,
norm_cfg: Optional[dict] = None,
init_cfg: Optional[mmengine.ConfigDict] = None):
init_cfg: Optional[mmengine.ConfigDict] = None,
value_proj_ratio: float = 1.0):
super().__init__(init_cfg)
if embed_dims % num_heads != 0:
raise ValueError(f'embed_dims must be divisible by num_heads, '
Expand Down Expand Up @@ -228,8 +231,9 @@ def _is_power_of_2(n):
embed_dims, num_heads * num_levels * num_points * 2)
self.attention_weights = nn.Linear(embed_dims,
num_heads * num_levels * num_points)
self.value_proj = nn.Linear(embed_dims, embed_dims)
self.output_proj = nn.Linear(embed_dims, embed_dims)
value_proj_size = int(embed_dims * value_proj_ratio)
self.value_proj = nn.Linear(embed_dims, value_proj_size)
self.output_proj = nn.Linear(value_proj_size, embed_dims)
self.init_weights()

def init_weights(self) -> None:
Expand Down
20 changes: 20 additions & 0 deletions tests/test_ops/test_ms_deformable_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,26 @@ def test_multiscale_deformable_attention(device):
spatial_shapes=spatial_shapes,
level_start_index=level_start_index)

# test with value_proj_ratio
embed_dims = 6
value_proj_ratio = 0.5
query = torch.rand(num_query, bs, embed_dims).to(device)
key = torch.rand(num_query, bs, embed_dims).to(device)
msda = MultiScaleDeformableAttention(
embed_dims=embed_dims,
num_levels=2,
num_heads=3,
value_proj_ratio=value_proj_ratio)
msda.init_weights()
msda.to(device)
msda(
query,
key,
key,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index)


def test_forward_multi_scale_deformable_attn_pytorch():
N, M, D = 1, 2, 2
Expand Down