[Relay Translator] Use OpStrategy for lowering (#130)
* [Relay Translator] Use OpStrategy for lowering

* Address review feedback and fix a lint issue

* Consider contexts for PassContext, Target, etc., for both pass application and lowering
sunggg authored and yongwww committed May 10, 2022
1 parent 0142d5d commit 1f75be5
Showing 5 changed files with 291 additions and 154 deletions.
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/__init__.py
@@ -31,6 +31,7 @@
from .apply_history_best import ApplyHistoryBest
from .extracted_task import ExtractedTask
from .relay_integration import extract_task_from_relay
from .relax_integration import extract_task_from_relax
from .search_strategy import MeasureCandidate
from .tune import TuneConfig, tune_relay, tune_relax, tune_te, tune_tir
from .tune_context import TuneContext
38 changes: 30 additions & 8 deletions python/tvm/meta_schedule/relax_integration.py
@@ -15,13 +15,15 @@
# specific language governing permissions and limitations
# under the License.
"""Meta schedule integration with high-level IR"""
from typing import List, Union, Tuple, Dict
from typing import Any, List, Union, Tuple, Dict, Optional

import tvm
from tvm.ir import IRModule, structural_hash, structural_equal
from tvm.meta_schedule import ExtractedTask
from tvm.target import Target
from tvm.relax.expr import Function as RelaxFunc
from tvm.relax.utils import tir_partitioner
from tvm.runtime import NDArray


def deduplicate_extracted_tasks(
@@ -67,7 +69,15 @@ def deduplicate_extracted_tasks(
return dedup, count


def extract_task_from_relax(mod: Union[IRModule, RelaxFunc], target: Target) -> List[ExtractedTask]:
def extract_task_from_relax(
mod: Union[IRModule, RelaxFunc],
target: Target,
params: Optional[Dict[str, NDArray]] = None,
*,
opt_level: int = 3,
pass_config: Optional[Dict[str, Any]] = None,
disabled_pass: Optional[List[str]] = None,
) -> List[ExtractedTask]:
"""Extract tuning tasks from a relax program.
Parameters
@@ -87,13 +97,25 @@ def extract_task_from_relax(mod: Union[IRModule, RelaxFunc], target: Target) ->
if not isinstance(target, Target):
target = Target(target)

if disabled_pass is None:
disabled_pass = []
if pass_config is None:
pass_config = {}

if params:
mod = tvm.relax.transform.BindParams("main", params)(mod)

tir_partitions = tir_partitioner(mod)
tir_mods, tir_counts = deduplicate_extracted_tasks(tir_partitions)

tasks = []
for i, tir_mod in enumerate(tir_mods):
task_name = tir_mod.get_global_vars()[0].name_hint
# The second arg to ExtractedTask is supposed to be a high-level IRModule,
# passing tir_mod as a workaround.
tasks.append(ExtractedTask(task_name, tir_mod, target, [tir_mod], tir_counts[i]))
with target, tvm.transform.PassContext(
opt_level=opt_level,
config=pass_config,
disabled_pass=disabled_pass,
):
for i, tir_mod in enumerate(tir_mods):
task_name = tir_mod.get_global_vars()[0].name_hint
# The second arg to ExtractedTask is supposed to be a high-level IRModule,
# passing tir_mod as a workaround.
tasks.append(ExtractedTask(task_name, tir_mod, target, [tir_mod], tir_counts[i]))
return tasks
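
A minimal usage sketch of the new keyword arguments (not part of the commit): `relax_mod` is a placeholder for an existing Relax IRModule, and the argument values shown are only illustrative.

# Hedged sketch: extract tuning tasks from an existing Relax module.
import tvm
from tvm.meta_schedule import extract_task_from_relax

relax_mod = ...  # placeholder: an IRModule containing a Relax "main" function
tasks = extract_task_from_relax(
    relax_mod,
    target="llvm",        # str or tvm.target.Target
    params=None,          # optional weights, bound via BindParams("main", ...)
    opt_level=3,
    pass_config={},       # forwarded to the PassContext used during lowering
    disabled_pass=[],
)
for task in tasks:
    # Each ExtractedTask carries a deduplicated TIR module plus a weight
    # (how many identical partitions were found).
    print(task.task_name, task.weight)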
208 changes: 62 additions & 146 deletions python/tvm/relax/testing/relay_translator.py
@@ -18,138 +18,26 @@
"""Relay to Relax translator."""

from __future__ import annotations
from typing import Dict, List
from typing import Any, Dict, List, Optional
import tvm
from tvm.ir.module import IRModule
from tvm import relax, relay, topi
from tvm import relax, relay
from tvm.relax.testing import nn


class RelayOpConverter(object):
"""A helper class for holding Relay op converters."""

@classmethod
def get_converter(cls):
"""Get converter.
:return: converter, which should be `_impl`.
"""

if hasattr(cls, "_impl"):
return getattr(cls, "_impl")
raise tvm.error.OpNotImplemented("Operator {} is not supported.".format(cls.__name__))


class Dense(RelayOpConverter):
"""Operator converter for nn.dense."""

@classmethod
def _impl(cls, inputs, attrs):
return nn.emit_te(topi.nn.dense, *inputs)


class BatchNorm(RelayOpConverter):
"""Operator converter for nn.batch_norm."""

@classmethod
def _impl(cls, inputs, attrs):
new_attrs = attr_convert(attrs)
return nn.emit_te(topi.nn.batch_norm, *inputs, **new_attrs)


class Conv2D(RelayOpConverter):
"""Operator converter for nn.conv2d."""

@classmethod
def _impl(cls, inputs, attrs):
new_inputs = [*inputs]
if attrs is not None:
new_inputs.append(attrs["strides"])
new_inputs.append(attrs["padding"])
new_inputs.append(attrs["dilation"])
else:
raise RuntimeError("attrs must be provided to conv2d op.")
return nn.emit_te(topi.nn.conv2d_nchw, *new_inputs)


class BatchMatmul(RelayOpConverter):
"""Operator converter for nn.batch_matmul."""

@classmethod
def _impl(cls, inputs, attrs):
new_attrs = attr_convert(attrs)
if "out_dtype" in new_attrs:
new_attrs["out_dtype"] = None
if "transpose_a" in new_attrs:
new_attrs["transpose_a"] = bool(new_attrs["transpose_a"])
if "transpose_b" in new_attrs:
new_attrs["transpose_b"] = bool(new_attrs["transpose_b"])
return nn.emit_te(topi.nn.batch_matmul, *inputs, **new_attrs)


class Softmax(RelayOpConverter):
"""Operator converter for softmax."""

@classmethod
def _impl(cls, inputs, attrs):
new_attrs = attr_convert(attrs)
return nn.emit_te(topi.nn.softmax, *inputs, **new_attrs)


# convert_map defines maps of name to converter functor(callable)
# use attr_convert if attributes need to be converted
# for 1 to N mapping(composed), use custom callable functions
# for N to 1 mapping (fusion), write custom topi func

# Minimal set of ops for transformer
def get_convert_map():
return {
"nn.dense": Dense.get_converter(),
"nn.batch_norm": BatchNorm.get_converter(),
"nn.conv2d": Conv2D.get_converter(),
"nn.batch_matmul": BatchMatmul.get_converter(),
"nn.softmax": Softmax.get_converter(),
}


def convert_operator(op_type: str, inputs: List[relax.Expr], attrs: Dict = None):
"""Convert from Relay operator to Relax operator/topi function.
The converter must specify conversions explicitly for incompatible name, and
apply handlers to operator attributes.
Parameters
----------
op_type : str
Operator name, such as Convolution, FullyConnected
inputs : list of Expr
List of input inputs.
attrs : dict
Dict of operator attributes
Returns
-------
func : tvm.relay.function.Function
Converted relay function
"""
convert_map = get_convert_map()
if op_type in convert_map:
func = convert_map[op_type](inputs, attrs)
else:
raise tvm.error.OpNotImplemented("Operator {} is not supported.".format(op_type))
return func


def attr_convert(attrs: tvm.ir.Attrs) -> Dict:
"""Convert attributes to a dict."""
attrs_dict = {}

for k in attrs.keys():
attrs_dict[k] = attrs[k]

return attrs_dict


def from_relay(func: relay.Function) -> IRModule:
from tvm.relay.backend.te_compiler import select_implementation
from tvm.runtime import NDArray
from tvm.target import Target
from tvm.meta_schedule.utils import autotvm_silencer


def from_relay(
func: relay.Function,
target: Target,
relay_params: Optional[Dict[str, NDArray]] = None,
*,
opt_level: int = 3,
pass_config: Optional[Dict[str, Any]] = None,
disabled_pass: Optional[List[str]] = None,
) -> IRModule:
"""Convert a Relay function into a Relax program.
Parameters
@@ -166,8 +54,21 @@ def from_relay(func: relay.Function) -> IRModule:
var_map = {}
# The output of the function
output_var = None

if not isinstance(target, Target):
target = Target(target)
if disabled_pass is None:
disabled_pass = []
if pass_config is None:
pass_config = {
"relay.FuseOps.max_depth": 1, # Disable relay fusion
"relay.backend.use_meta_schedule": True,
}

if relay_params:
func = relay.build_module.bind_params_by_name(func, relay_params)

params = []
convert_map = get_convert_map()

def visit_func(node):
nonlocal output_var
@@ -182,25 +83,29 @@ def visit_func(node):
elif isinstance(node, relay.Call):
args = node.args
new_args = []
te_inputs = []
for arg in args:
if arg in var_map:
new_args.append(var_map[arg])
te_inputs.append(tvm.relax.expr.te_tensor(new_args[-1]))

op_name = node.op.name
attrs = node.attrs
compute_func = node.op.get_attr("FTVMCompute")
if compute_func is None:
if node.op.name not in convert_map:
raise tvm.error.OpNotImplemented(
"Operator {} is not supported.".format(op_name)
)
var = convert_operator(op_name, new_args, attrs)
else:
name_hint = op_name.split(".")[-1]
var = bb.emit_te(
compute_func, attrs, new_args, node.checked_type, primfunc_name_hint=name_hint
)

out_type = node.checked_type

best_impl, outputs = select_implementation(
node.op,
attrs,
te_inputs,
out_type,
target,
use_autotvm=False,
)
compute_func = best_impl.compute
name_hint = op_name.split(".")[-1]
var = bb.emit_te(
compute_func, attrs, new_args, node.checked_type, primfunc_name_hint=name_hint
)
output_var = var
var_map[node] = var
elif isinstance(node, relay.Constant):
@@ -234,8 +139,19 @@ def visit_func(node):
else:
raise TypeError("{} is not supported yet.".format(str(type(node))))

bb = relax.BlockBuilder()
with bb.function("main"):
relay.analysis.post_order_visit(func, visit_func)
# List of subset of relay->relay optimizations
# See src/relay/backend/utils.cc::GetPassPrefix() for full list
seq = tvm.get_global_func("relay.backend.GetPassPrefixSeq")(True, True)

# Since optimization passes and OpStrategy are highly context-dependent,
# we match the exact same context with `extract_task_from_relay()` env
with autotvm_silencer(), target, tvm.transform.PassContext(
opt_level=opt_level, config=pass_config, disabled_pass=disabled_pass
):
mod = tvm.IRModule.from_expr(func)
mod = seq(mod)
bb = relax.BlockBuilder()
with bb.function("main"):
relay.analysis.post_order_visit(mod["main"], visit_func)

return bb.get()
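
A hedged usage sketch of the updated translator (not part of the commit); the tiny Relay function and the "llvm" target below are only for illustration.

# Translate a small Relay function into a Relax module; lowering now goes
# through OpStrategy under the given target and pass context.
import tvm
from tvm import relay
from tvm.relax.testing import relay_translator

x = relay.var("x", shape=(1, 784), dtype="float32")
w = relay.var("w", shape=(128, 784), dtype="float32")
func = relay.Function([x, w], relay.nn.dense(x, w))

relax_mod = relay_translator.from_relay(
    func,
    target="llvm",      # str or tvm.target.Target
    relay_params=None,  # optional weights to bind before translation
    # opt_level / pass_config / disabled_pass default to the values above,
    # i.e. Relay fusion is disabled and meta schedule is enabled.
)
print(relax_mod)  # Relax "main" plus the emitted TIR PrimFuncs

The resulting module can then be passed to extract_task_from_relax under the same target and pass configuration.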
7 changes: 7 additions & 0 deletions src/relay/backend/utils.cc
@@ -357,6 +357,13 @@ void BindParamsInModule(IRModule mod, Map<String, runtime::NDArray> params) {
BindParamsInModule(mod, params_tmp);
}

TVM_REGISTER_GLOBAL("relay.backend.GetPassPrefixSeq")
.set_body_typed([](bool is_homogeneous, bool is_vm) {
auto pass_seqs = GetPassPrefix(is_homogeneous, is_vm);
transform::Sequential seq(pass_seqs);
return seq;
});

} // namespace backend
} // namespace relay
} // namespace tvm
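
For reference, a small Python-side sketch (not part of the commit) of how this newly registered global might be fetched and applied to an arbitrary Relay module; the trivial function below is a placeholder.

# Retrieve the Relay pass prefix exposed by GetPassPrefixSeq and run it.
import tvm
from tvm import relay

get_prefix = tvm.get_global_func("relay.backend.GetPassPrefixSeq")
seq = get_prefix(True, True)  # (is_homogeneous, is_vm) -> transform.Sequential

x = relay.var("x", shape=(4,), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.add(x, x)))
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)  # the same prefix relay_translator.from_relay applies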