-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathactivations.py
314 lines (237 loc) · 9.72 KB
/
activations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
"""
An implementation of entmax (Peters et al., 2019). See
https://arxiv.org/pdf/1905.05702 for a detailed description.
This builds on previous work with sparsemax (Martins & Astudillo, 2016).
See https://arxiv.org/pdf/1602.02068.
"""
# Author: Ben Peters
# Author: Vlad Niculae <[email protected]>
# License: MIT
import torch
import torch.nn as nn
from torch.autograd import Function
def _make_ix_like(X, dim):
    """Build the 1-based index vector ``1, 2, ..., d`` laid out along `dim`.

    Parameters
    ----------
    X : torch.Tensor
        Reference tensor; supplies the length ``d = X.size(dim)``, the number
        of dimensions, the device and the dtype of the result.
    dim : int
        The dimension along which the indices run.

    Returns
    -------
    torch.Tensor
        Tensor with as many dimensions as `X`, of size ``d`` along `dim` and
        size 1 everywhere else, so it broadcasts against `X`.
    """
    length = X.size(dim)
    # Start with the indices on axis 0, then swap that axis into place.
    shape = [1] * X.dim()
    shape[0] = -1
    index = torch.arange(1, length + 1, device=X.device, dtype=X.dtype)
    return index.view(shape).transpose(0, dim)
def _roll_last(X, dim):
    """Return `X` with dimension `dim` moved to the last position.

    The relative order of the remaining dimensions is preserved. The result
    of ``permute`` is a view, so writes into it are visible through `X`.

    Parameters
    ----------
    X : torch.Tensor
        The input tensor.
    dim : int
        The dimension to move to the end. Negative values count from the
        last dimension, as usual in PyTorch.

    Returns
    -------
    torch.Tensor
        `X` itself when ``dim == -1``, otherwise a permuted view of `X`.
    """
    if dim == -1:
        return X
    if dim < 0:
        # Normalize a negative index: -2 on a 3-d tensor refers to axis 1.
        # (Subtracting here would yield an out-of-range axis.)
        dim = X.dim() + dim

    perm = [i for i in range(X.dim()) if i != dim] + [dim]
    return X.permute(perm)
def _sparsemax_threshold_and_support(X, dim=-1, k=None):
    """Compute the sparsemax threshold and support size along `dim`.

    Parameters
    ----------
    X : torch.Tensor
        The input tensor to compute thresholds over.
    dim : int
        The dimension along which to apply sparsemax.
    k : int or None
        Number of largest elements to partial-sort over. Ideally slightly
        larger than the expected number of nonzeros in the solution. Slices
        whose support fills the whole top-k window are recomputed by a
        recursive call with ``2 * k``. If `None` (or not smaller than the
        size of `dim`), a full sort is performed.

    Returns
    -------
    tau : torch.Tensor like `X`, with size 1 along `dim`
        The threshold value for each vector.
    support_size : torch LongTensor, shape like `tau`
        The number of nonzeros in each vector.
    """
    use_topk = k is not None and k < X.shape[dim]

    if use_topk:
        z_sorted, _ = torch.topk(X, k=k, dim=dim)
    else:  # full sort
        z_sorted, _ = torch.sort(X, dim=dim, descending=True)

    shifted_cumsum = z_sorted.cumsum(dim) - 1
    ranks = _make_ix_like(z_sorted, dim)
    in_support = ranks * z_sorted > shifted_cumsum

    support_size = in_support.sum(dim=dim).unsqueeze(dim)
    # Threshold = (cumulative sum over the support - 1) / support size.
    tau = shifted_cumsum.gather(dim, support_size - 1)
    tau.div_(support_size.to(X.dtype))

    if not use_topk:
        return tau, support_size

    # A support of exactly k entries means the top-k window may have been
    # too small for that slice; redo those slices with a doubled window.
    unsolved = (support_size == k).squeeze(dim)
    if torch.any(unsolved):
        pending = _roll_last(X, dim)[unsolved]
        tau_retry, size_retry = _sparsemax_threshold_and_support(
            pending, dim=-1, k=2 * k
        )
        # _roll_last yields views, so these assignments update the results.
        _roll_last(tau, dim)[unsolved] = tau_retry
        _roll_last(support_size, dim)[unsolved] = size_retry

    return tau, support_size
