diff --git a/references/classification/train_pytorch_orientation.py b/references/classification/train_pytorch_orientation.py index e1d0df02f3..46c77d4c38 100644 --- a/references/classification/train_pytorch_orientation.py +++ b/references/classification/train_pytorch_orientation.py @@ -41,7 +41,10 @@ def rnd_rotate(img: torch.Tensor, target): angle = int(np.random.choice(CLASSES)) idx = CLASSES.index(angle) - rotated_img = F.rotate(img, angle=-angle, fill=0, expand=False)[:3] + # augment the angle randomly with a probability of 0.5 + if np.random.rand() < 0.5: + angle += float(np.random.choice(np.arange(-25, 25, 5))) + rotated_img = F.rotate(img, angle=-angle, fill=0, expand=angle not in CLASSES)[:3] return rotated_img, idx diff --git a/references/classification/train_tensorflow_orientation.py b/references/classification/train_tensorflow_orientation.py index 8af14a5cf4..a37ce8e1dc 100644 --- a/references/classification/train_tensorflow_orientation.py +++ b/references/classification/train_tensorflow_orientation.py @@ -36,8 +36,11 @@ def rnd_rotate(img: tf.Tensor, target): angle = int(np.random.choice(CLASSES)) idx = CLASSES.index(angle) + # augment the angle randomly with a probability of 0.5 + if np.random.rand() < 0.5: + angle += float(np.random.choice(np.arange(-25, 25, 5))) # clockwise rotation - rotated_img = rotated_img_tensor(img, -angle, expand=False) + rotated_img = rotated_img_tensor(img, -angle, expand=angle not in CLASSES) return rotated_img, idx