-
Notifications
You must be signed in to change notification settings - Fork 11
/
samplers.py
91 lines (79 loc) · 4.19 KB
/
samplers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
import time
import h5py
import numpy as np
import torch
from torch.utils.data.sampler import BatchSampler
class BalancedRandomBatchSampler(BatchSampler):
    """Batch sampler that replays a precomputed per-epoch index order.

    The order is a 2-D integer tensor with one row per epoch, produced by
    ``balanced_shuffle`` (when ``balance`` is true) or ``shuffle`` from
    ``data_source.target_tensor``. Negative entries are treated as padding
    and are dropped from each batch, so a yielded batch may hold fewer than
    ``batch_size`` indices (possibly none).

    Note: ``BatchSampler.__init__`` is intentionally not invoked here; only
    the attributes set below are available on the instance.

    Args:
        data_source: Dataset-like object exposing a ``target_tensor``
            attribute that is forwarded to the shuffle function as labels.
        batch_size: Number of order slots consumed per batch.
        num_epochs: Number of distinct epoch orders to precompute.
        balance: Use ``balanced_shuffle`` if true, otherwise ``shuffle``.
        path: Optional directory forwarded to the shuffle function, which
            uses it to cache the order on disk.
    """

    def __init__(self, data_source, batch_size, num_epochs=50, balance=True, path=None):
        self.data_source = data_source
        self.batch_size = batch_size
        # Pick the generator first so only the selected one is ever called.
        build_order = balanced_shuffle if balance else shuffle
        self.order = build_order(data_source.target_tensor, num_epochs, path)
        self.epoch = 0

    def __iter__(self):
        """Yield one iterator of dataset indices per batch for the current epoch.

        After the last batch has been produced the epoch counter advances,
        wrapping back to 0 once every precomputed row has been used.
        """
        num_batches = len(self)
        for batch_idx in range(num_batches):
            lo = batch_idx * self.batch_size
            window = self.order[self.epoch, lo:lo + self.batch_size]
            # Negative values are padding slots, not dataset indices.
            indices = window[window >= 0].tolist()
            yield iter(indices)
        following = self.epoch + 1
        self.epoch = 0 if following >= self.order.size(0) else following

    def __len__(self):
        """Return the number of full ``batch_size`` windows per epoch row.

        Floor division means a trailing partial window is never yielded.
        """
        slots_per_epoch = self.order.shape[1]
        return slots_per_epoch // self.batch_size
