Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Features] : add flow1d correlation and correlation lookup #213

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions mmflow/models/utils/correlation1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) OpenMMLab. All rights reserved.

import torch
from mmcv.runner import BaseModule
from torch import Tensor


class Correlation1D(BaseModule):
    """Correlation1D Module.

    The neck of Flow1D. It correlates two feature maps along a single
    spatial axis (width or height) with a scaled matrix product, so the
    result for each sample is a 3D cost volume.
    """

    def __init__(self):
        super().__init__()

    def forward(self,
                feat1: Tensor,
                feat2: Tensor,
                y_direction: bool = False) -> Tensor:
        """Forward function for Correlation1D.

        Args:
            feat1 (Tensor): The feature from first input image. Its shape is
                unpacked as [B, C, H, W].
            feat2 (Tensor): The 1D cross attention feat2 on x or y direction.
                It is permuted with the same four-axis layout as ``feat1``.
            y_direction (bool): If True, correlate along the height axis;
                otherwise correlate along the width axis. Default: False.

        Returns:
            Tensor: Correlation of shape [B, W, H, H] when ``y_direction``
            is True, otherwise [B, H, W, W].
        """
        # Unpacking also enforces that feat1 has exactly four dimensions.
        _, channels, _, _ = feat1.shape

        if y_direction:
            # [B, W, H, C] @ [B, W, C, H] -> [B, W, H, H]
            lhs_axes, rhs_axes = (0, 3, 2, 1), (0, 3, 1, 2)
        else:
            # [B, H, W, C] @ [B, H, C, W] -> [B, H, W, W]
            lhs_axes, rhs_axes = (0, 2, 3, 1), (0, 2, 1, 3)

        lhs = feat1.permute(*lhs_axes)
        rhs = feat2.permute(*rhs_axes)
        # Dot products over the channel axis, divided by sqrt(C).
        return torch.matmul(lhs, rhs) / channels**0.5
87 changes: 87 additions & 0 deletions mmflow/ops/corr_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,90 @@ def forward(self, corr_pyramid: Sequence[Tensor], flow: Tensor) -> Tensor:

out = torch.cat(out_corr_pyramid, dim=-1)
return out.permute(0, 3, 1, 2).contiguous().float()


