-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathutils.py
225 lines (190 loc) · 8.89 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
import pdb
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.utils import weight_norm as wn
import numpy as np
import os
from PIL import Image
def concat_elu(x):
    """Concatenated ELU: ELU applied to ``x`` and ``-x`` stacked on the channel axis.

    Like concatenated ReLU (http://arxiv.org/abs/1603.05201), but with ELU.
    The channel axis is taken to be the third from the end (PyTorch ordering),
    so the output has twice as many channels as the input.
    """
    channel_axis = x.dim() - 3
    mirrored = torch.cat((x, -x), dim=channel_axis)
    return F.elu(mirrored)
def log_sum_exp(x):
    """Numerically stable log(sum(exp(x))) over the last axis.

    The per-row maximum is subtracted before exponentiating so large inputs do
    not overflow; the last axis is reduced away (TF-style, channels-last).
    """
    last_axis = x.dim() - 1
    peak = torch.max(x, dim=last_axis, keepdim=True)[0]
    summed = torch.sum(torch.exp(x - peak), dim=last_axis)
    return peak.squeeze(last_axis) + torch.log(summed)
def log_prob_from_logits(x):
    """Numerically stable log-softmax over the last axis.

    The last axis is kept (size unchanged); the per-row maximum is subtracted
    first so the exponentials cannot overflow.
    """
    last_axis = x.dim() - 1
    shifted = x - torch.max(x, dim=last_axis, keepdim=True)[0]
    log_norm = torch.log(torch.sum(torch.exp(shifted), dim=last_axis, keepdim=True))
    return shifted - log_norm