def _entmax_threshold_and_support(X, dim=-1, k=None):
    """Compute the 1.5-entmax threshold and support size along `dim`.

    Parameters
    ----------
    X : torch.Tensor
        The input tensor to compute thresholds over.
    dim : int
        The dimension along which to apply 1.5-entmax.
    k : int or None
        Number of largest elements to partial-sort over. Ideally slightly
        larger than the expected number of nonzeros in the solution. Slices
        whose support fills the whole top-k window are recomputed by a
        recursive call with ``2 * k``. If `None` (or not smaller than the
        size of `dim`), a full sort is performed.

    Returns
    -------
    tau_star : torch.Tensor like `X`, with size 1 along `dim`
        The threshold value for each vector.
    support_size : torch LongTensor, shape like `tau_star`
        The number of nonzeros in each vector.
    """
    use_topk = k is not None and k < X.shape[dim]

    if use_topk:
        z_sorted, _ = torch.topk(X, k=k, dim=dim)
    else:  # full sort
        z_sorted, _ = torch.sort(X, dim=dim, descending=True)

    ranks = _make_ix_like(z_sorted, dim)
    running_mean = z_sorted.cumsum(dim) / ranks
    running_mean_sq = (z_sorted ** 2).cumsum(dim) / ranks
    scatter = ranks * (running_mean_sq - running_mean ** 2)
    discriminant = (1 - scatter) / ranks

    # NOTE: clamping at zero deviates from the reference algorithm. The
    # clamped entries appear never to be wrongly selected by the
    # `tau <= sorted_z` test below, but this has not been proven.
    candidate_tau = running_mean - discriminant.clamp(min=0).sqrt()

    support_size = (candidate_tau <= z_sorted).sum(dim).unsqueeze(dim)
    tau_star = candidate_tau.gather(dim, support_size - 1)

    if not use_topk:
        return tau_star, support_size

    # A support of exactly k entries means the top-k window may have been
    # too small for that slice; redo those slices with a doubled window.
    unsolved = (support_size == k).squeeze(dim)
    if torch.any(unsolved):
        pending = _roll_last(X, dim)[unsolved]
        tau_retry, size_retry = _entmax_threshold_and_support(
            pending, dim=-1, k=2 * k
        )
        # _roll_last yields views, so these assignments update the results.
        _roll_last(tau_star, dim)[unsolved] = tau_retry
        _roll_last(support_size, dim)[unsolved] = size_retry

    return tau_star, support_size
class SparsemaxFunction(Function):
    """Autograd function for sparsemax with a hand-written backward pass."""

    @classmethod
    def forward(cls, ctx, X, dim=-1, k=None):
        """Apply sparsemax to `X` along `dim`; `k` is the partial-sort size."""
        ctx.dim = dim
        # Subtract the max along `dim`: same numerical stability trick as
        # softmax.
        shifted = X - X.max(dim=dim, keepdim=True)[0]
        tau, support_size = _sparsemax_threshold_and_support(
            shifted, dim=dim, k=k
        )
        probs = torch.clamp(shifted - tau, min=0)
        ctx.save_for_backward(support_size, probs)
        return probs

    @classmethod
    def backward(cls, ctx, grad_output):
        """Return the gradient w.r.t. `X`; `dim` and `k` get no gradient."""
        support_size, probs = ctx.saved_tensors
        dim = ctx.dim

        # Gradient entries outside the support are zeroed.
        masked = grad_output.clone()
        masked[probs == 0] = 0

        # Mean of the surviving gradient over the support, per slice.
        count = support_size.to(probs.dtype).squeeze(dim)
        mean_grad = (masked.sum(dim=dim) / count).unsqueeze(dim)

        grad_input = torch.where(probs != 0, masked - mean_grad, masked)
        return grad_input, None, None
