import random

import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms


def _get_transforms():
    """
    The AdaMatch paper uses CTAugment for its strong augmentations. CTAugment
    isn't available in torchvision, so this builds a pipeline of transforms
    similar to the ones CTAugment samples from.
    """
    train_transform_weak = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(28),
        transforms.RandomAffine(45, translate=(0.3, 0.3), scale=(0.8, 1.2), shear=(-0.3, 0.3, -0.3, 0.3)),
    ])
    # Note: RandomAdjustSharpness and RandomSolarize take fixed factors, so the
    # random.uniform calls below are evaluated once, when the pipeline is built,
    # not once per image.
    train_transform_strong = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(28),
        transforms.RandomAutocontrast(),
        transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 1.0)),
        # transforms.RandomEqualize(),  # needs PIL or uint8 input, not float tensors
        transforms.RandomInvert(),
        # transforms.RandomPosterize(random.randint(1, 8)),  # needs PIL or uint8 input
        transforms.RandomAdjustSharpness(random.uniform(0, 1)),
        transforms.RandomSolarize(random.uniform(0, 1)),
        transforms.RandomAffine(45, translate=(0.3, 0.3), scale=(0.8, 1.2), shear=(-0.3, 0.3, -0.3, 0.3)),
        transforms.RandomErasing(),  # tensor-only op, hence after ToTensor
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(28),
    ])
    return train_transform_weak, train_transform_strong, test_transform


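# A quick sanity check of the pipelines above (illustrative only; assumes a
# local MNIST copy under "./data"). Datasets yield PIL images, which ToTensor
# converts before the tensor-only ops (e.g. RandomErasing) run:
#
#   weak, strong, _ = _get_transforms()
#   img, _ = torchvision.datasets.MNIST("./data", download=True)[0]
#   print(strong(img).shape)  # torch.Size([1, 28, 28])

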
def get_dataloaders(download_path, batch_size_source=32, workers=2):
    train_transform_weak, train_transform_strong, test_transform = _get_transforms()
    # Target batches are 3x larger than source batches.
    batch_size_target = 3 * batch_size_source

    # Source (MNIST) datasets and dataloaders. shuffle=False keeps the weak and
    # strong train loaders index-aligned, so each batch holds the same images
    # under both augmentations.
    source_dataset_train_weak = torchvision.datasets.MNIST(download_path, train=True, download=True, transform=train_transform_weak)
    source_dataset_train_strong = torchvision.datasets.MNIST(download_path, train=True, download=True, transform=train_transform_strong)
    source_dataset_test = torchvision.datasets.MNIST(download_path, train=False, download=True, transform=test_transform)
    source_dataloader_train_weak = DataLoader(source_dataset_train_weak, shuffle=False, batch_size=batch_size_source, num_workers=workers)
    source_dataloader_train_strong = DataLoader(source_dataset_train_strong, shuffle=False, batch_size=batch_size_source, num_workers=workers)
    source_dataloader_test = DataLoader(source_dataset_test, shuffle=True, batch_size=batch_size_source, num_workers=workers)

    # Target (USPS) datasets and dataloaders.
    target_dataset_train_weak = torchvision.datasets.USPS(download_path, train=True, download=True, transform=train_transform_weak)
    target_dataset_train_strong = torchvision.datasets.USPS(download_path, train=True, download=True, transform=train_transform_strong)
    target_dataset_test = torchvision.datasets.USPS(download_path, train=False, download=True, transform=test_transform)
    target_dataloader_train_weak = DataLoader(target_dataset_train_weak, shuffle=False, batch_size=batch_size_target, num_workers=workers)
    target_dataloader_train_strong = DataLoader(target_dataset_train_strong, shuffle=False, batch_size=batch_size_target, num_workers=workers)
    target_dataloader_test = DataLoader(target_dataset_test, shuffle=True, batch_size=batch_size_target, num_workers=workers)

    return (
        (source_dataloader_train_weak, source_dataloader_train_strong, source_dataloader_test),
        (target_dataloader_train_weak, target_dataloader_train_strong, target_dataloader_test),
    )
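

# A minimal usage sketch (an illustration, not part of the training code):
# fetches one batch per domain and checks shapes. Assumes "./data" is a
# writable download directory. Because the train loaders use shuffle=False,
# iterating the weak and strong loaders in lockstep yields the same images
# under the two augmentation strengths.
if __name__ == "__main__":
    source_loaders, target_loaders = get_dataloaders("./data", batch_size_source=32)
    source_weak, source_strong, _ = source_loaders
    target_weak, _, _ = target_loaders

    x_weak, y = next(iter(source_weak))
    x_strong, _ = next(iter(source_strong))
    x_target, _ = next(iter(target_weak))
    print(x_weak.shape)    # torch.Size([32, 1, 28, 28]) -- MNIST, weakly augmented
    print(x_strong.shape)  # same images as x_weak, strongly augmented
    print(x_target.shape)  # torch.Size([96, 1, 28, 28]) -- USPS, 3x batch size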