Skip to content

Commit

Permalink
Merge pull request #34 from meikuam/master
Browse files Browse the repository at this point in the history
pytorch 1.2.0 c++ interface (new-style torch.autograd.Function)
  • Loading branch information
longcw authored Oct 16, 2019
2 parents c0a44ed + 1af6929 commit 8df4f37
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 32 deletions.
63 changes: 55 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,64 @@ Currently it only works using the default GPU (index 0)
./test.sh
```
+ Use RoIAlign or crop_and_resize
+ Use RoIAlign or crop_and_resize
Since PyTorch 1.2.0 [Legacy autograd function with non-static forward method is deprecated.](https://github.com/pytorch/pytorch/blob/fdfc676eb6c4d9f50496e564976fbe6d124e23a5/torch/csrc/autograd/python_function.cpp#L636-L638)
We use new-style autograd function with static forward method. Example:
```python
import torch
from roi_align import RoIAlign # RoIAlign module
from roi_align import CropAndResize # crop_and_resize module
# input data
image = to_varabile(image_data, requires_grad=True, is_cuda=is_cuda)
boxes = to_varabile(boxes_data, requires_grad=False, is_cuda=is_cuda)
box_index = to_varabile(box_index_data, requires_grad=False, is_cuda=is_cuda)
# RoIAlign layer
# input feature maps (suppose that we have batch_size==2)
image = torch.arange(0., 49).view(1, 1, 7, 7).repeat(2, 1, 1, 1)
image[0] += 10
print('image: ', image)
# for example, we have two bboxes with coords xyxy (first with batch_id=0, second with batch_id=1).
boxes = torch.Tensor([[1, 0, 5, 4],
[0.5, 3.5, 4, 7]])
box_index = torch.tensor([0, 1], dtype=torch.int) # index of bbox in batch
# RoIAlign layer with crop sizes:
crop_height = 4
crop_width = 4
roi_align = RoIAlign(crop_height, crop_width)
# make crops:
crops = roi_align(image, boxes, box_index)
print('crops:', crops)
```
Output:
```python
image: tensor([[[[10., 11., 12., 13., 14., 15., 16.],
[17., 18., 19., 20., 21., 22., 23.],
[24., 25., 26., 27., 28., 29., 30.],
[31., 32., 33., 34., 35., 36., 37.],
[38., 39., 40., 41., 42., 43., 44.],
[45., 46., 47., 48., 49., 50., 51.],
[52., 53., 54., 55., 56., 57., 58.]]],
[[[ 0., 1., 2., 3., 4., 5., 6.],
[ 7., 8., 9., 10., 11., 12., 13.],
[14., 15., 16., 17., 18., 19., 20.],
[21., 22., 23., 24., 25., 26., 27.],
[28., 29., 30., 31., 32., 33., 34.],
[35., 36., 37., 38., 39., 40., 41.],
[42., 43., 44., 45., 46., 47., 48.]]]])
crops: tensor([[[[11.0000, 12.0000, 13.0000, 14.0000],
[18.0000, 19.0000, 20.0000, 21.0000],
[25.0000, 26.0000, 27.0000, 28.0000],
[32.0000, 33.0000, 34.0000, 35.0000]]],
[[[24.5000, 25.3750, 26.2500, 27.1250],
[30.6250, 31.5000, 32.3750, 33.2500],
[36.7500, 37.6250, 38.5000, 39.3750],
[ 0.0000, 0.0000, 0.0000, 0.0000]]]])
```
2 changes: 1 addition & 1 deletion roi_align/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .roi_align import RoIAlign, CropAndResizeFunction
from .roi_align import RoIAlign, CropAndResizeFunction, CropAndResize
30 changes: 15 additions & 15 deletions roi_align/crop_and_resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,34 @@

class CropAndResizeFunction(Function):

def __init__(self, crop_height, crop_width, extrapolation_value=0):
self.crop_height = crop_height
self.crop_width = crop_width
self.extrapolation_value = extrapolation_value

def forward(self, image, boxes, box_ind):
@staticmethod
def forward(ctx, image, boxes, box_ind, crop_height, crop_width, extrapolation_value=0):
ctx.crop_height = crop_height
ctx.crop_width = crop_width
ctx.extrapolation_value = extrapolation_value
crops = torch.zeros_like(image)

if image.is_cuda:
crop_and_resize_gpu.forward(
image, boxes, box_ind,
self.extrapolation_value, self.crop_height, self.crop_width, crops)
ctx.extrapolation_value, ctx.crop_height, ctx.crop_width, crops)
else:
crop_and_resize_cpu.forward(
image, boxes, box_ind,
self.extrapolation_value, self.crop_height, self.crop_width, crops)
ctx.extrapolation_value, ctx.crop_height, ctx.crop_width, crops)

# save for backward
self.im_size = image.size()
self.save_for_backward(boxes, box_ind)
ctx.im_size = image.size()
ctx.save_for_backward(boxes, box_ind)

return crops

def backward(self, grad_outputs):
boxes, box_ind = self.saved_tensors
@staticmethod
def backward(ctx, grad_outputs):
boxes, box_ind = ctx.saved_tensors

grad_outputs = grad_outputs.contiguous()
grad_image = torch.zeros_like(grad_outputs).resize_(*self.im_size)
grad_image = torch.zeros_like(grad_outputs).resize_(*ctx.im_size)

if grad_outputs.is_cuda:
crop_and_resize_gpu.backward(
Expand All @@ -50,7 +50,7 @@ def backward(self, grad_outputs):
grad_outputs, boxes, box_ind, grad_image
)

return grad_image, None, None
return grad_image, None, None, None, None, None


class CropAndResize(nn.Module):
Expand All @@ -67,4 +67,4 @@ def __init__(self, crop_height, crop_width, extrapolation_value=0):
self.extrapolation_value = extrapolation_value

def forward(self, image, boxes, box_ind):
return CropAndResizeFunction(self.crop_height, self.crop_width, self.extrapolation_value)(image, boxes, box_ind)
return CropAndResizeFunction.apply(image, boxes, box_ind, self.crop_height, self.crop_width, self.extrapolation_value)
2 changes: 1 addition & 1 deletion roi_align/roi_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@ def forward(self, featuremap, boxes, box_ind):

boxes = boxes.detach().contiguous()
box_ind = box_ind.detach()
return CropAndResizeFunction(self.crop_height, self.crop_width, self.extrapolation_value)(featuremap, boxes, box_ind)
return CropAndResizeFunction.apply(featuremap, boxes, box_ind, self.crop_height, self.crop_width, self.extrapolation_value)
7 changes: 4 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
modules = [
CppExtension(
'roi_align.crop_and_resize_cpu',
['roi_align/src/crop_and_resize.cpp']
['roi_align/src/crop_and_resize.cpp'],
extra_compile_args={'cxx': ['-g', '-fopenmp']}
)
]

Expand All @@ -22,7 +23,7 @@

setup(
name='roi_align',
version='0.0.1',
version='0.0.2',
description='PyTorch version of RoIAlign',
author='Long Chen',
author_email='[email protected]',
Expand All @@ -31,5 +32,5 @@

ext_modules=modules,
cmdclass={'build_ext': BuildExtension},
install_requires=['torch']
install_requires=['torch>=1.2.0']
)
2 changes: 1 addition & 1 deletion tests/crop_and_resize_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def to_varabile(tensor, requires_grad=False, is_cuda=True):
box_index = to_varabile(box_index_data, is_cuda=is_cuda)

# Crops and resize bbox1 from img1 and bbox2 from img2
crops_torch = CropAndResizeFunction(crop_height, crop_width, 0)(image_torch, boxes, box_index)
crops_torch = CropAndResizeFunction.apply(image_torch, boxes, box_index, crop_height, crop_width, 0)

# Visualize the crops
print(crops_torch.data.size())
Expand Down
2 changes: 1 addition & 1 deletion tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def compare_with_tf(crop_height, crop_width, is_cuda=True):
box_index = to_varabile(box_index_data, requires_grad=False, is_cuda=is_cuda)

print('pytorch forward and backward start')
crops_torch = CropAndResizeFunction(crop_height, crop_width, 0)(image_torch, boxes, box_index)
crops_torch = CropAndResizeFunction.apply(image_torch, boxes, box_index, crop_height, crop_width, 0)
crops_torch = conv_torch(crops_torch)
crops_torch_data = crops_torch.data.cpu().numpy()

Expand Down
3 changes: 1 addition & 2 deletions tests/test2.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,4 @@ def to_varabile(arr, requires_grad=False, is_cuda=True):
box_index = to_varabile(box_index_data, requires_grad=False, is_cuda=is_cuda)

# set transform_fpcoor to False is the crop_and_resize
roi_align = RoIAlign(3, 3, transform_fpcoor=True)
print(roi_align(image_torch, boxes, box_index))
print(RoIAlign.apply(image_torch, boxes, box_index, 3, 3, transform_fpcoor=True))

0 comments on commit 8df4f37

Please sign in to comment.