From beebe7200cad9396a178096d091a8eacc79ecf7a Mon Sep 17 00:00:00 2001 From: guanxinq Date: Tue, 7 Jan 2020 19:36:31 +0000 Subject: [PATCH] add RandomApply in gluon's transforms --- python/mxnet/gluon/data/vision/transforms.py | 26 +++++++++++++++++++ .../python/unittest/test_gluon_data_vision.py | 16 ++++++++++++ 2 files changed, 42 insertions(+) diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py index 935ce2738a6f..1ec72499bcf4 100644 --- a/python/mxnet/gluon/data/vision/transforms.py +++ b/python/mxnet/gluon/data/vision/transforms.py @@ -24,6 +24,7 @@ from .... import image from ....base import numeric_types from ....util import is_np_array +import random class Compose(Sequential): @@ -581,3 +582,28 @@ def hybrid_forward(self, F, x): if is_np_array(): F = F.npx return F.image.random_lighting(x, self._alpha) + + +class RandomApply(Sequential): + """Apply a list of transformations randomly given probability + + Parameters + ---------- + Inputs: + - **transforms**: list of transformations + - **p**: probability + + Outputs: + Transformed image. + """ + + def __init__(self, transforms, p=0.5): + super(RandomApply, self).__init__() + self.transforms = transforms + self.p = p + + def forward(self, x): + if self.p < random.random(): + return x + x = self.transforms(x) + return x diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py index 8bc0f8072260..71efb72b9ce5 100644 --- a/tests/python/unittest/test_gluon_data_vision.py +++ b/tests/python/unittest/test_gluon_data_vision.py @@ -229,6 +229,22 @@ def test_transformer(): transform(mx.nd.ones((245, 480, 3), dtype='uint8')).wait_to_read() +@with_seed() +def test_random_transforms(): + from mxnet.gluon.data.vision import transforms + + tmp_t = transforms.Compose([transforms.Resize(300), transforms.RandomResizedCrop(224)]) + transform = transforms.Compose([transforms.RandomApply(tmp_t, 0.5)]) + + img = mx.nd.ones((10, 10, 3), dtype='uint8') + iteration = 1000 + num_apply = 0 + for _ in range(iteration): + out = transform(img) + if out.shape[0] == 224: + num_apply += 1 + assert_almost_equal(num_apply/float(iteration), 0.5, 0.1) + if __name__ == '__main__': import nose