-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtorch_utils.py
69 lines (44 loc) · 1.47 KB
/
torch_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import numpy as np
import torch
def gpu(tensor, gpu=False):
    """Optionally move ``tensor`` to CUDA.

    Args:
        tensor: any object exposing a ``.cuda()`` method (typically a torch
            tensor or module). It is only touched when ``gpu`` is truthy.
        gpu: when truthy, ``tensor.cuda()`` is returned; otherwise the input
            is returned unchanged (same object).

    Returns:
        ``tensor.cuda()`` or ``tensor`` itself.
    """
    # NOTE: the ``gpu`` parameter shadows this function's own name inside the
    # body; it is part of the public signature, so it is kept as-is.
    return tensor.cuda() if gpu else tensor
def cpu(tensor):
    """Return ``tensor`` on the host.

    If ``tensor.is_cuda`` is truthy, the result of ``tensor.cpu()`` is
    returned; otherwise the very same object is handed back without a call
    to ``.cpu()``.
    """
    # Guard clause: already on the host, nothing to do.
    if not tensor.is_cuda:
        return tensor
    return tensor.cpu()
def minibatch(*tensors, **kwargs):
    """Yield consecutive slices of at most ``batch_size`` rows.

    Args:
        *tensors: one or more sliceable sequences (tensors, arrays, lists).
            The number of batches is determined by ``len(tensors[0])``; every
            other input is sliced with the same ``[i:i + batch_size]`` bounds.
        batch_size (keyword, default 128): maximum number of rows per batch.

    Yields:
        A single slice when exactly one input is given, otherwise a tuple
        containing one slice per input, in argument order.

    Raises:
        ValueError: if ``batch_size`` is less than 1. Because this is a
            generator, the error surfaces on the first iteration.
    """
    batch_size = kwargs.get('batch_size', 128)

    # A negative step would make ``range`` empty and silently yield nothing,
    # and a zero step fails with an opaque ``range()`` message; reject both
    # with an explicit error instead.
    if batch_size < 1:
        raise ValueError('batch_size must be a positive integer, '
                         'got {}'.format(batch_size))

    if len(tensors) == 1:
        tensor = tensors[0]
        for i in range(0, len(tensor), batch_size):
            yield tensor[i:i + batch_size]
    else:
        for i in range(0, len(tensors[0]), batch_size):
            yield tuple(x[i:i + batch_size] for x in tensors)
def shuffle(*arrays, **kwargs):
    """Apply one shared random permutation to each of ``arrays``.

    Args:
        *arrays: inputs of identical length that support indexing by an
            integer NumPy array (e.g. NumPy arrays; presumably torch tensors
            too — confirm for other types).
        random_state (keyword, optional): object whose ``shuffle`` method
            permutes an index array in place, e.g. ``np.random.RandomState``.
            When omitted or ``None``, a fresh unseeded ``RandomState`` is used.

    Returns:
        The permuted array when a single input is given, otherwise a tuple of
        permuted inputs in argument order.

    Raises:
        ValueError: if the inputs do not all share one length (this also
            triggers when no inputs are passed at all).
    """
    rng = kwargs.get('random_state')

    distinct_lengths = {len(arr) for arr in arrays}
    if len(distinct_lengths) != 1:
        raise ValueError('All inputs to shuffle must have '
                         'the same length.')

    if rng is None:
        rng = np.random.RandomState()

    # Permute row indices once so every input is reordered identically.
    order = np.arange(len(arrays[0]))
    rng.shuffle(order)

    shuffled = tuple(arr[order] for arr in arrays)
    return shuffled[0] if len(shuffled) == 1 else shuffled
def assert_no_grad(variable):
    """Reject a target tensor that requires gradients.

    Args:
        variable: object with a ``requires_grad`` attribute (typically a
            torch tensor used as a loss target).

    Returns:
        None when ``variable.requires_grad`` is falsy.

    Raises:
        ValueError: if ``variable.requires_grad`` is truthy.
    """
    # The ``volatile`` flag was removed in PyTorch 0.4, so the message now
    # points at the current mechanisms (``detach()`` / ``requires_grad``).
    if variable.requires_grad:
        raise ValueError(
            "nn criterions don't compute the gradient w.r.t. targets - please "
            "detach these tensors or mark them as not requiring gradients"
        )
def set_seed(seed, cuda=False):
    """Seed PyTorch's random number generation.

    Args:
        seed: value passed to ``torch.manual_seed``.
        cuda: when truthy, ``torch.cuda.manual_seed(seed)`` is called as well.

    Note:
        NumPy's RNG is not touched here; code that draws from NumPy needs its
        own seeding.
    """
    torch.manual_seed(seed)
    if not cuda:
        return
    torch.cuda.manual_seed(seed)