def balanced_shuffle(labels, num_epochs=50, path=None, start_time=None):
    """Build (or load) a class-interleaved random index order for each epoch.

    For every epoch a ``[max_data_per_class, num_classes]`` grid is filled
    column by column: column ``j`` receives a random permutation of the
    indices whose label is the ``j``-th distinct class. Classes with fewer
    samples than the largest class leave ``-1`` padding in their column.
    The grid is flattened row-major, so every run of ``num_classes``
    consecutive slots holds at most one index per class.

    Args:
        labels: 1-D CPU tensor of class labels (``labels.numpy()`` is used).
            Label values may be arbitrary integers; they do not need to be
            the contiguous range ``0..K-1``.
        num_epochs: Number of epoch rows to generate.
        path: Optional directory. If ``{path}/balanced_order_{num_epochs}.h5``
            exists it is loaded instead of generating; otherwise the newly
            generated order is written there. The cached file is not checked
            against ``labels``.
        start_time: Reference timestamp for the elapsed-time log lines.
            Defaults to the time of the call.

    Returns:
        Integer tensor of shape ``[num_epochs, max_data_per_class * num_classes]``
        with ``-1`` marking padding slots.

    Raises:
        ValueError: If ``labels`` is empty and no cached order is loaded.
    """
    # Resolve the default per call: a ``time.time()`` default in the signature
    # is evaluated once at import and would report time since module load.
    if start_time is None:
        start_time = time.time()

    order_path = '{path}/balanced_order_{num_epochs}.h5' \
        .format(path=path, num_epochs=num_epochs)

    if path is not None and os.path.isfile(order_path):
        with h5py.File(order_path, 'r') as f:
            order = f['order'][:]
    else:
        evenness = 5  # batch_size | evenness*num_classes
        labels_np = labels.numpy()
        if labels_np.size == 0:
            raise ValueError('labels must contain at least one element')

        classes = np.unique(labels_np)
        num_classes = len(classes)
        loc_data_per_class = [np.argwhere(labels_np == k).flatten() for k in classes]
        max_data_per_class = max(len(loc) for loc in loc_data_per_class)

        # Split the largest class size into `evenness` nearly equal chunks;
        # the first (max % evenness) chunks get one extra element.
        num_loc_split = (max_data_per_class // evenness) * np.ones(evenness, dtype=int)
        num_loc_split[:(max_data_per_class % evenness)] += 1
        loc_split = [0]
        loc_split.extend(np.cumsum(num_loc_split).tolist())

        order = -np.ones([num_epochs, max_data_per_class*num_classes], dtype=int)
        print_freq = min([100, (num_epochs-1) // 5 + 1])
        for epoch in range(num_epochs):
            order_e = -np.ones([max_data_per_class, num_classes], dtype=int)
            # Index columns by class position, not by label value, so that
            # non-contiguous or non-zero-based labels are handled correctly.
            for col, loc in enumerate(loc_data_per_class):
                loc_k = np.random.permutation(loc)
                for i in range(evenness):
                    loc_i = loc_k[loc_split[i]:loc_split[i+1]]
                    # Chunk i occupies rows i, i+evenness, i+2*evenness, ...
                    order_e[i:(len(loc_i)*evenness+i):evenness, col] = loc_i
            order[epoch] = order_e.flatten()

            print_me = (epoch == 0 or epoch == num_epochs-1 or (epoch+1) % print_freq == 0)
            if print_me:
                print('{epoch:4d}/{num_epochs:4d} e; '.format(epoch=epoch+1, num_epochs=num_epochs), end='')
                print('generate balanced random order; {time:8.3f} s'.format(time=time.time()-start_time))

        if path is not None:
            with h5py.File(order_path, 'w') as f:
                f.create_dataset('order', data=order, compression='gzip', compression_opts=9)

    print('balanced random order; {time:8.3f} s'.format(time=time.time()-start_time))
    return torch.from_numpy(order)
def shuffle(labels, num_epochs=50, path=None, start_time=None):
    """Build (or load) an independent random permutation of indices per epoch.

    Args:
        labels: Tensor whose first dimension gives the number of samples;
            only ``labels.size(0)`` is used.
        num_epochs: Number of epoch rows to generate.
        path: Optional directory. If ``{path}/order_{num_epochs}.h5`` exists
            it is loaded instead of generating; otherwise the newly generated
            order is written there. The cached file is not checked against
            ``labels``.
        start_time: Reference timestamp for the elapsed-time log lines.
            Defaults to the time of the call.

    Returns:
        Integer tensor of shape ``[num_epochs, labels.size(0)]`` where each
        row is a permutation of ``0..labels.size(0)-1``.
    """
    # Resolve the default per call: a ``time.time()`` default in the signature
    # is evaluated once at import and would report time since module load.
    if start_time is None:
        start_time = time.time()

    order_path = '{path}/order_{num_epochs}.h5' \
        .format(path=path, num_epochs=num_epochs)

    if path is not None and os.path.isfile(order_path):
        with h5py.File(order_path, 'r') as f:
            order = f['order'][:]
    else:
        num_data = labels.size(0)
        order = -np.ones([num_epochs, num_data], dtype=int)
        print_freq = min([100, (num_epochs-1) // 5 + 1])
        for epoch in range(num_epochs):
            order[epoch] = np.random.permutation(num_data)

            print_me = (epoch == 0 or epoch == num_epochs-1 or (epoch+1) % print_freq == 0)
            if print_me:
                print('{epoch:4d}/{num_epochs:4d} e; '.format(epoch=epoch+1, num_epochs=num_epochs), end='')
                print('generate random order; {time:8.3f} s'.format(time=time.time()-start_time))

        if path is not None:
            with h5py.File(order_path, 'w') as f:
                f.create_dataset('order', data=order, compression='gzip', compression_opts=9)

    print('random order; {time:8.3f} s'.format(time=time.time()-start_time))
    return torch.from_numpy(order)