-
Notifications
You must be signed in to change notification settings - Fork 0
/
paste_masks.py
140 lines (120 loc) · 5.78 KB
/
paste_masks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import numpy as np
from typing import Tuple
import torch
from torch.nn import functional as F
def _do_paste_mask(masks, boxes, img_h: int, img_w: int, skip_empty: bool = True):
"""
Args:
masks: N, 1, H, W
boxes: N, 4
img_h, img_w (int):
skip_empty (bool): only paste masks within the region that
tightly bound all boxes, and returns the results this region only.
An important optimization for CPU.
Returns:
if skip_empty == False, a mask of shape (N, img_h, img_w)
if skip_empty == True, a mask of shape (N, h', w'), and the slice
object for the corresponding region.
"""
# On GPU, paste all masks together (up to chunk size)
# by using the entire image to sample the masks
# Compared to pasting them one by one,
# this has more operations but is faster on COCO-scale dataset.
device = masks.device
if skip_empty and not torch.jit.is_scripting():
x0_int, y0_int = torch.clamp(boxes.min(dim=0).values.floor()[:2] - 1, min=0).to(
dtype=torch.int32
)
x1_int = torch.clamp(boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32)
y1_int = torch.clamp(boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32)
else:
x0_int, y0_int = 0, 0
x1_int, y1_int = img_w, img_h
x0, y0, x1, y1 = torch.split(boxes, 1, dim=1) # each is Nx1
N = masks.shape[0]
img_y = torch.arange(y0_int, y1_int, device=device, dtype=torch.float32) + 0.5
img_x = torch.arange(x0_int, x1_int, device=device, dtype=torch.float32) + 0.5
img_y = (img_y - y0) / (y1 - y0) * 2 - 1
img_x = (img_x - x0) / (x1 - x0) * 2 - 1
# img_x, img_y have shapes (N, w), (N, h)
gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
grid = torch.stack([gx, gy], dim=3)
if not torch.jit.is_scripting():
if not masks.dtype.is_floating_point:
masks = masks.float()
img_masks = F.grid_sample(masks, grid.to(masks.dtype), align_corners=False)
if skip_empty and not torch.jit.is_scripting():
return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int))
else:
return img_masks[:, 0], ()
def paste_masks_in_image(
    masks: torch.Tensor, boxes: torch.Tensor, image_shape: Tuple[int, int], threshold: float = 0.5
):
    """
    Paste a set of masks that are of a fixed resolution (e.g., 28 x 28) into an image.
    The location, height, and width for pasting each mask is determined by their
    corresponding bounding boxes in boxes.

    Note:
        This is a complicated but more accurate implementation. In actual deployment, it is
        often enough to use a faster but less accurate implementation.

    Args:
        masks (tensor): Tensor of shape (Bimg, Hmask, Wmask), where Bimg is the number of
            detected object instances in the image and Hmask, Wmask are the mask height and
            mask width of the predicted mask (e.g., Hmask = Wmask = 28). Only square masks
            are accepted. Values are in [0, 1].
        boxes (Boxes or Tensor): A Boxes of length Bimg or Tensor of shape (Bimg, 4).
            boxes[i] and masks[i] correspond to the same object instance. Anything that
            is not a Tensor is expected to expose its data as ``boxes.tensor``.
        image_shape (tuple): height, width
        threshold (float): A threshold in [0, 1] for converting the (soft) masks to
            binary masks. A negative value disables thresholding: the soft masks are
            scaled by 255 and returned as uint8 (for visualization and debugging).

    Returns:
        img_masks (Tensor): A tensor of shape (Bimg, Himage, Wimage), where Bimg is the
        number of detected object instances and Himage, Wimage are the image height
        and width. With ``threshold >= 0`` it is a bool tensor and img_masks[i] is the
        binary mask for object instance i; otherwise it is a uint8 tensor.
        For empty input (Bimg == 0) an empty uint8 tensor is returned regardless of
        ``threshold``.

    Raises:
        AssertionError: if the masks are not square, if ``boxes`` and ``masks`` differ in
            length, or if a single mask does not fit into the GPU memory budget.
    """
    BYTES_PER_FLOAT = 4
    # TODO: This memory limit may be too much or too little. It would be better to
    # determine it based on available resources.
    GPU_MEM_LIMIT = 1024 ** 3  # 1 GB memory limit

    assert masks.shape[-1] == masks.shape[-2], "Only square mask predictions are supported"
    N = len(masks)
    if N == 0:
        # NOTE: the empty result is always uint8, even though the non-empty path
        # below yields bool when threshold >= 0.
        return masks.new_empty((0,) + image_shape, dtype=torch.uint8)
    if not isinstance(boxes, torch.Tensor):
        boxes = boxes.tensor
    device = boxes.device
    assert len(boxes) == N, boxes.shape

    img_h, img_w = image_shape

    # The actual implementation split the input into chunks,
    # and paste them chunk by chunk.
    if device.type == "cpu" or torch.jit.is_scripting():
        # CPU is most efficient when they are pasted one by one with skip_empty=True
        # so that it performs minimal number of operations.
        num_chunks = N
    else:
        # GPU benefits from parallelism for larger chunks, but may have memory issue
        # int(img_h) because shape may be tensors in tracing
        num_chunks = int(np.ceil(N * int(img_h) * int(img_w) * BYTES_PER_FLOAT / GPU_MEM_LIMIT))
        assert (
            num_chunks <= N
        ), "Default GPU_MEM_LIMIT in mask_ops.py is too small; try increasing it"
    chunks = torch.chunk(torch.arange(N, device=device), num_chunks)

    img_masks = torch.zeros(
        N, img_h, img_w, device=device, dtype=torch.bool if threshold >= 0 else torch.uint8
    )
    for inds in chunks:
        # spatial_inds is () when the whole image was sampled, or a (rows, cols)
        # pair of slices when only the window around the boxes was sampled.
        masks_chunk, spatial_inds = _do_paste_mask(
            masks[inds, None, :, :], boxes[inds], img_h, img_w, skip_empty=device.type == "cpu"
        )

        if threshold >= 0:
            masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool)
        else:
            # for visualization and debugging
            masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8)

        if torch.jit.is_scripting():  # Scripting does not use the optimized codepath
            img_masks[inds] = masks_chunk
        else:
            img_masks[(inds,) + spatial_inds] = masks_chunk
    return img_masks