-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathMobileNetV3.py
349 lines (273 loc) · 12.7 KB
/
MobileNetV3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url
from dropblock import DropBlockScheduled, DropBlock2D
def swish(x):
    """Swish (a.k.a. SiLU) activation: ``x * sigmoid(x)``.

    :param x: input tensor.
    :return: a new tensor with the same shape as ``x``.
    """
    gate = torch.sigmoid(x)
    return x * gate
def hard_sigmoid(x, inplace=False):
    """Piecewise-linear sigmoid approximation: ``relu6(x + 3) / 6``.

    :param x: input tensor.
    :param inplace: forwarded to ``F.relu6``. It only applies to the temporary
        ``x + 3``, so ``x`` itself is never modified.
    :return: a new tensor whose values lie in [0, 1].
    """
    shifted = x + 3
    clipped = F.relu6(shifted, inplace=inplace)
    return clipped / 6
def hard_swish(x, inplace=False):
    """Hard-swish activation: ``x * hard_sigmoid(x)``.

    :param x: input tensor.
    :param inplace: forwarded to :func:`hard_sigmoid`.
    :return: a new tensor with the same shape as ``x``.
    """
    gate = hard_sigmoid(x, inplace=inplace)
    return x * gate
class HardSigmoid(nn.Module):
    """``nn.Module`` wrapper around :func:`hard_sigmoid`.

    :param inplace: stored and forwarded to :func:`hard_sigmoid` on every call.
    """

    def __init__(self, inplace=False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        use_inplace = self.inplace
        return hard_sigmoid(x, inplace=use_inplace)
class HardSwish(nn.Module):
    """``nn.Module`` wrapper around :func:`hard_swish`.

    :param inplace: stored and forwarded to :func:`hard_swish` on every call.
    """

    def __init__(self, inplace=False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        use_inplace = self.inplace
        return hard_swish(x, inplace=use_inplace)
def _make_divisible(v, divisor, min_value=None):
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
:param v:
:param divisor:
:param min_value:
:return:
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
# Adapted from https://github.com/jonnedtc/Squeeze-Excitation-PyTorch/blob/master/networks.py
class SqEx(nn.Module):
    """Squeeze-and-Excitation block with a hard-sigmoid gate.

    Each channel is averaged over its spatial extent. The pooled vector goes
    through Linear -> ReLU -> Linear -> HardSigmoid, and the input channels are
    rescaled by the resulting per-channel gate.

    :param n_features: number of input channels (dimension 1 of the input).
    :param reduction: channel reduction factor of the hidden layer.
    :raises ValueError: if ``n_features`` is not divisible by ``reduction``.
    """

    def __init__(self, n_features, reduction=4):
        super().__init__()
        if n_features % reduction:
            raise ValueError('n_features must be divisible by reduction (default = 4)')
        hidden = n_features // reduction
        self.linear1 = nn.Linear(n_features, hidden, bias=True)
        self.nonlin1 = nn.ReLU(inplace=True)
        self.linear2 = nn.Linear(hidden, n_features, bias=True)
        self.nonlin2 = HardSigmoid(inplace=True)

    def forward(self, x):
        # Squeeze: a kernel covering the full trailing two dims gives one value per channel.
        pooled = F.avg_pool2d(x, kernel_size=x.size()[2:4])
        # nn.Linear acts on the last dimension, so move channels there and back.
        gate = pooled.permute(0, 2, 3, 1)
        gate = self.nonlin2(self.linear2(self.nonlin1(self.linear1(gate))))
        gate = gate.permute(0, 3, 1, 2)
        return x * gate
class LinearBottleneck(nn.Module):
    """Inverted-residual bottleneck block (MobileNetV2/V3 style).

    Layout:
        1x1 expansion conv -> BN -> DropBlock -> activation
        kxk depthwise conv (groups=expplanes) -> BN -> DropBlock -> activation -> SE (optional)
        1x1 projection conv -> BN -> DropBlock    (no activation: a "linear" bottleneck)
    An identity shortcut is added when ``stride == 1`` and ``inplanes == outplanes``.

    :param inplanes: input channels.
    :param outplanes: output channels.
    :param expplanes: channels of the expanded (depthwise) representation.
    :param k: depthwise kernel size. Padding ``k // 2`` preserves spatial size
        for odd ``k`` at stride 1.
    :param stride: stride of the depthwise conv.
    :param drop_prob: final DropBlock probability handed to the schedulers.
    :param num_steps: ``nr_steps`` of each DropBlock scheduler.
    :param start_step: ``start_step`` of each DropBlock scheduler.
    :param activation: activation class, instantiated twice with ``**act_params``.
    :param act_params: keyword arguments for ``activation``; ``None`` means
        ``{"inplace": True}``.
    :param SE: insert a :class:`SqEx` block after the depthwise activation.
    """

    def __init__(self, inplanes, outplanes, expplanes, k=3, stride=1, drop_prob=0, num_steps=3e5, start_step=0,
                 activation=nn.ReLU, act_params=None, SE=False):
        super(LinearBottleneck, self).__init__()
        # Avoid a shared mutable default dict.
        if act_params is None:
            act_params = {"inplace": True}
        self.conv1 = nn.Conv2d(inplanes, expplanes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(expplanes)
        # NOTE(review): db1's DropBlock2D starts at drop_prob=0 while db2/db3 start at
        # drop_prob; presumably the scheduler overrides it -- confirm in dropblock.py.
        self.db1 = DropBlockScheduled(DropBlock2D(drop_prob=0, block_size=7), start_value=0.,
                                      stop_value=drop_prob, nr_steps=num_steps, start_step=start_step)
        self.act1 = activation(**act_params)  # the expansion conv does have an activation, as in MobileNetV2
        self.conv2 = nn.Conv2d(expplanes, expplanes, kernel_size=k, stride=stride, padding=k // 2, bias=False,
                               groups=expplanes)
        self.bn2 = nn.BatchNorm2d(expplanes)
        self.db2 = DropBlockScheduled(DropBlock2D(drop_prob=drop_prob, block_size=7), start_value=0.,
                                      stop_value=drop_prob, nr_steps=num_steps, start_step=start_step)
        self.act2 = activation(**act_params)
        # nn.Identity (not a lambda) keeps the module picklable; it has no
        # parameters, so state_dict keys are unchanged.
        self.se = SqEx(expplanes) if SE else nn.Identity()
        self.conv3 = nn.Conv2d(expplanes, outplanes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(outplanes)
        self.db3 = DropBlockScheduled(DropBlock2D(drop_prob=drop_prob, block_size=7), start_value=0.,
                                      stop_value=drop_prob, nr_steps=num_steps, start_step=start_step)
        # No activation after the projection conv: adding one was tried and worked worse.
        self.stride = stride
        self.expplanes = expplanes
        self.inplanes = inplanes
        self.outplanes = outplanes

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.db1(out)
        out = self.act1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.db2(out)
        out = self.act2(out)

        out = self.se(out)

        out = self.conv3(out)
        out = self.bn3(out)
        out = self.db3(out)

        # Shortcut only when the shapes match; TODO: or add a 1x1 projection otherwise?
        if self.stride == 1 and self.inplanes == self.outplanes:
            out += residual  # No inplace if there is in-place activation before
        return out
class LastBlockLarge(nn.Module):
    """Classification head used by the large MobileNetV3 variant.

    Pipeline:
        1x1 conv (no bias) -> BN -> HardSwish -> global average pool
        -> 1x1 conv (with bias, no BN) -> HardSwish -> flatten -> Dropout(0.2) -> Linear

    :param inplanes: channels produced by the last bottleneck.
    :param num_classes: number of output logits.
    :param expplanes1: output channels of the first 1x1 conv.
    :param expplanes2: output channels of the second 1x1 conv (input features of ``fc``).
    """

    def __init__(self, inplanes, num_classes, expplanes1, expplanes2):
        super().__init__()
        self.inplanes = inplanes
        self.num_classes = num_classes
        self.expplanes1 = expplanes1
        self.expplanes2 = expplanes2
        # Submodules are registered in the original order so that state_dict
        # keys and the printed structure stay the same.
        self.conv1 = nn.Conv2d(inplanes, expplanes1, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(expplanes1)
        self.act1 = HardSwish(inplace=True)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv2 = nn.Conv2d(expplanes1, expplanes2, kernel_size=1, stride=1)
        self.act2 = HardSwish(inplace=True)
        self.dropout = nn.Dropout(p=0.2, inplace=True)
        self.fc = nn.Linear(expplanes2, num_classes)

    def forward(self, x):
        features = self.act1(self.bn1(self.conv1(x)))
        pooled = self.avgpool(features)
        embedding = self.act2(self.conv2(pooled))
        # Collapse the (N, C, 1, 1) map into an (N, C) matrix for the linear layer.
        embedding = embedding.view(embedding.size(0), -1)
        return self.fc(self.dropout(embedding))
class LastBlockSmall(nn.Module):
    """Classification head used by the small MobileNetV3 variant.

    Pipeline:
        1x1 conv (no bias) -> BN -> HardSwish -> SqEx -> global average pool
        -> 1x1 conv (no bias, no BN) -> HardSwish -> flatten -> Dropout(0.2) -> Linear

    :param inplanes: channels produced by the last bottleneck.
    :param num_classes: number of output logits.
    :param expplanes1: output channels of the first 1x1 conv (also the SqEx width).
    :param expplanes2: output channels of the second 1x1 conv (input features of ``fc``).
    """

    def __init__(self, inplanes, num_classes, expplanes1, expplanes2):
        super().__init__()
        self.inplanes = inplanes
        self.num_classes = num_classes
        self.expplanes1 = expplanes1
        self.expplanes2 = expplanes2
        # Submodules are registered in the original order so that state_dict
        # keys and the printed structure stay the same.
        self.conv1 = nn.Conv2d(inplanes, expplanes1, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(expplanes1)
        self.act1 = HardSwish(inplace=True)
        self.se = SqEx(expplanes1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv2 = nn.Conv2d(expplanes1, expplanes2, kernel_size=1, stride=1, bias=False)
        self.act2 = HardSwish(inplace=True)
        self.dropout = nn.Dropout(p=0.2, inplace=True)
        self.fc = nn.Linear(expplanes2, num_classes)

    def forward(self, x):
        features = self.se(self.act1(self.bn1(self.conv1(x))))
        pooled = self.avgpool(features)
        embedding = self.act2(self.conv2(pooled))
        # Collapse the (N, C, 1, 1) map into an (N, C) matrix for the linear layer.
        embedding = embedding.view(embedding.size(0), -1)
        return self.fc(self.dropout(embedding))
class MobileNetV3(nn.Module):
    """MobileNetV3 image classifier (large or small variant).

    Structure: a stem (3x3 stride-2 conv -> BN -> HardSwish), then a stack of
    :class:`LinearBottleneck` blocks described by a settings table, then a
    :class:`LastBlockLarge` or :class:`LastBlockSmall` head.

    :param num_classes: number of output logits.
    :param scale: width multiplier for every bottleneck's channel counts,
        rounded to a multiple of 8 with :func:`_make_divisible`.
    :param in_channels: channels of the input image.
    :param drop_prob: DropBlock probability used by the later bottlenecks
        (the earlier rows of each table use 0).
    :param num_steps: forwarded to every bottleneck's DropBlock scheduler.
    :param start_step: forwarded to every bottleneck's DropBlock scheduler.
    :param small: build the small variant instead of the large one.
    """

    def __init__(self, num_classes=1000, scale=1., in_channels=3, drop_prob=0.0, num_steps=3e5, start_step=0,
                 small=False):
        super().__init__()
        self.num_steps = num_steps
        self.start_step = start_step
        self.scale = scale
        self.num_classes = num_classes
        self.small = small

        # Row layout: in, exp, out, stride, kernel, drop_prob, use SE, activation class.
        # The "->" comments give feature-map sizes for a 224x224 input.
        self.bottlenecks_setting_large = [
            [16, 16, 16, 1, 3, 0, False, nn.ReLU],  # -> 112x112
            [16, 64, 24, 2, 3, 0, False, nn.ReLU],  # -> 56x56
            [24, 72, 24, 1, 3, 0, False, nn.ReLU],  # -> 56x56
            [24, 72, 40, 2, 5, 0, True, nn.ReLU],  # -> 28x28
            [40, 120, 40, 1, 5, 0, True, nn.ReLU],  # -> 28x28
            [40, 120, 40, 1, 5, 0, True, nn.ReLU],  # -> 28x28
            [40, 240, 80, 2, 3, drop_prob, False, HardSwish],  # -> 14x14
            [80, 200, 80, 1, 3, drop_prob, False, HardSwish],  # -> 14x14
            [80, 184, 80, 1, 3, drop_prob, False, HardSwish],  # -> 14x14
            [80, 184, 80, 1, 3, drop_prob, False, HardSwish],  # -> 14x14
            [80, 480, 112, 1, 3, drop_prob, True, HardSwish],  # -> 14x14
            [112, 672, 112, 1, 3, drop_prob, True, HardSwish],  # -> 14x14
            [112, 672, 160, 2, 5, drop_prob, True, HardSwish],  # -> 7x7
            [160, 960, 160, 1, 5, drop_prob, True, HardSwish],  # -> 7x7
            [160, 960, 160, 1, 5, drop_prob, True, HardSwish],  # -> 7x7
        ]
        self.bottlenecks_setting_small = [
            [16, 64, 16, 2, 3, 0, True, nn.ReLU],  # -> 56x56
            [16, 72, 24, 2, 3, 0, False, nn.ReLU],  # -> 28x28
            [24, 88, 24, 1, 3, 0, False, nn.ReLU],  # -> 28x28
            [24, 96, 40, 2, 5, 0, True, HardSwish],  # -> 14x14
            [40, 240, 40, 1, 5, drop_prob, True, HardSwish],  # -> 14x14
            [40, 240, 40, 1, 5, drop_prob, True, HardSwish],  # -> 14x14
            [40, 120, 48, 1, 5, drop_prob, True, HardSwish],  # -> 14x14
            [48, 144, 96, 1, 5, drop_prob, True, HardSwish],  # -> 14x14
            [96, 288, 96, 2, 5, drop_prob, True, HardSwish],  # -> 7x7
            [96, 576, 96, 1, 5, drop_prob, True, HardSwish],  # -> 7x7
            [96, 576, 96, 1, 5, drop_prob, True, HardSwish],  # -> 7x7
        ]
        self.bottlenecks_setting = self.bottlenecks_setting_small if small else self.bottlenecks_setting_large

        # Scale the in/exp/out channel counts of the selected table in place.
        for row in self.bottlenecks_setting:
            row[:3] = [_make_divisible(channels * self.scale, 8) for channels in row[:3]]

        stem_channels = self.bottlenecks_setting[0][0]
        self.conv1 = nn.Conv2d(in_channels, stem_channels, kernel_size=3, bias=False, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(stem_channels)
        self.act1 = HardSwish(inplace=True)
        self.bottlenecks = self._make_bottlenecks()

        # The last conv keeps 1280 output channels unless the network is widened (scale > 1).
        self.last_exp2 = 1280 if self.scale <= 1 else _make_divisible(1280 * self.scale, 8)
        head_in = self.bottlenecks_setting[-1][2]
        if small:
            self.last_exp1 = _make_divisible(576 * self.scale, 8)
            head_cls = LastBlockSmall
        else:
            self.last_exp1 = _make_divisible(960 * self.scale, 8)
            head_cls = LastBlockLarge
        self.last_block = head_cls(head_in, num_classes, self.last_exp1, self.last_exp2)

    def _make_bottlenecks(self):
        """Build an ``nn.Sequential`` with one ``Bottleneck_<i>`` per settings row."""
        blocks = OrderedDict(
            ("Bottleneck_{}".format(idx),
             LinearBottleneck(inp, out, exp, k=kernel, stride=stride, drop_prob=dp,
                              num_steps=self.num_steps, start_step=self.start_step, activation=act,
                              act_params={"inplace": True}, SE=se))
            for idx, (inp, exp, out, stride, kernel, dp, se, act) in enumerate(self.bottlenecks_setting)
        )
        return nn.Sequential(blocks)

    def forward(self, x):
        stem = self.act1(self.bn1(self.conv1(x)))
        return self.last_block(self.bottlenecks(stem))
# Pretrained checkpoints, keyed by 'mobilenetv3_<small|large>_<scale>_<input_size>'.
# TODO: only the large, scale 1.0, 224px checkpoint is published so far.
# GitHub '/raw/' URLs serve the file itself; '/blob/' URLs serve an HTML page,
# which load_state_dict_from_url cannot deserialize.
model_urls = {
    'mobilenetv3_large_1.0_224': 'https://github.com/Randl/MobileNetV3-pytorch/raw/master/results/mobilenetv3large-v1/model_best0-ec869f9b.pth',
}
def mobilenetv3(input_size=224, num_classes=1000, scale=1., in_channels=3, drop_prob=0.0, num_steps=3e5, start_step=0,
                small=False, get_weights=True, progress=True):
    """Build a :class:`MobileNetV3` and optionally load pretrained weights.

    :param input_size: only used to form the checkpoint key; it is not passed
        to :class:`MobileNetV3`.
    :param num_classes: forwarded to :class:`MobileNetV3`.
    :param scale: forwarded to :class:`MobileNetV3`; also part of the checkpoint key.
    :param in_channels: forwarded to :class:`MobileNetV3`.
    :param drop_prob: forwarded to :class:`MobileNetV3`.
    :param num_steps: forwarded to :class:`MobileNetV3`.
    :param start_step: forwarded to :class:`MobileNetV3`.
    :param small: forwarded to :class:`MobileNetV3`; also part of the checkpoint key.
    :param get_weights: download and load the matching entry of ``model_urls``.
    :param progress: show a download progress bar.
    :return: the constructed model.
    :raises ValueError: if ``get_weights`` is true and ``model_urls`` has no
        checkpoint for this configuration.
    """
    name = 'mobilenetv3_{}_{}_{}'.format('small' if small else 'large', scale, input_size)
    # Fail fast, before allocating the model, when no checkpoint exists.
    if get_weights and name not in model_urls:
        raise ValueError('No pretrained weights for {!r}; available: {}. '
                         'Pass get_weights=False to build an untrained model.'.format(name, sorted(model_urls)))
    model = MobileNetV3(num_classes=num_classes, scale=scale, in_channels=in_channels, drop_prob=drop_prob,
                        num_steps=num_steps, start_step=start_step, small=small)
    if get_weights:
        state_dict = load_state_dict_from_url(model_urls[name], progress=progress, map_location='cpu')
        model.load_state_dict(state_dict)
    return model
if __name__ == "__main__":
    # Smoke test: build several configurations, print them and run a few forward passes.
    model1 = MobileNetV3()
    print(model1)

    model2 = MobileNetV3(scale=0.35)
    print(model2)

    model3 = MobileNetV3(in_channels=2, num_classes=10)
    print(model3)
    x = torch.randn(1, 2, 224, 224)
    print(model3(x))

    # Non-224 input size.
    model4_size = 32 * 10
    model4 = MobileNetV3(num_classes=10)
    print(model4)
    x2 = torch.randn(1, 3, model4_size, model4_size)
    print(model4(x2))

    model5 = MobileNetV3(scale=0.35, small=True)
    print(model5)