@OPERATORS.register_module()
class CorrLookupFlow1D(nn.Module):
    """Correlation lookup operator for Flow1D.

    This operator is used in `Flow1D <https://arxiv.org/pdf/2104.13918.pdf>`_

    Args:
        radius (int): the radius of the local neighborhood of the pixels.
            Default to 4.
        mode (str): interpolation mode to calculate output values 'bilinear'
            | 'nearest' | 'bicubic'. Default: 'bilinear' Note: mode='bicubic'
            supports only 4-D input.
        padding_mode (str): padding mode for outside grid values 'zeros' |
            'border' | 'reflection'. Default: 'zeros'
        align_corners (bool): If set to True, the extrema (-1 and 1) are
            considered as referring to the center points of the input's corner
            pixels. If set to False, they are instead considered as referring
            to the corner points of the input's corner pixels, making the
            sampling more resolution agnostic. Default to True.
    """

    def __init__(self,
                 radius: int = 4,
                 mode: str = 'bilinear',
                 padding_mode: str = 'zeros',
                 align_corners: bool = True) -> None:
        super().__init__()
        self.r = radius
        self.mode = mode
        self.padding_mode = padding_mode
        self.align_corners = align_corners

    def forward(self, corr: Sequence[Tensor], flow: Tensor) -> Tensor:
        """Forward function of Correlation lookup for Flow1D.

        Args:
            corr (Sequence[Tensor]): Correlation on x and y direction;
                ``corr[0]`` is the x correlation and ``corr[1]`` the y
                correlation.
            flow (Tensor): Current estimated optical flow, unpacked as
                [B, 2, H, W]. Channel 0 and channel 1 are used directly as
                the x and y centers of the sampling windows.
                NOTE(review): no base coordinate grid is added to ``flow``
                inside this method - confirm callers pass what is intended.

        Returns:
            Tensor: lookup cost volume on the correlation of x and y directions
            concatenate together.
        """
        batch, _, height, width = flow.shape
        num_pixels = batch * height * width
        window = 2 * self.r + 1

        # reshape corr_x from [B, H, W, W] to [B*H*W, 1, 1, W]
        volume_x = corr[0].view(-1, 1, 1, width)
        # reshape corr_y from [B, W, H, H] to [B*H*W, 1, H, 1]
        volume_y = corr[1].permute(0, 2, 1, 3).contiguous()
        volume_y = volume_y.view(-1, 1, height, 1)

        # reshape flow to [B, H, W, 2]
        flow = flow.permute(0, 2, 3, 1)
        center_x = flow[:, :, :, 0]
        center_y = flow[:, :, :, 1]
        # Each 1D lookup only moves along one axis; the other grid
        # component is filled with zeros.
        center_x = torch.stack((center_x, torch.zeros_like(center_x)), dim=-1)
        center_y = torch.stack((torch.zeros_like(center_y), center_y), dim=-1)
        center_x = center_x.view(num_pixels, 1, 1, 2)
        center_y = center_y.view(num_pixels, 1, 1, 2)

        # Offsets -r, ..., r shared by both directions.
        offsets = torch.linspace(-self.r, self.r, window, device=flow.device)
        no_shift = torch.zeros_like(offsets)

        # [2r+1, 2]; broadcasts with the centers to [B*H*W, 1, 2r+1, 2]
        shift_x = torch.stack((offsets, no_shift), dim=-1)
        # [1, 2r+1, 1, 2]; broadcasts with the centers to [B*H*W, 2r+1, 1, 2]
        shift_y = torch.stack((no_shift, offsets), dim=-1)
        shift_y = shift_y.view(1, window, 1, 2)

        grid_x = center_x + shift_x
        grid_y = center_y + shift_y

        sampled_x = bilinear_sample(volume_x, grid_x, self.mode,
                                    self.padding_mode, self.align_corners)
        sampled_y = bilinear_sample(volume_y, grid_y, self.mode,
                                    self.padding_mode, self.align_corners)

        # shape is [B, 2r+1, H, W]
        sampled_x = sampled_x.view(batch, height, width, -1)
        sampled_x = sampled_x.permute(0, 3, 1, 2).contiguous()
        sampled_y = sampled_y.view(batch, height, width, -1)
        sampled_y = sampled_y.permute(0, 3, 1, 2).contiguous()

        # x lookups first, then y lookups, along the channel axis.
        return torch.cat((sampled_x, sampled_y), dim=1)
43 changes: 43 additions & 0 deletions tests/test_models/test_utils/test_correlation1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import Tensor

from mmflow.models.utils.correlation1d import Correlation1D

# Shared fixtures: integers 0..23 laid out as [B=1, C=2, H=3, W=4], and the
# same values shifted by one for the second feature map.
_feat1 = torch.arange(24).reshape(1, 2, 3, 4)
_feat2 = _feat1 + 1
b, c, h, w = _feat1.shape


def test_correlation():
    """Check Correlation1D output shape and values in both directions."""
    # Expected x-direction correlation, shape [B, H, W, W].
    expected_x = Tensor([[
        [[110.3087, 118.7939, 127.2792, 135.7645],
         [120.2082, 130.1077, 140.0071, 149.9066],
         [130.1077, 141.4214, 152.7351, 164.0488],
         [140.0071, 152.7351, 165.4630, 178.1909]],
        [[206.4752, 220.6173, 234.7595, 248.9016],
         [222.0315, 237.5879, 253.1442, 268.7006],
         [237.5879, 254.5584, 271.5290, 288.4996],
         [253.1442, 271.5290, 289.9138, 308.2986]],
        [[347.8965, 367.6955, 387.4945, 407.2935],
         [369.1097, 390.3229, 411.5362, 432.7494],
         [390.3229, 412.9504, 435.5778, 458.2052],
         [411.5362, 435.5778, 459.6194, 483.6610]],
    ]])
    # Expected y-direction correlation, shape [B, W, H, H].
    expected_y = Tensor([[
        [[110.3087, 144.2498, 178.1909],
         [149.9066, 206.4752, 263.0437],
         [189.5046, 268.7006, 347.8965]],
        [[130.1077, 169.7056, 209.3036],
         [175.3625, 237.5879, 299.8133],
         [220.6173, 305.4701, 390.3229]],
        [[152.7351, 197.9899, 243.2447],
         [203.6468, 271.5290, 339.4113],
         [254.5584, 345.0681, 435.5778]],
        [[178.1909, 229.1026, 280.0143],
         [234.7595, 308.2986, 381.8377],
         [291.3280, 387.4945, 483.6610]],
    ]])

    module = Correlation1D()
    # (y_direction flag, expected values, expected shape)
    cases = (
        (False, expected_x, (b, h, w, w)),
        (True, expected_y, (b, w, h, h)),
    )
    for y_direction, expected, shape in cases:
        result = module(_feat1, _feat2, y_direction)
        assert result.size() == shape
        assert torch.allclose(result, expected, atol=1e-4)
