forked from PaulKMueller/llama_traffic
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathraster_barlow_twins_transform.py
71 lines (60 loc) · 2.11 KB
/
raster_barlow_twins_transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import torchvision.transforms as transforms
class BarlowTwinsTransform:
    """Produce two independently augmented views of one raster sample.

    Src: https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/barlow-twins.html

    The pipeline assembled in ``__init__`` is::

        ToPILImage -> RandomRotation(10)
                   -> RandomResizedCrop(input_height, scale=(0.6, 1.0))
                   -> colour augmentations
                   -> ToTensor [-> normalize]

    The rotation, crop and colour stages are random, so running the pipeline
    twice on the same input (see ``__call__``) gives two different views.

    Args:
        train: Stored as ``self.train``. NOTE(review): nothing in this class
            reads it; the same augmenting pipeline is built for either value.
        input_height: Target size handed to ``RandomResizedCrop``; also used
            to derive the Gaussian-blur kernel size (10% of this value).
        gaussian_blur: When true, a Gaussian blur applied with probability
            0.5 is appended to the colour augmentations.
        jitter_strength: Multiplier for the ``ColorJitter`` ranges
            (0.8 for brightness/contrast/saturation, 0.2 for hue).
        normalize: Optional callable applied after ``ToTensor``. When
            ``None`` the pipeline ends with ``ToTensor``.
    """

    def __init__(
        self,
        train: bool = True,
        input_height: int = 224,
        gaussian_blur: bool = True,
        jitter_strength: float = 1.0,
        normalize=None,
    ):
        self.input_height = input_height
        self.gaussian_blur = gaussian_blur
        self.jitter_strength = jitter_strength
        self.normalize = normalize
        self.train = train

        color_jitter = transforms.ColorJitter(
            brightness=0.8 * self.jitter_strength,
            contrast=0.8 * self.jitter_strength,
            saturation=0.8 * self.jitter_strength,
            hue=0.2 * self.jitter_strength,
        )
        color_transform = [
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
        ]

        if self.gaussian_blur:
            # Kernel spans roughly 10% of the image side and is bumped to the
            # next odd number (GaussianBlur expects an odd kernel size).
            kernel_size = int(0.1 * self.input_height)
            if kernel_size % 2 == 0:
                kernel_size += 1
            color_transform.append(
                transforms.RandomApply(
                    [transforms.GaussianBlur(kernel_size=kernel_size)], p=0.5
                )
            )

        self.color_transform = transforms.Compose(color_transform)

        if normalize is None:
            self.final_transform = transforms.ToTensor()
        else:
            self.final_transform = transforms.Compose(
                [transforms.ToTensor(), normalize]
            )

        self.transform = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.RandomRotation(10),
                # Crop scale is narrowed from torchvision's default (0.08, 1.0).
                transforms.RandomResizedCrop(
                    self.input_height, scale=(0.6, 1.0)
                ),
                # Horizontal flip from the reference implementation is
                # intentionally left out of this pipeline.
                self.color_transform,
                self.final_transform,
            ]
        )

    def __call__(self, sample):
        """Return a pair of augmented views of ``sample``.

        Args:
            sample: A NumPy array (it is passed to ``torch.from_numpy``).
                Presumably laid out as (C, H, W), since ``ToPILImage``
                receives it as a tensor -- TODO confirm against the caller.

        Returns:
            A 2-tuple ``(view_a, view_b)``, each the result of one pass of
            ``self.transform`` over the same input.
        """
        # Convert once: the tensor shares memory with ``sample`` either way,
        # so a second ``from_numpy`` call would only repeat the wrapping.
        tensor = torch.from_numpy(sample)
        return (self.transform(tensor), self.transform(tensor))