"""
Paper: Rethinking Attention with Performers
Link: https://arxiv.org/abs/2009.14794
mainly modified from https://github.com/lucidrains/performer-pytorch/blob/main/performer_pytorch/performer_pytorch.py
"""
import math
from scipy.stats import ortho_group
import torch
from torch import nn
from einops import rearrange, repeat
from functools import partial

# helpers

def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


# kernel functions
# transcribed from jax to pytorch from
# https://github.com/google-research/google-research/blob/master/performer/fast_attention/jax/fast_attention.py
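
# The feature map below is the positive random-feature ("FAVOR+") softmax kernel
# from the paper. With the d**-0.25 data_normalizer folded into x, each row is
# mapped (up to the 1/sqrt(m) ratio and the max-subtraction used purely for
# numerical stability) to
#     phi(x) = exp(x @ W^T - ||x||^2 / 2),
# where W is the (m x d) random projection, so that phi(q) @ phi(k)^T gives a
# Monte Carlo estimate of exp(q . k / sqrt(d)), the softmax attention kernel.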
def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True,
                   eps=1e-4, device=None):
    b, h, *_ = data.shape

    data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.
    ratio = (projection_matrix.shape[0] ** -0.5)

    projection = repeat(projection_matrix, 'j d -> b h j d', b=b, h=h)
    projection = projection.type_as(data)

    data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data),
                             projection)

    diag_data = data ** 2
    diag_data = torch.sum(diag_data, dim=-1)
    diag_data = (diag_data / 2.0) * (data_normalizer ** 2)
    diag_data = diag_data.unsqueeze(dim=-1)

    if is_query:
        data_dash = ratio * (
            torch.exp(data_dash - diag_data -
                      torch.max(data_dash, dim=-1,
                                keepdim=True).values) + eps)
    else:
        data_dash = ratio * (
            torch.exp(data_dash - diag_data - torch.max(data_dash)) + eps)

    return data_dash.type_as(data)


def generalized_kernel(data, *, projection_matrix, kernel_fn=nn.ReLU(),
                       kernel_epsilon=0.001, normalize_data=True, device=None):
    b, h, *_ = data.shape

    data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.

    if projection_matrix is None:
        return kernel_fn(data_normalizer * data) + kernel_epsilon

    projection = repeat(projection_matrix, 'j d -> b h j d', b=b, h=h)
    projection = projection.type_as(data)

    data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data),
                             projection)

    data_prime = kernel_fn(data_dash) + kernel_epsilon
    return data_prime.type_as(data)
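

# The random projection W is built from orthogonal blocks: ortho_group.rvs draws
# a Haar-random d x d orthogonal matrix, blocks are stacked until nb_rows rows
# are available, and each row is then rescaled. With scaling=0 the row norms are
# resampled to match those of i.i.d. Gaussian vectors; with scaling=1 every row
# norm is fixed to sqrt(d).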
def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling=0,
                                      device=None):
    nb_full_blocks = int(nb_rows / nb_columns)
    block_list = []

    for _ in range(nb_full_blocks):
        q = torch.FloatTensor(ortho_group.rvs(nb_columns), device='cpu').to(device)
        block_list.append(q)

    remaining_rows = nb_rows - nb_full_blocks * nb_columns
    if remaining_rows > 0:
        q = torch.FloatTensor(ortho_group.rvs(nb_columns), device='cpu').to(device)
        block_list.append(q[:remaining_rows])

    final_matrix = torch.cat(block_list)

    if scaling == 0:
        multiplier = torch.randn((nb_rows, nb_columns), device=device).norm(
            dim=1)
    elif scaling == 1:
        multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,),
                                                                 device=device)
    else:
        raise ValueError(f'Invalid scaling {scaling}')

    return torch.diag(multiplier) @ final_matrix


# linear attention classes with softmax kernel

# non-causal linear attention
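# Given feature maps q, k of shape (..., n, nb_features) and values v of shape
# (..., n, dim_head), the output is q @ (k^T v) rescaled by the row-wise
# normalizer D_inv, so the cost is O(n * nb_features * dim_head) instead of the
# O(n^2 * dim_head) of exact softmax attention; no n x n attention matrix is
# ever materialized.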
def linear_attention(q, k, v):
    k_cumsum = k.sum(dim=-2)
    D_inv = 1. / torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q))
    context = torch.einsum('...nd,...ne->...de', k, v)
    out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv)
    return out


class FastAttention(nn.Module):
    def __init__(self, dim_heads, nb_features=None, ortho_scaling=0,
                 generalized_attention=False, kernel_fn=nn.ReLU(),
                 no_projection=False):
        super().__init__()
        nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))

        self.dim_heads = dim_heads
        self.nb_features = nb_features
        self.ortho_scaling = ortho_scaling

        self.create_projection = partial(gaussian_orthogonal_random_matrix,
                                         nb_rows=self.nb_features,
                                         nb_columns=dim_heads,
                                         scaling=ortho_scaling)
        projection_matrix = self.create_projection()
        self.register_buffer('projection_matrix', projection_matrix)

        self.generalized_attention = generalized_attention
        self.kernel_fn = kernel_fn

        # if this is turned on, no projection will be used
        # queries and keys will be softmax-ed as in the original efficient attention paper
        self.no_projection = no_projection

        self.__flops__ = 0

    @torch.no_grad()
    def redraw_projection_matrix(self, device):
        projections = self.create_projection(device=device)
        self.projection_matrix.copy_(projections)
        del projections

    def forward(self, q, k, v):
        device = q.device

        if self.no_projection:
            q = q.softmax(dim=-1)
            k = k.softmax(dim=-2)
        elif self.generalized_attention:
            create_kernel = partial(generalized_kernel,
                                    kernel_fn=self.kernel_fn,
                                    projection_matrix=self.projection_matrix,
                                    device=device)
            q, k = map(create_kernel, (q, k))
        else:
            create_kernel = partial(softmax_kernel,
                                    projection_matrix=self.projection_matrix,
                                    device=device)
            q = create_kernel(q, is_query=True)
            k = create_kernel(k, is_query=False)

        out = linear_attention(q, k, v)
        return out

    @staticmethod
    def compute_macs(module, input, output):
        input = input[0]
        _, H, N, C = input.shape
        Nf = module.nb_features
        assert C == module.dim_heads

        macs = 0
        n_params = 0
        if module.no_projection:
            raise ValueError("Not supported yet!")
        elif module.generalized_attention:
            raise ValueError("Not supported yet!")
        else:
            n_params += C * Nf
            # q = create_kernel(q, is_query=True)
            macs += H * N * Nf * C + 2 * H * N * C + 2 * H * N * Nf
            # k = create_kernel(k, is_query=False)
            macs += H * N * Nf * C + 2 * H * N * C + 2 * H * N * Nf
            # out = linear_attention(q, k, v)
            # k_cumsum = k.sum(dim=-2)
            macs += H * N * Nf
            # D_inv = 1. / torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q))
            macs += H * N * Nf
            # context = torch.einsum('...nd,...ne->...de', k, v)
            macs += H * N * Nf * C
            # out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv)
            macs += 2 * H * N * Nf * C
        # print('macs fast att', macs / 1e8)
        module.__flops__ += macs
        # return n_params, macs
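

# Usage sketch (illustrative, not executed at import time): FastAttention expects
# q, k, v of shape (batch, heads, seq_len, dim_heads), e.g.
#
#     fa = FastAttention(dim_heads=64, nb_features=256)
#     q = k = v = torch.randn(2, 8, 196, 64)
#     out = fa(q, k, v)  # -> (2, 8, 196, 64)
#
# The projection matrix is registered as a buffer, so it follows .to(device) and
# state_dict; redraw_projection_matrix() resamples it in place if periodic
# redraws are desired.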


class PerformerSelfAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., nb_features=None,
                 generalized_attention=False,
                 kernel_fn=nn.ReLU(), no_projection=False):
        super().__init__()
        assert dim % num_heads == 0, 'dimension must be divisible by number of heads'
        head_dim = dim // num_heads

        self.dim = dim
        self.head_dim = head_dim
        self.nb_features = nb_features
        self.no_projection = no_projection
        self.generalized_attention = generalized_attention

        self.fast_attention = FastAttention(
            head_dim, nb_features, generalized_attention=generalized_attention,
            kernel_fn=kernel_fn, no_projection=no_projection
        )

        self.num_heads = num_heads
        self.scale = qk_scale or head_dim ** -0.5  # not used in performer

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, H=14, W=14):
        # H and W are unused here
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        x = self.fast_attention(q, k, v)
        x = rearrange(x, 'b h n d -> b n (h d)')
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def flops(self, N):
        H = self.num_heads
        C = self.head_dim
        Nf = self.nb_features

        macs = 0
        n_params = 0
        if self.no_projection:
            raise ValueError("Not supported yet!")
        elif self.generalized_attention:
            raise ValueError("Not supported yet!")
        else:
            n_params += C * Nf
            # q = create_kernel(q, is_query=True)
            macs += H * N * Nf * C + 2 * H * N * C + 2 * H * N * Nf
            # k = create_kernel(k, is_query=False)
            macs += H * N * Nf * C + 2 * H * N * C + 2 * H * N * Nf
            # out = linear_attention(q, k, v)
            # k_cumsum = k.sum(dim=-2)
            macs += H * N * Nf
            # D_inv = 1. / torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q))
            macs += H * N * Nf
            # context = torch.einsum('...nd,...ne->...de', k, v)
            macs += H * N * Nf * C
            # out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv)
            macs += 2 * H * N * Nf * C
        # print('macs fast att', macs / 1e8)
        return macs


if __name__ == '__main__':
    dim = 768
    num_heads = 12
    H = W = 14
    B = 64
    # special for performer
    nb_features = 256

    model = PerformerSelfAttention(dim, num_heads=num_heads, qkv_bias=True, nb_features=nb_features)

    from utils import measure_flops_params, measure_throughput_cpu, measure_throughput_gpu
    x = torch.randn(1, H * W, dim)
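    # quick standalone sanity check (illustrative sketch, independent of the
    # profiling helpers below): one forward pass on CPU plus a projection redraw
    with torch.no_grad():
        y = model(x)
    assert y.shape == x.shape, f'unexpected output shape {tuple(y.shape)}'
    model.fast_attention.redraw_projection_matrix(device=x.device)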
    measure_flops_params(model, x)
    measure_throughput_cpu(model)
    measure_throughput_gpu(model)