-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathmodels.py
256 lines (189 loc) · 8.42 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
from torch import nn, einsum
import torch
from einops.layers.torch import Rearrange
from einops import rearrange, repeat
class PreNorm(nn.Module):
    """Layer-normalise the input, then hand it to a wrapped module.

    Args:
        dim: size of the last input axis, used for the LayerNorm.
        fn: module called on the normalised tensor; any keyword arguments
            given to ``forward`` are forwarded to it unchanged.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalised = self.norm(x)
        return self.fn(normalised, **kwargs)
class FSAttention(nn.Module):
    """Factorized Self-Attention: multi-head self-attention over one token axis.

    The module attends over the ``n`` axis of a ``(b, n, dim)`` input. In
    FSATransformerEncoder one instance is applied to per-frame sequences
    (spatial attention) and another to per-location sequences (temporal
    attention).

    Args:
        dim: embedding size of the input and output tokens.
        heads: number of attention heads.
        dim_head: per-head size; q, k and v each have ``heads * dim_head``
            features.
        dropout: dropout applied after the output projection.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)
        # One bias-free projection produces q, k and v side by side.
        self.to_qkv = nn.Linear(dim, 3 * inner_dim, bias=False)
        # A single head whose width already equals `dim` needs no projection back.
        if heads == 1 and dim_head == dim:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout),
            )

    def forward(self, x):
        n_heads = self.heads
        projected = self.to_qkv(x)
        q, k, v = [rearrange(part, 'b n (h d) -> b h n d', h=n_heads)
                   for part in projected.chunk(3, dim=-1)]
        # Scaled dot-product scores between every pair of tokens, per head.
        scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        weights = self.attend(scores)
        mixed = einsum('b h i j, b h j d -> b h i d', weights, v)
        merged = rearrange(mixed, 'b h n d -> b n (h d)')
        return self.to_out(merged)
class FDAttention(nn.Module):
    """Factorized Dot-product Attention (ViViT Model 4 style).

    The heads are split in two halves:

    * spatial heads attend among the ``nh * nw`` tokens of the same frame;
    * temporal heads attend among the ``nt`` tokens at the same spatial
      location.

    The per-head outputs are concatenated back to ``heads * dim_head``
    features and projected to ``dim``.

    Args:
        dim: token embedding size (last axis of the input).
        nt, nh, nw: number of tubelets along time, height and width. The input
            must hold ``nt * nh * nw`` tokens in frame-major order (all spatial
            tokens of frame 0, then frame 1, ...).
        heads: total number of heads; must be even and at least 2.
        dim_head: size of each head.
        dropout: dropout applied after the output projection.

    Raises:
        ValueError: if ``heads`` is odd or smaller than 2, or if the input
            sequence length is not ``nt * nh * nw``.
    """

    def __init__(self, dim, nt, nh, nw, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        if heads < 2 or heads % 2 != 0:
            raise ValueError(f"FDAttention needs an even number of heads >= 2, got {heads}")
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)
        self.nt = nt
        self.nh = nh
        self.nw = nw
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def _attend(self, q, k, v):
        """Scaled dot-product attention over the second-to-last axis (any leading dims)."""
        dots = einsum('...id,...jd->...ij', q, k) * self.scale
        return einsum('...ij,...jd->...id', self.attend(dots), v)

    def forward(self, x):
        b, n, _ = x.shape
        h = self.heads
        nt = self.nt
        ns = self.nh * self.nw
        if n != nt * ns:
            raise ValueError(f"expected {nt * ns} tokens (nt*nh*nw), got {n}")

        # (b, n, h*d) -> (b, h, n, d) for each of q, k, v.
        q, k, v = (t.reshape(b, n, h, -1).transpose(1, 2)
                   for t in self.to_qkv(x).chunk(3, dim=-1))
        half = h // 2

        # Spatial heads: token index is t*ns + s, so splitting n into (nt, ns)
        # groups the tokens of one frame together -> (b, h/2, nt, ns, d).
        qs, ks, vs = (t[:, :half].reshape(b, half, nt, ns, -1) for t in (q, k, v))
        spatial_out = self._attend(qs, ks, vs).reshape(b, half, n, -1)

        # Temporal heads: split the same way, then swap axes so that each
        # sequence is one spatial location across time -> (b, h/2, ns, nt, d).
        qt, kt, vt = (t[:, half:].reshape(b, half, nt, ns, -1).transpose(2, 3)
                      for t in (q, k, v))
        temporal_out = self._attend(qt, kt, vt).transpose(2, 3).reshape(b, half, n, -1)

        # Concatenate the two head groups and merge heads: (b, h, n, d) -> (b, n, h*d).
        out = torch.cat((spatial_out, temporal_out), dim=1)
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input and output feature size.
        hidden_dim: width of the hidden layer.
        dropout: dropout probability used after the activation and after the
            output layer.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)
