diff --git a/README.md b/README.md index 6db9e30..7b30f95 100644 --- a/README.md +++ b/README.md @@ -34,17 +34,64 @@ Currently it only works using the default GPU (index 0) ./test.sh ``` -+ Use RoIAlign or crop_and_resize ++ Use RoIAlign or crop_and_resize + + Since PyTorch 1.2.0 [Legacy autograd function with non-static forward method is deprecated.](https://github.com/pytorch/pytorch/blob/fdfc676eb6c4d9f50496e564976fbe6d124e23a5/torch/csrc/autograd/python_function.cpp#L636-L638) + We use new-style autograd function with static forward method. Example: ```python + import torch from roi_align import RoIAlign # RoIAlign module from roi_align import CropAndResize # crop_and_resize module - - # input data - image = to_varabile(image_data, requires_grad=True, is_cuda=is_cuda) - boxes = to_varabile(boxes_data, requires_grad=False, is_cuda=is_cuda) - box_index = to_varabile(box_index_data, requires_grad=False, is_cuda=is_cuda) - - # RoIAlign layer + + # input feature maps (suppose that we have batch_size==2) + image = torch.arange(0., 49).view(1, 1, 7, 7).repeat(2, 1, 1, 1) + image[0] += 10 + print('image: ', image) + + + # for example, we have two bboxes with coords xyxy (first with batch_id=0, second with batch_id=1). + boxes = torch.Tensor([[1, 0, 5, 4], + [0.5, 3.5, 4, 7]]) + + box_index = torch.tensor([0, 1], dtype=torch.int) # index of bbox in batch + + # RoIAlign layer with crop sizes: + crop_height = 4 + crop_width = 4 roi_align = RoIAlign(crop_height, crop_width) + + # make crops: crops = roi_align(image, boxes, box_index) + + print('crops:', crops) + ``` + Output: + ```python + image: tensor([[[[10., 11., 12., 13., 14., 15., 16.], + [17., 18., 19., 20., 21., 22., 23.], + [24., 25., 26., 27., 28., 29., 30.], + [31., 32., 33., 34., 35., 36., 37.], + [38., 39., 40., 41., 42., 43., 44.], + [45., 46., 47., 48., 49., 50., 51.], + [52., 53., 54., 55., 56., 57., 58.]]], + + + [[[ 0., 1., 2., 3., 4., 5., 6.], + [ 7., 8., 9., 10., 11., 12., 13.], + [14., 15., 16., 17., 18., 19., 20.], + [21., 22., 23., 24., 25., 26., 27.], + [28., 29., 30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39., 40., 41.], + [42., 43., 44., 45., 46., 47., 48.]]]]) + + crops: tensor([[[[11.0000, 12.0000, 13.0000, 14.0000], + [18.0000, 19.0000, 20.0000, 21.0000], + [25.0000, 26.0000, 27.0000, 28.0000], + [32.0000, 33.0000, 34.0000, 35.0000]]], + + + [[[24.5000, 25.3750, 26.2500, 27.1250], + [30.6250, 31.5000, 32.3750, 33.2500], + [36.7500, 37.6250, 38.5000, 39.3750], + [ 0.0000, 0.0000, 0.0000, 0.0000]]]]) ``` diff --git a/roi_align/__init__.py b/roi_align/__init__.py index 60a6bef..158e4f1 100644 --- a/roi_align/__init__.py +++ b/roi_align/__init__.py @@ -1 +1 @@ -from .roi_align import RoIAlign, CropAndResizeFunction \ No newline at end of file +from .roi_align import RoIAlign, CropAndResizeFunction, CropAndResize \ No newline at end of file diff --git a/roi_align/crop_and_resize.py b/roi_align/crop_and_resize.py index a315886..6ac7796 100755 --- a/roi_align/crop_and_resize.py +++ b/roi_align/crop_and_resize.py @@ -12,34 +12,34 @@ class CropAndResizeFunction(Function): - def __init__(self, crop_height, crop_width, extrapolation_value=0): - self.crop_height = crop_height - self.crop_width = crop_width - self.extrapolation_value = extrapolation_value - - def forward(self, image, boxes, box_ind): + @staticmethod + def forward(ctx, image, boxes, box_ind, crop_height, crop_width, extrapolation_value=0): + ctx.crop_height = crop_height + ctx.crop_width = crop_width + ctx.extrapolation_value = extrapolation_value crops = torch.zeros_like(image) if image.is_cuda: crop_and_resize_gpu.forward( image, boxes, box_ind, - self.extrapolation_value, self.crop_height, self.crop_width, crops) + ctx.extrapolation_value, ctx.crop_height, ctx.crop_width, crops) else: crop_and_resize_cpu.forward( image, boxes, box_ind, - self.extrapolation_value, self.crop_height, self.crop_width, crops) + ctx.extrapolation_value, ctx.crop_height, ctx.crop_width, crops) # save for backward - self.im_size = image.size() - self.save_for_backward(boxes, box_ind) + ctx.im_size = image.size() + ctx.save_for_backward(boxes, box_ind) return crops - def backward(self, grad_outputs): - boxes, box_ind = self.saved_tensors + @staticmethod + def backward(ctx, grad_outputs): + boxes, box_ind = ctx.saved_tensors grad_outputs = grad_outputs.contiguous() - grad_image = torch.zeros_like(grad_outputs).resize_(*self.im_size) + grad_image = torch.zeros_like(grad_outputs).resize_(*ctx.im_size) if grad_outputs.is_cuda: crop_and_resize_gpu.backward( @@ -50,7 +50,7 @@ def backward(self, grad_outputs): grad_outputs, boxes, box_ind, grad_image ) - return grad_image, None, None + return grad_image, None, None, None, None, None class CropAndResize(nn.Module): @@ -67,4 +67,4 @@ def __init__(self, crop_height, crop_width, extrapolation_value=0): self.extrapolation_value = extrapolation_value def forward(self, image, boxes, box_ind): - return CropAndResizeFunction(self.crop_height, self.crop_width, self.extrapolation_value)(image, boxes, box_ind) + return CropAndResizeFunction.apply(image, boxes, box_ind, self.crop_height, self.crop_width, self.extrapolation_value) diff --git a/roi_align/roi_align.py b/roi_align/roi_align.py index 6931539..a6076cd 100644 --- a/roi_align/roi_align.py +++ b/roi_align/roi_align.py @@ -45,4 +45,4 @@ def forward(self, featuremap, boxes, box_ind): boxes = boxes.detach().contiguous() box_ind = box_ind.detach() - return CropAndResizeFunction(self.crop_height, self.crop_width, self.extrapolation_value)(featuremap, boxes, box_ind) + return CropAndResizeFunction.apply(featuremap, boxes, box_ind, self.crop_height, self.crop_width, self.extrapolation_value) diff --git a/setup.py b/setup.py index 54cb410..31c12c9 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,8 @@ modules = [ CppExtension( 'roi_align.crop_and_resize_cpu', - ['roi_align/src/crop_and_resize.cpp'] + ['roi_align/src/crop_and_resize.cpp'], + extra_compile_args={'cxx': ['-g', '-fopenmp']} ) ] @@ -22,7 +23,7 @@ setup( name='roi_align', - version='0.0.1', + version='0.0.2', description='PyTorch version of RoIAlign', author='Long Chen', author_email='longch1024@gmail.com', @@ -31,5 +32,5 @@ ext_modules=modules, cmdclass={'build_ext': BuildExtension}, - install_requires=['torch'] + install_requires=['torch>=1.2.0'] ) diff --git a/tests/crop_and_resize_example.py b/tests/crop_and_resize_example.py index d71447a..08ba825 100644 --- a/tests/crop_and_resize_example.py +++ b/tests/crop_and_resize_example.py @@ -42,7 +42,7 @@ def to_varabile(tensor, requires_grad=False, is_cuda=True): box_index = to_varabile(box_index_data, is_cuda=is_cuda) # Crops and resize bbox1 from img1 and bbox2 from img2 -crops_torch = CropAndResizeFunction(crop_height, crop_width, 0)(image_torch, boxes, box_index) +crops_torch = CropAndResizeFunction.apply(image_torch, boxes, box_index, crop_height, crop_width, 0) # Visualize the crops print(crops_torch.data.size()) diff --git a/tests/test.py b/tests/test.py index 0d77293..d9415d6 100644 --- a/tests/test.py +++ b/tests/test.py @@ -68,7 +68,7 @@ def compare_with_tf(crop_height, crop_width, is_cuda=True): box_index = to_varabile(box_index_data, requires_grad=False, is_cuda=is_cuda) print('pytorch forward and backward start') - crops_torch = CropAndResizeFunction(crop_height, crop_width, 0)(image_torch, boxes, box_index) + crops_torch = CropAndResizeFunction.apply(image_torch, boxes, box_index, crop_height, crop_width, 0) crops_torch = conv_torch(crops_torch) crops_torch_data = crops_torch.data.cpu().numpy() diff --git a/tests/test2.py b/tests/test2.py index e48183d..3d9d756 100644 --- a/tests/test2.py +++ b/tests/test2.py @@ -25,5 +25,4 @@ def to_varabile(arr, requires_grad=False, is_cuda=True): box_index = to_varabile(box_index_data, requires_grad=False, is_cuda=is_cuda) # set transform_fpcoor to False is the crop_and_resize -roi_align = RoIAlign(3, 3, transform_fpcoor=True) -print(roi_align(image_torch, boxes, box_index)) \ No newline at end of file +print(RoIAlign.apply(image_torch, boxes, box_index, 3, 3, transform_fpcoor=True)) \ No newline at end of file