52 changes: 51 additions & 1 deletion tests/test_op/test_corr_lookup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from torch import Tensor

from mmflow.models.decoders.raft_decoder import CorrelationPyramid
from mmflow.models.utils.correlation1d import Correlation1D
from mmflow.ops.builder import build_operators
from mmflow.ops.corr_lookup import bilinear_sample, coords_grid

Expand All @@ -17,7 +20,6 @@ def test_coords_grid():
assert grid.shape == torch.Size((2, 2, H, W))
for i in range(H):
for j in range(W):

assert torch.all(grid[0, :, i, j] == torch.Tensor((j, i)))


Expand Down Expand Up @@ -56,3 +58,51 @@ def test_corr_lookup():

corr_lpt = corr_lookup_op(corr_pyramid, torch.randn(1, 2, H, W))
assert corr_lpt.shape == torch.Size((1, 81 * 4, H, W))


@pytest.mark.parametrize('mode', ['bilinear', 'nearest', 'bicubic'])
@pytest.mark.parametrize('padding_mode', ['zeros', 'border', 'reflection'])
@pytest.mark.parametrize('align_corners', [True, False])
def test_corr_lookup_flow1d(mode, padding_mode, align_corners):
    """Check CorrLookupFlow1D output shape and values on a tiny input."""
    # Integer features 0..23 laid out as [B=1, C=2, H=3, W=4].
    feat1 = torch.arange(0, 24).view(1, 2, 3, 4)
    feat2 = feat1 + 1
    # Every flow component is 1; ones_like keeps feat1's integer dtype.
    flow = torch.ones_like(feat1)
    b, c, h, w = feat1.size()
    radius = 1

    # ground truth
    expected_x = Tensor([[
        [[110.3087, 120.2082, 130.1077, 140.0071],
         [206.4752, 222.0315, 237.5879, 253.1442],
         [347.8965, 369.1097, 390.3229, 411.5362]],
        [[118.7939, 130.1077, 141.4214, 152.7351],
         [220.6173, 237.5879, 254.5584, 271.5290],
         [367.6955, 390.3229, 412.9504, 435.5778]],
        [[127.2792, 140.0071, 152.7351, 165.4630],
         [234.7595, 253.1442, 271.5290, 289.9138],
         [387.4945, 411.5362, 435.5778, 459.6194]],
    ]])
    expected_y = Tensor([[
        [[110.3087, 130.1077, 152.7351, 178.1909],
         [149.9066, 175.3625, 203.6468, 234.7595],
         [189.5046, 220.6173, 254.5584, 291.3280]],
        [[144.2498, 169.7056, 197.9899, 229.1026],
         [206.4752, 237.5879, 271.5290, 308.2986],
         [268.7006, 305.4701, 345.0681, 387.4945]],
        [[178.1909, 209.3036, 243.2447, 280.0143],
         [263.0437, 299.8133, 339.4113, 381.8377],
         [347.8965, 390.3229, 435.5778, 483.6610]],
    ]])
    expected = torch.cat((expected_x, expected_y), dim=1)

    # Build the x and y correlation volumes that the lookup consumes.
    corr_block = Correlation1D()
    correlation = [
        corr_block(feat1, feat2, False),
        corr_block(feat1, feat2, True),
    ]

    # NOTE(review): the parametrized ``align_corners`` argument is not
    # forwarded; the operator is always built with align_corners=True, so
    # both parametrizations exercise the same configuration. Presumably the
    # ground truth above only holds for True - confirm before wiring it in.
    corr_lookup_op = build_operators(
        dict(
            type='CorrLookupFlow1D',
            radius=radius,
            mode=mode,
            padding_mode=padding_mode,
            align_corners=True))

    corr_xy = corr_lookup_op(correlation, flow)
    assert corr_xy.size() == (b, 2 * (2 * radius + 1), h, w)
    assert torch.allclose(expected, corr_xy, atol=1e-4)