class Entmax15Function(Function):
    """Autograd function for 1.5-entmax with a hand-written backward pass."""

    # NOTE(review): `dim` defaults to 0 here, unlike the dim=-1 default of
    # `entmax15` and `SparsemaxFunction`. `entmax15` always passes `dim`
    # explicitly, so the default is kept unchanged.
    @classmethod
    def forward(cls, ctx, X, dim=0, k=None):
        """Apply 1.5-entmax to `X` along `dim`; `k` is the partial-sort size."""
        ctx.dim = dim
        # Subtract the max along `dim` (same numerical stability trick as for
        # softmax), then divide by 2 to solve the actual entmax problem.
        shifted = X - X.max(dim=dim, keepdim=True)[0]
        halved = shifted / 2

        tau_star, _ = _entmax_threshold_and_support(halved, dim=dim, k=k)
        probs = torch.clamp(halved - tau_star, min=0) ** 2
        ctx.save_for_backward(probs)
        return probs

    @classmethod
    def backward(cls, ctx, dY):
        """Return the gradient w.r.t. `X`; `dim` and `k` get no gradient."""
        (probs,) = ctx.saved_tensors
        dim = ctx.dim

        root = probs.sqrt()  # = 1 / g''(Y)
        grad = dY * root
        ratio = (grad.sum(dim) / root.sum(dim)).unsqueeze(dim)
        grad.sub_(ratio * root)
        return grad, None, None
def sparsemax(X, dim=-1, k=None):
    """Apply sparsemax, a sparse normalizing transform (a la softmax).

    Solves the projection onto the probability simplex:

        min_p ||x - p||_2   s.t.    p >= 0, sum(p) == 1.

    Parameters
    ----------
    X : torch.Tensor
        The input tensor.
    dim : int
        The dimension along which to apply sparsemax.
    k : int or None
        Number of largest elements to partial-sort over. Ideally slightly
        larger than the expected number of nonzeros in the solution; if the
        solution is more than k-sparse, the computation is retried with a
        2*k schedule. If `None`, full sorting is performed from the
        beginning.

    Returns
    -------
    P : torch tensor, same shape as X
        The projection result, such that P.sum(dim=dim) == 1 elementwise.
    """
    projection = SparsemaxFunction.apply(X, dim, k)
    return projection
def entmax15(X, dim=-1, k=None):
    """Apply 1.5-entmax, a sparse normalizing transform (a la softmax).

    Solves the optimization problem:

        max_p <x, p> - H_1.5(p)    s.t.    p >= 0, sum(p) == 1.

    where H_1.5(p) is the Tsallis alpha-entropy with alpha=1.5.

    Parameters
    ----------
    X : torch.Tensor
        The input tensor.
    dim : int
        The dimension along which to apply 1.5-entmax.
    k : int or None
        Number of largest elements to partial-sort over. Ideally slightly
        larger than the expected number of nonzeros in the solution; if the
        solution is more than k-sparse, the computation is retried with a
        2*k schedule. If `None`, full sorting is performed from the
        beginning.

    Returns
    -------
    P : torch tensor, same shape as X
        The projection result, such that P.sum(dim=dim) == 1 elementwise.
    """
    projection = Entmax15Function.apply(X, dim, k)
    return projection
class Sparsemax(nn.Module):
    """Module wrapper around `sparsemax` (a sparse alternative to softmax).

    Solves the projection:

        min_p ||x - p||_2   s.t.    p >= 0, sum(p) == 1.

    Parameters
    ----------
    dim : int
        The dimension along which to apply sparsemax.
    k : int or None
        Number of largest elements to partial-sort over. Ideally slightly
        larger than the expected number of nonzeros in the solution; if the
        solution is more than k-sparse, the computation is retried with a
        2*k schedule. If `None`, full sorting is performed from the
        beginning.
    """

    def __init__(self, dim=-1, k=None):
        super().__init__()
        self.dim = dim
        self.k = k

    def forward(self, X):
        """Apply sparsemax to `X` using the stored `dim` and `k`."""
        return sparsemax(X, dim=self.dim, k=self.k)
class Entmax15(nn.Module):
    """Module wrapper around `entmax15` (a sparse alternative to softmax).

    Solves the optimization problem:

        max_p <x, p> - H_1.5(p)    s.t.    p >= 0, sum(p) == 1.

    where H_1.5(p) is the Tsallis alpha-entropy with alpha=1.5.

    Parameters
    ----------
    dim : int
        The dimension along which to apply 1.5-entmax.
    k : int or None
        Number of largest elements to partial-sort over. Ideally slightly
        larger than the expected number of nonzeros in the solution; if the
        solution is more than k-sparse, the computation is retried with a
        2*k schedule. If `None`, full sorting is performed from the
        beginning.
    """

    def __init__(self, dim=-1, k=None):
        super().__init__()
        self.dim = dim
        self.k = k

    def forward(self, X):
        """Apply 1.5-entmax to `X` using the stored `dim` and `k`."""
        return entmax15(X, dim=self.dim, k=self.k)