def discretized_mix_logistic_loss(x: torch.Tensor, l: torch.Tensor) -> torch.Tensor:
    """Negative log-likelihood of ``x`` under a mixture of discretized logistics.

    Assumes the data has been rescaled to the [-1, 1] interval.

    Args:
        x: target images in PyTorch channel-first ordering (channel axis 1);
            permuted to channels-last below. Sub-pixels 0, 1 and 2 are indexed
            explicitly, which implies 3 channels (presumably (B, 3, H, W) --
            TODO confirm with callers).
        l: network output in the same channel-first ordering. Its channels are
            split as ``nr_mix`` mixture logits followed by ``3 * 3 * nr_mix``
            values (per sub-pixel: means, log-scales, coefficients), so the
            channel count must be ``10 * nr_mix`` for the ``view`` to succeed.

    Returns:
        Scalar tensor: minus the log-likelihood summed (not averaged) over all
        pixels and all batch elements.
    """
    # Move the channel axis (1) to the end so the mixture parameters sit on
    # the trailing axes.
    x = x.permute(0, 2, 3, 1)
    l = l.permute(0, 2, 3, 1)
    xs = [int(y) for y in x.size()]
    ls = [int(y) for y in l.size()]
    # here and below: unpacking the params of the mixture of logistics
    # 10 = 1 logit + 3 sub-pixels * 3 params, per mixture component.
    nr_mix = int(ls[-1] / 10)
    logit_probs = l[:, :, :, :nr_mix]
    l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 3]) # 3 for mean, scale, coef
    means = l[:, :, :, :, :nr_mix]
    # Lower-bound the log-scales at -7 so inv_stdv = exp(-log_scales) stays finite.
    log_scales = torch.clamp(l[:, :, :, :, nr_mix:2 * nr_mix], min=-7.)
    # tanh keeps the inter-channel coefficients in (-1, 1).
    coeffs = F.tanh(l[:, :, :, :, 2 * nr_mix:3 * nr_mix])
    # here and below: getting the means and adjusting them based on preceding
    # sub-pixels
    x = x.contiguous()
    # Broadcast against zeros to give x a trailing mixture axis of size nr_mix
    # (one copy of each target per mixture component).
    x = x.unsqueeze(-1) + Variable(torch.zeros(xs + [nr_mix]).to(x.device), requires_grad=False)
    # Autoregressive channel dependence: the mean of sub-pixel 1 is shifted by
    # coeff * (true) sub-pixel 0, the mean of sub-pixel 2 by coeffs * sub-pixels 0 and 1.
    m2 = (means[:, :, :, 1, :] + coeffs[:, :, :, 0, :]
          * x[:, :, :, 0, :]).view(xs[0], xs[1], xs[2], 1, nr_mix)
    m3 = (means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] +
          coeffs[:, :, :, 2, :] * x[:, :, :, 1, :]).view(xs[0], xs[1], xs[2], 1, nr_mix)
    means = torch.cat((means[:, :, :, 0, :].unsqueeze(3), m2, m3), dim=3)
    centered_x = x - means
    inv_stdv = torch.exp(-log_scales)
    # Each discretization bin spans x +/- 1/255 (bin width 2/255 on [-1, 1]).
    plus_in = inv_stdv * (centered_x + 1. / 255.)
    cdf_plus = F.sigmoid(plus_in)
    min_in = inv_stdv * (centered_x - 1. / 255.)
    cdf_min = F.sigmoid(min_in)
    # log probability for edge case of 0 (before scaling):
    # log(sigmoid(a)) = a - softplus(a); the lowest bin extends to -inf.
    log_cdf_plus = plus_in - F.softplus(plus_in)
    # log probability for edge case of 255 (before scaling):
    # log(1 - sigmoid(a)) = -softplus(a); the highest bin extends to +inf.
    log_one_minus_cdf_min = -F.softplus(min_in)
    cdf_delta = cdf_plus - cdf_min # probability for all other cases
    mid_in = inv_stdv * centered_x
    # Log density of the logistic at the bin centre, used as a fallback when
    # the bin probability is too small to take its log accurately.
    log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)
    # Per-element selection, written as a float-mask blend:
    #   x < -0.999 (lowest level)  -> log_cdf_plus
    #   x >  0.999 (highest level) -> log_one_minus_cdf_min
    #   otherwise                  -> log(cdf_delta), or, when cdf_delta <= 1e-5,
    #                                 log(pdf(mid) * bin width); the bin width is
    #                                 2/255 = 1/127.5, hence "- np.log(127.5)".
    # Because every branch is evaluated for every element, the clamp to 1e-12
    # keeps log() finite in branches that are masked out (0 * -inf would be NaN).
    inner_inner_cond = (cdf_delta > 1e-5).float()
    inner_inner_out = inner_inner_cond * torch.log(torch.clamp(cdf_delta, min=1e-12)) + (1. - inner_inner_cond) * (log_pdf_mid - np.log(127.5))
    inner_cond = (x > 0.999).float()
    inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out
    cond = (x < -0.999).float()
    log_probs = cond * log_cdf_plus + (1. - cond) * inner_out
    # Sum over the 3 sub-pixels (axis 3), add the log mixture weights
    # (log-softmax of the logits), then marginalize the mixture axis.
    log_probs = torch.sum(log_probs, dim=3) + log_prob_from_logits(logit_probs)
    return -torch.sum(log_sum_exp(log_probs))
def to_one_hot(tensor, n, fill_with=1.):
    """One-hot encode integer indices along a new trailing axis.

    Args:
        tensor: integer index tensor of any shape; values must lie in [0, n).
        n: size of the new last axis (number of classes).
        fill_with: value written at each hot position (default 1.).

    Returns:
        float32 tensor of shape ``tensor.shape + (n,)`` on ``tensor``'s device,
        zero everywhere except ``fill_with`` at the indexed positions.
    """
    # Allocate directly on the input's device (works for any backend, not just
    # CPU/CUDA) and pin float32 to match the historical FloatTensor result.
    one_hot = torch.zeros(*tensor.size(), n, dtype=torch.float32, device=tensor.device)
    # scatter_ requires int64 indices; the new axis is the last one.
    one_hot.scatter_(-1, tensor.unsqueeze(-1).long(), fill_with)
    return one_hot