class FSATransformerEncoder(nn.Module):
    """Factorized Self-Attention Transformer Encoder (ViViT Model 3 style).

    Each of the ``depth`` layers applies, with pre-norm residual connections:
    spatial self-attention among the tokens of one frame, then temporal
    self-attention among the tokens at one spatial location, then an MLP.

    Expects tokens shaped ``(b, nt, S, dim)`` with ``S`` spatial tokens per
    frame, as produced by ViViTBackbone's tubelet embedding. Returns a
    frame-major sequence shaped ``(b, nt * S, dim)``.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, nt, nh, nw, dropout=0.):
        super().__init__()

        def make_attention():
            return PreNorm(dim, FSAttention(dim, heads=heads, dim_head=dim_head, dropout=dropout))

        # Per layer: [spatial attention, temporal attention, feed-forward].
        self.layers = nn.ModuleList([
            nn.ModuleList([
                make_attention(),
                make_attention(),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)),
            ])
            for _ in range(depth)
        ])
        # Stored for reference; forward() derives the grouping from x.shape.
        self.nt, self.nh, self.nw = nt, nh, nw

    @staticmethod
    def _regroup(tokens, batch):
        """Swap the two sequence axes per sample: (batch*A, B, dim) -> (batch*B, A, dim)."""
        per_sample = tokens.reshape(batch, -1, *tokens.shape[1:])
        return per_sample.transpose(1, 2).flatten(0, 1)

    def forward(self, x):
        batch = x.shape[0]
        # (b, nt, S, dim) -> (b*nt, S, dim): one sequence per frame.
        seq = x.flatten(0, 1)
        for spatial, temporal, mlp in self.layers:
            seq = spatial(seq) + seq
            # (b*nt, S, dim) -> (b*S, nt, dim): one sequence per spatial location.
            seq = self._regroup(seq, batch)
            seq = temporal(seq) + seq
            seq = mlp(seq) + seq
            # Back to one sequence per frame for the next layer.
            seq = self._regroup(seq, batch)
        # (b*nt, S, dim) -> (b, nt*S, dim).
        return seq.reshape(batch, -1, seq.shape[-1])
class FDATransformerEncoder(nn.Module):
    """Factorized Dot-product Attention Transformer Encoder (ViViT Model 4 style).

    A stack of ``depth`` pre-norm residual FDAttention layers over a flat,
    frame-major token sequence.

    Accepts either ``(b, nt*nh*nw, dim)`` or ``(b, nt, nh*nw, dim)`` tokens. The
    latter is the shape ViViTBackbone's tubelet embedding produces. Returns
    ``(b, nt*nh*nw, dim)``.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, nt, nh, nw, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.nt = nt
        self.nh = nh
        self.nw = nw
        # NOTE(review): mlp_dim is accepted for signature parity with
        # FSATransformerEncoder, but no feed-forward sub-layer is built here.
        for _ in range(depth):
            self.layers.append(
                PreNorm(dim, FDAttention(dim, nt, nh, nw, heads=heads, dim_head=dim_head, dropout=dropout)))

    def forward(self, x):
        # FDAttention works on a flat (b, n, dim) sequence, so merge the
        # (nt, nh*nw) token axes of a 4-D input. Flattening axes 1 and 2 keeps
        # the frame-major order FDAttention expects.
        if x.dim() == 4:
            x = torch.flatten(x, start_dim=1, end_dim=2)
        for attn in self.layers:
            x = attn(x) + x
        return x
