Skip to content

Commit

Permalink
[Enhancement] Support value_proj_ratio in MultiScaleDeformableAttention (#2452)
Browse files Browse the repository at this point in the history

* add ratio in ms_deform_attn_

* add ratio in ms_deform_attn

* Update mmcv/ops/multi_scale_deform_attn.py

Co-authored-by: Zaida Zhou <[email protected]>

* Update tests/test_ops/test_ms_deformable_attn.py

Co-authored-by: Zaida Zhou <[email protected]>

* add ratio in ms_deform_attn

* add ratio in ms_deform_attn

* add ratio in ms_deform_attn

* add ratio in ms_deform_attn

* add ratio in ms_deform_attn

* add ratio in ms_deform_attn

Co-authored-by: Zaida Zhou <[email protected]>
  • Loading branch information
okotaku and zhouzaida authored Dec 11, 2022
1 parent fb39e1e commit 4336070
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
10 changes: 7 additions & 3 deletions mmcv/ops/multi_scale_deform_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ class MultiScaleDeformableAttention(BaseModule):
Default: None.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
value_proj_ratio (float): The expansion ratio of value_proj.
Default: 1.0.
"""

def __init__(self,
Expand All @@ -193,7 +195,8 @@ def __init__(self,
dropout: float = 0.1,
batch_first: bool = False,
norm_cfg: Optional[dict] = None,
init_cfg: Optional[mmengine.ConfigDict] = None):
init_cfg: Optional[mmengine.ConfigDict] = None,
value_proj_ratio: float = 1.0):
super().__init__(init_cfg)
if embed_dims % num_heads != 0:
raise ValueError(f'embed_dims must be divisible by num_heads, '
Expand Down Expand Up @@ -228,8 +231,9 @@ def _is_power_of_2(n):
embed_dims, num_heads * num_levels * num_points * 2)
self.attention_weights = nn.Linear(embed_dims,
num_heads * num_levels * num_points)
self.value_proj = nn.Linear(embed_dims, embed_dims)
self.output_proj = nn.Linear(embed_dims, embed_dims)
value_proj_size = int(embed_dims * value_proj_ratio)
self.value_proj = nn.Linear(embed_dims, value_proj_size)
self.output_proj = nn.Linear(value_proj_size, embed_dims)
self.init_weights()

def init_weights(self) -> None:
Expand Down
20 changes: 20 additions & 0 deletions tests/test_ops/test_ms_deformable_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,26 @@ def test_multiscale_deformable_attention(device):
spatial_shapes=spatial_shapes,
level_start_index=level_start_index)

# test with value_proj_ratio
embed_dims = 6
value_proj_ratio = 0.5
query = torch.rand(num_query, bs, embed_dims).to(device)
key = torch.rand(num_query, bs, embed_dims).to(device)
msda = MultiScaleDeformableAttention(
embed_dims=embed_dims,
num_levels=2,
num_heads=3,
value_proj_ratio=value_proj_ratio)
msda.init_weights()
msda.to(device)
msda(
query,
key,
key,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index)


def test_forward_multi_scale_deformable_attn_pytorch():
N, M, D = 1, 2, 2
Expand Down

0 comments on commit 4336070

Please sign in to comment.