-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathmobilenet.py
84 lines (73 loc) · 5.25 KB
/
mobilenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import mxnet as mx
from utils import symbol_utils
def Act(data, act_type, name):
    """Apply the activation selected by ``act_type`` to ``data``.

    The previous version ignored ``act_type`` and always applied ReLU; the
    in-module caller (``Conv``) passes ``'relu'``, so its graph is unchanged.

    Args:
        data: input ``mx.sym.Symbol``.
        act_type: activation name. ``'prelu'`` is routed to
            ``mx.sym.LeakyReLU`` (learnable negative slope); any other value
            is passed straight to ``mx.sym.Activation`` (e.g. ``'relu'``,
            ``'sigmoid'``, ``'tanh'``, ``'softrelu'``).
        name: name given to the resulting symbol.

    Returns:
        The activation ``mx.sym.Symbol``.
    """
    if act_type == 'prelu':
        # mx.sym.Activation has no PReLU mode; LeakyReLU implements it.
        return mx.sym.LeakyReLU(data=data, act_type='prelu', name=name)
    return mx.sym.Activation(data=data, act_type=act_type, name=name)
def Conv(data, num_filter=1, kernel=(1, 1), stride=(1, 1), pad=(0, 0), num_group=1, name=None, suffix=''):
    """Convolution -> BatchNorm -> ReLU building block.

    The convolution is created without a bias term and is followed by a
    BatchNorm with ``fix_gamma=True`` and then a ReLU (via ``Act``).
    Setting ``num_group`` equal to the channel count gives a depthwise conv.

    Symbols are named ``<name><suffix>_conv2d``, ``<name><suffix>_batchnorm``
    and ``<name><suffix>_relu``.

    Returns:
        The ReLU output ``mx.sym.Symbol``.
    """
    x = mx.sym.Convolution(
        data=data,
        num_filter=num_filter,
        kernel=kernel,
        stride=stride,
        pad=pad,
        num_group=num_group,
        no_bias=True,
        name='%s%s_conv2d' % (name, suffix),
    )
    x = mx.sym.BatchNorm(
        data=x,
        fix_gamma=True,
        name='%s%s_batchnorm' % (name, suffix),
    )
    return Act(data=x, act_type='relu', name='%s%s_relu' % (name, suffix))
def ConvOnly(data, num_filter=1, kernel=(1, 1), stride=(1, 1), pad=(0, 0), num_group=1, name=None, suffix=''):
    """Bias-free convolution with no normalization or activation.

    Takes the same arguments as ``Conv`` and names the symbol
    ``<name><suffix>_conv2d``, but returns the raw convolution output.
    """
    return mx.sym.Convolution(
        data=data,
        name='%s%s_conv2d' % (name, suffix),
        num_filter=num_filter,
        kernel=kernel,
        stride=stride,
        pad=pad,
        num_group=num_group,
        no_bias=True,
    )
# Depthwise-separable stages that follow conv_1, as
# (stride of the 3x3 depthwise conv, output channels as a multiple of base_filter).
# Entry k (counting from 2) builds "conv_<k>_dw" then "conv_<k>". Layer names
# must stay stable because saved parameters are keyed by them.
_DW_SEPARABLE_STAGES = (
    (1, 2),   # conv_2
    (2, 4),   # conv_3
    (1, 4),   # conv_4
    (2, 8),   # conv_5
    (1, 8),   # conv_6
    (2, 16),  # conv_7
    (1, 16),  # conv_8
    (1, 16),  # conv_9
    (1, 16),  # conv_10
    (1, 16),  # conv_11
    (1, 16),  # conv_12
    (2, 32),  # conv_13
    (1, 32),  # conv_14
)


def _depthwise_separable(data, in_filters, out_filters, stride, index):
    """Build one 3x3 depthwise + 1x1 pointwise pair named conv_<index>_dw / conv_<index>."""
    dw = Conv(data, num_group=in_filters, num_filter=in_filters, kernel=(3, 3), pad=(1, 1),
              stride=(stride, stride), name="conv_%d_dw" % index)
    return Conv(dw, num_filter=out_filters, kernel=(1, 1), pad=(0, 0), stride=(1, 1),
                name="conv_%d" % index)


def get_symbol(num_classes, **kwargs):
    """Build a MobileNet(v1)-style backbone followed by ``symbol_utils.get_fc1``.

    The backbone is one 3x3 Conv-BN-ReLU (``conv_1``) followed by 13
    depthwise-separable stages (``conv_2`` .. ``conv_14``). Four of those
    stages use stride 2, so total downsampling is 16x when ``conv_1`` has
    stride 1 and 32x when it has stride 2.

    Args:
        num_classes: forwarded to ``symbol_utils.get_fc1``.
        **kwargs:
            version_input (int, default 1): 0 gives ``conv_1`` stride 2;
                any other non-negative value gives stride 1.
            version_output (default 'E'): forwarded to ``get_fc1`` as its
                ``fc_type`` argument (presumably selects the output head —
                see ``symbol_utils``).
            multiplier (float, default 1.0): width multiplier; the base
                filter count is ``int(32 * multiplier)``.

    Returns:
        Whatever ``symbol_utils.get_fc1`` returns for the backbone output.

    Raises:
        ValueError: if ``version_input`` is negative.
    """
    import logging

    version_input = kwargs.get('version_input', 1)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if version_input < 0:
        raise ValueError('version_input must be >= 0, got %r' % (version_input,))
    version_output = kwargs.get('version_output', 'E')
    multiplier = kwargs.get('multiplier', 1.0)
    fc_type = version_output
    bf = int(32 * multiplier)
    logging.getLogger(__name__).info(
        'mobilenet: version_input=%s version_output=%s base_filter=%s',
        version_input, version_output, bf)

    data = mx.symbol.Variable(name="data")
    # (x - 127.5) / 128: maps 0..255 pixel values to about [-1, 1]
    # (assumes 0..255 input — confirm against the data pipeline).
    data = data - 127.5
    data = data * 0.0078125

    first_stride = (2, 2) if version_input == 0 else (1, 1)
    body = Conv(data, num_filter=bf, kernel=(3, 3), pad=(1, 1), stride=first_stride, name="conv_1")

    in_filters = bf
    for index, (stride, out_mult) in enumerate(_DW_SEPARABLE_STAGES, start=2):
        out_filters = bf * out_mult
        body = _depthwise_separable(body, in_filters, out_filters, stride, index)
        in_filters = out_filters

    return symbol_utils.get_fc1(body, num_classes, fc_type)