def sample_from_discretized_mix_logistic(l, nr_mix):
    """Draw one image sample from a mixture of discretized logistics.

    Args:
        l: network output in PyTorch channel-first ordering (channel axis 1).
            Its channels are ``nr_mix`` mixture logits followed by
            ``3 * 3 * nr_mix`` values (per sub-pixel: means, log-scales,
            coefficients); the ``view`` below fails otherwise.
        nr_mix: number of mixture components.

    Returns:
        Tensor in channel-first ordering with 3 channels and the batch/spatial
        sizes of ``l``; every value is clamped to [-1, 1]. Values are not
        rounded to the 8-bit grid.
    """
    # Pytorch ordering -> channels last
    l = l.permute(0, 2, 3, 1)
    ls = [int(y) for y in l.size()]
    xs = ls[:-1] + [3]
    # unpack parameters
    logit_probs = l[:, :, :, :nr_mix]
    l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 3])
    # Sample the mixture indicator with the Gumbel-max trick. Noise is drawn
    # as float32 (as before) but directly on l's device, so any backend works
    # rather than only CPU and CUDA.
    gumbel_u = torch.empty(logit_probs.size(), dtype=torch.float32, device=l.device)
    gumbel_u.uniform_(1e-5, 1. - 1e-5)
    perturbed = logit_probs.detach() - torch.log(-torch.log(gumbel_u))
    _, argmax = perturbed.max(dim=3)
    sel = F.one_hot(argmax, nr_mix).float().view(xs[:-1] + [1, nr_mix])
    # select logistic parameters
    means = torch.sum(l[:, :, :, :, :nr_mix] * sel, dim=4)
    log_scales = torch.clamp(torch.sum(l[:, :, :, :, nr_mix:2 * nr_mix] * sel, dim=4), min=-7.)
    coeffs = torch.sum(torch.tanh(l[:, :, :, :, 2 * nr_mix:3 * nr_mix]) * sel, dim=4)
    # Sample from the logistic via its inverse CDF, then clip to [-1, 1].
    u = torch.empty(means.size(), dtype=torch.float32, device=l.device)
    u.uniform_(1e-5, 1. - 1e-5)
    x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))
    # Sub-pixels 1 and 2 depend linearly on the already-sampled sub-pixels.
    x0 = torch.clamp(x[:, :, :, 0], min=-1., max=1.)
    x1 = torch.clamp(x[:, :, :, 1] + coeffs[:, :, :, 0] * x0, min=-1., max=1.)
    x2 = torch.clamp(x[:, :, :, 2] + coeffs[:, :, :, 1] * x0 + coeffs[:, :, :, 2] * x1, min=-1., max=1.)
    out = torch.stack([x0, x1, x2], dim=3)
    # put back in Pytorch ordering
    return out.permute(0, 3, 1, 2)
# Utilities for shifting the image around: an efficient alternative to masking convolutions.
def down_shift(x, pad=None):
    """Shift a 4-D tensor down by one row along axis 2.

    The last row (axis 2) is dropped and ``pad`` is applied to restore the
    height; by default that is ``nn.ZeroPad2d((0, 0, 1, 0))``, which adds one
    row of zeros at the top. Assumes PyTorch (N, C, H, W) ordering.
    """
    if pad is None:
        # padding left, padding right, padding top, padding bottom
        pad = nn.ZeroPad2d((0, 0, 1, 0))
    height = int(x.size(2))
    return pad(x[:, :, :height - 1, :])
def right_shift(x, pad=None):
    """Shift a 4-D tensor right by one column along axis 3.

    The last column (axis 3) is dropped and ``pad`` is applied to restore the
    width; by default that is ``nn.ZeroPad2d((1, 0, 0, 0))``, which adds one
    column of zeros on the left. Assumes PyTorch (N, C, H, W) ordering.
    """
    if pad is None:
        # padding left, padding right, padding top, padding bottom
        pad = nn.ZeroPad2d((1, 0, 0, 0))
    width = int(x.size(3))
    return pad(x[:, :, :, :width - 1])
