Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add Mosaic transform #1093

Merged
merged 21 commits into from
Jan 11, 2022
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions mmseg/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
from .test_time_aug import MultiScaleFlipAug
from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
PhotoMetricDistortion, RandomCrop, RandomCutOut,
RandomFlip, RandomRotate, Rerange, Resize, RGB2Gray,
SegRescale)
RandomFlip, RandomMosaic, RandomRotate, Rerange,
Resize, RGB2Gray, SegRescale)

# Public API of the pipelines package.
# The diff view showed the pre-change line (ending in 'RandomCutOut' with no
# trailing comma) together with its replacement; taken literally that
# concatenates two adjacent string literals and duplicates several names.
# Only the post-change entries are kept here.
__all__ = [
    'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
    'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
    'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
    'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
    'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut',
    'RandomMosaic'
]
260 changes: 260 additions & 0 deletions mmseg/datasets/pipelines/transforms.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy

import mmcv
import numpy as np
from mmcv.utils import deprecated_api_warning, is_tuple_of
Expand Down Expand Up @@ -1040,3 +1042,261 @@ def __repr__(self):
repr_str += f'fill_in={self.fill_in}, '
repr_str += f'seg_fill_in={self.seg_fill_in})'
return repr_str


@PIPELINES.register_module()
class RandomMosaic(object):
    """Mosaic augmentation.

    Given 4 images, mosaic transform combines them into one output image.
    The output image is composed of the parts from each sub-image.

    .. code:: text

                        mosaic transform
                           center_x
                +------------------------------+
                |       pad        |  pad      |
                |      +-----------+           |
                |      |           |           |
                |      |  image1   |--------+  |
                |      |           |        |  |
                |      |           | image2 |  |
     center_y   |----+-------------+-----------|
                |    |   cropped   |           |
                |pad |   image3    |  image4   |
                |    |             |           |
                +----|-------------+-----------+
                     |             |
                     +-------------+

    The mosaic transform steps are as follows:

        1. Choose the mosaic center as the intersections of 4 images
        2. Get the left top image according to the index, and randomly
           sample another 3 images from the custom dataset.
        3. Sub image will be cropped if image is larger than mosaic patch

    Args:
        prob (float): mosaic probability. Must lie in [0, 1].
        img_scale (tuple[int]): Image size (height, width) after mosaic
            pipeline of a single image. The output image comprises 4
            single images, so its size is (2 * height, 2 * width).
            Default: (640, 640).
        center_ratio_range (Sequence[float]): Center ratio range of mosaic
            output. Default: (0.5, 1.5).
        pad_val (int): Pad value. Default: 0.
        seg_pad_val (int): Pad value of segmentation map. Default: 255.
    """

    # Quadrant order. Index 0 is filled from the current sample, indices
    # 1-3 from results['mix_results'][0..2].
    _LOCATIONS = ('top_left', 'top_right', 'bottom_left', 'bottom_right')

    def __init__(self,
                 prob,
                 img_scale=(640, 640),
                 center_ratio_range=(0.5, 1.5),
                 pad_val=0,
                 seg_pad_val=255):
        # Kept as asserts: callers/tests expect AssertionError on bad config.
        assert 0 <= prob <= 1
        assert isinstance(img_scale, tuple)
        self.prob = prob
        self.img_scale = img_scale
        self.center_ratio_range = center_ratio_range
        self.pad_val = pad_val
        self.seg_pad_val = seg_pad_val

    def __call__(self, results):
        """Call function to make a mosaic of image.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Result dict with mosaic transformed. Returned unchanged
                when the random draw is not below ``prob``.
        """
        if np.random.rand() < self.prob:
            # The image pass samples the mosaic center that the
            # segmentation pass reuses, so it has to run first.
            results = self._mosaic_transform_img(results)
            results = self._mosaic_transform_seg(results)
        return results

    def get_indexes(self, dataset):
        """Call function to collect indexes.

        Args:
            dataset (:obj:`MultiImageMixDataset`): The dataset.

        Returns:
            list: 3 random indexes, each in ``[0, len(dataset))``.
        """
        # np.random.randint excludes the upper bound, so every index is a
        # valid position in the dataset.
        return [np.random.randint(0, len(dataset)) for _ in range(3)]

    def _paste_patches(self, canvas, patches, **resize_kwargs):
        """Resize each patch (keeping its ratio) and paste it into its
        quadrant of ``canvas`` in place.

        Args:
            canvas (np.ndarray): Pre-filled mosaic array to write into.
            patches (Sequence[np.ndarray]): 4 arrays ordered as
                ``_LOCATIONS``.
            **resize_kwargs: Extra keyword arguments forwarded to
                ``mmcv.imresize`` (e.g. ``interpolation``).
        """
        center_position = (self.center_x, self.center_y)
        for loc, patch in zip(self._LOCATIONS, patches):
            h_i, w_i = patch.shape[:2]
            # keep_ratio resize so the patch fits inside img_scale
            scale_ratio_i = min(self.img_scale[0] / h_i,
                                self.img_scale[1] / w_i)
            patch = mmcv.imresize(
                patch, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)),
                **resize_kwargs)

            # compute the combine parameters
            paste_coord, crop_coord = self._mosaic_combine(
                loc, center_position, patch.shape[:2][::-1])
            x1_p, y1_p, x2_p, y2_p = paste_coord
            x1_c, y1_c, x2_c, y2_c = crop_coord

            # crop and paste; the slice assignment copies the data, so the
            # source arrays are never modified.
            canvas[y1_p:y2_p, x1_p:x2_p] = patch[y1_c:y2_c, x1_c:x2_c]

    def _mosaic_transform_img(self, results):
        """Mosaic transform function.

        Samples the mosaic center (stored on ``self.center_x`` /
        ``self.center_y``) and replaces ``results['img']`` with the mosaic.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Updated result dict.
        """

        assert 'mix_results' in results
        img = results['img']
        canvas_shape = (int(self.img_scale[0] * 2),
                        int(self.img_scale[1] * 2))
        if len(img.shape) == 3:
            canvas_shape += (3, )
        mosaic_img = np.full(canvas_shape, self.pad_val, dtype=img.dtype)

        # mosaic center x, y
        self.center_x = int(
            np.random.uniform(*self.center_ratio_range) * self.img_scale[1])
        self.center_y = int(
            np.random.uniform(*self.center_ratio_range) * self.img_scale[0])

        # Inputs are only read, so no deep copy of the result dicts is made.
        patches = [img] + [
            results['mix_results'][i]['img'] for i in range(3)
        ]
        self._paste_patches(mosaic_img, patches)

        results['img'] = mosaic_img
        results['img_shape'] = mosaic_img.shape
        results['ori_shape'] = mosaic_img.shape

        return results

    def _mosaic_transform_seg(self, results):
        """Mosaic transform function for label annotations.

        Uses the center sampled by :meth:`_mosaic_transform_img`, which
        must therefore have been called beforehand.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Updated result dict.
        """

        assert 'mix_results' in results
        for key in results.get('seg_fields', []):
            seg = results[key]
            mosaic_seg = np.full(
                (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2)),
                self.seg_pad_val,
                dtype=seg.dtype)

            patches = [seg] + [
                results['mix_results'][i][key] for i in range(3)
            ]
            # nearest interpolation keeps label ids intact
            self._paste_patches(mosaic_seg, patches, interpolation='nearest')

            results[key] = mosaic_seg

        return results

    def _mosaic_combine(self, loc, center_position_xy, img_shape_wh):
        """Calculate global coordinate of mosaic image and local coordinate of
        cropped sub-image.

        Args:
            loc (str): Index for the sub-image, loc in ('top_left',
                'top_right', 'bottom_left', 'bottom_right').
            center_position_xy (Sequence[float]): Mixing center for 4 images,
                (x, y).
            img_shape_wh (Sequence[int]): Width and height of sub-image

        Returns:
            tuple[tuple[float]]: Corresponding coordinate of pasting and
                cropping
                - paste_coord (tuple): paste corner coordinate in mosaic image.
                - crop_coord (tuple): crop corner coordinate in sub-image.
        """

        assert loc in ('top_left', 'top_right', 'bottom_left', 'bottom_right')
        center_x, center_y = center_position_xy[0], center_position_xy[1]
        width, height = img_shape_wh[0], img_shape_wh[1]
        mosaic_w = self.img_scale[1] * 2
        mosaic_h = self.img_scale[0] * 2

        if loc == 'top_left':
            # index0 to top left part of image
            x1, y1 = max(center_x - width, 0), max(center_y - height, 0)
            x2, y2 = center_x, center_y
            crop_coord = (width - (x2 - x1), height - (y2 - y1), width,
                          height)

        elif loc == 'top_right':
            # index1 to top right part of image
            x1, y1 = center_x, max(center_y - height, 0)
            x2, y2 = min(center_x + width, mosaic_w), center_y
            crop_coord = (0, height - (y2 - y1), min(width, x2 - x1), height)

        elif loc == 'bottom_left':
            # index2 to bottom left part of image
            x1, y1 = max(center_x - width, 0), center_y
            x2, y2 = center_x, min(mosaic_h, center_y + height)
            crop_coord = (width - (x2 - x1), 0, width, min(y2 - y1, height))

        else:
            # index3 to bottom right part of image
            x1, y1 = center_x, center_y
            x2 = min(center_x + width, mosaic_w)
            y2 = min(mosaic_h, center_y + height)
            crop_coord = (0, 0, min(width, x2 - x1), min(y2 - y1, height))

        paste_coord = x1, y1, x2, y2
        return paste_coord, crop_coord

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(prob={self.prob}, '
        repr_str += f'img_scale={self.img_scale}, '
        repr_str += f'center_ratio_range={self.center_ratio_range}, '
        repr_str += f'pad_val={self.pad_val}, '
        repr_str += f'seg_pad_val={self.seg_pad_val})'
        return repr_str
49 changes: 49 additions & 0 deletions tests/test_data/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,3 +614,52 @@ def test_cutout():
cutout_result = cutout_module(copy.deepcopy(results))
assert cutout_result['img'].sum() > img.sum()
assert cutout_result['gt_semantic_seg'].sum() > seg.sum()


def test_mosaic():
    """Check RandomMosaic config validation and output shapes for both the
    image and the segmentation map, on colour and grayscale inputs."""
    # test prob
    with pytest.raises(AssertionError):
        transform = dict(type='RandomMosaic', prob=1.5)
        build_from_cfg(transform, PIPELINES)
    # test assertion for invalid img_scale
    with pytest.raises(AssertionError):
        transform = dict(type='RandomMosaic', prob=1, img_scale=640)
        build_from_cfg(transform, PIPELINES)

    results = dict()
    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    seg = np.array(
        Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))

    results['img'] = img
    results['gt_semantic_seg'] = seg
    results['seg_fields'] = ['gt_semantic_seg']

    transform = dict(type='RandomMosaic', prob=1, img_scale=(10, 12))
    mosaic_module = build_from_cfg(transform, PIPELINES)
    assert 'Mosaic' in repr(mosaic_module)

    # test assertion for invalid mix_results
    with pytest.raises(AssertionError):
        mosaic_module(results)

    # colour input: mosaic is (2 * h, 2 * w) for image and label alike
    results['mix_results'] = [copy.deepcopy(results)] * 3
    results = mosaic_module(results)
    assert results['img'].shape[:2] == (20, 24)
    assert results['gt_semantic_seg'].shape[:2] == (20, 24)

    results = dict()
    results['img'] = img[:, :, 0]
    results['gt_semantic_seg'] = seg
    results['seg_fields'] = ['gt_semantic_seg']

    # prob=0: the transform must leave image and label untouched
    transform = dict(type='RandomMosaic', prob=0, img_scale=(10, 12))
    mosaic_module = build_from_cfg(transform, PIPELINES)
    results['mix_results'] = [copy.deepcopy(results)] * 3
    results = mosaic_module(results)
    assert results['img'].shape[:2] == img.shape[:2]
    assert results['gt_semantic_seg'].shape == seg.shape

    # grayscale input: the mosaic stays 2-D
    transform = dict(type='RandomMosaic', prob=1, img_scale=(10, 12))
    mosaic_module = build_from_cfg(transform, PIPELINES)
    results = mosaic_module(results)
    assert results['img'].shape[:2] == (20, 24)
    assert results['img'].ndim == 2
    assert results['gt_semantic_seg'].shape[:2] == (20, 24)