From 178a197efa0f2d295c07bed1b79610fd5693a8ff Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Tue, 13 Jul 2021 19:00:03 +0300 Subject: [PATCH] [Relay] Modify create_executor to pass params (#8418) * Overload create_executor to accept params * [fix] Add stringdoc for new param in create_executor --- python/tvm/relay/build_module.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index aa826aee57a1..d1cf1c9bea2f 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -511,7 +511,7 @@ def _graph_wrapper(*args, **kwargs): return _graph_wrapper -def create_executor(kind="debug", mod=None, device=None, target="llvm"): +def create_executor(kind="debug", mod=None, device=None, target="llvm", params=None): """Factory function to create an executor. Example @@ -544,6 +544,10 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm"): target : :py:class:`tvm.Target` The corresponding context + params : dict of str to NDArray + Input parameters to the graph that do not change + during inference time. + Returns ------- executor : :py:class:`~tvm.relay.backend.interpreter.Executor` @@ -555,6 +559,9 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm"): else: device = _nd.device(str(target), 0) + if params is not None: + mod = IRModule.from_expr(bind_params_by_name(mod["main"], params)) + if isinstance(target, str): target = Target(target) if kind == "debug":