Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add RandomApply in gluon's transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
guanxinq committed Jan 8, 2020
1 parent 634f95e commit beebe72
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
26 changes: 26 additions & 0 deletions python/mxnet/gluon/data/vision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .... import image
from ....base import numeric_types
from ....util import is_np_array
import random


class Compose(Sequential):
Expand Down Expand Up @@ -581,3 +582,28 @@ def hybrid_forward(self, F, x):
if is_np_array():
F = F.npx
return F.image.random_lighting(x, self._alpha)


class RandomApply(Sequential):
"""Apply a list of transformations randomly given probability
Parameters
----------
Inputs:
- **transforms**: list of transformations
- **p**: probability
Outputs:
Transformed image.
"""

def __init__(self, transforms, p=0.5):
super(RandomApply, self).__init__()
self.transforms = transforms
self.p = p

def forward(self, x):
if self.p < random.random():
return x
x = self.transforms(x)
return x
16 changes: 16 additions & 0 deletions tests/python/unittest/test_gluon_data_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,22 @@ def test_transformer():
transform(mx.nd.ones((245, 480, 3), dtype='uint8')).wait_to_read()


@with_seed()
def test_random_transforms():
from mxnet.gluon.data.vision import transforms

tmp_t = transforms.Compose([transforms.Resize(300), transforms.RandomResizedCrop(224)])
transform = transforms.Compose([transforms.RandomApply(tmp_t, 0.5)])

img = mx.nd.ones((10, 10, 3), dtype='uint8')
iteration = 1000
num_apply = 0
for _ in range(iteration):
out = transform(img)
if out.shape[0] == 224:
num_apply += 1
assert_almost_equal(num_apply/float(iteration), 0.5, 0.1)


if __name__ == '__main__':
import nose
Expand Down

0 comments on commit beebe72

Please sign in to comment.