-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathmodel.py
159 lines (121 loc) · 5.21 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import math
import six
import tensorflow as tf
from einops.layers.tensorflow import Rearrange
def gelu(x):
    """Apply the Gaussian Error Linear Unit (GELU) activation.

    Uses the tanh-based approximation
    ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``,
    a smooth alternative to ReLU.
    Reference: https://arxiv.org/abs/1606.08415

    Args:
        x: float Tensor to activate.

    Returns:
        A Tensor with the same shape as ``x`` holding GELU(x).
    """
    cubic_term = 0.044715 * tf.pow(x, 3)
    inner = math.sqrt(2 / math.pi) * (x + cubic_term)
    # Approximate Gaussian CDF Phi(x); x is multiplied in last, as before.
    cdf = 0.5 * (1.0 + tf.tanh(inner))
    return x * cdf
def get_activation(identifier):
    """Resolve an activation identifier to a callable.

    String identifiers are lower-cased first. If the name matches one of the
    custom activations defined in this module (currently only ``"gelu"``),
    that function is returned. Any other string, and any non-string
    identifier (e.g. a callable), is passed to ``tf.keras.activations.get``.

    Args:
        identifier: activation name (case-insensitive) or a callable.

    Returns:
        The Python callable implementing the requested activation.
    """
    custom_activations = {"gelu": gelu}
    if isinstance(identifier, six.string_types):
        identifier = str(identifier).lower()
        custom_fn = custom_activations.get(identifier)
        if custom_fn is not None:
            return tf.keras.activations.get(custom_fn)
    return tf.keras.activations.get(identifier)
class Residual(tf.keras.Model):
    """Adds an identity skip connection around a sub-layer: ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        # Kept under the attribute name `fn`; renaming it would change the
        # tracked sub-layer's path in saved checkpoints.
        self.fn = fn

    def call(self, x):
        branch = self.fn(x)
        return branch + x
class PreNorm(tf.keras.Model):
    """Applies layer normalization before a sub-layer: ``fn(norm(x))``."""

    def __init__(self, dim, fn):
        super().__init__()
        # `dim` is accepted for signature parity but not used here: the
        # LayerNormalization is constructed without an explicit feature size.
        self.norm = tf.keras.layers.LayerNormalization(epsilon=1e-5)
        self.fn = fn

    def call(self, x):
        normalized = self.norm(x)
        return self.fn(normalized)
class FeedForward(tf.keras.Model):
    """Two-layer MLP: Dense(hidden_dim) with GELU, then Dense(dim)."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        expand = tf.keras.layers.Dense(hidden_dim, activation=get_activation('gelu'))
        project = tf.keras.layers.Dense(dim)
        self.net = tf.keras.Sequential([expand, project])

    def call(self, x):
        return self.net(x)
