-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathresnet18_v1_run_3.py
26 lines (24 loc) · 941 Bytes
/
resnet18_v1_run_3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import tvm
import numpy as np
# Run one inference of a TVM-compiled model on a single pre-processed input
# and print the index of the highest-scoring output element (top-1).
#
# Files read (paths relative to the current working directory):
#   ./model/model_lib.so        compiled function library (tvm.module.load)
#   ./model/model_graph.json    graph JSON passed to the graph runtime
#   ./model/model_graph.params  serialized parameters (binary)
#   ./data/cat.bin              raw float32 input values

# tvm module for compiled functions.
loaded_lib = tvm.module.load("./model/model_lib.so")

# json graph; context managers close the file handles promptly.
with open("./model/model_graph.json", encoding="utf-8") as graph_file:
    loaded_json = graph_file.read()

# parameters in binary
with open("./model/model_graph.params", "rb") as params_file:
    loaded_params = bytearray(params_file.read())

# data in binary
# NOTE(review): presumably NCHW layout — confirm against the model's input.
x_shape = (1, 3, 224, 224)
# np.fromfile defaults to float64 (8 bytes per element), but the file stores
# 4-byte floats, so the dtype must be given explicitly as float32.
x = np.fromfile("./data/cat.bin", dtype="float32")
# Raises ValueError if the file does not contain exactly 1*3*224*224 values.
x.shape = x_shape

fcreate = tvm.get_global_func("tvm.graph_runtime.create")
ctx = tvm.cpu(0)
gmodule = fcreate(loaded_json, loaded_lib, ctx.device_type, ctx.device_id)
set_input, get_output, run = gmodule["set_input"], gmodule["get_output"], gmodule["run"]
# x is already float32 (see np.fromfile above), so no astype copy is needed.
set_input("data", tvm.nd.array(x))
gmodule["load_params"](loaded_params)
run()

# NOTE(review): assumes output 0 holds exactly 1000 float32 values
# (e.g. one score per class) — confirm against the compiled graph.
out = tvm.nd.empty((1000,), "float32")
get_output(0, out)
top1 = np.argmax(out.asnumpy())
print('TVM prediction top-1:', top1)