def sample(model, sample_batch_size, obs, sample_op):
    """Generate a batch of images autoregressively, one pixel position at a time.

    Args:
        model: module called as ``model(canvas, sample=True)``; it is switched
            to eval mode via ``model.train(False)`` and left in that mode.
        sample_batch_size: number of images to generate.
        obs: image shape; ``obs[0]``, ``obs[1]``, ``obs[2]`` are used as the
            channel, height and width sizes of the canvas.
        sample_op: maps the model output to a sampled tensor indexable like
            the canvas.

    Returns:
        Tensor of shape (sample_batch_size, obs[0], obs[1], obs[2]) on the
        device of the model's first parameter. Positions are filled in row-major
        order; at each step only pixel (row, col) of the sample is kept.
    """
    model.train(False)
    with torch.no_grad():
        canvas = torch.zeros(sample_batch_size, obs[0], obs[1], obs[2])
        canvas = canvas.to(next(model.parameters()).device)
        height, width = obs[1], obs[2]
        for row in range(height):
            for col in range(width):
                drawn = sample_op(model(canvas, sample=True))
                canvas[:, :, row, col] = drawn.data[:, :, row, col]
    return canvas
class mean_tracker:
    """Running arithmetic mean of the values passed to ``update``."""

    def __init__(self):
        self.reset()

    def update(self, new_value):
        """Record one observation."""
        self.sum += new_value
        self.count += 1

    def get_mean(self):
        """Return sum / count; raises ZeroDivisionError before any update."""
        total, n = self.sum, self.count
        return total / n

    def reset(self):
        """Forget every recorded observation."""
        self.sum, self.count = 0, 0
class ratio_tracker:
    """Accumulates a numerator and a denominator and reports their ratio."""

    def __init__(self):
        self.reset()

    def update(self, new_value, new_count):
        """Add ``new_value`` to the numerator and ``new_count`` to the denominator."""
        self.sum += new_value
        self.count += new_count

    def get_ratio(self):
        """Return sum / count; raises ZeroDivisionError while count is 0."""
        numerator, denominator = self.sum, self.count
        return numerator / denominator

    def reset(self):
        """Zero both accumulators."""
        self.sum, self.count = 0, 0
def check_dir_and_create(dir):
    """Create directory ``dir`` (including parents) unless that path already exists.

    If ``dir`` already exists (as a directory or as any other file), nothing
    is done.
    """
    if os.path.exists(dir):
        return
    os.makedirs(dir, exist_ok=True)
def save_images(tensor, images_folder_path, label=''):
    """Write each image of a batch to ``images_folder_path`` as a JPEG file.

    Args:
        tensor: iterable of per-image tensors. Each is transposed with
            (1, 2, 0) and saved as RGB, so presumably (3, H, W) with values in
            [0, 1] -- TODO confirm with callers. Values outside [0, 1] are
            clipped.
        images_folder_path: destination directory; created if missing.
        label: filename prefix. Files are named
            ``{label}_image_{k:02d}.jpg`` with ``k`` counting from 1.
    """
    os.makedirs(images_folder_path, exist_ok=True)
    for k, img_tensor in enumerate(tensor, start=1):
        # detach() so tensors that still track gradients can go to numpy.
        pixels = img_tensor.detach().cpu().numpy().transpose(1, 2, 0)
        # Casting out-of-range floats to uint8 is undefined and typically wraps
        # around (e.g. 1.01 * 255 -> 1), so saturate to [0, 1] first.
        pixels = np.clip(pixels, 0.0, 1.0)
        img = Image.fromarray((pixels * 255).astype(np.uint8), mode='RGB')
        img_path = f"{images_folder_path}/{label}_image_{k:02d}.jpg"
        img.save(img_path)