From 349d736a4adefbec90a89ae290003b9043de960d Mon Sep 17 00:00:00 2001 From: ibsidorenko <98739392+ibsidorenko@users.noreply.github.com> Date: Wed, 28 Dec 2022 00:10:15 +0300 Subject: [PATCH] [QNN] Change in Pass Context for lookup table calculation (#13660) Motivation: It is possible to disable specific passes through the "disabled_pass" parameter in the Pass Context. These "disabled" passes can be optional for one target and mandatory for another one. Since lookup table for some QNN operations (tanh, round and etc.) is calculated on the host and some of disabled passes can be required for the host, no need to disable these passes. This constant calculation/ evaluation is orthogonal to the compilation process for specific target. What was changed: This commit creates its own compilation Pass Context for lookup table calculation and evaluation (for elemwise QNN ops: tanh, sqrt ...). --- python/tvm/relay/qnn/op/canonicalizations.py | 23 ++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/qnn/op/canonicalizations.py b/python/tvm/relay/qnn/op/canonicalizations.py index 1f2c57c6da34..6bfcd34aba90 100644 --- a/python/tvm/relay/qnn/op/canonicalizations.py +++ b/python/tvm/relay/qnn/op/canonicalizations.py @@ -23,10 +23,25 @@ def run_const_expr(expr: "relay.Expr") -> np.ndarray: - """Evaluate a const expression, receiving result as np array.""" - mod = tvm.IRModule.from_expr(expr) - vm_exe = relay.create_executor("vm", mod=mod) - return vm_exe.evaluate()().asnumpy() + """Evaluate a const expression, receiving result as np array. + + If a number of passes are disabled in the current Pass Context, then there is no need to disable + these passes for const expression evaluation as well. That's why we use empty list + "disabled_pass=[]", all other arguments are inherited from the current Pass Context. + """ + curr_pass_ctx = tvm.ir.transform.PassContext.current() + with tvm.ir.transform.PassContext( + opt_level=curr_pass_ctx.opt_level, + required_pass=curr_pass_ctx.required_pass, + disabled_pass=[], + instruments=curr_pass_ctx.instruments, + config=curr_pass_ctx.config, + ): + mod = tvm.IRModule.from_expr(expr) + vm_exe = relay.create_executor("vm", mod=mod) + output = vm_exe.evaluate()().asnumpy() + + return output def create_integer_lookup_table(