class Attention(tf.keras.Model):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: width of the input/output features; must be divisible by
            ``heads`` (each head gets ``dim // heads`` features).
        heads: number of attention heads.

    Raises:
        ValueError: if ``heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self, dim, heads=8):
        super().__init__()
        if heads <= 0 or dim % heads != 0:
            # The einops split below needs 3*dim == 3 * heads * d_head;
            # fail here instead of with an opaque error on the first call.
            raise ValueError(
                'dim (%r) must be divisible by a positive number of heads (%r)'
                % (dim, heads))
        self.heads = heads
        # Scaled dot-product attention divides by sqrt(d_k), where d_k is the
        # per-head width. The previous code used dim ** -0.5, which shrank
        # the logits by an extra factor of sqrt(heads). Checkpoints trained
        # with the old scale will produce different outputs.
        self.scale = (dim // heads) ** -0.5
        self.to_qkv = tf.keras.layers.Dense(dim * 3, use_bias=False)
        self.to_out = tf.keras.layers.Dense(dim)
        self.rearrange_qkv = Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
        self.rearrange_out = Rearrange('b h n d -> b n (h d)')

    def call(self, x):
        # (b, n, 3*dim) -> (3, b, heads, n, d_head)
        qkv = self.rearrange_qkv(self.to_qkv(x))
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Attention logits between every query i and key j, per head.
        dots = tf.einsum('bhid,bhjd->bhij', q, k) * self.scale
        attn = tf.nn.softmax(dots, axis=-1)
        out = tf.einsum('bhij,bhjd->bhid', attn, v)
        # Merge heads back: (b, heads, n, d_head) -> (b, n, dim)
        out = self.rearrange_out(out)
        return self.to_out(out)
class Transformer(tf.keras.Model):
    """Stack of ``depth`` pre-norm residual attention + feed-forward pairs.

    Args:
        dim: feature width carried through the stack.
        depth: number of (attention, feed-forward) pairs.
        heads: attention heads per attention block.
        mlp_dim: hidden width of each feed-forward block.
    """

    def __init__(self, dim, depth, heads, mlp_dim):
        super().__init__()
        blocks = []
        for _ in range(depth):
            # Construction order (attention first, then feed-forward) is kept
            # so auto-generated layer names stay the same.
            blocks.append(Residual(PreNorm(dim, Attention(dim, heads=heads))))
            blocks.append(Residual(PreNorm(dim, FeedForward(dim, mlp_dim))))
        self.net = tf.keras.Sequential(blocks)

    def call(self, x):
        return self.net(x)
class ViT(tf.keras.Model):
    """Vision Transformer (ViT) image classifier.

    Input images are channels-first, shaped ``(batch, channels, height,
    width)``, as required by the einops pattern used for patching.

    The forward pass:
    - cuts each image into non-overlapping ``patch_size x patch_size``
      patches and flattens each patch;
    - embeds each patch with a Dense layer;
    - prepends a learned class token;
    - adds learned position embeddings;
    - runs the sequence through a pre-norm Transformer;
    - classifies from the class-token output with an MLP head.

    Args:
        image_size: side length in pixels of the square input image.
        patch_size: side length of each square patch; must be a positive
            divisor of ``image_size``.
        num_classes: number of output scores (the head has no final
            activation).
        dim: embedding width used throughout the transformer.
        depth: number of attention + feed-forward pairs.
        heads: number of attention heads.
        mlp_dim: hidden width of the feed-forward blocks and of the head.
        channels: number of image channels. Kept for API compatibility; it
            is not used, because the patch-embedding Dense layer takes its
            input width from the data.

    Raises:
        ValueError: if ``patch_size`` is not positive or does not divide
            ``image_size``.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3):
        super().__init__()
        # Explicit checks instead of `assert`, which is stripped under -O.
        if patch_size <= 0:
            raise ValueError('patch_size must be a positive integer')
        if image_size % patch_size != 0:
            raise ValueError('image dimensions must be divisible by the patch size')
        num_patches = (image_size // patch_size) ** 2
        self.patch_size = patch_size
        self.dim = dim
        # Weight/layer creation order below is unchanged to keep checkpoint
        # variable names stable.
        # One learned position vector per patch plus one for the class token.
        self.pos_embedding = self.add_weight("position_embeddings",
                                             shape=[num_patches + 1, dim],
                                             initializer=tf.keras.initializers.RandomNormal(),
                                             dtype=tf.float32)
        self.patch_to_embedding = tf.keras.layers.Dense(dim)
        self.cls_token = self.add_weight("cls_token",
                                         shape=[1, 1, dim],
                                         initializer=tf.keras.initializers.RandomNormal(),
                                         dtype=tf.float32)
        self.rearrange = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size)
        self.transformer = Transformer(dim, depth, heads, mlp_dim)
        self.to_cls_token = tf.identity
        self.mlp_head = tf.keras.Sequential([tf.keras.layers.Dense(mlp_dim, activation=get_activation('gelu')),
                                             tf.keras.layers.Dense(num_classes)])

    @tf.function
    def call(self, img):
        """Return class scores of shape ``(batch, num_classes)`` for ``img``."""
        batch_size = tf.shape(img)[0]
        # (b, c, H, W) -> (b, num_patches, patch_size * patch_size * c)
        x = self.rearrange(img)
        x = self.patch_to_embedding(x)  # -> (b, num_patches, dim)
        # One copy of the learned class token per example, placed at index 0.
        cls_tokens = tf.broadcast_to(self.cls_token, (batch_size, 1, self.dim))
        x = tf.concat((cls_tokens, x), axis=1)  # -> (b, num_patches + 1, dim)
        # (num_patches + 1, dim) position table broadcasts over the batch axis.
        x += self.pos_embedding
        x = self.transformer(x)
        # Only the class-token output is used for classification.
        x = self.to_cls_token(x[:, 0])
        return self.mlp_head(x)