-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdataset.py
47 lines (36 loc) · 2.23 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
def _rescale(x):
    """Linearly map ``x`` from [0, 1] (the range ``ToTensor`` produces) to [-1, 1].

    Defined at module level rather than as a lambda inside ``get_dataset`` so
    that the transform can be pickled when DataLoader worker processes are
    started with the ``spawn`` method.
    """
    return (x - .5) * 2.


def _dataset_class(name):
    """Return the ``torchvision.datasets`` class selected by ``name``.

    Raises:
        ValueError: if ``name`` is not a supported dataset name.
    """
    if name == 'MNIST':
        return datasets.MNIST
    if name == 'FashionMNIST':
        return datasets.FashionMNIST
    # Substring match preserved from the original selection logic.
    # NOTE(review): this also matches e.g. 'CIFAR100', which would silently
    # load CIFAR10 -- confirm whether that is intended.
    if 'CIFAR10' in name:
        return datasets.CIFAR10
    raise ValueError(
        f"Unsupported dataset {name!r}; expected 'MNIST', 'FashionMNIST' "
        "or a name containing 'CIFAR10'.")


def get_dataset(config):
    """Build the train and test DataLoaders for the dataset named in ``config``.

    Args:
        config: object exposing ``dataset`` (dataset name), ``batch_size``
            (samples per batch) and ``data_dir`` (root directory the dataset
            is downloaded to / read from, used for both splits).

    Returns:
        A ``(train_loader, test_loader)`` tuple. Images go through
        ``ToTensor`` and are then rescaled to [-1, 1]. The train loader
        shuffles; the test loader does not.

    Raises:
        ValueError: if ``config.dataset`` is not a supported dataset name.
    """
    # Resolve the dataset before doing any work so an unknown name fails fast
    # with a clear error instead of an UnboundLocalError at the return.
    dataset_cls = _dataset_class(config.dataset)
    ds_transforms = transforms.Compose([transforms.ToTensor(), _rescale])
    # NOTE(review): drop_last=True also applies to the test loader, so up to
    # batch_size - 1 test samples are skipped during evaluation. Kept as-is
    # because callers may rely on fixed-size batches -- confirm.
    loader_kwargs = {'num_workers': 1, 'pin_memory': True, 'drop_last': True}

    def make_loader(train, shuffle):
        # Both splits share the same root so train and test data come from
        # the same download.
        dataset = dataset_cls(config.data_dir, train=train, download=True,
                              transform=ds_transforms)
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=config.batch_size,
                                           shuffle=shuffle, **loader_kwargs)

    train_loader = make_loader(train=True, shuffle=True)
    test_loader = make_loader(train=False, shuffle=False)
    return train_loader, test_loader