# Forked from rwightman/gen-efficientnet-pytorch.
# File: onnx_optimize.py
import argparse
import onnx
from onnx import optimizer
# Command-line interface: one positional input model path and a mandatory
# --output path for the optimized model.
parser = argparse.ArgumentParser(
    description="Optimize ONNX model",
)
parser.add_argument(
    "model",
    help="The ONNX model",
)
parser.add_argument(
    "--output",
    required=True,
    help="The optimized model output filename",
)
def traverse_graph(graph, prefix=''):
    """Count the nodes of ``graph`` and build a printable listing of them.

    Each entry of ``graph.node`` is rendered with
    ``onnx.helper.printable_node(..., subgraphs=True)``, which also hands back
    the subgraphs attached to that node. After all direct nodes have been
    rendered, every collected subgraph is traversed recursively; its text is
    appended after a blank line and its node count is added to the total.

    Args:
        graph: Object exposing an iterable ``node`` attribute (the graph of a
            loaded ONNX model, or one of its subgraphs).
        prefix: String placed in front of the indent used for the rendered
            nodes of this graph. Note that subgraphs are traversed with the
            default prefix, not with this one.

    Returns:
        A ``(num_nodes, text)`` tuple: the number of nodes in ``graph``
        including all nested subgraphs, and the newline-joined listing.
    """
    node_indent = prefix + ' '
    lines = []
    pending_subgraphs = []

    for node in graph.node:
        rendered, nested = onnx.helper.printable_node(
            node, node_indent, subgraphs=True)
        assert isinstance(nested, list)
        lines.append(rendered)
        pending_subgraphs.extend(nested)

    # Exactly one line was collected per direct node of this graph.
    total = len(lines)

    for subgraph in pending_subgraphs:
        sub_total, sub_text = traverse_graph(subgraph)
        lines.append('\n' + sub_text)
        total += sub_total

    return total, '\n'.join(lines)
def main():
    """Optimize the ONNX model named on the command line and save the result.

    Parses the arguments with the module-level ``parser``, loads the model,
    runs the enabled optimizer passes over it, prints the optimized graph
    together with the node counts after and before optimization, and writes
    the optimized model to the ``--output`` path.
    """
    # Names of the optimizer passes to run; commented-out entries are disabled.
    enabled_passes = [
        # 'eliminate_deadend',
        'eliminate_identity',
        'eliminate_nop_dropout',
        'eliminate_nop_pad',
        'eliminate_nop_transpose',
        'eliminate_unused_initializer',
        'extract_constant_to_initializer',
        'fuse_add_bias_into_conv',
        'fuse_bn_into_conv',
        'fuse_consecutive_concats',
        'fuse_consecutive_reduce_unsqueeze',
        'fuse_consecutive_squeezes',
        'fuse_consecutive_transposes',
        # 'fuse_matmul_add_bias_into_gemm',
        'fuse_pad_into_conv',
        # 'fuse_transpose_into_gemm',
        # 'lift_lexical_references',
    ]

    cli = parser.parse_args()
    source_model = onnx.load(cli.model)

    # Only the node count of the unoptimized graph is reported below.
    nodes_before, _ = traverse_graph(source_model.graph)

    # Run the enabled passes over the loaded model.
    result_model = optimizer.optimize(source_model, enabled_passes)
    nodes_after, optimized_text = traverse_graph(result_model.graph)

    print('==> The model after optimization:\n{}\n'.format(optimized_text))
    summary = '==> The optimized model has {} nodes, the original had {}.'
    print(summary.format(nodes_after, nodes_before))

    # Write the optimized model to the requested output file.
    onnx.save(result_model, cli.output)
# Run the optimizer only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()