Skip to content

Commit

Permalink
[Tutorial][Executor] Fix the usage of executors in tutorials (apache#…
Browse files Browse the repository at this point in the history
…8586)

* fix: executor usage for keras tutorial

* fix: executor usage for onnx tutorial

* [Tutorial][Executor] Fix executors in tutorials
  • Loading branch information
ganler authored and ylc committed Jan 13, 2022
1 parent bcc3469 commit e6a1b1b
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
3 changes: 2 additions & 1 deletion tutorials/dev/bring_your_own_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,9 @@ def get_cat_image():
######################################################################
# It's easy to execute MobileNet with native TVM:

ex = tvm.relay.create_executor("graph", mod=module, params=params)
input = get_cat_image()
result = tvm.relay.create_executor("graph", mod=module).evaluate()(input, **params).numpy()
result = ex.evaluate()(input).numpy()
# print first 10 elements
print(result.flatten()[:10])

Expand Down
4 changes: 2 additions & 2 deletions tutorials/frontend/from_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,14 @@
# due to a latent bug. Note that the pass context only has an effect within
# evaluate() and is not captured by create_executor().
with tvm.transform.PassContext(opt_level=0):
model = relay.build_module.create_executor("graph", mod, dev, target).evaluate()
model = relay.build_module.create_executor("graph", mod, dev, target, params).evaluate()


######################################################################
# Execute on TVM
# ---------------
dtype = "float32"
tvm_out = model(tvm.nd.array(data.astype(dtype)), **params)
tvm_out = model(tvm.nd.array(data.astype(dtype)))
top1_tvm = np.argmax(tvm_out.numpy()[0])

#####################################################################
Expand Down
6 changes: 4 additions & 2 deletions tutorials/frontend/from_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,15 @@
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)

with tvm.transform.PassContext(opt_level=1):
compiled = relay.build_module.create_executor("graph", mod, tvm.cpu(0), target).evaluate()
executor = relay.build_module.create_executor(
"graph", mod, tvm.cpu(0), target, params
).evaluate()

######################################################################
# Execute on TVM
# ---------------------------------------------
dtype = "float32"
tvm_output = compiled(tvm.nd.array(x.astype(dtype)), **params).numpy()
tvm_output = executor(tvm.nd.array(x.astype(dtype))).numpy()

######################################################################
# Display results
Expand Down

0 comments on commit e6a1b1b

Please sign in to comment.