utils.py
import math
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
'''
Linear schedule for RSPO
'''
def linear_schedule(startval, endval, endtime):
    # Linearly interpolate from startval to endval over endtime steps, then hold endval.
    return lambda t: (startval + t / endtime * (endval - startval)
                      if t < endtime else endval)
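
# Usage sketch (illustrative; not executed on import):
#     sched = linear_schedule(1.0, 0.1, 100)
#     sched(0)    # -> 1.0
#     sched(50)   # -> 0.55
#     sched(200)  # -> 0.1 (held at endval once t >= endtime)
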
def extend_and_repeat(tensor, dim, repeat):
    # Insert a new axis at `dim` and repeat the tensor `repeat` times along that axis.
ones_shape = [1 for _ in range(tensor.ndim + 1)]
ones_shape[dim] = repeat
return torch.unsqueeze(tensor, dim) * tensor.new_ones(ones_shape)
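
# Usage sketch (illustrative; not executed on import):
#     x = torch.arange(6).reshape(2, 3)    # shape (2, 3)
#     extend_and_repeat(x, 1, 4).shape     # -> torch.Size([2, 4, 3])
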
def long_num_to_string(x):
if x >= 1000000:
return str(int(x/1000000))+'M'
elif x >= 1000:
return str(int(x/1000)) + 'K'
else:
return str(x)
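
# Usage sketch (illustrative; not executed on import):
#     long_num_to_string(2_500_000)  # -> '2M'
#     long_num_to_string(7500)       # -> '7K'
#     long_num_to_string(42)         # -> '42'
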
def create_log_gaussian(mean, log_std, t):
    # Log-density of t under a diagonal Gaussian with the given mean and log-std,
    # summed over the last dimension.
    quadratic = -0.5 * ((t - mean) / log_std.exp()).pow(2)
    log_z = log_std
    z = mean.shape[-1] * math.log(2 * math.pi)
    log_p = quadratic.sum(dim=-1) - log_z.sum(dim=-1) - 0.5 * z
    return log_p
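
# Usage sketch (illustrative; not executed on import). At t == mean with unit std,
# the log-density reduces to -0.5 * d * log(2 * pi):
#     mean = torch.zeros(1, 3)
#     log_std = torch.zeros(1, 3)
#     create_log_gaussian(mean, log_std, mean)  # -> tensor([-2.7568])
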
def logsumexp(inputs, dim=None, keepdim=False):
    # Numerically stable log(sum(exp(inputs))) along dim; reduces over all elements when dim is None.
    if dim is None:
inputs = inputs.view(-1)
dim = 0
s, _ = torch.max(inputs, dim=dim, keepdim=True)
outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log()
if not keepdim:
outputs = outputs.squeeze(dim)
return outputs
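
# Usage sketch (illustrative; not executed on import):
#     x = torch.log(torch.tensor([1.0, 2.0, 3.0]))
#     logsumexp(x)                            # -> tensor(1.7918) == log(6)
#     logsumexp(x.expand(2, 3), dim=1).shape  # -> torch.Size([2])
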
def soft_update(target, source, tau):
    # Polyak averaging: target <- (1 - tau) * target + tau * source.
for target_param, param in zip(target.parameters(), source.parameters()):
target_param.data.copy_(
target_param.data * (1.0 - tau) + param.data * tau)
def hard_update(target, source):
for target_param, param in zip(target.parameters(), source.parameters()):
target_param.data.copy_(param.data)
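
# Usage sketch (illustrative; `online` and `target` are hypothetical networks, not part of this file):
#     online = nn.Linear(8, 2)
#     target = nn.Linear(8, 2)
#     hard_update(target, online)             # start from identical weights
#     soft_update(target, online, tau=0.005)  # Polyak step after each gradient update
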
def disable_gradients(network):
# Disable calculations of gradients.
for param in network.parameters():
param.requires_grad = False
def soft_update_from_to(source, target, tau):
    # Same Polyak update as soft_update, but with (source, target) argument order.
for target_param, param in zip(target.parameters(), source.parameters()):
target_param.data.copy_(
target_param.data * (1.0 - tau) + param.data * tau)
def copy_model_params_from_to(source, target):
for target_param, param in zip(target.parameters(), source.parameters()):
target_param.data.copy_(param.data)
def fanin_init(tensor):
size = tensor.size()
if len(size) == 2:
fan_in = size[0]
elif len(size) > 2:
fan_in = np.prod(size[1:])
else:
        raise Exception("Shape must have at least 2 dimensions.")
bound = 1. / np.sqrt(fan_in)
return tensor.data.uniform_(-bound, bound)
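
# Usage sketch (illustrative; not executed on import):
#     layer = nn.Linear(256, 64)
#     fanin_init(layer.weight)  # uniform in [-b, b] with b = 1 / sqrt(layer.weight.size(0))
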
def fanin_init_weights_like(tensor):
size = tensor.size()
if len(size) == 2:
fan_in = size[0]
elif len(size) > 2:
fan_in = np.prod(size[1:])
else:
        raise Exception("Shape must have at least 2 dimensions.")
bound = 1. / np.sqrt(fan_in)
new_tensor = FloatTensor(tensor.size())
new_tensor.uniform_(-bound, bound)
return new_tensor
"""
GPU wrappers
"""
_use_gpu = False
device = None
_gpu_id = 0
def set_gpu_mode(mode):
    # mode: a GPU index (e.g. 0) to select that CUDA device, or None to run on the CPU.
global _use_gpu
global device
global _gpu_id
_gpu_id = mode
_use_gpu = mode is not None
device = torch.device("cuda:" + str(_gpu_id) if _use_gpu else "cpu")
print(device)
if _use_gpu:
torch.cuda.set_device(device)
def gpu_enabled():
return _use_gpu
def set_device(gpu_id):
if _use_gpu:
torch.cuda.set_device(gpu_id)
def get_device():
return device
# noinspection PyPep8Naming
def FloatTensor(*args, torch_device=None, **kwargs):
if torch_device is None:
torch_device = device
return torch.FloatTensor(*args, **kwargs, device=torch_device)
def from_numpy(*args, **kwargs):
return torch.from_numpy(*args, **kwargs).float().to(device)
def get_numpy(tensor):
return tensor.to('cpu').detach().numpy()
def zeros(*sizes, torch_device=None, **kwargs):
if torch_device is None:
torch_device = device
return torch.zeros(*sizes, **kwargs, device=torch_device)
def ones(*sizes, torch_device=None, **kwargs):
if torch_device is None:
torch_device = device
return torch.ones(*sizes, **kwargs, device=torch_device)
def ones_like(*args, torch_device=None, **kwargs):
if torch_device is None:
torch_device = device
return torch.ones_like(*args, **kwargs, device=torch_device)
def randn(*args, torch_device=None, **kwargs):
if torch_device is None:
torch_device = device
return torch.randn(*args, **kwargs, device=torch_device)
def rand(*args, torch_device=None, **kwargs):
if torch_device is None:
torch_device = device
return torch.rand(*args, **kwargs, device=torch_device)
def zeros_like(*args, torch_device=None, **kwargs):
if torch_device is None:
torch_device = device
return torch.zeros_like(*args, **kwargs, device=torch_device)
def tensor(*args, torch_device=None, **kwargs):
if torch_device is None:
torch_device = device
return torch.tensor(*args, **kwargs, device=torch_device)
def normal(*args, **kwargs):
return torch.normal(*args, **kwargs).to(device)
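
# Usage sketch for the GPU wrappers (illustrative; not executed on import):
#     set_gpu_mode(None)          # CPU; pass a GPU index such as 0 to use cuda:0
#     x = zeros(4, 3)             # allocated on the selected device
#     y = from_numpy(np.eye(3))   # float32 tensor on the selected device
#     get_numpy(x @ y)            # back to a NumPy array on the CPU
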
def fast_clip_grad_norm(parameters, max_norm):
    r"""Clip the gradient norm of an iterable of parameters.

    Only norm_type = 2 is supported. If max_norm == 0, the total-norm
    computation is skipped and 0 is returned.
    https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html#clip_grad_norm_

    Returns:
        Total norm of the parameters (viewed as a single vector).
    """
max_norm = float(max_norm)
if abs(max_norm) < 1e-6: # max_norm = 0
return 0
else:
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
total_norm = torch.stack([(p.grad.detach().pow(2)).sum()
for p in parameters]).sum().sqrt().item()
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
for p in parameters:
p.grad.detach().mul_(clip_coef)
return total_norm
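
# Usage sketch (illustrative; `model`, `loss`, and `optimizer` are hypothetical, not part of this file):
#     loss.backward()
#     total_norm = fast_clip_grad_norm(model.parameters(), max_norm=10.0)
#     optimizer.step()
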
def quantile_regression_loss(input, target, tau, weight):
    """
    Quantile Huber (smooth L1) regression loss.
    input: (N, T)
    target: (N, T)
    tau: (N, T)
    weight: (N, T)
    """
    input = input.unsqueeze(-1)              # (N, T, 1)
    target = target.detach().unsqueeze(-2)   # (N, 1, T)
    tau = tau.detach().unsqueeze(-1)         # (N, T, 1)
    weight = weight.detach().unsqueeze(-2)   # (N, 1, T)
expanded_input, expanded_target = torch.broadcast_tensors(input, target)
L = F.smooth_l1_loss(expanded_input, expanded_target,
reduction="none") # (N, T, T)
sign = torch.sign(expanded_input - expanded_target) / 2. + 0.5
rho = torch.abs(tau - sign) * L * weight
return rho.sum(dim=-1).mean()
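
# Usage sketch (illustrative; not executed on import):
#     N, T = 32, 8
#     pred = torch.rand(N, T)  # predicted quantile values
#     targ = torch.rand(N, T)  # target quantile samples
#     tau = torch.rand(N, T)   # quantile fractions in (0, 1)
#     loss = quantile_regression_loss(pred, targ, tau, torch.ones(N, T))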