[microNPU] Refactor Relay to TIR hook #10599

Merged
merged 4 commits on Mar 21, 2022
Changes from 3 commits
98 changes: 51 additions & 47 deletions python/tvm/relay/backend/contrib/ethosu/codegen.py
@@ -19,8 +19,7 @@

import tvm
from tvm import relay
from tvm import ir
from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
from tvm.relay.backend.contrib.ethosu.tir.compiler import LowerToTIR
from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants
from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU
from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator
@@ -112,30 +111,24 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
return new_call


@ir.transform.module_pass(opt_level=1, name="LUTsOptimizer")
@util.create_npu_function_pass(opt_level=1)
class LUTsOptimizer:
"""Register LUTsOptimizer as a relay pass."""

def transform_module(self, mod: tvm.ir.IRModule, _) -> tvm.IRModule:
"""Visit relay nodes in the given module.
def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
"""Visit relay nodes in the given NPU function.

Parameters
----------
func : tvm.relay.function.Function
The function to apply the optimization pass for multiple LUTs to.
mod : tvm.IRModule
The module to apply the optimization pass for multiple LUTs to.

Returns
-------
func : tvm.relay.function.Function
The new NPU function with optimized LUTs.
"""
assert len(mod.functions.items()) == 1, "Module can only contain one function."
global_var, func = mod.functions.items()[0]
optimized_func = OptimizeLUTs().visit(func)
mod.update_func(global_var, optimized_func)
return mod
return OptimizeLUTs().visit(func)

def __call__(self, *args, **kwargs):
pass
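
For reference, the new util.create_npu_function_pass decorator replaces the per-pass ir.transform.module_pass boilerplate: instead of asserting that the module contains exactly one function, each optimization now only implements transform_npu_function and the decorator applies it to every NPU external function in the module. Below is a minimal sketch of a decorator with that shape; the actual helper in util.py may differ, and _is_npu_func is a hypothetical stand-in for util.is_npu_func referenced later in this file.

import tvm
from tvm import relay


def _is_npu_func(func) -> bool:
    # Hypothetical check: an NPU external function is a Relay function that
    # carries the Compiler="ethos-u" attribute.
    return (
        isinstance(func, relay.Function)
        and func.attrs is not None
        and "Compiler" in func.attrs
        and func.attrs["Compiler"] == "ethos-u"
    )


def create_npu_function_pass_sketch(opt_level: int):
    # Sketch of the decorator shape: wrap the class in a module pass that calls
    # transform_npu_function once per NPU external function and leaves every
    # other function in the module untouched.
    def decorator(cls):
        @tvm.ir.transform.module_pass(opt_level=opt_level, name=cls.__name__)
        class _NpuFunctionPass:
            def transform_module(self, mod, _ctx):
                inner = cls()
                for global_var, func in mod.functions.items():
                    if _is_npu_func(func):
                        new_func = inner.transform_npu_function(global_var, func)
                        mod.update_func(global_var, new_func)
                return mod

        return _NpuFunctionPass

    return decorator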
@@ -272,30 +265,27 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
return super().visit_call(call)


@ir.transform.module_pass(opt_level=1, name="LayoutOptimizer")
@util.create_npu_function_pass(opt_level=1)
class LayoutOptimizer:
"""Register LayoutOptimizer as a Relay pass."""

OPTIMIZE_OPS = {
"contrib.ethosu.conv2d": op.ethosu_conv2d,
"contrib.ethosu.depthwise_conv2d": op.ethosu_depthwise_conv2d,
"contrib.ethosu.pooling": op.ethosu_pooling,
"contrib.ethosu.binary_elementwise": op.ethosu_binary_elementwise,
"contrib.ethosu.unary_elementwise": op.ethosu_unary_elementwise,
}

def transform_module(self, mod: tvm.ir.IRModule, _) -> tvm.IRModule:
def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
"""A pass to optimize the layout of NPU operations. If both the
producer and consumer of a tensor are NPU operators, then the
layout is converted from NHWC to NHCWB16, as this is the layout the NPU
uses internally."""
assert len(mod.functions.items()) == 1, "Module can only contain one function."
global_var, func = mod.functions.items()[0]
analyze = AnalyzeConsumers(self.OPTIMIZE_OPS)

optimize_ops = {
"contrib.ethosu.conv2d": op.ethosu_conv2d,
"contrib.ethosu.depthwise_conv2d": op.ethosu_depthwise_conv2d,
"contrib.ethosu.pooling": op.ethosu_pooling,
"contrib.ethosu.binary_elementwise": op.ethosu_binary_elementwise,
"contrib.ethosu.unary_elementwise": op.ethosu_unary_elementwise,
}

analyze = AnalyzeConsumers(optimize_ops)
analyze.visit(func)
optimized_func = LayoutOptimization(analyze.npu_consumers, self.OPTIMIZE_OPS).visit(func)
mod.update_func(global_var, optimized_func)
return mod
return LayoutOptimization(analyze.npu_consumers, optimize_ops).visit(func)

def __call__(self, *args, **kwargs):
pass
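
The layout decision itself is simple: a tensor produced by an NPU operator is switched to the internal NHCWB16 layout only when its consumers are also NPU operators, and otherwise stays in NHWC. The following self-contained sketch illustrates that rule; it is not the actual AnalyzeConsumers/LayoutOptimization implementation, and the dictionary shape used here is hypothetical.

from typing import Dict, List


def choose_layouts(npu_consumers: Dict[str, List[bool]]) -> Dict[str, str]:
    # For each NPU-produced tensor, use the NPU-internal NHCWB16 layout only if
    # every consumer is an NPU operator; otherwise fall back to NHWC so that
    # non-NPU consumers can still read the tensor.
    layouts = {}
    for tensor, consumers_are_npu in npu_consumers.items():
        all_npu = bool(consumers_are_npu) and all(consumers_are_npu)
        layouts[tensor] = "NHCWB16" if all_npu else "NHWC"
    return layouts


# Example: a conv2d feeding an NPU pooling op can use NHCWB16, while a conv2d
# feeding a host (non-NPU) operation must keep NHWC.
print(choose_layouts({"conv_out": [True], "pool_out": [False]}))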
@@ -312,6 +302,22 @@ def IdentityOptimizer():  # pylint: disable=invalid-name
return _ffi_api.IdentityOptimizer()


def OutlineCompilerFunctions(compiler_name): # pylint: disable=invalid-name
"""Pass that outlines functions given a named Compiler attribute.

Parameters
----------
compiler_name
The name of the compiler to look for and outline.

Returns
-------
Pass
The module pass.
"""
return _ffi_api.OutlineCompilerFunctions(compiler_name)
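
A hedged usage sketch of the new pass follows; it mirrors how the pass is applied in relay_to_tir below, and assumes a module that has already been partitioned so that the "ethos-u" external functions exist as inlined Compiler functions.

import tvm
from tvm.relay.backend.contrib.ethosu import codegen


def outline_npu_functions(partitioned_mod: tvm.ir.IRModule) -> tvm.ir.IRModule:
    # Lift functions carrying the Compiler="ethos-u" attribute out into global
    # functions of the module, so the later NPU passes can visit each of them.
    return codegen.OutlineCompilerFunctions("ethos-u")(partitioned_mod)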


@tvm._ffi.register_func("relay.ext.ethos-u.constant_updater")
def constant_updater(expr, symbol): # pylint: disable=unused-argument
"""
@@ -322,43 +328,41 @@ def constant_updater(expr, symbol):  # pylint: disable=unused-argument
return dict()


@tvm._ffi.register_func("relay.ext.ethos-u.relay_to_tir_func")
def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc:
@tvm._ffi.register_func("relay.ext.ethos-u.relay_to_tir")
def relay_to_tir(mod: tvm.ir.IRModule) -> tvm.ir.IRModule:
"""
This is the hook for python-based lowering of relay function
that gets offloaded to the microNPU.
This is the hook for python-based lowering of a Relay module which lowers NPU
external functions to TIR.

Parameters
----------
ext_func : relay.Function
This is the partitioned relay function
mod : tvm.ir.IRModule
This is the Relay module.

Returns
-------
primfunc : tir.PrimFunc
This returns the scheduled PrimFunc
mod : tvm.ir.IRModule
The Relay module with scheduled NPU external functions.
"""
assert len(ext_func.params) == 1
mod = tvm.IRModule()
mod["main"] = ext_func
mod = OutlineCompilerFunctions("ethos-u")(mod)
mod = LegalizeEthosU()(mod)
mod = LUTsOptimizer()(mod)
mod = IdentityOptimizer()(mod)
mod = LayoutOptimizer()(mod)
mod = relay.transform.InferType()(mod)

device_contexts = {
gv: "ethos-u" for gv, _ in filter(lambda x: util.is_npu_func(x[1]), mod.functions.items())
}
mod = mod.with_attr("device_contexts", device_contexts)

# We are currently using the copy_constants scheduler. In the long run,
# this should be a single intelligent and composite scheduler
# that can perform scheduling based on user inputs such as
# scratch memory size.
tir_mod, const_dict = lower_to_tir(mod["main"], copy_constants())

for param in const_dict.keys():
const_dict[param] = tvm.nd.array(const_dict[param])
mod = LowerToTIR(copy_constants)(mod)

primfunc = tir_mod["main"]
primfunc = primfunc.with_attr("global_symbol", ext_func.attrs["global_symbol"])
primfunc = primfunc.with_attr("ethos-u.constants", const_dict)
return primfunc
return mod
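
For context, here is a hedged sketch of exercising the hook directly, e.g. from a test. In a normal build the compiler invokes the registered "relay.ext.ethos-u.relay_to_tir" function itself after partitioning; partition_for_ethosu is assumed to be the existing offloading helper in tvm.relay.op.contrib.ethosu.

import tvm
from tvm.relay.op.contrib.ethosu import partition_for_ethosu
from tvm.relay.backend.contrib.ethosu.codegen import relay_to_tir


def lower_npu_module(relay_mod: tvm.ir.IRModule, params: dict) -> tvm.ir.IRModule:
    # Offload supported operators into "ethos-u" external functions, then run
    # the whole-module Relay-to-TIR hook over the result: the NPU external
    # functions are legalized, optimized and scheduled down to TIR, and the
    # module is annotated with the "device_contexts" attribute set above.
    partitioned = partition_for_ethosu(relay_mod, params)
    return relay_to_tir(partitioned)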


@tvm._ffi.register_func("relay.ext.ethos-u.primfunc_to_artifact")