-
Notifications
You must be signed in to change notification settings - Fork 237
/
Copy pathmobilevit_v2.py
77 lines (69 loc) · 2.59 KB
/
mobilevit_v2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
from typing import Dict
from utils.math_utils import bound_fn, make_divisible
def get_configuration(opts) -> Dict:
    """Build the layer-wise configuration dictionary for MobileViTv2.

    The width multiplier is looked up on ``opts`` under the attribute name
    ``model.classification.mitv2.width_multiplier`` (default ``1.0``) and is
    used to scale the output and attention-unit channel counts of each stage.

    Args:
        opts: Options object queried with ``getattr`` for the width
            multiplier.

    Returns:
        A dictionary with the keys ``layer0`` through ``layer5`` (one
        sub-dictionary per stage) and ``last_layer_exp_factor``.
    """
    scale = getattr(opts, "model.classification.mitv2.width_multiplier", 1.0)

    # Both multipliers are fixed at 2. Width-dependent alternatives were left
    # commented out in the original code:
    #   ffn:  bound_fn(min_val=2.0, max_val=4.0, value=2.0 * width_multiplier)
    #   mv2:  max(1.0, min(2.0, 2.0 * width_multiplier))
    ffn_mult = 2
    mv2_expansion = 2

    def _scaled(base_channels, divisor=8):
        # Scale the base width, then hand it to make_divisible with `divisor`.
        return int(make_divisible(base_channels * scale, divisor=divisor))

    def _mv2_stage(base_channels, num_blocks, stride, divisor=8):
        # Configuration of a stage whose block type is "mv2".
        return {
            "out_channels": _scaled(base_channels, divisor),
            "expand_ratio": mv2_expansion,
            "num_blocks": num_blocks,
            "stride": stride,
            "block_type": "mv2",
        }

    def _mobilevit_stage(base_channels, base_attn_dim, attn_blocks):
        # Configuration of a stage whose block type is "mobilevit"; every such
        # stage uses 2x2 patches and stride 2.
        return {
            "out_channels": _scaled(base_channels),
            "attn_unit_dim": _scaled(base_attn_dim),
            "ffn_multiplier": ffn_mult,
            "attn_blocks": attn_blocks,
            "patch_h": 2,
            "patch_w": 2,
            "stride": 2,
            "mv_expand_ratio": mv2_expansion,
            "block_type": "mobilevit",
        }

    # The first layer's width is bounded to [16, 64] before being passed to
    # make_divisible (divisor 8, minimum value 16).
    stem_channels = bound_fn(min_val=16, max_val=64, value=32 * scale)
    stem_channels = int(make_divisible(stem_channels, divisor=8, min_value=16))

    return {
        "layer0": {
            "img_channels": 3,
            "out_channels": stem_channels,
        },
        # layer1 is the only stage that uses a divisor of 16.
        "layer1": _mv2_stage(64, num_blocks=1, stride=1, divisor=16),
        "layer2": _mv2_stage(128, num_blocks=2, stride=2),
        "layer3": _mobilevit_stage(256, 128, attn_blocks=2),  # 28x28
        "layer4": _mobilevit_stage(384, 192, attn_blocks=4),  # 14x14
        "layer5": _mobilevit_stage(512, 256, attn_blocks=3),  # 7x7
        "last_layer_exp_factor": 4,
    }