-
Notifications
You must be signed in to change notification settings - Fork 185
/
Copy pathdata.py
60 lines (46 loc) · 2.65 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import numpy as np
import torch
from torchvision import datasets
class Data:
    """Train/test data plus a mask of which training samples are labeled.

    The training set acts as a pool: ``labeled_idxs`` is a boolean array
    with one entry per pool sample, ``True`` where the sample is labeled.
    ``handler`` is called as ``handler(X, Y)`` to wrap a selection of
    samples; its return value is handed back to the caller unchanged.
    """

    def __init__(self, X_train, Y_train, X_test, Y_test, handler):
        """Store the data and start with every pool sample unlabeled.

        Args:
            X_train: Pool samples. Must support ``len()`` and indexing
                with an integer index array.
            Y_train: Pool targets, indexed the same way as ``X_train``.
            X_test: Test samples. Must support ``len()``.
            Y_test: Test targets, compared element-wise with predictions
                in ``cal_test_acc``.
            handler: Callable invoked as ``handler(X, Y)``.
        """
        self.X_train = X_train
        self.Y_train = Y_train
        self.X_test = X_test
        self.Y_test = Y_test
        self.handler = handler

        self.n_pool = len(X_train)
        self.n_test = len(X_test)

        # Boolean mask over the pool; True marks a labeled sample.
        self.labeled_idxs = np.zeros(self.n_pool, dtype=bool)

    def initialize_labels(self, num):
        """Mark ``num`` randomly chosen pool samples as labeled.

        Samples that are already labeled stay labeled. The draw uses
        NumPy's global random state (``np.random.shuffle``). If ``num``
        is larger than the pool, the whole pool becomes labeled.

        Args:
            num: Number of pool samples to draw; must not be negative.

        Raises:
            ValueError: If ``num`` is negative. A negative value would
                otherwise be treated as a slice end and label all but
                the last ``abs(num)`` shuffled samples.
        """
        if num < 0:
            raise ValueError(
                "num must be non-negative, got {!r}".format(num)
            )
        # Shuffle the full index range and keep the first `num` entries;
        # this exact sequence is kept so seeded runs pick the same samples.
        shuffled = np.arange(self.n_pool)
        np.random.shuffle(shuffled)
        self.labeled_idxs[shuffled[:num]] = True

    def get_labeled_data(self):
        """Return ``(indices, handler(X, Y))`` for the labeled samples.

        ``indices`` is an ascending integer array of pool positions.
        """
        labeled_idxs = np.flatnonzero(self.labeled_idxs)
        return labeled_idxs, self.handler(
            self.X_train[labeled_idxs], self.Y_train[labeled_idxs]
        )

    def get_unlabeled_data(self):
        """Return ``(indices, handler(X, Y))`` for the unlabeled samples.

        ``indices`` is an ascending integer array of pool positions.
        """
        unlabeled_idxs = np.flatnonzero(~self.labeled_idxs)
        return unlabeled_idxs, self.handler(
            self.X_train[unlabeled_idxs], self.Y_train[unlabeled_idxs]
        )

    def get_train_data(self):
        """Return ``(mask, handler(X_train, Y_train))`` for the whole pool.

        Unlike the labeled/unlabeled getters, the first element is a
        copy of the boolean mask rather than an integer index array, so
        mutating it does not affect this object.
        """
        return self.labeled_idxs.copy(), self.handler(self.X_train, self.Y_train)

    def get_test_data(self):
        """Return ``handler(X_test, Y_test)``."""
        return self.handler(self.X_test, self.Y_test)

    def cal_test_acc(self, preds):
        """Return the fraction of ``preds`` equal to ``Y_test``.

        The comparison result must provide ``.sum().item()``.

        Raises:
            ZeroDivisionError: If the test set is empty.
        """
        correct = (self.Y_test == preds).sum().item()
        return float(correct) / self.n_test
def get_MNIST(handler):
    """Build a ``Data`` object from torchvision's MNIST train and test sets.

    Both sets are requested with ``download=True`` under ``./data/MNIST``.
    At most the first 40000 entries of each ``data`` / ``targets`` array
    are kept, and ``handler`` is forwarded to ``Data`` untouched.
    """
    root = './data/MNIST'
    head = slice(40000)  # same selection as [:40000]
    train_split = datasets.MNIST(root, train=True, download=True)
    test_split = datasets.MNIST(root, train=False, download=True)
    return Data(
        train_split.data[head],
        train_split.targets[head],
        test_split.data[head],
        test_split.targets[head],
        handler,
    )
def get_FashionMNIST(handler):
    """Build a ``Data`` object from torchvision's FashionMNIST train and test sets.

    Both sets are requested with ``download=True`` under
    ``./data/FashionMNIST``. At most the first 40000 entries of each
    ``data`` / ``targets`` array are kept, and ``handler`` is forwarded
    to ``Data`` untouched.
    """
    root = './data/FashionMNIST'
    head = slice(40000)  # same selection as [:40000]
    train_split = datasets.FashionMNIST(root, train=True, download=True)
    test_split = datasets.FashionMNIST(root, train=False, download=True)
    return Data(
        train_split.data[head],
        train_split.targets[head],
        test_split.data[head],
        test_split.targets[head],
        handler,
    )
def get_SVHN(handler):
    """Build a ``Data`` object from torchvision's SVHN 'train' and 'test' splits.

    Both splits are requested with ``download=True`` under ``./data/SVHN``.
    Each split's ``labels`` array is converted with ``torch.from_numpy``
    before slicing; ``data`` is passed through as loaded. At most the
    first 40000 entries of each array are kept.
    """
    root = './data/SVHN'
    head = slice(40000)  # same selection as [:40000]
    train_split = datasets.SVHN(root, split='train', download=True)
    test_split = datasets.SVHN(root, split='test', download=True)
    train_labels = torch.from_numpy(train_split.labels)
    test_labels = torch.from_numpy(test_split.labels)
    return Data(
        train_split.data[head],
        train_labels[head],
        test_split.data[head],
        test_labels[head],
        handler,
    )
def get_CIFAR10(handler):
    """Build a ``Data`` object from torchvision's CIFAR10 train and test sets.

    Both sets are requested with ``download=True`` under
    ``./data/CIFAR10``. Each set's ``targets`` is wrapped in a
    ``torch.LongTensor`` before slicing; ``data`` is passed through as
    loaded. At most the first 40000 entries of each array are kept.
    """
    root = './data/CIFAR10'
    head = slice(40000)  # same selection as [:40000]
    train_split = datasets.CIFAR10(root, train=True, download=True)
    test_split = datasets.CIFAR10(root, train=False, download=True)
    train_targets = torch.LongTensor(train_split.targets)
    test_targets = torch.LongTensor(test_split.targets)
    return Data(
        train_split.data[head],
        train_targets[head],
        test_split.data[head],
        test_targets[head],
        handler,
    )