class ViViTBackbone(nn.Module):
    """ViViT video classifier backbone (Model 3 or Model 4 encoder).

    The clip is cut into non-overlapping tubelets of size
    ``patch_t x patch_h x patch_w``, each linearly embedded to ``dim``. A
    learned spatial position embedding, shared across time, is added. The
    tokens go through the selected factorised transformer encoder, are
    mean-pooled over tokens, and are classified by an MLP head.

    Args:
        t, h, w: clip length and frame height/width; each must be divisible
            by the matching patch size.
        patch_t, patch_h, patch_w: tubelet size.
        num_classes: number of output logits.
        dim: token embedding size.
        depth, heads, mlp_dim, dim_head: encoder hyper-parameters.
        channels: number of input channels.
        mode, device: stored on the instance but not used by this class.
        emb_dropout: dropout applied to the embedded tokens.
        dropout: dropout used inside the encoder.
        model: 3 selects FSATransformerEncoder, 4 selects FDATransformerEncoder.

    Raises:
        AssertionError: if the clip size is not divisible by the tubelet size,
            or if ``model == 4`` with an odd number of heads.
        ValueError: if ``model`` is neither 3 nor 4.
    """

    # NOTE(review): the default dim_head=3 is unusually small; it is kept
    # as-is for backward compatibility.
    def __init__(self, t, h, w, patch_t, patch_h, patch_w, num_classes, dim, depth, heads, mlp_dim, dim_head=3,
                 channels=3, mode='tubelet', device='cuda', emb_dropout=0., dropout=0., model=3):
        super().__init__()
        assert t % patch_t == 0 and h % patch_h == 0 and w % patch_w == 0, "Video dimensions should be divisible by " \
                                                                             "tubelet size "
        self.T = t
        self.H = h
        self.W = w
        self.channels = channels
        self.t = patch_t
        self.h = patch_h
        self.w = patch_w
        self.mode = mode
        self.device = device
        self.nt = self.T // self.t
        self.nh = self.H // self.h
        self.nw = self.W // self.w
        tubelet_dim = self.t * self.h * self.w * channels
        # (b, c, T, H, W) -> (b, nt, nh*nw, tubelet_dim) -> (b, nt, nh*nw, dim)
        self.to_tubelet_embedding = nn.Sequential(
            Rearrange('b c (t pt) (h ph) (w pw) -> b t (h w) (pt ph pw c)', pt=self.t, ph=self.h, pw=self.w),
            nn.Linear(tubelet_dim, dim)
        )
        # One learned embedding per spatial position, shape (1, 1, nh*nw, dim).
        # It must stay a registered Parameter so that it is trained, saved in
        # the state_dict and moved by .to(). Broadcasting it over the time axis
        # in forward() repeats the same spatial encoding for every frame.
        self.pos_embedding = nn.Parameter(torch.randn(1, 1, self.nh * self.nw, dim))
        self.dropout = nn.Dropout(emb_dropout)
        if model == 3:
            self.transformer = FSATransformerEncoder(dim, depth, heads, dim_head, mlp_dim,
                                                     self.nt, self.nh, self.nw, dropout)
        elif model == 4:
            assert heads % 2 == 0, "Number of heads should be even"
            self.transformer = FDATransformerEncoder(dim, depth, heads, dim_head, mlp_dim,
                                                     self.nt, self.nh, self.nw, dropout)
        else:
            raise ValueError(f"model must be 3 or 4, got {model!r}")
        self.to_latent = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        """ x is a video: (b, C, T, H, W); returns (b, num_classes) logits. """
        tokens = self.to_tubelet_embedding(x)
        # Out-of-place add so autograd tracks the position-embedding parameter;
        # (1, 1, S, dim) broadcasts over the batch and time axes.
        tokens = tokens + self.pos_embedding
        tokens = self.dropout(tokens)
        x = self.transformer(tokens)
        # Average over the token axis before classification.
        x = x.mean(dim=1)
        x = self.to_latent(x)
        return self.mlp_head(x)
if __name__ == '__main__':
    # Smoke test on CPU: 32 random clips of 3 channels, 32 frames, 64x64 pixels,
    # cut into 8x4x4 tubelets and classified into 10 classes.
    cpu = torch.device('cpu')
    clips = torch.rand(32, 3, 32, 64, 64).to(cpu)
    net = ViViTBackbone(32, 64, 64, 8, 4, 4, 10, 512, 6, 10, 8, model=3).to(cpu)
    logits = net(clips)
    print(logits)