-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathload_data.py
115 lines (91 loc) · 4.85 KB
/
load_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, CIFAR100
from torch.utils.data import DataLoader, random_split
import importlib.util
import sys
import logging
class DataManager():
    """Load a dataset and split its training set into per-client loaders.

    Supported ``dataset`` values are ``'cifar10'``, ``'cifar100'`` and
    ``'custom_dataset'``. All loading and splitting happens in the
    constructor; the results are stored on the instance:

    * ``trainloader`` -- ``DataLoader`` over the full training set.
    * ``testloader`` -- ``DataLoader`` over the test set.
    * ``dl_clients`` -- list with one ``{'train': DataLoader, 'val': DataLoader}``
      dict per client.
    * ``remainset`` -- the part of the training set held back from the
      clients, or ``None`` when nothing is held back.
    """

    def __init__(self, dataset='cifar10', path="", custom_path="", batch_size=32, remain_ratio=0, validation_ratio=0.1, num_clients=10, random_seed=1234):
        """Load ``dataset`` and build every loader.

        Args:
            dataset: ``'cifar10'``, ``'cifar100'`` or ``'custom_dataset'``.
            path: Root directory for the CIFAR download/cache. An empty
                value falls back to ``"./data"``.
            custom_path: Only read when ``dataset == 'custom_dataset'``; a
                mapping with the keys ``'py_path'`` (Python file defining a
                ``CustomDataset`` class) and ``'data_path'`` (forwarded to
                that class as ``path=``).
            batch_size: Batch size used for every ``DataLoader``.
            remain_ratio: Amount of training data withheld from the clients.
                Values greater than 1 are taken as an absolute sample count,
                other values as a fraction of the training set.
            validation_ratio: Fraction of each client's data used for its
                validation loader.
            num_clients: Number of client partitions to create.
            random_seed: Seed for every random split.

        Raises:
            ValueError: If ``dataset`` is not supported, ``num_clients`` is
                smaller than 1, or more samples are withheld than exist.
            ImportError: If the custom dataset module cannot be loaded.
        """
        self.dataset = dataset
        (
            self.trainloader,
            self.testloader,
            self.dl_clients,
            self.remainset,
        ) = self.__load_data(
            dataset=dataset,
            path=path,
            custom_path=custom_path,
            batch_size=batch_size,
            remain_ratio=remain_ratio,
            validation_ratio=validation_ratio,
            num_clients=num_clients,
            random_seed=random_seed,
        )

    def get_data(self):
        """Return ``(trainloader, testloader, dl_clients, remainset)``."""
        return self.trainloader, self.testloader, self.dl_clients, self.remainset

    def __load_data(self, dataset='cifar10', path='', custom_path='', batch_size=32, remain_ratio=0, validation_ratio=0.1, num_clients=10, random_seed=1234):
        """Load the datasets and build the global and per-client loaders.

        Parameters have the same meaning as in ``__init__``.

        Returns:
            A ``(trainloader, testloader, dl_clients, remainset)`` tuple.

        Raises:
            ValueError: If ``dataset`` is not supported, ``num_clients`` is
                smaller than 1, or more samples are withheld than exist.
            ImportError: If the custom dataset module cannot be loaded.
        """
        if num_clients < 1:
            raise ValueError(f"num_clients must be at least 1, got {num_clients}.")

        trainset, testset = self._load_datasets(dataset, path, custom_path)

        trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
        testloader = DataLoader(testset, batch_size=batch_size, shuffle=True)

        # remain_ratio > 1 is an absolute sample count, otherwise a fraction.
        total = len(trainset)
        remain_count = int(remain_ratio) if remain_ratio > 1 else int(total * remain_ratio)
        if remain_count > total:
            raise ValueError(
                f"Cannot withhold {remain_count} samples from a training set of {total}."
            )

        if remain_count > 0:
            remainset, ds_target = random_split(
                trainset,
                [remain_count, total - remain_count],
                generator=torch.Generator().manual_seed(random_seed),
            )
        else:
            remainset, ds_target = None, trainset

        ds_clients = random_split(
            ds_target,
            self._client_sizes(len(ds_target), num_clients),
            generator=torch.Generator().manual_seed(random_seed),
        )

        dl_clients = []
        for ds_client in ds_clients:
            len_val = int(len(ds_client) * validation_ratio)
            len_train = len(ds_client) - len_val
            ds_train, ds_val = random_split(
                ds_client,
                [len_train, len_val],
                generator=torch.Generator().manual_seed(random_seed),
            )
            dl_clients.append({
                'train': DataLoader(ds_train, batch_size=batch_size, shuffle=True),
                'val': DataLoader(ds_val, batch_size=batch_size, shuffle=True),
            })

        return trainloader, testloader, dl_clients, remainset

    def _load_datasets(self, dataset, path, custom_path):
        """Return ``(trainset, testset)`` for the dataset named ``dataset``."""
        cifar_loaders = {
            'cifar10': self.load_cifar10,
            'cifar100': self.load_cifar100,
        }
        if dataset == "custom_dataset":
            return self._load_custom_dataset(custom_path)
        if dataset in cifar_loaders:
            # An empty path keeps the historical default download location.
            return cifar_loaders[dataset](path or "./data")

        logging.error("Unavailable dataset name.")
        raise ValueError(
            f"Unavailable dataset name: {dataset!r}. "
            "Expected 'cifar10', 'cifar100' or 'custom_dataset'."
        )

    def _load_custom_dataset(self, custom_path):
        """Import ``CustomDataset`` from a user file and build both splits."""
        py_path = custom_path['py_path']
        spec = importlib.util.spec_from_file_location("custom_dataset", py_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot load a custom dataset module from {py_path!r}.")
        module = importlib.util.module_from_spec(spec)
        # Register under the module's real name so that classes defined in it
        # can be located again through their __module__ attribute.
        sys.modules[spec.name] = module
        spec.loader.exec_module(module)

        data_path = custom_path['data_path']
        trainset = module.CustomDataset(path=data_path, train=True)
        testset = module.CustomDataset(path=data_path, train=False)
        return trainset, testset

    @staticmethod
    def _client_sizes(total, num_clients):
        """Split ``total`` samples evenly; the last client absorbs the remainder."""
        base, extra = divmod(total, num_clients)
        sizes = [base] * num_clients
        sizes[-1] += extra
        return sizes

    def _load_cifar(self, dataset_cls, mean, std, path):
        """Build the train/test splits of a CIFAR-style torchvision dataset.

        The training split gets a padded random crop and a random horizontal
        flip; both splits are converted to tensors and normalized with
        ``mean`` and ``std``.
        """
        normalize = transforms.Normalize(mean, std)
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        trainset = dataset_cls(path, train=True, download=True, transform=transform_train)
        testset = dataset_cls(path, train=False, download=True, transform=transform_test)
        return trainset, testset

    def load_cifar10(self, path="./data"):
        """Download (if needed) and return CIFAR-10 as ``(trainset, testset)``.

        Args:
            path: Root directory used to download/cache the dataset.
        """
        return self._load_cifar(
            CIFAR10,
            (0.4914, 0.4822, 0.4465),
            (0.2023, 0.1994, 0.2010),
            path,
        )

    def load_cifar100(self, path="./data"):
        """Download (if needed) and return CIFAR-100 as ``(trainset, testset)``.

        Args:
            path: Root directory used to download/cache the dataset.
        """
        return self._load_cifar(
            CIFAR100,
            (0.5071, 0.4865, 0.4409),
            (0.2673, 0.2564, 0.2762),
            path,
        )