diff --git a/python/tvm/contrib/msc/core/codegen/codegen.py b/python/tvm/contrib/msc/core/codegen/codegen.py index c2711231f400..888f1bad4ebe 100644 --- a/python/tvm/contrib/msc/core/codegen/codegen.py +++ b/python/tvm/contrib/msc/core/codegen/codegen.py @@ -180,9 +180,10 @@ def visit_var_binding_(self, binding: relax.VarBinding) -> None: def _to_var(tensor: MSCTensor): v_name = tensor.alias if use_alias else graph.find_producer(tensor).name - return tvm.relax.Var( - v_name, tvm.relax.TensorStructInfo(tensor.get_shape(), tensor.dtype_name) - ) + dims = [ + d if isinstance(d, int) else tvm.tir.Var(d, "int64") for d in tensor.get_shape(True) + ] + return tvm.relax.Var(v_name, tvm.relax.TensorStructInfo(dims, tensor.dtype_name)) def _save_weights(folder: msc_utils.MSCDirectory): if weights: diff --git a/python/tvm/contrib/msc/core/frontend/translate.py b/python/tvm/contrib/msc/core/frontend/translate.py index cea021ade331..8e9bb0cf00d7 100644 --- a/python/tvm/contrib/msc/core/frontend/translate.py +++ b/python/tvm/contrib/msc/core/frontend/translate.py @@ -31,6 +31,44 @@ from tvm.contrib.msc.core.ir import MSCGraph, MSCTensor +def normalize_inputs(inputs: List[tuple]) -> List[tuple]: + """Normalize the inputs info + + Parameters + ---------- + inputs: list of + The inputs info. + + Returns + ------- + inputs: list of + The normalized inputs info. + """ + + recorded_vars = {} + + def _normalize_input(inp): + def _normalize(info): + if not isinstance(info, (tuple, list)): + return info + dims = [] + for dim in info: + if isinstance(dim, int): + dims.append(dim) + elif dim in recorded_vars: + dims.append(recorded_vars[dim]) + elif isinstance(dim, str): + recorded_vars[dim] = tvm.tir.Var(dim, "int64") + dims.append(recorded_vars[dim]) + else: + raise TypeError("Unexpected dim {} in shape {}".format(dim, info)) + return dims + + return [_normalize(i) for i in inp] + + return [_normalize_input(inp) for inp in inputs] + + def normalize_weights( t_weights: Dict[MSCTensor, tvm.nd.array], graph: MSCGraph ) -> Dict[str, tvm.nd.array]: diff --git a/python/tvm/contrib/msc/core/ir/graph.py b/python/tvm/contrib/msc/core/ir/graph.py index 19a16a375b7a..172f40e06a31 100644 --- a/python/tvm/contrib/msc/core/ir/graph.py +++ b/python/tvm/contrib/msc/core/ir/graph.py @@ -41,6 +41,8 @@ class MSCTensor(Object): The shape of the tensor. alias: string The alias of the tensor. + prims: list + The prims of the tensor. """ def __init__( @@ -50,15 +52,31 @@ def __init__( layout: str, shape: List[int], alias: Optional[str] = None, + prims: List[str] = None, ): if not isinstance(dtype, tvm.DataType): dtype = tvm.DataType(dtype) self.__init_handle_by_constructor__( - _ffi_api.MSCTensor, name, dtype, layout, shape, alias or "" + _ffi_api.MSCTensor, name, dtype, layout, shape, alias or "", prims or [] ) - def get_shape(self) -> List[int]: - return [int(i) for i in self.shape] + def get_shape(self, with_prims: bool = False) -> List[Union[int, str]]: + """Get shape of the tensor + + Parameters + ------- + with_prims: bool + Whether get shape with prims. + + Returns + ------- + shape: list + The shape of tensor. 
+ """ + + if not self.prims or not with_prims: + return [int(i) for i in self.shape] + return [int(p) if p.isdigit() else p for p in self.prims] def get_size(self) -> int: return int(_ffi_api.MSCTensorGetSize(self)) @@ -98,7 +116,7 @@ def equal(self, other: Object) -> bool: if not isinstance(other, MSCTensor): return False - if self.get_shape() != other.get_shape(): + if self.get_shape(True) != other.get_shape(True): return False if self.dtype != other.dtype: return False @@ -124,7 +142,7 @@ def inspect(self) -> dict: The tensor description in json format. """ - tensor_des = {"name": self.alias, "shape": self.get_shape(), "dtype": self.dtype_name} + tensor_des = {"name": self.alias, "shape": self.get_shape(True), "dtype": self.dtype_name} tensor_des["layout"] = self.layout.name if self.layout else "" return tensor_des @@ -405,6 +423,30 @@ def equal(self, other: BaseJoint) -> bool: return msc_utils.dict_equal(self.get_attrs(), other.get_attrs()) +@tvm._ffi.register_object("msc.core.MSCPrim") +class MSCPrim(BaseJoint): + """Prim in MSCGraph + + Parameters + ---------- + index: int + The index of the prim. + name: string + The name of the prim. + optype: string + The optype of the prim. + attrs: dict + The attributes of the node. + parents: list + The parents of the prim. + """ + + def __init__( + self, index: int, name: str, optype: str, attrs: Dict[str, str], parents: List[BaseJoint] + ): + self.__init_handle_by_constructor__(_ffi_api.MSCPrim, index, name, optype, attrs, parents) + + @tvm._ffi.register_object("msc.core.WeightJoint") class WeightJoint(BaseJoint): """Node in WeightGraph @@ -586,6 +628,22 @@ def find_node(self, name: str) -> MSCJoint: return _ffi_api.MSCGraphFindNode(self, name) + def find_prim(self, name: str) -> MSCPrim: + """Find prim by name. + + Parameters + ---------- + name: string + The name of the prim. + + Returns + ------- + prim: MSCPrim + The found prim. + """ + + return _ffi_api.MSCGraphFindPrim(self, name) + def has_tensor(self, name: str) -> bool: """Check if tensor in the graph. @@ -679,6 +737,18 @@ def get_nodes(self) -> Iterable[MSCJoint]: for n in self.node_names: yield self.find_node(n) + def get_prims(self) -> Iterable[MSCPrim]: + """Get all the prims in the graph. + + Returns + ------- + prims: generator + The generator of prims. + """ + + for n in self.prim_names: + yield self.find_prim(n) + def get_weights(self) -> Iterable[MSCTensor]: """Get all the weights in the graph. 
@@ -789,11 +859,16 @@ def inspect(self) -> dict: "nodes": {"total": 0}, } for node in self.get_nodes(): + graph_des["nodes"].setdefault(node.optype, 0) graph_des["nodes"]["total"] += 1 - if node.optype not in graph_des["nodes"]: - graph_des["nodes"][node.optype] = 1 - else: - graph_des["nodes"][node.optype] += 1 + graph_des["nodes"][node.optype] += 1 + prims = {"total": 0} + for prim in self.get_prims(): + prims.setdefault(prim.optype, 0) + prims["total"] += 1 + prims[prim.optype] += 1 + if prims["total"] > 0: + graph_des["prims"] = prims return graph_des @classmethod diff --git a/python/tvm/contrib/msc/core/tools/prune/pruner.py b/python/tvm/contrib/msc/core/tools/prune/pruner.py index 90273e25416b..a008100be252 100644 --- a/python/tvm/contrib/msc/core/tools/prune/pruner.py +++ b/python/tvm/contrib/msc/core/tools/prune/pruner.py @@ -340,7 +340,12 @@ def _prune_by_shape(tensor: MSCTensor, shape: List[int]): def _prune_by_channel(tensor: MSCTensor, dim, channel_axis: int = None): shape = tensor.get_shape() if channel_axis is None: - channel_axis = tensor.layout_of("C") + if self.has_w_node(tensor.name): + w_node = self.find_w_node(tensor.name) + _, channel_axis = self._get_io_axes(w_node) + else: + channel_axis = tensor.layout_of("C") + assert channel_axis >= 0, "Can not infer channel_axis for " + str(tensor) shape[channel_axis] = dim return _prune_by_shape(tensor, shape) diff --git a/python/tvm/contrib/msc/core/tools/tool.py b/python/tvm/contrib/msc/core/tools/tool.py index 626ae312bcf4..06a16f2bbe49 100644 --- a/python/tvm/contrib/msc/core/tools/tool.py +++ b/python/tvm/contrib/msc/core/tools/tool.py @@ -1620,6 +1620,9 @@ def _get_io_axes(self, w_node: WeightJoint) -> Tuple[int, int]: in_axis, out_axis = w_node.weight.layout_of("I"), w_node.weight.layout_of("O") if in_axis >= 0 and out_axis >= 0: return in_axis, out_axis + if w_node.weight.ndim == 2 and w_node.weight.dim_at("N") > 0: + io_axis = 1 - w_node.weight.layout_of("N") + return io_axis, io_axis if w_node.weight.layout_of("C") >= 0: return w_node.weight.layout_of("C"), w_node.weight.layout_of("C") raise Exception("Can not infer in_axis/out_axis from " + str(w_node)) diff --git a/python/tvm/contrib/msc/framework/torch/frontend/translate.py b/python/tvm/contrib/msc/framework/torch/frontend/translate.py index 2509f1abfcbe..c8c2844c2859 100644 --- a/python/tvm/contrib/msc/framework/torch/frontend/translate.py +++ b/python/tvm/contrib/msc/framework/torch/frontend/translate.py @@ -22,9 +22,8 @@ import torch import tvm from tvm.relax.frontend.torch import from_fx - from tvm.contrib.msc.core.ir.graph import MSCGraph -from tvm.contrib.msc.core.frontend import from_relax +from tvm.contrib.msc.core.frontend import from_relax, normalize_inputs from tvm.contrib.msc.core.codegen import relay_to_relax @@ -104,6 +103,7 @@ def from_torch( """ if via_relax: + input_info = normalize_inputs(input_info) graph_model, params = torch.fx.symbolic_trace(model), None with torch.no_grad(): relax_mod = from_fx(graph_model, input_info, custom_convert_map=custom_convert_map) diff --git a/python/tvm/contrib/msc/pipeline/pipeline.py b/python/tvm/contrib/msc/pipeline/pipeline.py index f02503a113ca..e003f692241c 100644 --- a/python/tvm/contrib/msc/pipeline/pipeline.py +++ b/python/tvm/contrib/msc/pipeline/pipeline.py @@ -676,10 +676,20 @@ def _get_loader(self, name: str = MSCStage.PREPARE) -> Any: max_batch = config.get("max_batch", 5) def get_random(): + def _to_data(inp): + shape = [1 if isinstance(d, str) else d for d in inp[1]] + return 
np.random.rand(*shape).astype(inp[2])
+
                 for _ in range(max_batch):
-                    yield {i[0]: np.random.rand(*i[1]).astype(i[2]) for i in self._config["inputs"]}
+                    yield {i[0]: _to_data(i) for i in self._config["inputs"]}
 
             loader, source_type = get_random, "random"
+        elif isinstance(source_loader, dict):
+
+            def load_data():
+                return [source_loader]
+
+            loader, source_type = load_data, "dict"
         elif msc_utils.is_io_dataset(source_loader):
             max_batch = config.get("max_batch", -1)
diff --git a/python/tvm/contrib/msc/pipeline/utils.py b/python/tvm/contrib/msc/pipeline/utils.py
index e4d91ee14b62..c6689e1f0091 100644
--- a/python/tvm/contrib/msc/pipeline/utils.py
+++ b/python/tvm/contrib/msc/pipeline/utils.py
@@ -16,6 +16,7 @@
 # under the License.
 """tvm.contrib.msc.pipeline.config"""
 
+import copy
 from typing import List, Union, Dict, Tuple
 
 from tvm.contrib.msc.core.tools import ToolType
@@ -129,6 +130,7 @@ def create_config(
     dataset: Dict[str, dict] = None,
     tools: List[Tuple[str, Union[dict, str]]] = None,
     dynamic: bool = False,
+    run_config: Dict[str, dict] = None,
     skip_config: Dict[str, str] = None,
     **extra_config,
 ) -> dict:
@@ -160,11 +162,13 @@ def create_config(
         The extra config.
     """
 
+    all_stages = [MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE]
     baseline_type = baseline_type or model_type
     optimize_type = optimize_type or baseline_type
     compile_type = compile_type or optimize_type
     tools = tools or []
     tools = [config_tool(t_type, t_config) for t_type, t_config in tools]
+    extra_config = extra_config or {}
     # basic config
     config = {
         "model_type": model_type,
@@ -194,27 +198,34 @@ def create_config(
         "profile": {"check": {"atol": 1e-3, "rtol": 1e-3}, "benchmark": {"repeat": -1}},
     }
 
+    # update run config
+    if run_config:
+        if "all" in run_config:
+            all_config = run_config.pop("all")
+            run_config.update({s: copy.deepcopy(all_config) for s in all_stages})
+        for stage, r_config in run_config.items():
+            extra_config.setdefault(stage, {}).setdefault("run_config", {}).update(r_config)
+
     # update config
     if extra_config:
         config = msc_utils.update_dict(config, extra_config)
 
     # skip stages
-    skip_config = skip_config or {}
-    for stage in [MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE]:
-        if stage not in config:
-            continue
-        for key in ["all", stage]:
-            if key not in skip_config:
+    if skip_config:
+        if "all" in skip_config:
+            all_config = skip_config.pop("all")
+            skip_config.update({s: copy.deepcopy(all_config) for s in all_stages})
+        for stage, s_type in skip_config.items():
+            if stage not in config:
                 continue
-            if skip_config[key] == "stage":
+            if s_type == "stage":
                 config.pop(stage)
-            elif skip_config[key] == "profile":
+            elif s_type == "profile":
                 config[stage].pop("profile")
-            elif skip_config[key] == "check":
-                config[stage]["profile"].pop("check")
-            elif skip_config[key] == "benchmark":
+            elif s_type == "check":
+                config[stage]["profile"]["check"]["err_rate"] = -1
+            elif s_type == "benchmark":
                 config[stage]["profile"].pop("benchmark")
             else:
-                raise TypeError("Unexpected skip type " + str(skip_config[key]))
-
+                raise TypeError("Unexpected skip type " + str(s_type))
     return config
diff --git a/python/tvm/contrib/msc/pipeline/wrapper.py b/python/tvm/contrib/msc/pipeline/wrapper.py
index 1332b3c79115..91862c794027 100644
--- a/python/tvm/contrib/msc/pipeline/wrapper.py
+++ b/python/tvm/contrib/msc/pipeline/wrapper.py
@@ -240,6 +240,9 @@ class TorchWrapper(BaseWrapper):
     """Wrapper of torch models"""
 
     def __call__(self, *inputs):
+        return self.forward(*inputs)
+
+    def forward(self, *inputs):
         framework = 
self._get_framework() if framework != MSCFramework.TORCH: inputs = [msc_utils.cast_array(i, framework, self.device) for i in inputs] diff --git a/src/contrib/msc/core/codegen/base_codegen.h b/src/contrib/msc/core/codegen/base_codegen.h index acaac896a153..f582f6416d93 100644 --- a/src/contrib/msc/core/codegen/base_codegen.h +++ b/src/contrib/msc/core/codegen/base_codegen.h @@ -58,9 +58,11 @@ class BaseOpCode { virtual ~BaseOpCode() = default; /*! \brief Config the BaseOpCode*/ - void Config(const MSCJoint& node, const std::shared_ptr config) { + void Config(const MSCJoint& node, const std::shared_ptr config, + const Map& prims) { node_ = node; config_ = config; + prims_ = prims; } /*! \brief Get docs for the node*/ @@ -158,6 +160,13 @@ class BaseCodeGen { } } + virtual void Init() { + // define prims + for (const auto& p_name : this->graph()->prim_names) { + prims_.Set(p_name, this->DescribePrim(this->graph()->FindPrim(p_name))); + } + } + virtual ~BaseCodeGen() = default; /*! \brief Get sources*/ @@ -211,6 +220,29 @@ class BaseCodeGen { /*! \brief Get the docs for the op*/ virtual const Array GetOpCodes(const MSCJoint& node) = 0; + /*! \brief Describe the prim*/ + virtual const String DescribePrim(const MSCPrim& prim) { + if (prim->optype == "Int") { + return prim->GetTypeAttr("value"); + } + if (prim->optype == "shape") { + const auto& producer = this->graph()->FindNode(prim->GetTypeAttr("producer")); + int out_idx = prim->GetTypeAttr("out_idx"); + const auto& dim = prim->GetTypeAttr("dim"); + return this->IdxOutputBase(producer, out_idx) + ".shape[" + dim + "]"; + } + // binary ops + DESCRIBE_PRIM_BINARY("Add", "+", false) + DESCRIBE_PRIM_BINARY("Sub", "-", false) + DESCRIBE_PRIM_BINARY("Mul", "*", false) + DESCRIBE_PRIM_BINARY("Divide", "/", false) + DESCRIBE_PRIM_BINARY("LT", "<", false) + DESCRIBE_PRIM_BINARY("LE", "<=", false) + DESCRIBE_PRIM_BINARY("GT", ">", false) + DESCRIBE_PRIM_BINARY("GE", ">=", false) + LOG_FATAL << "Unexpected prim " << prim; + } + /*! \brief Get the graph*/ const MSCGraph graph() const { return graph_; } diff --git a/src/contrib/msc/core/codegen/codegen_utils.cc b/src/contrib/msc/core/codegen/codegen_utils.cc index 44626debe1d8..741b729bd015 100644 --- a/src/contrib/msc/core/codegen/codegen_utils.cc +++ b/src/contrib/msc/core/codegen/codegen_utils.cc @@ -54,13 +54,37 @@ const String CodeGenUtils::IdxWeight(const MSCJoint& node, const String& wtype, return wtype + "_" + std::to_string(node->index) + suffix; } -const String CodeGenUtils::CommentNode(const MSCJoint& node, const String& prefix) { +const Array CodeGenUtils::GetPrims(const MSCTensor& tensor, + const Map& prims) { + Array dims; + if (tensor->prims.size() == 0) { + for (size_t i = 0; i < tensor->Ndim(); i++) { + dims.push_back(StringUtils::ToString(tensor->DimAt(i))); + } + return dims; + } + for (size_t i = 0; i < tensor->Ndim(); i++) { + const auto& prim = tensor->PrimAt(i); + dims.push_back(prims.count(prim) ? prims[prim] : prim); + } + return dims; +} + +const String CodeGenUtils::CommentNode(const MSCJoint& node, const String& prefix, + const Map& prims) { String comment = node->name + "(" + node->optype + "): <"; for (size_t i = 0; i < node->inputs.size(); i++) { comment = comment + IdxInput(node, prefix, i) + (i == node->inputs.size() - 1 ? "> -> <" : ","); } for (size_t i = 0; i < node->outputs.size(); i++) { - comment = comment + IdxOutput(node, prefix, i) + (i == node->outputs.size() - 1 ? 
">" : ","); + const auto& t_output = node->OutputAt(i); + const auto& t_prims = GetPrims(t_output, prims); + comment = comment + IdxOutput(node, prefix, i) + "|" + StringUtils::Join(t_prims, ":"); + comment = comment + "|" + t_output->DTypeName(); + if (t_output->layout.defined()) { + comment = comment + "|" + t_output->layout->name; + } + comment = comment + (i == node->outputs.size() - 1 ? ">" : ", "); } return comment; } diff --git a/src/contrib/msc/core/codegen/codegen_utils.h b/src/contrib/msc/core/codegen/codegen_utils.h index 1af8df5ac1a4..abdb91b4703f 100644 --- a/src/contrib/msc/core/codegen/codegen_utils.h +++ b/src/contrib/msc/core/codegen/codegen_utils.h @@ -76,12 +76,23 @@ using namespace tvm::script::printer; LOG(FATAL) << "Do not support key " << key; \ } +#define DESCRIBE_PRIM_BINARY(OpType, Symbol, AsFunc) \ + if (prim->optype == OpType) { \ + if (AsFunc) { \ + return std::string(Symbol) + "(" + this->DescribePrim(prim->ParentAt(0)) + "," + \ + this->DescribePrim(prim->ParentAt(1)) + ")"; \ + } \ + return "(" + this->DescribePrim(prim->ParentAt(0)) + Symbol + \ + this->DescribePrim(prim->ParentAt(1)) + ")"; \ + } + #define CODEGEN_MEMBERS \ public: \ virtual const String DType(const DataType& dtype) { return runtime::DLDataType2String(dtype); } \ \ protected: \ const std::shared_ptr config() { return config_; } \ + const Map prims() { return prims_; } \ const String IdxNodeBase(const MSCJoint& node) { \ return helper_.IdxNodeBase(node, config()->prefix, ""); \ } \ @@ -95,13 +106,19 @@ using namespace tvm::script::printer; const String IdxWeightBase(const MSCJoint& node, const String& wtype, bool process = true) { \ return helper_.IdxWeightBase(node, wtype, "", process && config()->use_tools); \ } \ - const String Comment(const MSCJoint& node) { return helper_.Comment(node, config()->prefix); } \ + const Array GetPrims(const MSCTensor& tensor) { \ + return CodeGenUtils::GetPrims(tensor, prims_); \ + } \ + const String Comment(const MSCJoint& node) { \ + return helper_.Comment(node, config()->prefix, prims_); \ + } \ int CompareVersion(size_t major, size_t minor, size_t patch) { \ return CommonUtils::CompareVersion(config()->version, {major, minor, patch}); \ } \ \ private: \ std::shared_ptr config_; \ + Map prims_; \ HelperType helper_; /*! @@ -137,11 +154,18 @@ class CodeGenUtils { TVM_DLL static const String IdxWeight(const MSCJoint& node, const String& wtype, const String& suffix = ""); + /*! + * \brief Infer prims of tensor. + * \return The prims. + */ + TVM_DLL static const Array GetPrims(const MSCTensor& tensor, + const Map& prims); /*! * \brief Get comment of a node. * \return The String. */ - TVM_DLL static const String CommentNode(const MSCJoint& node, const String& prefix); + TVM_DLL static const String CommentNode(const MSCJoint& node, const String& prefix, + const Map& prims); }; /*! 
@@ -180,8 +204,9 @@ class BaseCodeGenHelper { const String& suffix = "", bool process = false) { return CodeGenUtils::IdxWeight(node, wtype, suffix + GetSuffix(node, process)); } - virtual const String Comment(const MSCJoint& node, const String& prefix = "") { - return CodeGenUtils::CommentNode(node, prefix); + virtual const String Comment(const MSCJoint& node, const String& prefix = "", + const Map& prims = Map()) { + return CodeGenUtils::CommentNode(node, prefix, prims); } }; diff --git a/src/contrib/msc/core/codegen/cpp_codegen.h b/src/contrib/msc/core/codegen/cpp_codegen.h index 2c07aeb4c741..81b7d1e871a2 100644 --- a/src/contrib/msc/core/codegen/cpp_codegen.h +++ b/src/contrib/msc/core/codegen/cpp_codegen.h @@ -95,6 +95,20 @@ class CppCodeGen : public BaseCodeGen { } protected: + /*! \brief Describe the prim*/ + virtual const String DescribePrim(const MSCPrim& prim) { + // binary ops + DESCRIBE_PRIM_BINARY("Min", "std::min", true) + DESCRIBE_PRIM_BINARY("Max", "std::max", true) + // special + if (prim->optype == "if_then_else") { + return "(" + this->DescribePrim(prim->ParentAt(0)) + "?" + + this->DescribePrim(prim->ParentAt(1)) + ":" + this->DescribePrim(prim->ParentAt(2)) + + ")"; + } + return BaseCodeGen::DescribePrim(prim); + } + /*! \brief Stack the docs for the node*/ virtual void CodeGenNode(const MSCJoint& node, bool use_tools) { this->stack_.comment(this->Comment(node)); diff --git a/src/contrib/msc/core/codegen/py_codegen.h b/src/contrib/msc/core/codegen/py_codegen.h index e1ceb716a278..c1ecded61df1 100644 --- a/src/contrib/msc/core/codegen/py_codegen.h +++ b/src/contrib/msc/core/codegen/py_codegen.h @@ -82,6 +82,20 @@ class PyCodeGen : public BaseCodeGen { } protected: + /*! \brief Describe the prim*/ + virtual const String DescribePrim(const MSCPrim& prim) { + // binary ops + DESCRIBE_PRIM_BINARY("Min", "min", true) + DESCRIBE_PRIM_BINARY("Max", "max", true) + // special + if (prim->optype == "if_then_else") { + return "(" + this->DescribePrim(prim->ParentAt(1)) + " if " + + this->DescribePrim(prim->ParentAt(0)) + " else " + + this->DescribePrim(prim->ParentAt(2)) + ")"; + } + return BaseCodeGen::DescribePrim(prim); + } + /*! 
\brief Stack the docs for the header*/ virtual void CodeGenHeader() { this->stack_.line("import os") diff --git a/src/contrib/msc/core/ir/graph.cc b/src/contrib/msc/core/ir/graph.cc index ca1bff09725f..ae42537a4ce1 100644 --- a/src/contrib/msc/core/ir/graph.cc +++ b/src/contrib/msc/core/ir/graph.cc @@ -35,13 +35,14 @@ namespace contrib { namespace msc { MSCTensor::MSCTensor(const String& name, const DataType& dtype, const String& layout, - const Array& shape, const String& alias) { + const Array& shape, const String& alias, const Array& prims) { ObjectPtr n = make_object(); n->name = std::move(name); n->alias = std::move(alias); n->dtype = std::move(dtype); n->shape = std::move(shape); n->layout = tvm::tir::Layout(layout); + n->prims = prims; data_ = std::move(n); } @@ -68,6 +69,9 @@ const JsonMSCTensor MSCTensorNode::ToJson() const { for (const auto& s : shape) { j_tensor.shape.push_back(s->value); } + for (const auto& p : prims) { + j_tensor.prims.push_back(p); + } return j_tensor; } @@ -81,6 +85,9 @@ void MSCTensorNode::FromJson(const JsonMSCTensor& j_tensor) { for (const auto& s : j_tensor.shape) { shape.push_back(s); } + for (const auto& p : j_tensor.prims) { + prims.push_back(p); + } } void MSCTensorNode::FromJson(const std::string& json_str) { @@ -103,6 +110,17 @@ const Integer MSCTensorNode::DimAt(const String& axis) const { return DimAt(index); } +const String MSCTensorNode::PrimAt(int index) const { + if (prims.size() == 0) { + return ""; + } + return prims[CommonUtils::GetIndex(index, Ndim())]; +} + +const String MSCTensorNode::PrimAt(const String& axis) const { + return PrimAt(layout.IndexOf(tvm::tir::LayoutAxis::Get(axis))); +} + int32_t MSCTensorNode::LayoutOf(const String& axis) const { return layout.IndexOf(tvm::tir::LayoutAxis::Get(axis)); } @@ -498,6 +516,76 @@ const std::pair MSCJointNode::ProducerAndIdxOf(const MSCTensor return ProducerAndIdxOf(input->name); } +MSCPrim::MSCPrim(int index, const String& name, const String& optype, + const Array& parents, const Map& attrs) { + ObjectPtr n = make_object(); + n->index = index; + n->name = std::move(name); + n->optype = std::move(optype); + n->attrs = std::move(attrs); + for (const auto& p : parents) { + n->parents.push_back(p); + } + data_ = std::move(n); +} + +MSCPrim::MSCPrim(const JsonMSCPrim& j_prim, const Map& prims) { + ObjectPtr n = make_object(); + n->FromJson(j_prim, prims); + data_ = std::move(n); +} + +MSCPrim::MSCPrim(const std::string& json_str, const Map& prims) { + ObjectPtr n = make_object(); + n->FromJson(json_str, prims); + data_ = std::move(n); +} + +const JsonMSCPrim MSCPrimNode::ToJson() const { + JsonMSCPrim j_prim; + j_prim.index = index; + j_prim.name = name; + j_prim.optype = optype; + for (const auto& pair : attrs) { + j_prim.attrs[pair.first] = pair.second; + } + for (const auto& p : parents) { + j_prim.parents.push_back(Downcast(p)->name); + } + return j_prim; +} + +void MSCPrimNode::FromJson(const JsonMSCPrim& j_prim, const Map& prims) { + index = j_prim.index; + name = j_prim.name; + optype = j_prim.optype; + for (const auto& pair : j_prim.attrs) { + attrs.Set(pair.first, pair.second); + } + for (const auto& p_name : j_prim.parents) { + ICHECK(prims.count(p_name)) << "Can not find parent " << p_name; + parents.push_back(prims[p_name]); + } +} + +void MSCPrimNode::FromJson(const std::string& json_str, const Map& prims) { + std::istringstream is(json_str); + dmlc::JSONReader reader(&is); + JsonMSCPrim j_prim; + reader.Read(&j_prim); + FromJson(j_prim, prims); +} + +const MSCPrim 
MSCPrimNode::ParentAt(int index) const { + size_t v_index = CommonUtils::GetIndex(index, parents.size()); + return Downcast(parents[v_index]); +} + +const MSCPrim MSCPrimNode::ChildAt(int index) const { + size_t v_index = CommonUtils::GetIndex(index, children.size()); + return Downcast(children[v_index]); +} + WeightJoint::WeightJoint(int index, const String& name, const String& shared_ref, const String& weight_type, const MSCTensor& weight, const Array parents, const Map& attrs, @@ -587,7 +675,8 @@ const bool BaseGraphNode::HasNode(const String& name) const { } MSCGraph::MSCGraph(const String& name, const Array& nodes, - const Array& input_names, const Array& output_names) { + const Array& input_names, const Array& output_names, + const Array& prims) { ObjectPtr n = make_object(); n->name = std::move(name); for (const auto& node : nodes) { @@ -596,6 +685,10 @@ MSCGraph::MSCGraph(const String& name, const Array& nodes, } n->input_names = std::move(input_names); n->output_names = std::move(output_names); + for (const auto& prim : prims) { + n->prim_names.push_back(prim->name); + n->prims.Set(prim->name, prim); + } n->AnalysisGraph(); data_ = std::move(n); } @@ -625,6 +718,10 @@ const JsonMSCGraph MSCGraphNode::ToJson() const { const auto& node = FindNode(n); j_graph.nodes.push_back(node->ToJson()); } + for (const auto& n : prim_names) { + const auto& prim = FindPrim(n); + j_graph.prims.push_back(prim->ToJson()); + } return j_graph; } @@ -646,6 +743,16 @@ void MSCGraphNode::FromJson(const JsonMSCGraph& j_graph) { node_names.push_back(node->name); nodes.Set(node->name, node); } + Map loaded_prims; + for (const auto& n : j_graph.prims) { + const auto& prim = MSCPrim(n, loaded_prims); + loaded_prims.Set(prim->name, prim); + for (const auto& p : prim->parents) { + Downcast(p)->AddChild(prim); + } + prim_names.push_back(prim->name); + prims.Set(prim->name, prim); + } AnalysisGraph(); } @@ -697,6 +804,11 @@ const MSCJoint MSCGraphNode::FindNode(const String& name) const { return Downcast(nodes[name]); } +const MSCPrim MSCGraphNode::FindPrim(const String& name) const { + ICHECK(prims.count(name)) << "Can not find prim " << name; + return prims[name]; +} + const MSCTensor MSCGraphNode::InputAt(int index) const { size_t v_index = CommonUtils::GetIndex(index, input_names.size()); return FindTensor(input_names[v_index]); @@ -1004,9 +1116,8 @@ void WeightGraphNode::Build(const MSCGraph& graph, const MapOutputAt(0); Map attrs; attrs.Set("producer_type", node->optype); - if (node->optype == "reshape" && node->InputAt(0)->LayoutOf("C") >= 0 && - node->OutputAt(0)->LayoutOf("C") >= 0 && - node->InputAt(0)->DimAt("C")->value == node->OutputAt(0)->DimAt("C")->value) { + if (node->optype == "reshape") { + // TODO(archermmt): check non-passby reshape attrs.Set("weight_strategy", "passby"); } else { attrs.Set("weight_strategy", relation_wtypes[node->optype]); @@ -1155,7 +1266,11 @@ MSCGraph PruneWeights(const MSCGraph& graph, const Map& prune Downcast(p)->AddChild(new_node); } } - return MSCGraph(graph->name, nodes, graph->input_names, graph->output_names); + Array prims; + for (const auto& name : graph->prim_names) { + prims.push_back(graph->FindPrim(name)); + } + return MSCGraph(graph->name, nodes, graph->input_names, graph->output_names, prims); } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -1168,7 +1283,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) } p->stream << "<"; for (size_t i = 0; i < tensor->Ndim(); i++) { - p->stream << tensor->shape[i]->value << (i == tensor->Ndim() - 1 ? 
"|" : ","); + const auto& prim = tensor->PrimAt(i); + p->stream << (prim.size() > 0 ? prim : StringUtils::ToString(tensor->shape[i])) + << (i == tensor->Ndim() - 1 ? "|" : ","); } p->stream << tensor->dtype; if (tensor->layout.defined()) { @@ -1177,8 +1294,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ">"; }); -#define MSC_NODE_BASE_HEAD(Stream, Joint) \ - Stream << "ID_" << Joint->index << " " << Joint->name; \ +#define MSC_NODE_BASE_HEAD(Stream, Joint, Type) \ + Stream << Type << "_" << Joint->index << " " << Joint->name; \ if (Joint->shared_ref.size() > 0) { \ Stream << "(M: " << Joint->shared_ref << ")"; \ } \ @@ -1200,7 +1317,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* joint = static_cast(node.get()); p->PrintIndent(); - MSC_NODE_BASE_HEAD(p->stream, joint); + MSC_NODE_BASE_HEAD(p->stream, joint, "N"); if (joint->inputs.size() > 0) { p->stream << " IN: "; for (size_t i = 0; i < joint->inputs.size(); i++) { @@ -1234,11 +1351,26 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) } }); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* prim = static_cast(node.get()); + p->PrintIndent(); + MSC_NODE_BASE_HEAD(p->stream, prim, "P"); + p->stream << " OPTYPE: " << prim->optype; + if (prim->attrs.size() > 0) { + p->stream << "\n ATTRS: "; + for (const auto& pair : prim->attrs) { + p->stream << pair.first << "=" << pair.second << " "; + } + } + p->stream << "\n"; + }); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* joint = static_cast(node.get()); p->PrintIndent(); - MSC_NODE_BASE_HEAD(p->stream, joint); + MSC_NODE_BASE_HEAD(p->stream, joint, "W"); if (joint->friends.size() > 0) { p->stream << " FRIENDS: "; for (size_t i = 0; i < joint->friends.size(); i++) { @@ -1279,6 +1411,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) for (size_t i = 0; i < graph->output_names.size(); i++) { p->stream << graph->output_names[i] << (i == graph->output_names.size() - 1 ? 
">\n" : ","); } + for (const auto& n : graph->prim_names) { + p->stream << graph->FindPrim(n) << "\n"; + } for (const auto& n : graph->node_names) { p->stream << graph->FindNode(n) << "\n"; } @@ -1288,6 +1423,8 @@ TVM_REGISTER_NODE_TYPE(MSCTensorNode); TVM_REGISTER_NODE_TYPE(MSCJointNode); +TVM_REGISTER_NODE_TYPE(MSCPrimNode); + TVM_REGISTER_NODE_TYPE(WeightJointNode); TVM_REGISTER_NODE_TYPE(MSCGraphNode); @@ -1296,8 +1433,9 @@ TVM_REGISTER_NODE_TYPE(WeightGraphNode); TVM_REGISTER_GLOBAL("msc.core.MSCTensor") .set_body_typed([](const String& name, const DataType& dtype, const String& layout, - const Array& shape, const String& alias) -> MSCTensor { - return MSCTensor(name, dtype, layout, shape, alias); + const Array& shape, const String& alias, + const Array& prims) -> MSCTensor { + return MSCTensor(name, dtype, layout, shape, alias, prims); }); TVM_REGISTER_GLOBAL("msc.core.MSCTensorToJson") @@ -1326,6 +1464,16 @@ TVM_REGISTER_GLOBAL("msc.core.MSCJoint") weights); }); +TVM_REGISTER_GLOBAL("msc.core.MSCPrim") + .set_body_typed([](Integer index, const String& name, const String& optype, + const Map& attrs, const Array& parents) -> MSCPrim { + Array b_parents; + for (const auto& p : parents) { + b_parents.push_back(p); + } + return MSCPrim(index->value, name, optype, b_parents, attrs); + }); + TVM_REGISTER_GLOBAL("msc.core.WeightJoint") .set_body_typed([](Integer index, const String& name, const String& shared_ref, const String& weight_type, const MSCTensor& weight, @@ -1349,9 +1497,9 @@ TVM_REGISTER_GLOBAL("msc.core.WeightJointSetAttr") TVM_REGISTER_GLOBAL("msc.core.MSCGraph") .set_body_typed([](const String& name, const Array& nodes, - const Array& input_names, - const Array& output_names) -> MSCGraph { - return MSCGraph(name, nodes, input_names, output_names); + const Array& input_names, const Array& output_names, + const Array& prims) -> MSCGraph { + return MSCGraph(name, nodes, input_names, output_names, prims); }); TVM_REGISTER_GLOBAL("msc.core.WeightGraph") @@ -1371,6 +1519,11 @@ TVM_REGISTER_GLOBAL("msc.core.MSCGraphFindNode") return graph->FindNode(name); }); +TVM_REGISTER_GLOBAL("msc.core.MSCGraphFindPrim") + .set_body_typed([](const MSCGraph& graph, const String& name) -> MSCPrim { + return graph->FindPrim(name); + }); + TVM_REGISTER_GLOBAL("msc.core.MSCGraphHasTensor") .set_body_typed([](const MSCGraph& graph, const String& name) -> Bool { return Bool(graph->HasTensor(name)); diff --git a/src/contrib/msc/core/ir/graph.h b/src/contrib/msc/core/ir/graph.h index 7005518f367b..1e22e96ac951 100644 --- a/src/contrib/msc/core/ir/graph.h +++ b/src/contrib/msc/core/ir/graph.h @@ -48,6 +48,7 @@ struct JsonMSCTensor { std::string dtype; std::string layout; std::vector shape; + std::vector prims; void Save(dmlc::JSONWriter* writer) const { writer->BeginObject(); @@ -56,6 +57,7 @@ struct JsonMSCTensor { writer->WriteObjectKeyValue("dtype", dtype); writer->WriteObjectKeyValue("layout", layout); writer->WriteObjectKeyValue("shape", shape); + writer->WriteObjectKeyValue("prims", prims); writer->EndObject(); } @@ -77,6 +79,8 @@ struct JsonMSCTensor { } else if (key == "shape") { reader->Read(&shape); bitmask |= 4; + } else if (key == "prims") { + reader->Read(&prims); } } ICHECK_EQ(bitmask, 1 | 2 | 4) << "name, dtype and shape should be given"; @@ -147,6 +151,51 @@ struct JsonMSCJoint { } }; +/*! + * \brief Json serialize and deserialize for MSCPrim. + * MSCPrim is node in MSCGraph with name, op and attrbutes. 
+ */ +struct JsonMSCPrim { + size_t index; + std::string name; + std::string optype; + std::vector parents; + std::unordered_map attrs; + + void Save(dmlc::JSONWriter* writer) const { + writer->BeginObject(); + writer->WriteObjectKeyValue("index", index); + writer->WriteObjectKeyValue("name", name); + writer->WriteObjectKeyValue("optype", optype); + writer->WriteObjectKeyValue("parents", parents); + writer->WriteObjectKeyValue("attrs", attrs); + writer->EndObject(); + } + + void Load(dmlc::JSONReader* reader) { + int bitmask = 0; + std::string key; + reader->BeginObject(); + while (reader->NextObjectItem(&key)) { + if (key == "index") { + reader->Read(&index); + bitmask |= 1; + } else if (key == "name") { + reader->Read(&name); + bitmask |= 2; + } else if (key == "optype") { + reader->Read(&optype); + bitmask |= 4; + } else if (key == "parents") { + reader->Read(&parents); + } else if (key == "attrs") { + reader->Read(&attrs); + } + } + ICHECK_EQ(bitmask, 1 | 2 | 4) << "index, name and optype should be given"; + } +}; + /*! * \brief Json serialize and deserialize for WeightJoint. * WeightJoint is node in WeightGraph with name, wtype and attrbutes. @@ -216,6 +265,7 @@ struct JsonMSCGraph { std::vector inputs; std::vector outputs; std::vector nodes; + std::vector prims; void Save(dmlc::JSONWriter* writer) const { writer->BeginObject(); @@ -223,6 +273,7 @@ struct JsonMSCGraph { writer->WriteObjectKeyValue("inputs", inputs); writer->WriteObjectKeyValue("outputs", outputs); writer->WriteObjectKeyValue("nodes", nodes); + writer->WriteObjectKeyValue("prims", prims); writer->EndObject(); } @@ -243,6 +294,8 @@ struct JsonMSCGraph { } else if (key == "nodes") { reader->Read(&nodes); bitmask |= 8; + } else if (key == "prims") { + reader->Read(&prims); } } ICHECK_EQ(bitmask, 1 | 2 | 4 | 8) << "name, inputs, outputs and nodes should be given"; @@ -297,6 +350,8 @@ class MSCTensorNode : public Object { tvm::tir::Layout layout; /*! \brief The shape of tensor. */ Array shape; + /*! \brief The prims of tensor. */ + Array prims; /*! \brief Export tensor to json. */ const JsonMSCTensor ToJson() const; /*! \brief Load tensor from json struct. */ @@ -309,6 +364,10 @@ class MSCTensorNode : public Object { const Integer DimAt(int index) const; /*! \brief Get dim at given axis. */ const Integer DimAt(const String& axis) const; + /*! \brief Get prim at given index. */ + const String PrimAt(int index) const; + /*! \brief Get prim at given axis. */ + const String PrimAt(const String& axis) const; /*! \brief Get layout index of given axis. */ int32_t LayoutOf(const String& axis) const; /*! \brief Get size of the tensor. */ @@ -322,11 +381,12 @@ class MSCTensorNode : public Object { v->Visit("dtype", &dtype); v->Visit("layout", &layout); v->Visit("shape", &shape); + v->Visit("prims", &prims); } bool SEqualReduce(const MSCTensorNode* other, SEqualReducer equal) const { return equal(name, other->name) && equal(dtype, other->dtype) && equal(shape, other->shape) && - equal(layout, other->layout); + equal(layout, other->layout) && equal(prims, other->prims); } void SHashReduce(SHashReducer hash_reduce) const { @@ -334,6 +394,7 @@ class MSCTensorNode : public Object { hash_reduce(dtype); hash_reduce(shape); hash_reduce(layout); + hash_reduce(prims); } static constexpr const char* _type_key = "msc.core.MSCTensor"; @@ -353,9 +414,11 @@ class MSCTensor : public ObjectRef { * \param layout The layout of the tensor. * \param shape The shape of the tensor. * \param alias The alias of the tensor. 
+ * \param prims The prims of the tensor shape. */ TVM_DLL MSCTensor(const String& name, const DataType& dtype, const String& layout, - const Array& shape, const String& alias = ""); + const Array& shape, const String& alias = "", + const Array& prims = Array()); /*! * \brief The json constructor. @@ -576,6 +639,76 @@ class MSCJoint : public BaseJoint { TVM_DEFINE_OBJECT_REF_METHODS(MSCJoint, BaseJoint, MSCJointNode); }; +/*! + * \brief MSCPrim in MSCGraph. + */ +class MSCPrim; +class MSCPrimNode : public BaseJointNode { + public: + /*! \brief The op of prim. */ + String optype; + /*! \brief Export prim to json. */ + const JsonMSCPrim ToJson() const; + /*! \brief Load prim from json struct. */ + void FromJson(const JsonMSCPrim& j_prim, const Map& prims); + /*! \brief Load prim from json string. */ + void FromJson(const std::string& json_str, const Map& prims); + /*! \brief Get parent from the prim. */ + const MSCPrim ParentAt(int index) const; + /*! \brief Get child from the prim. */ + const MSCPrim ChildAt(int index) const; + + void VisitAttrs(AttrVisitor* v) { + BaseJointNode::VisitAttrs(v); + v->Visit("optype", &optype); + } + + bool SEqualReduce(const MSCPrimNode* other, SEqualReducer equal) const { + return BaseJointNode::SEqualReduce(other, equal) && equal(optype, other->optype); + } + + void SHashReduce(SHashReducer hash_reduce) const { + BaseJointNode::SHashReduce(hash_reduce); + hash_reduce(optype); + } + + static constexpr const char* _type_key = "msc.core.MSCPrim"; + TVM_DECLARE_FINAL_OBJECT_INFO(MSCPrimNode, BaseJointNode); +}; + +/*! + * \brief Managed reference to MSCPrimNode. + * \sa MSCPrimNode + */ +class MSCPrim : public BaseJoint { + public: + /*! + * \brief The constructor. + * \param index The index of the prim. + * \param name The name of the prim. + * \param optype The optype of the prim. + * \param parents The parents of the prim. + * \param attrs The attributes of the prim. + */ + TVM_DLL MSCPrim(int index, const String& name, const String& optype, + const Array& parents, + const Map& attrs = Map()); + + /*! + * \brief The json constructor. + * \param j_prim The json describe of the prim. + */ + TVM_DLL MSCPrim(const JsonMSCPrim& j_prim, const Map& prims); + + /*! + * \brief The json constructor. + * \param json_str The json describe of the prim. + */ + TVM_DLL MSCPrim(const std::string& json_str, const Map& prims); + + TVM_DEFINE_OBJECT_REF_METHODS(MSCPrim, BaseJoint, MSCPrimNode); +}; + /*! * \brief Node in WeightGraph. */ @@ -713,6 +846,10 @@ class BaseGraph : public ObjectRef { class MSCGraph; class MSCGraphNode : public BaseGraphNode { public: + /*! \brief The shape node names in graph. */ + Array prim_names; + /*! \brief The shape nodes in graph. */ + Map prims; /*! \brief The input names of graph. */ Array input_names; /*! \brief The output names of graph. */ @@ -731,6 +868,8 @@ class MSCGraphNode : public BaseGraphNode { const String ToPrototxt() const; /*! \brief Find node in graph. */ const MSCJoint FindNode(const String& name) const; + /*! \brief Find prim in graph. */ + const MSCPrim FindPrim(const String& name) const; /*! \brief Get input from the graph. */ const MSCTensor InputAt(int index) const; /*! \brief Get inputs from the graph. 
*/ @@ -769,18 +908,23 @@ class MSCGraphNode : public BaseGraphNode { void VisitAttrs(AttrVisitor* v) { BaseGraphNode::VisitAttrs(v); + v->Visit("prims", &prims); + v->Visit("prim_names", &prim_names); v->Visit("input_names", &input_names); v->Visit("output_names", &output_names); v->Visit("weight_holders", &weight_holders); } bool SEqualReduce(const MSCGraphNode* other, SEqualReducer equal) const { - return BaseGraphNode::SEqualReduce(other, equal) && equal(input_names, other->input_names) && + return BaseGraphNode::SEqualReduce(other, equal) && equal(prims, other->prims) && + equal(prim_names, other->prim_names) && equal(input_names, other->input_names) && equal(output_names, other->output_names) && equal(weight_holders, other->weight_holders); } void SHashReduce(SHashReducer hash_reduce) const { BaseGraphNode::SHashReduce(hash_reduce); + hash_reduce(prims); + hash_reduce(prim_names); hash_reduce(input_names); hash_reduce(output_names); hash_reduce(weight_holders); @@ -799,14 +943,14 @@ class MSCGraph : public BaseGraph { /*! * \brief The constructor. * \param name The name of the node. - * \param node_names The node names in the graph * \param nodes The nodes in the graph. * \param input_names The input names of the graph. * \param output_names The output names of the graph. - * \param weight_holders The weights info of the graph. + * \param prims The prims in the graph. */ TVM_DLL MSCGraph(const String& name, const Array& nodes, - const Array& input_names, const Array& output_names); + const Array& input_names, const Array& output_names, + const Array& prims = Array()); /*! * \brief The json constructor. diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc index a968df4204a2..20c7dbcc9172 100644 --- a/src/contrib/msc/core/ir/graph_builder.cc +++ b/src/contrib/msc/core/ir/graph_builder.cc @@ -138,6 +138,27 @@ const MSCGraph RelaxGraphBuilder::Build(const relax::Function& func) { // Add input nodes and record inputs; Array input_names, output_names; std::set added_inputs; + // Add prims + for (const auto& p : func->params) { + if (!p->struct_info_.defined()) { + continue; + } + if (p->struct_info_.value()->IsInstance()) { + const auto& shape = ExprUtils::GetShape(p, false); + for (size_t i = 0; i < shape.size(); i++) { + if (shape[i]->IsInstance()) { + Map attrs; + attrs.Set("producer", p->name_hint()); + attrs.Set("out_idx", "0"); + attrs.Set("dim", std::to_string(i)); + MatchOrCreatePrim(shape[i], "shape", Array(), attrs); + } + } + } else { + LOG_FATAL << "Unexpected func param " << p << "(" << p->GetTypeKey() << ")"; + } + } + for (const auto& p : func->params) { if (expr_tensor_map_.count(p)) { continue; @@ -203,7 +224,7 @@ const MSCGraph RelaxGraphBuilder::Build(const relax::Function& func) { } } // build graph - const auto& graph = MSCGraph(name_, valid_nodes, valid_inputs, output_names); + const auto& graph = MSCGraph(name_, valid_nodes, valid_inputs, output_names, prims_); // set inputs and outputs alias if (config_.input_aliases.size() == valid_inputs.size()) { for (size_t i = 0; i < valid_inputs.size(); i++) { @@ -471,14 +492,27 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional } // Build output tensor - auto build_output = [](const relax::StructInfo& sinfo, const String& node_name, - const String& layout) { + auto build_output = [this](const relax::StructInfo& sinfo, const String& node_name, + const String& layout) { ICHECK(sinfo->IsInstance()) << "sinfo should be TensorStructInfo, get " << 
sinfo->GetTypeKey(); const auto& t_info = Downcast(sinfo); - const auto& shape_opt = t_info->GetShape(); - const auto& shape = - shape_opt.defined() ? ArrayUtils::Cast(shape_opt.value()) : Array(); + const auto& shape = ArrayUtils::Cast(ExprUtils::GetShape(t_info)); + Array prims; + bool has_prims = false; + if (shape.size() > 0) { + for (const auto& s : t_info->GetShape().value()) { + if (prim_map_.count(s)) { + prims.push_back(prim_map_[s]->name); + has_prims = true; + } else { + prims.push_back(StringUtils::ToString(s)); + } + } + } + if (has_prims) { + return MSCTensor(node_name, t_info->dtype, layout, shape, "", prims); + } return MSCTensor(node_name, t_info->dtype, layout, shape); }; @@ -552,6 +586,104 @@ void RelaxGraphBuilder::VisitBindingBlock(const relax::BindingBlock& block) { block_stack_.pop_back(); } +#define ADD_BINARY_PRIM(TypeName) \ + if (prim->IsInstance()) { \ + const auto& binary = Downcast(prim); \ + return MatchOrCreatePrim(prim, "", {AddPrim(binary->a), AddPrim(binary->b)}); \ + } + +const MSCPrim RelaxGraphBuilder::AddPrim(const PrimExpr& prim) { + if (prim_map_.count(prim)) { + return prim_map_[prim]; + } + + // binary + ADD_BINARY_PRIM(tvm::tir::Add) + ADD_BINARY_PRIM(tvm::tir::Sub) + ADD_BINARY_PRIM(tvm::tir::Mul) + ADD_BINARY_PRIM(tvm::tir::Div) + ADD_BINARY_PRIM(tvm::tir::Mod) + ADD_BINARY_PRIM(tvm::tir::FloorDiv) + ADD_BINARY_PRIM(tvm::tir::FloorMod) + ADD_BINARY_PRIM(tvm::tir::Max) + ADD_BINARY_PRIM(tvm::tir::Min) + + // compare + ADD_BINARY_PRIM(tvm::tir::EQ) + ADD_BINARY_PRIM(tvm::tir::NE) + ADD_BINARY_PRIM(tvm::tir::LT) + ADD_BINARY_PRIM(tvm::tir::LE) + ADD_BINARY_PRIM(tvm::tir::GT) + ADD_BINARY_PRIM(tvm::tir::GE) + + // scalar + if (prim->IsInstance()) { + Map attrs; + attrs.Set("value", StringUtils::ToString(prim)); + return MatchOrCreatePrim(prim, "Int", Array(), attrs); + } + + // call + if (const auto* c_node = prim.as()) { + String optype; + Array parents; + if (const auto* op_node = c_node->op.as()) { + optype = StringUtils::Replace(op_node->name, "tir.", ""); + } else { + optype = "Prim"; + } + for (const auto& a : c_node->args) { + parents.push_back(AddPrim(a)); + } + return MatchOrCreatePrim(prim, optype, parents); + } + return MatchOrCreatePrim(prim); +} + +const MSCPrim RelaxGraphBuilder::MatchOrCreatePrim(const PrimExpr& prim, const String& optype, + const Array& parents, + const Map& attrs) { + if (prim_map_.count(prim)) { + return prim_map_[prim]; + } + const auto& op_ = + optype.size() == 0 ? 
StringUtils::Replace(prim->GetTypeKey(), "tir.", "") : optype; + for (const auto& p : prims_) { + if (p->optype != op_ || p->attrs.size() != attrs.size() || + p->parents.size() != parents.size()) { + continue; + } + bool attrs_match = std::all_of(p->attrs.begin(), p->attrs.end(), [&attrs](const auto& pair) { + return attrs.count(pair.first) && attrs[pair.first] == pair.second; + }); + if (!attrs_match) { + continue; + } + bool parents_match = true; + for (size_t i = 0; i < parents.size(); i++) { + if (p->ParentAt(i)->name != parents[i]->name) { + parents_match = false; + break; + } + } + if (!parents_match) { + continue; + } + prim_map_.Set(prim, p); + return p; + } + String name; + if (const auto* v_node = prim.as()) { + name = v_node->name_hint; + } else { + name = StringUtils::Upper(op_) + "_" + std::to_string(prims_.size()); + } + const auto& node = MSCPrim(prims_.size(), name, op_, parents, attrs); + prims_.push_back(node); + prim_map_.Set(prim, node); + return node; +} + void RelaxGraphBuilder::VisitExpr_(const relax::ConstantNode* op) { AddNode(GetRef(op)); } @@ -649,6 +781,13 @@ const std::tuple RelaxGraphBuilder::ParseFunc(const rela return std::make_tuple(node_name, optype, layout); } +void RelaxGraphBuilder::VisitPrimExpr(const PrimExpr& prim) { + RelaxExprVisitor::VisitPrimExpr(prim); + if (!prim->IsInstance() && !prim->IsInstance()) { + AddPrim(prim); + } +} + Array RelaxGraphBuilder::GetPluginInputs(const relax::Expr& expr) { ICHECK(expr->IsInstance()) << "plugin expr should be call"; const auto& call = Downcast(expr); diff --git a/src/contrib/msc/core/ir/graph_builder.h b/src/contrib/msc/core/ir/graph_builder.h index d514a793475d..250fa38ef91b 100644 --- a/src/contrib/msc/core/ir/graph_builder.h +++ b/src/contrib/msc/core/ir/graph_builder.h @@ -265,6 +265,13 @@ class RelaxGraphBuilder : public RelaxExprVisitor { const MSCJoint AddNode(const Expr& expr, const Optional& binding_var = NullOpt, const String& name = ""); + /*! \brief Create and add MSCPrim from prim*/ + const MSCPrim AddPrim(const PrimExpr& prim); + + const MSCPrim MatchOrCreatePrim(const PrimExpr& prim, const String& op = "", + const Array& parents = Array(), + const Map& attrs = Map()); + void VisitBindingBlock(const relax::BindingBlock& block) final; void VisitExpr_(const relax::ConstantNode* op) final; @@ -286,6 +293,8 @@ class RelaxGraphBuilder : public RelaxExprVisitor { void VisitBinding_(const relax::VarBindingNode* binding, const relax::FunctionNode* val) final; + void VisitPrimExpr(const PrimExpr& prim) final; + private: /*! 
\brief Get the node_name, optype, layout for func*/ const std::tuple ParseFunc(const relax::Function& func); @@ -309,6 +318,9 @@ class RelaxGraphBuilder : public RelaxExprVisitor { // BYOC maps Map target_funcs_; Map func_params_; + // prims + Array prims_; + Map prim_map_; }; class RelaxWeightsExtractor : public RelaxExprVisitor { diff --git a/src/contrib/msc/core/transform/layout_utils.cc b/src/contrib/msc/core/transform/layout_utils.cc index 317a39ab4e1a..a634b8e9e36a 100644 --- a/src/contrib/msc/core/transform/layout_utils.cc +++ b/src/contrib/msc/core/transform/layout_utils.cc @@ -156,29 +156,30 @@ const LayoutDecision LayoutUtils::ExpandLayout(const LayoutDecision& src_layout, std::string new_layout = src_layout.name(); ICHECK_EQ(new_layout.size(), src_layout->layout.ndim()) << "Only support normal layout, get " << src_layout->layout; - std::vector priority_dims{"N", "C", "H", "W", "D", "G", "T"}; - size_t left_size = axes.size(); + std::set used_axes; + for (size_t i = 0; i < src_layout->layout.ndim(); i++) { + used_axes.insert(src_layout->layout[i].name()); + } + std::vector prefer_axes{"N", "C", "H", "W", "D"}; for (const auto& a : axes) { - std::string target = "U"; - if (new_layout.find("H") && !new_layout.find("W")) { - target = "W"; - } else if (new_layout.find("W") && !new_layout.find("H")) { - target = "H"; - } else if (left_size == 1 && new_layout.find("C") && !new_layout.find("D")) { - target = "D"; - } else if (left_size == 1 && new_layout.find("D") && !new_layout.find("C")) { - target = "C"; + bool use_prefer = false; + if (used_axes.size() < prefer_axes.size()) { + use_prefer = + std::all_of(prefer_axes.begin(), prefer_axes.begin() + used_axes.size(), + [&used_axes](const std::string& axis) { return used_axes.count(axis); }); + } + std::string new_axis; + char cur_axis = 'A'; + if (use_prefer) { + new_axis = prefer_axes[used_axes.size()]; } else { - for (const auto& p : priority_dims) { - int pos = new_layout.find(p); - if (pos < 0) { - target = p; - break; - } + while (used_axes.count(std::string(1, cur_axis))) { + cur_axis += 1; } + new_axis = std::string(1, cur_axis); } - new_layout = new_layout.insert(a, target); - left_size--; + used_axes.insert(new_axis); + new_layout = new_layout.insert(a, new_axis); } return LayoutDecision(new_layout); } @@ -220,6 +221,18 @@ const LayoutDecision LayoutUtils::PermuteLayout(const LayoutDecision& src_layout return LayoutDecision(layout_str); } +int LayoutUtils::InferBatchDim(const LayoutDecision& layout) { + if (!layout->layout.defined()) { + return -1; + } + for (size_t i = 0; i < layout->layout.ndim(); i++) { + if (layout->layout[i].name() == "N") { + return static_cast(i); + } + } + return -1; +} + } // namespace msc } // namespace contrib } // namespace tvm diff --git a/src/contrib/msc/core/transform/layout_utils.h b/src/contrib/msc/core/transform/layout_utils.h index 7748f217d6ec..e7781a95a8f7 100644 --- a/src/contrib/msc/core/transform/layout_utils.h +++ b/src/contrib/msc/core/transform/layout_utils.h @@ -123,6 +123,12 @@ class LayoutUtils { const Array& axes); TVM_DLL static const LayoutDecision PermuteLayout(const LayoutDecision& src_layout, const std::vector& axes); + + /*! + * \brief Infer batch dim from the Layout + * \return The batch dim. 
+ */ + TVM_DLL static int InferBatchDim(const LayoutDecision& layout); }; } // namespace msc diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc b/src/contrib/msc/core/transform/set_expr_layout.cc index 56517fdae8d6..a3902a44bfaa 100644 --- a/src/contrib/msc/core/transform/set_expr_layout.cc +++ b/src/contrib/msc/core/transform/set_expr_layout.cc @@ -34,49 +34,11 @@ namespace relax { using namespace tvm::contrib::msc; -std::tuple AccumulateMatch(const std::vector& in_shape, - const std::vector& out_shape, size_t in_start, +std::tuple AccumulateMatch(const Array& input_shape, + const Array& output_shape, size_t in_start, size_t out_start) { // find input position in_pos and output position out_pos - // cumsum(in_shape[in_start:in_ops])==cumsum(out_shape[out_start:out_pos]) - int64_t in_pos = -1; - int64_t out_pos = -1; - int64_t in_accumulate = 1; - int64_t out_accumulate = 1; - for (size_t i = in_start; i < in_shape.size(); i++) { - in_accumulate *= in_shape[i]; - out_accumulate = 1; - for (size_t j = out_start; j < out_shape.size(); j++) { - out_accumulate *= out_shape[j]; - if (in_accumulate == out_accumulate) { - in_pos = i; - out_pos = j; - break; - } else if (out_accumulate > in_accumulate) { - break; - } - } - if (in_pos >= 0) { - break; - } - } - // append tailed 1s - if (in_pos >= 0) { - int64_t in_size = static_cast(in_shape.size()); - int64_t out_size = static_cast(out_shape.size()); - while (in_pos < in_size - 1 && in_shape[in_pos + 1] == 1) { - in_pos++; - } - while (out_pos < out_size - 1 && out_shape[out_pos + 1] == 1) { - out_pos++; - } - } - return std::make_tuple(in_pos, out_pos); -} - -std::vector InferReduceAxes(const Array& input_shape, - const Array& output_shape) { - std::vector reduce_axes, out_axes; + // cumsum(in_shape[in_start:in_pos])==cumsum(out_shape[out_start:out_pos]) std::vector in_shape, out_shape; for (const auto& s : input_shape) { in_shape.push_back(Downcast(s)->value); @@ -84,71 +46,76 @@ std::vector InferReduceAxes(const Array& input_shape, for (const auto& s : output_shape) { out_shape.push_back(Downcast(s)->value); } - size_t start = 0; - while (start < in_shape.size() && out_axes.size() < out_shape.size()) { - if (in_shape[start] == out_shape[out_axes.size()]) { - out_axes.push_back(start); - start++; - } else { - int64_t in_pos, out_pos; - size_t out_start = out_axes.size(); - std::tie(in_pos, out_pos) = AccumulateMatch(in_shape, out_shape, start, out_start); - if (in_pos == -1) { - return std::vector(); + int64_t in_size = static_cast(in_shape.size()); + int64_t out_size = static_cast(out_shape.size()); + int64_t in_pos = in_start; + int64_t out_pos = out_start; + int64_t in_accumulate = in_shape[in_pos]; + int64_t out_accumulate = out_shape[out_pos]; + while (in_accumulate != out_accumulate) { + if (in_accumulate > out_accumulate) { + out_pos += 1; + if (out_pos >= out_size) { + return std::make_tuple(-1, -1); } - for (size_t i = out_start; i < static_cast(out_pos) + 1; i++) { - out_axes.push_back(i + 1); + out_accumulate *= out_shape[out_pos]; + } else { + in_pos += 1; + if (in_pos >= in_size) { + return std::make_tuple(-1, -1); } - start = in_pos + 1; + in_accumulate *= in_shape[in_pos]; } } - if (out_axes.size() != out_shape.size()) { - return std::vector(); - } - std::set out_axes_set; - for (const auto& a : out_axes) { - out_axes_set.insert(a); + if (in_accumulate != out_accumulate) { + return std::make_tuple(-1, -1); } - for (size_t i = 0; i < in_shape.size(); i++) { - if (!out_axes_set.count(i)) { - reduce_axes.push_back(i); + 
// append tailing + if (in_pos >= 0) { + while (in_pos < in_size - 1 && in_shape[in_pos + 1] == 1) { + in_pos++; + } + while (out_pos < out_size - 1 && out_shape[out_pos + 1] == 1) { + out_pos++; } } - return reduce_axes; + return std::make_tuple(in_pos - in_start, out_pos - out_start); } -std::vector InferExpandAxes(const Array& input_shape, - const Array& output_shape) { - std::vector expand_axes; - std::vector in_shape, out_shape; - for (const auto& s : input_shape) { - in_shape.push_back(Downcast(s)->value); - } - for (const auto& s : output_shape) { - out_shape.push_back(Downcast(s)->value); - } - size_t start = 0; - while (start < in_shape.size() && expand_axes.size() + in_shape.size() < out_shape.size()) { - if (in_shape[start] == out_shape[start + expand_axes.size()]) { - start++; - } else { - int64_t in_pos, out_pos; - size_t out_start = start + expand_axes.size(); - std::tie(in_pos, out_pos) = AccumulateMatch(in_shape, out_shape, start, out_start); - if (in_pos == -1) { - return std::vector(); +std::tuple, std::vector> InferReshapeAxes( + const Array& input_shape, const Array& output_shape, int batch_dim) { + std::vector expand_axes, reduce_axes; + size_t in_start = 0; + while (in_start < input_shape.size()) { + size_t out_start = in_start + expand_axes.size() - reduce_axes.size(); + int64_t in_dist, out_dist; + std::tie(in_dist, out_dist) = AccumulateMatch(input_shape, output_shape, in_start, out_start); + if (in_dist == -1) { + return std::make_tuple(std::vector(), std::vector()); + } + if (out_dist >= in_dist) { + for (size_t i = 0; i < static_cast(out_dist - in_dist); i++) { + if (batch_dim >= 0 && (out_start + i) == static_cast(batch_dim)) { + expand_axes.push_back(out_start + i + 1); + } else { + expand_axes.push_back(out_start + i); + } } - size_t expand_size = out_pos - in_pos - expand_axes.size(); - for (size_t i = 0; i < expand_size; i++) { - expand_axes.push_back(out_start + i); + } else { + for (size_t i = 0; i < static_cast(in_dist - out_dist); i++) { + if (batch_dim >= 0 && (in_start + i) == static_cast(batch_dim)) { + reduce_axes.push_back(in_start + i + 1); + } else { + reduce_axes.push_back(in_start + i); + } } - start = in_pos + 1; } + in_start += in_dist + 1; } - if (expand_axes.size() + in_shape.size() != out_shape.size()) { - return std::vector(); + if (input_shape.size() + expand_axes.size() - reduce_axes.size() != output_shape.size()) { + return std::make_tuple(std::vector(), std::vector()); } - return expand_axes; + return std::make_tuple(expand_axes, reduce_axes); } // Forward and Backward infer @@ -167,6 +134,11 @@ InferLayoutOutput MSCInferLayoutConv(const Call& call, data_layout = LayoutDecision(attrs->data_layout); kernel_layout = LayoutDecision(attrs->kernel_layout); out_layout = LayoutDecision(attrs->out_layout); + } else if (op_name == "relax.nn.conv2d_transpose") { + const auto* attrs = call->attrs.as(); + data_layout = LayoutDecision(attrs->data_layout); + kernel_layout = LayoutDecision(attrs->kernel_layout); + out_layout = LayoutDecision(attrs->out_layout); } return InferLayoutOutput({data_layout, kernel_layout}, {out_layout}, Attrs()); } @@ -213,18 +185,48 @@ InferLayoutOutput ForwardInferLayoutCommon(const Call& call, if (!layout_hint.defined()) { return InferLayoutOutput(); } - std::vector output_layouts; const auto& sinfo = GetStructInfo(call); if (sinfo->IsInstance()) { - output_layouts.push_back(layout_hint); - } else if (const auto* tuple_sinfo = sinfo.as()) { + return InferLayoutOutput(input_layouts, {layout_hint}, Attrs()); + } + Array 
output_layouts; + if (const auto* tuple_sinfo = sinfo.as()) { for (size_t i = 0; i < tuple_sinfo->fields.size(); i++) { output_layouts.push_back(layout_hint); } - } else { + return InferLayoutOutput(input_layouts, {output_layouts}, Attrs()); + } + return InferLayoutOutput(); +} + +InferLayoutOutput ForwardInferLayoutBroadcast(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + Array input_layouts; + LayoutDecision layout_hint; + for (const auto& arg : call->args) { + const auto& in_layout = LayoutUtils::InferLayoutDecision(arg, var_layout_map); + if (in_layout->layout.defined()) { + if (!layout_hint.defined() || layout_hint->layout.ndim() < in_layout->layout.ndim()) { + layout_hint = in_layout; + } + } + input_layouts.push_back(in_layout); + } + if (!layout_hint.defined()) { return InferLayoutOutput(); } - return InferLayoutOutput(input_layouts, {output_layouts}, Attrs()); + const auto& sinfo = GetStructInfo(call); + if (sinfo->IsInstance()) { + return InferLayoutOutput(input_layouts, {layout_hint}, Attrs()); + } + return InferLayoutOutput(); +} + +InferLayoutOutput ForwardInferLayoutInplace(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + return ForwardInferLayoutCommon(call, desired_layouts, var_layout_map); } InferLayoutOutput ForwardInferLayoutBinary(const Call& call, @@ -253,12 +255,6 @@ InferLayoutOutput ForwardInferLayoutBinary(const Call& call, return InferLayoutOutput(input_layouts, output->output_layouts, Attrs()); } -InferLayoutOutput ForwardInferLayoutInplace(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - return ForwardInferLayoutCommon(call, desired_layouts, var_layout_map); -} - InferLayoutOutput ForwardInferLayoutArgMaxMin(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { @@ -273,9 +269,7 @@ InferLayoutOutput ForwardInferLayoutArgMaxMin(const Call& call, if (!attrs->axis.defined()) { return InferLayoutOutput({input_layout}, {LayoutDecision("")}, Attrs()); } - Array empty; - const auto& input_shape = - Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); if (input_shape.size() == 0) { return InferLayoutOutput(); } @@ -288,9 +282,7 @@ InferLayoutOutput ForwardInferLayoutArgMaxMin(const Call& call, InferLayoutOutput ForwardInferLayoutBatchNorm(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - Array empty; - const auto& input_shape = - Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); if (input_shape.size() == 0) { return InferLayoutOutput(); } @@ -314,9 +306,7 @@ InferLayoutOutput ForkwardInferLayoutExpandDims(const Call& call, if (!input_layout->layout.defined()) { return InferLayoutOutput(); } - Array empty; - const auto& input_shape = - Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); if (input_shape.size() == 0) { return InferLayoutOutput(); } @@ -332,9 +322,7 @@ InferLayoutOutput ForkwardInferLayoutExpandDims(const Call& call, InferLayoutOutput ForwardInferLayoutNormalize(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - Array empty; - const auto& input_shape = - Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); if 
(input_shape.size() == 0) { return InferLayoutOutput(); } @@ -353,12 +341,8 @@ InferLayoutOutput ForwardInferLayoutNormalize(const Call& call, InferLayoutOutput ForwardInferLayoutMatmul(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - Array empty; - const auto& a_shape = - Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); - const auto& b_shape = - Downcast(GetStructInfo(call->args[1]))->GetShape().value_or(empty); - + const auto& a_shape = ExprUtils::GetShape(call->args[0]); + const auto& b_shape = ExprUtils::GetShape(call->args[1]); if (a_shape.size() == 0) { return InferLayoutOutput(); } @@ -417,9 +401,7 @@ InferLayoutOutput ForwardInferLayoutReduceAxis(const Call& call, if (!attrs->axis.defined()) { return InferLayoutOutput({input_layout}, {LayoutDecision("")}, Attrs()); } - Array empty; - const auto& input_shape = - Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); if (input_shape.size() == 0) { return InferLayoutOutput(); } @@ -438,29 +420,25 @@ InferLayoutOutput ForwardInferLayoutReshape(const Call& call, if (!input_layout->layout.defined()) { return InferLayoutOutput(); } - Array empty; - const auto& input_shape = - Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); - const auto& output_shape = - Downcast(GetStructInfo(call))->GetShape().value_or(empty); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); + const auto& output_shape = ExprUtils::GetShape(call); if (input_shape.size() == 0 || output_shape.size() == 0) { return InferLayoutOutput(); } - LayoutDecision output_layout; - if (input_shape.size() == output_shape.size()) { - output_layout = input_layout; - } else if (input_shape.size() > output_shape.size()) { - const auto& reduce_axes = InferReduceAxes(input_shape, output_shape); - if (reduce_axes.size() == 0) { + LayoutDecision output_layout = input_layout; + if (input_shape.size() != output_shape.size()) { + int batch_dim = LayoutUtils::InferBatchDim(input_layout); + std::vector expand_axes, reduce_axes; + std::tie(expand_axes, reduce_axes) = InferReshapeAxes(input_shape, output_shape, batch_dim); + if (reduce_axes.size() == 0 && expand_axes.size() == 0) { return InferLayoutOutput(); } - output_layout = LayoutUtils::ReduceLayout(input_layout, reduce_axes); - } else { - const auto& expand_axes = InferExpandAxes(input_shape, output_shape); - if (expand_axes.size() == 0) { - return InferLayoutOutput(); + if (reduce_axes.size() > 0) { + output_layout = LayoutUtils::ReduceLayout(output_layout, reduce_axes); + } + if (expand_axes.size() > 0) { + output_layout = LayoutUtils::ExpandLayout(output_layout, expand_axes); } - output_layout = LayoutUtils::ExpandLayout(input_layout, expand_axes); } return InferLayoutOutput({input_layout, LayoutDecision("O")}, {output_layout}, Attrs()); } @@ -472,9 +450,7 @@ InferLayoutOutput ForwardInferLayoutSqueeze(const Call& call, if (!input_layout->layout.defined()) { return InferLayoutOutput(); } - Array empty; - const auto& input_shape = - Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); if (input_shape.size() == 0) { return InferLayoutOutput(); } @@ -501,12 +477,27 @@ InferLayoutOutput ForwardInferLayoutSqueeze(const Call& call, InferLayoutOutput ForwardInferLayoutTake(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - LayoutDecision input_layout = 
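The reshape rule above now derives expand and reduce axes in one pass through InferReshapeAxes (defined earlier in this file), keeping the batch axis stable when one is known. A rough Python re-expression of that pass may help readers follow the hunk; this is a sketch with invented helper names, not the MSC implementation:

    def accumulate_match(in_shape, out_shape, in_start, out_start):
        # Walk both shapes from the given starts until the accumulated products
        # match, then absorb trailing 1-sized dims (the "append tailing" step).
        # Returns (-1, -1) when the products never line up.
        if out_start >= len(out_shape):
            return -1, -1
        in_pos, out_pos = in_start, out_start
        in_prod, out_prod = in_shape[in_pos], out_shape[out_pos]
        while in_prod != out_prod:
            if in_prod < out_prod:
                in_pos += 1
                if in_pos == len(in_shape):
                    return -1, -1
                in_prod *= in_shape[in_pos]
            else:
                out_pos += 1
                if out_pos == len(out_shape):
                    return -1, -1
                out_prod *= out_shape[out_pos]
        while in_pos < len(in_shape) - 1 and in_shape[in_pos + 1] == 1:
            in_pos += 1
        while out_pos < len(out_shape) - 1 and out_shape[out_pos + 1] == 1:
            out_pos += 1
        return in_pos - in_start, out_pos - out_start

    def infer_reshape_axes(in_shape, out_shape, batch_dim=-1):
        # Collect axes to expand (output longer) or reduce (input longer),
        # shifting past batch_dim so the batch axis keeps its layout label.
        expand_axes, reduce_axes = [], []
        in_start = 0
        while in_start < len(in_shape):
            out_start = in_start + len(expand_axes) - len(reduce_axes)
            in_dist, out_dist = accumulate_match(in_shape, out_shape, in_start, out_start)
            if in_dist == -1:
                return [], []
            if out_dist >= in_dist:
                for i in range(out_dist - in_dist):
                    axis = out_start + i
                    expand_axes.append(axis + 1 if axis == batch_dim else axis)
            else:
                for i in range(in_dist - out_dist):
                    axis = in_start + i
                    reduce_axes.append(axis + 1 if axis == batch_dim else axis)
            in_start += in_dist + 1
        if len(in_shape) + len(expand_axes) - len(reduce_axes) != len(out_shape):
            return [], []
        return expand_axes, reduce_axes

For example, infer_reshape_axes([1, 3, 10, 10], [1, 3, 100]) yields ([], [2]): one axis is folded away starting at position 2, so the input layout is reduced there rather than rebuilt from scratch.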
LayoutUtils::InferLayoutDecision(call->args[1], var_layout_map); - if (!input_layout->layout.defined()) { + LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); + LayoutDecision indices_layout = LayoutUtils::InferLayoutDecision(call->args[1], var_layout_map); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); + const auto& output_shape = ExprUtils::GetShape(call); + if (input_shape.size() == 0) { return InferLayoutOutput(); } - LayoutDecision output_layout = LayoutUtils::ExpandLayout(input_layout, std::vector{0}); - return InferLayoutOutput({LayoutDecision("WE"), input_layout}, {output_layout}, Attrs()); + if (input_layout->layout.defined()) { + if (input_shape.size() == output_shape.size()) { + return InferLayoutOutput({input_layout, indices_layout}, {input_layout}, Attrs()); + } + LayoutDecision output_layout = LayoutUtils::ReduceLayout(input_layout, std::vector{0}); + return InferLayoutOutput({input_layout, indices_layout}, {output_layout}, Attrs()); + } + if (indices_layout->layout.defined()) { + size_t indices_size = indices_layout->layout.ndim(); + LayoutDecision output_layout = + LayoutUtils::ExpandLayout(indices_layout, std::vector{indices_size}); + return InferLayoutOutput({input_layout, indices_layout}, {output_layout}, Attrs()); + } + return InferLayoutOutput(); } InferLayoutOutput ForwardInferLayoutPlugin(const Call& call, @@ -524,18 +515,27 @@ InferLayoutOutput ForwardInferLayoutPlugin(const Call& call, return (*pf)(args->fields, var_layout_map); } +// nn ops +TVM_REGISTER_OP("relax.nn.avg_pool2d") + .set_attr("FMSCForwardInferLayout", MSCInferLayoutPool2d); +TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d") + .set_attr("FMSCForwardInferLayout", MSCInferLayoutPool2d); +TVM_REGISTER_OP("relax.nn.batch_norm") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBatchNorm); TVM_REGISTER_OP("relax.nn.conv1d") .set_attr("FMSCForwardInferLayout", MSCInferLayoutConv); TVM_REGISTER_OP("relax.nn.conv2d") .set_attr("FMSCForwardInferLayout", MSCInferLayoutConv); +TVM_REGISTER_OP("relax.nn.conv2d_transpose") + .set_attr("FMSCForwardInferLayout", MSCInferLayoutConv); +TVM_REGISTER_OP("relax.nn.dropout") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutCommon); +TVM_REGISTER_OP("relax.nn.group_norm") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutNormalize); +TVM_REGISTER_OP("relax.nn.layer_norm") + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutNormalize); TVM_REGISTER_OP("relax.nn.max_pool2d") .set_attr("FMSCForwardInferLayout", MSCInferLayoutPool2d); -TVM_REGISTER_OP("relax.nn.avg_pool2d") - .set_attr("FMSCForwardInferLayout", MSCInferLayoutPool2d); -TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d") - .set_attr("FMSCForwardInferLayout", MSCInferLayoutPool2d); -TVM_REGISTER_OP("relax.image.resize2d") - .set_attr("FMSCForwardInferLayout", MSCInferLayoutResize2d); // reduce axis ops TVM_REGISTER_OP("relax.argmax") @@ -554,6 +554,7 @@ TVM_REGISTER_OP("relax.prod") .set_attr("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis); TVM_REGISTER_OP("relax.std") .set_attr("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis); + // binary ops TVM_REGISTER_OP("relax.add") .set_attr("FMSCForwardInferLayout", ForwardInferLayoutBinary); @@ -609,14 +610,8 @@ TVM_REGISTER_OP("relax.squeeze") .set_attr("FMSCForwardInferLayout", ForwardInferLayoutSqueeze); TVM_REGISTER_OP("relax.take") .set_attr("FMSCForwardInferLayout", ForwardInferLayoutTake); - -// nn ops -TVM_REGISTER_OP("relax.nn.batch_norm") - 
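The reworked take handling above keys off whichever of the data or indices layouts is known, instead of always treating the data tensor as a weight table. In Python-like pseudocode (string layouts and a hypothetical helper, not the MSC API), the forward direction reads roughly:

    def forward_take_layout(data_layout, indices_layout, data_rank, out_rank):
        # No shape info for the data tensor: give up, as the hunk above does.
        if data_rank == 0:
            return None
        if data_layout:
            # Same rank: layout passes through; otherwise drop axis 0,
            # mirroring LayoutUtils::ReduceLayout(input_layout, {0}).
            return data_layout if data_rank == out_rank else data_layout[1:]
        if indices_layout:
            # Only the indices layout is known: append a fresh trailing axis,
            # mirroring LayoutUtils::ExpandLayout(indices_layout, {ndim}).
            return indices_layout + "?"
        return None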
.set_attr("FMSCForwardInferLayout", ForwardInferLayoutBatchNorm); -TVM_REGISTER_OP("relax.nn.group_norm") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutNormalize); -TVM_REGISTER_OP("relax.nn.layer_norm") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutNormalize); +TVM_REGISTER_OP("relax.image.resize2d") + .set_attr("FMSCForwardInferLayout", MSCInferLayoutResize2d); // plugin op TVM_REGISTER_OP("relax.call_dps_packed") @@ -695,9 +690,7 @@ InferLayoutOutput BackwardInferLayoutArgMaxMin(const Call& call, if (attrs->keepdims) { return InferLayoutOutput({output_layout}, {output_layout}, Attrs()); } - Array empty; - const auto& input_shape = - Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); if (input_shape.size() == 0) { return InferLayoutOutput(); } @@ -726,9 +719,7 @@ InferLayoutOutput BackwardInferLayoutExpandDims(const Call& call, if (!output_layout->layout.defined()) { return InferLayoutOutput(); } - Array empty; - const auto& input_shape = - Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); if (input_shape.size() == 0) { return InferLayoutOutput(); } @@ -759,9 +750,7 @@ InferLayoutOutput BackwardInferLayoutMatmul(const Call& call, if (!output_layout->layout.defined()) { return InferLayoutOutput(); } - Array empty; - const auto& b_shape = - Downcast(GetStructInfo(call->args[1]))->GetShape().value_or(empty); + const auto& b_shape = ExprUtils::GetShape(call->args[1]); if (b_shape.size() == 0) { return InferLayoutOutput(); } @@ -816,9 +805,7 @@ InferLayoutOutput BackwardInferLayoutReduceAxis(const Call& call, if (attrs->keepdims) { return InferLayoutOutput({output_layout}, {output_layout}, Attrs()); } - Array empty; - const auto& input_shape = - Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); if (input_shape.size() == 0) { return InferLayoutOutput(); } @@ -837,29 +824,25 @@ InferLayoutOutput BackwardInferLayoutReshape(const Call& call, if (!output_layout->layout.defined()) { return InferLayoutOutput(); } - Array empty; - const auto& input_shape = - Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); - const auto& output_shape = - Downcast(GetStructInfo(call))->GetShape().value_or(empty); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); + const auto& output_shape = ExprUtils::GetShape(call); if (input_shape.size() == 0 || output_shape.size() == 0) { return InferLayoutOutput(); } - LayoutDecision input_layout; - if (input_shape.size() == output_shape.size()) { - input_layout = output_layout; - } else if (input_shape.size() > output_shape.size()) { - const auto& reduce_axes = InferReduceAxes(input_shape, output_shape); - if (reduce_axes.size() == 0) { + LayoutDecision input_layout = output_layout; + if (input_shape.size() != output_shape.size()) { + int batch_dim = LayoutUtils::InferBatchDim(output_layout); + std::vector expand_axes, reduce_axes; + std::tie(expand_axes, reduce_axes) = InferReshapeAxes(input_shape, output_shape, batch_dim); + if (reduce_axes.size() == 0 && expand_axes.size() == 0) { return InferLayoutOutput(); } - input_layout = LayoutUtils::ExpandLayout(output_layout, reduce_axes); - } else { - const auto& expand_axes = InferExpandAxes(input_shape, output_shape); - if (expand_axes.size() == 0) { - return InferLayoutOutput(); + if (expand_axes.size() > 0) { + input_layout = 
LayoutUtils::ReduceLayout(input_layout, expand_axes); + } + if (reduce_axes.size() > 0) { + input_layout = LayoutUtils::ExpandLayout(input_layout, reduce_axes); } - input_layout = LayoutUtils::ReduceLayout(output_layout, expand_axes); } return InferLayoutOutput({input_layout, LayoutDecision("O")}, {output_layout}, Attrs()); } @@ -871,9 +854,7 @@ InferLayoutOutput BackwardInferLayoutSqueeze(const Call& call, if (!output_layout->layout.defined()) { return InferLayoutOutput(); } - Array empty; - const auto& input_shape = - Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); if (input_shape.size() == 0) { return InferLayoutOutput(); } @@ -901,12 +882,28 @@ InferLayoutOutput BackwardInferLayoutTake(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); + LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); + LayoutDecision indices_layout = LayoutUtils::InferLayoutDecision(call->args[1], var_layout_map); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); + const auto& output_shape = ExprUtils::GetShape(call); if (!output_layout->layout.defined()) { return InferLayoutOutput(); } - LayoutDecision input_layout = LayoutUtils::ReduceLayout(output_layout, std::vector{0}); - return InferLayoutOutput({LayoutDecision("WE"), input_layout}, {output_layout}, Attrs()); + if (input_shape.size() == 0) { + return InferLayoutOutput(); + } + if (!indices_layout.defined()) { + indices_layout = LayoutUtils::ReduceLayout(output_layout, std::vector{0}); + } + if (input_shape.size() == output_shape.size()) { + return InferLayoutOutput({output_layout, indices_layout}, {output_layout}, Attrs()); + } + if (!input_layout.defined()) { + input_layout = LayoutUtils::ExpandLayout(output_layout, std::vector{0}); + } + return InferLayoutOutput({input_layout, indices_layout}, {output_layout}, Attrs()); } + InferLayoutOutput BackwardInferLayoutTupleInputs(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { @@ -925,18 +922,25 @@ InferLayoutOutput BackwardInferLayoutTupleInputs(const Call& call, return InferLayoutOutput(input_layouts, {output_layout}, Attrs()); } +// nn ops +TVM_REGISTER_OP("relax.nn.avg_pool2d") + .set_attr("FMSCBackwardInferLayout", MSCInferLayoutPool2d); +TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d") + .set_attr("FMSCBackwardInferLayout", MSCInferLayoutPool2d); +TVM_REGISTER_OP("relax.nn.batch_norm") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBatchNorm); TVM_REGISTER_OP("relax.nn.conv1d") .set_attr("FMSCBackwardInferLayout", MSCInferLayoutConv); TVM_REGISTER_OP("relax.nn.conv2d") .set_attr("FMSCBackwardInferLayout", MSCInferLayoutConv); +TVM_REGISTER_OP("relax.nn.conv2d_transpose") + .set_attr("FMSCBackwardInferLayout", MSCInferLayoutConv); +TVM_REGISTER_OP("relax.nn.group_norm") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutNormalize); +TVM_REGISTER_OP("relax.nn.layer_norm") + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutNormalize); TVM_REGISTER_OP("relax.nn.max_pool2d") .set_attr("FMSCBackwardInferLayout", MSCInferLayoutPool2d); -TVM_REGISTER_OP("relax.nn.avg_pool2d") - .set_attr("FMSCBackwardInferLayout", MSCInferLayoutPool2d); -TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d") - .set_attr("FMSCBackwardInferLayout", MSCInferLayoutPool2d); -TVM_REGISTER_OP("relax.image.resize2d") - 
.set_attr("FMSCBackwardInferLayout", MSCInferLayoutResize2d); // reduce axis ops TVM_REGISTER_OP("relax.argmax") @@ -1013,14 +1017,8 @@ TVM_REGISTER_OP("relax.squeeze") .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutSqueeze); TVM_REGISTER_OP("relax.take") .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutTake); - -// nn ops -TVM_REGISTER_OP("relax.nn.batch_norm") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutBatchNorm); -TVM_REGISTER_OP("relax.nn.group_norm") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutNormalize); -TVM_REGISTER_OP("relax.nn.layer_norm") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutNormalize); +TVM_REGISTER_OP("relax.image.resize2d") + .set_attr("FMSCBackwardInferLayout", MSCInferLayoutResize2d); class LayoutInfer : public ExprVisitor { public: @@ -1268,9 +1266,13 @@ class LayoutInfer : public ExprVisitor { SetExprLayout(call->args[i], var_layout_map_[func->params[i]]); } } - if (func->body->body->IsInstance() && - var_layout_map_.count(Downcast(func->body->body))) { - SetExprLayout(ret, var_layout_map_[Downcast(func->body->body)]); + if (const auto* b_node = func->body.as()) { + if (b_node->body->IsInstance() && + var_layout_map_.count(Downcast(b_node->body))) { + SetExprLayout(ret, var_layout_map_[Downcast(b_node->body)]); + } + } else { + LOG(FATAL) << "Function body should be SeqExpr, get " << func->body; } } @@ -1284,9 +1286,13 @@ class LayoutInfer : public ExprVisitor { if (producer->IsInstance() && local_funcs_.count(Downcast(producer)->op)) { const auto& caller = local_funcs_[Downcast(producer)->op]; - if (caller->body->body->IsInstance() && - var_map_.count(Downcast(caller->body->body))) { - SetExprLayout(caller->body->body, param_layout); + if (const auto* b_node = caller->body.as()) { + if (b_node->body->IsInstance() && + var_map_.count(Downcast(b_node->body))) { + SetExprLayout(b_node->body, param_layout); + } + } else { + LOG(FATAL) << "Caller body should be SeqExpr, get " << caller->body; } } } @@ -1298,7 +1304,7 @@ class LayoutInfer : public ExprVisitor { bool infered_; Map var_map_; Array ordered_exprs_; - std::unordered_map var_layout_map_; + std::unordered_map var_layout_map_; Map local_funcs_; }; // class LayoutInfer diff --git a/src/contrib/msc/framework/tensorflow/codegen.cc b/src/contrib/msc/framework/tensorflow/codegen.cc index 9e437f705c34..634dd7969889 100644 --- a/src/contrib/msc/framework/tensorflow/codegen.cc +++ b/src/contrib/msc/framework/tensorflow/codegen.cc @@ -141,7 +141,7 @@ const Array TensorflowCodeGen::GetOpCodes(const MSCJoint& node) { const auto& ops_map = GetTFV1OpCodes(); auto it = ops_map->find(node->optype); ICHECK(it != ops_map->end()) << "Unsupported tensorflow op(" << node->optype << "): " << node; - it->second->Config(node, config()); + it->second->Config(node, config(), prims()); try { return it->second->GetDocs(); } catch (runtime::InternalError& err) { @@ -154,6 +154,7 @@ TVM_REGISTER_GLOBAL("msc.framework.tensorflow.GetTensorflowSources") .set_body_typed([](const MSCGraph& graph, const String& codegen_config, const String& print_config) -> Map { TensorflowCodeGen codegen = TensorflowCodeGen(graph, codegen_config); + codegen.Init(); return codegen.GetSources(print_config); }); diff --git a/src/contrib/msc/framework/tensorrt/codegen.cc b/src/contrib/msc/framework/tensorrt/codegen.cc index 717eb75e1f36..a9c16994e5b6 100644 --- a/src/contrib/msc/framework/tensorrt/codegen.cc +++ b/src/contrib/msc/framework/tensorrt/codegen.cc @@ -544,7 +544,7 @@ const Array 
TensorRTCodeGen::GetOpCodes(const MSCJoint& node) { const auto& ops_map = GetTensorRTOpCodes(); auto it = ops_map->find(GetOpType(node)); ICHECK(it != ops_map->end()) << "Unsupported tensorrt op(" << node->optype << "): " << node; - it->second->Config(node, config()); + it->second->Config(node, config(), prims()); try { return it->second->GetDocs(); } catch (runtime::InternalError& err) { @@ -578,6 +578,7 @@ TVM_REGISTER_GLOBAL("msc.framework.tensorrt.GetTensorRTSources") .set_body_typed([](const MSCGraph& graph, const String& codegen_config, const String& print_config) -> Map { TensorRTCodeGen codegen = TensorRTCodeGen(graph, codegen_config); + codegen.Init(); return codegen.GetSources(print_config); }); diff --git a/src/contrib/msc/framework/torch/codegen.cc b/src/contrib/msc/framework/torch/codegen.cc index 54859ad0ce89..86351bdd060b 100644 --- a/src/contrib/msc/framework/torch/codegen.cc +++ b/src/contrib/msc/framework/torch/codegen.cc @@ -142,7 +142,7 @@ const Array TorchCodeGen::GetOpCodes(const MSCJoint& node) { const auto& ops_map = GetTorchOpCodes(); auto it = ops_map->find(GetOpType(node)); ICHECK(it != ops_map->end()) << "Unsupported torch op(" << node->optype << "): " << node; - it->second->Config(node, config(), is_init_); + it->second->Config(node, config(), is_init_, prims()); try { return it->second->GetDocs(); } catch (runtime::InternalError& err) { @@ -155,6 +155,7 @@ TVM_REGISTER_GLOBAL("msc.framework.torch.GetTorchSources") .set_body_typed([](const MSCGraph& graph, const String& codegen_config, const String& print_config) -> Map { TorchCodeGen codegen = TorchCodeGen(graph, codegen_config); + codegen.Init(); return codegen.GetSources(print_config); }); diff --git a/src/contrib/msc/framework/torch/torch_opcode.cc b/src/contrib/msc/framework/torch/torch_opcode.cc index e355626f859f..9ae825b804aa 100644 --- a/src/contrib/msc/framework/torch/torch_opcode.cc +++ b/src/contrib/msc/framework/torch/torch_opcode.cc @@ -202,6 +202,13 @@ class TorchClipCodeGen : public TorchOpCode { } }; +class TorchConcatCodeGen : public TorchOpCode { + TORCH_OP_CODEGEN_METHODS(TorchConcatCodeGen); + + protected: + void CodeGenForward() final { stack_.op_call().op_inputs_arg().op_arg("axis", "dim"); } +}; + class TorchConstantCodeGen : public TorchOpCode { TORCH_OP_CODEGEN_METHODS(TorchConstantCodeGen); @@ -298,8 +305,8 @@ class TorchEmbeddingCodeGen : public TorchOpCode { void CodeGenInit() final { const auto& weight = node()->WeightAt("weight"); stack_.op_call() - .call_arg(weight->DimAt("W"), "num_embeddings") - .call_arg(weight->DimAt("E"), "embedding_dim"); + .call_arg(weight->DimAt(0), "num_embeddings") + .call_arg(weight->DimAt(1), "embedding_dim"); } }; @@ -706,6 +713,7 @@ const std::shared_ptr>> map->emplace("astype", std::make_shared("", "to")); map->emplace("broadcast_to", std::make_shared("", "expand")); map->emplace("clip", std::make_shared("", "torch.clamp")); + map->emplace("concat", std::make_shared("", "torch.cat")); map->emplace("cumsum", std::make_shared("", "torch.cumsum")); map->emplace("expand_dims", std::make_shared("", "torch.unsqueeze")); map->emplace("permute_dims", std::make_shared("", "torch.permute")); diff --git a/src/contrib/msc/framework/torch/torch_opcode.h b/src/contrib/msc/framework/torch/torch_opcode.h index 6fe5cf5f96c4..80b7f5c60d1d 100644 --- a/src/contrib/msc/framework/torch/torch_opcode.h +++ b/src/contrib/msc/framework/torch/torch_opcode.h @@ -55,9 +55,9 @@ class TorchOpCode : public BaseOpCode { } /*! 
\brief Config the TorchOpCode*/ - void Config(const MSCJoint& node, const std::shared_ptr config, - bool is_init) { - BaseOpCode::Config(node, config); + void Config(const MSCJoint& node, const std::shared_ptr config, bool is_init, + const Map& prims) { + BaseOpCode::Config(node, config, prims); is_init_ = is_init; module_ref_ = "self." + StringUtils::Replace(node->name, ".", "_"); } diff --git a/src/contrib/msc/framework/tvm/codegen.cc b/src/contrib/msc/framework/tvm/codegen.cc index 783551eed35b..5443cdc96a05 100644 --- a/src/contrib/msc/framework/tvm/codegen.cc +++ b/src/contrib/msc/framework/tvm/codegen.cc @@ -187,11 +187,21 @@ void RelaxCodeGen::CodeGenInference() { } } +const String RelaxCodeGen::DescribePrim(const MSCPrim& prim) { + if (prim->optype == "shape") { + const auto& producer = graph()->FindNode(prim->GetTypeAttr("producer")); + int out_idx = prim->GetTypeAttr("out_idx"); + const auto& dim = prim->GetTypeAttr("dim"); + return IdxOutputBase(producer, out_idx) + ".struct_info.shape[" + dim + "]"; + } + return PyCodeGen::DescribePrim(prim); +} + const Array RelaxCodeGen::GetOpCodes(const MSCJoint& node) { const auto& ops_map = GetRelaxOpCodes(); auto it = ops_map->find(GetOpType(node)); ICHECK(it != ops_map->end()) << "Unsupported relax op(" << node->optype << "): " << node; - it->second->Config(node, config()); + it->second->Config(node, config(), prims()); try { return it->second->GetDocs(); } catch (runtime::InternalError& err) { @@ -204,6 +214,7 @@ TVM_REGISTER_GLOBAL("msc.framework.tvm.GetRelaxSources") .set_body_typed([](const MSCGraph& graph, const String& codegen_config, const String& print_config) -> Map { RelaxCodeGen codegen = RelaxCodeGen(graph, codegen_config); + codegen.Init(); return codegen.GetSources(print_config); }); diff --git a/src/contrib/msc/framework/tvm/codegen.h b/src/contrib/msc/framework/tvm/codegen.h index 944d4cdfe1cc..249105b5a50b 100644 --- a/src/contrib/msc/framework/tvm/codegen.h +++ b/src/contrib/msc/framework/tvm/codegen.h @@ -55,6 +55,9 @@ class RelaxCodeGen : public PyCodeGen { /*! \brief Stack the docs for the graph inference*/ void CodeGenInference() final; + /*! \brief Describe the prim*/ + const String DescribePrim(const MSCPrim& prim) final; + /*! \brief Get the docs for the op*/ const Array GetOpCodes(const MSCJoint& node) final; diff --git a/src/contrib/msc/framework/tvm/relax_opcode.cc b/src/contrib/msc/framework/tvm/relax_opcode.cc index 0b7ef6aa825e..1913e8ecda8e 100644 --- a/src/contrib/msc/framework/tvm/relax_opcode.cc +++ b/src/contrib/msc/framework/tvm/relax_opcode.cc @@ -562,12 +562,8 @@ class RelaxReshapeCodeGen : public RelaxOpCode { protected: void CodeGenBuild() final { - stack_.op_call().op_input_arg(); - if (config()->from_relay) { - stack_.op_list_arg("newshape", "shape"); - } else { - stack_.op_list_arg("shape"); - } + const auto& out_shape = GetPrims(node()->OutputAt(0)); + stack_.op_call().op_input_arg().call_arg(DocUtils::ToList(out_shape), "shape"); } }; diff --git a/tests/python/contrib/test_msc/test_graph_build.py b/tests/python/contrib/test_msc/test_graph_build.py index d02767208206..60c8a73dcc67 100644 --- a/tests/python/contrib/test_msc/test_graph_build.py +++ b/tests/python/contrib/test_msc/test_graph_build.py @@ -14,20 +14,23 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name """ Test graph builder && graph. 
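To make the new prim plumbing concrete: for a graph with a symbolic batch dimension, the Relax codegen above first describes a shape prim as an access into the producer's struct info and then feeds the prim-backed dims to reshape, instead of relying on a recorded newshape list. A hand-written equivalent of the generated pattern might look like the following; the names inp_0, conv2d and bz are placeholders for whatever the graph actually emits:

    import tvm
    from tvm import relax

    # A "shape" prim with producer "inp_0", out_idx 0 and dim 0 is described
    # (per RelaxCodeGen::DescribePrim above) roughly as:
    #     bz = inp_0.struct_info.shape[0]
    # and RelaxReshapeCodeGen then passes the prim-backed dims directly:
    #     reshape = relax.op.reshape(conv2d, (bz, 3, 100))
    # Outside generated code, the same symbolic dim would be introduced as:
    bz = tvm.tir.Var("bz", "int64")
    data = relax.Var("inp_0", relax.TensorStructInfo([bz, 3, 10, 10], "float32"))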
""" +import pytest import torch from torch import fx from torch.nn import Module import tvm.testing from tvm.relax.frontend.torch import from_fx -from tvm.contrib.msc.core.frontend import translate +from tvm.contrib.msc.core.frontend import translate, normalize_inputs from tvm.contrib.msc.core import utils as msc_utils def verify_model(torch_model, input_info, expected): + input_info = normalize_inputs(input_info) graph_model = fx.symbolic_trace(torch_model) with torch.no_grad(): mod = from_fx(graph_model, input_info) @@ -38,7 +41,8 @@ def verify_model(torch_model, input_info, expected): ) -def test_conv1d(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_conv1d(dynamic): """test graph builder for conv1d""" class Conv1D1(Module): @@ -49,12 +53,6 @@ def __init__(self): def forward(self, data): return self.conv(data) - expected1 = { - "inputs": [{"name": "inp_0", "shape": [1, 3, 10], "dtype": "float32", "layout": "NCW"}], - "outputs": [{"name": "conv1d", "shape": [1, 6, 4], "dtype": "float32", "layout": "NCW"}], - "nodes": {"total": 2, "input": 1, "msc.conv1d_bias": 1}, - } - class Conv1D2(Module): def __init__(self): super().__init__() @@ -63,18 +61,28 @@ def __init__(self): def forward(self, data): return self.conv(data) + bz = "bz" if dynamic else 1 + expected1 = { + "inputs": [{"name": "inp_0", "shape": [bz, 3, 10], "dtype": "float32", "layout": "NCW"}], + "outputs": [{"name": "conv1d", "shape": [bz, 6, 4], "dtype": "float32", "layout": "NCW"}], + "nodes": {"total": 2, "input": 1, "msc.conv1d_bias": 1}, + } expected2 = { - "inputs": [{"name": "inp_0", "shape": [1, 3, 10], "dtype": "float32", "layout": "NCW"}], - "outputs": [{"name": "conv1d", "shape": [1, 6, 4], "dtype": "float32", "layout": "NCW"}], + "inputs": [{"name": "inp_0", "shape": [bz, 3, 10], "dtype": "float32", "layout": "NCW"}], + "outputs": [{"name": "conv1d", "shape": [bz, 6, 4], "dtype": "float32", "layout": "NCW"}], "nodes": {"total": 2, "input": 1, "nn.conv1d": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10], "float32")] + input_info = [([bz, 3, 10], "float32")] verify_model(Conv1D1(), input_info, expected1) verify_model(Conv1D2(), input_info, expected2) -def test_conv2d(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_conv2d(dynamic): """test graph builder for conv2d""" class Conv2D1(Module): @@ -85,44 +93,49 @@ def __init__(self): def forward(self, data): return self.conv(data) + class Conv2D2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 7, bias=False) + + def forward(self, data): + return self.conv(data) + + bz = "bz" if dynamic else 1 expected1 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "outputs": [ { "name": "conv2d", - "shape": [1, 6, 4, 4], + "shape": [bz, 6, 4, 4], "dtype": "float32", "layout": "NCHW", } ], "nodes": {"total": 2, "input": 1, "msc.conv2d_bias": 1}, } - - class Conv2D2(Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 6, 7, bias=False) - - def forward(self, data): - return self.conv(data) - expected2 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "outputs": [ - {"name": "conv2d", "shape": [1, 6, 4, 4], "dtype": "float32", 
"layout": "NCHW"} + {"name": "conv2d", "shape": [bz, 6, 4, 4], "dtype": "float32", "layout": "NCHW"} ], "nodes": {"total": 2, "input": 1, "nn.conv2d": 1}, } - input_info = [([1, 3, 10, 10], "float32")] + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} + + input_info = [([bz, 3, 10, 10], "float32")] verify_model(Conv2D1(), input_info, expected1) verify_model(Conv2D2(), input_info, expected2) -def test_linear(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_linear(dynamic): """test graph builder for linear""" class Dense1(Module): @@ -133,123 +146,139 @@ def __init__(self): def forward(self, data): return self.linear(data) + class Dense2(Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 7, bias=False) + + def forward(self, data): + return self.linear(data) + + class MatMul1(Module): + def forward(self, x, y): + return torch.matmul(x, y) + + bz = "bz" if dynamic else 1 + mdim = "mdim" if dynamic else 10 + ndim = "ndim" if dynamic else 20 + kdim = "kdim" if dynamic else 30 + expected1 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "outputs": [ { "name": "matmul", - "shape": [1, 3, 10, 7], + "shape": [bz, 3, 10, 7], "dtype": "float32", "layout": "NCHW", } ], "nodes": {"total": 2, "input": 1, "msc.linear_bias": 1}, } - - class Dense2(Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(10, 7, bias=False) - - def forward(self, data): - return self.linear(data) - expected2 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "outputs": [ - {"name": "matmul", "shape": [1, 3, 10, 7], "dtype": "float32", "layout": "NCHW"} + {"name": "matmul", "shape": [bz, 3, 10, 7], "dtype": "float32", "layout": "NCHW"} ], "nodes": {"total": 2, "input": 1, "msc.linear": 1}, } - - class MatMul1(Module): - def forward(self, x, y): - return torch.matmul(x, y) - expected3 = { "inputs": [ - {"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": "NC"}, - {"name": "inp_1", "shape": [10, 10], "dtype": "float32", "layout": "IO"}, + {"name": "inp_0", "shape": [mdim, kdim], "dtype": "float32", "layout": "NC"}, + {"name": "inp_1", "shape": [kdim, ndim], "dtype": "float32", "layout": "IO"}, ], - "outputs": [{"name": "matmul", "shape": [10, 10], "dtype": "float32", "layout": "NC"}], + "outputs": [{"name": "matmul", "shape": [mdim, ndim], "dtype": "float32", "layout": "NC"}], "nodes": {"total": 3, "input": 2, "matmul": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} + expected3["prims"] = {"total": 3, "shape": 3} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(Dense1(), input_info, expected1) verify_model(Dense2(), input_info, expected2) - verify_model(MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")], expected3) + verify_model(MatMul1(), [([mdim, kdim], "float32"), ([kdim, ndim], "float32")], expected3) -def test_bmm(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_bmm(dynamic): """test graph builder for bmm""" class BMM(Module): def forward(self, x, y): return torch.bmm(x, y) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [4, 
128, 256], "dtype": "float32", "layout": "NCD"}, - {"name": "inp_1", "shape": [4, 256, 512], "dtype": "float32", "layout": "NIO"}, + {"name": "inp_0", "shape": [bz, 128, 256], "dtype": "float32", "layout": "NCD"}, + {"name": "inp_1", "shape": [bz, 256, 512], "dtype": "float32", "layout": "NIO"}, ], "outputs": [ - {"name": "matmul", "shape": [4, 128, 512], "dtype": "float32", "layout": "NCD"} + {"name": "matmul", "shape": [bz, 128, 512], "dtype": "float32", "layout": "NCD"} ], "nodes": {"total": 3, "input": 2, "matmul": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [((4, 128, 256), "float32"), ((4, 256, 512), "float32")] + input_info = [((bz, 128, 256), "float32"), ((bz, 256, 512), "float32")] verify_model(BMM(), input_info, expected) -def test_baddbmm(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_baddbmm(dynamic): """test graph builder for baddbmm""" class BAddBMM1(Module): def forward(self, c, x, y): return torch.baddbmm(c, x, y) + class BAddBMM2(Module): + def forward(self, c, x, y): + return torch.baddbmm(c, x, y, alpha=2, beta=0) + + bz = "bz" if dynamic else 1 expected1 = { "inputs": [ - {"name": "inp_0", "shape": [4, 128, 512], "dtype": "float32", "layout": "NCD"}, - {"name": "inp_1", "shape": [4, 128, 256], "dtype": "float32", "layout": "NCD"}, - {"name": "inp_2", "shape": [4, 256, 512], "dtype": "float32", "layout": "NIO"}, + {"name": "inp_0", "shape": [bz, 128, 512], "dtype": "float32", "layout": "NCD"}, + {"name": "inp_1", "shape": [bz, 128, 256], "dtype": "float32", "layout": "NCD"}, + {"name": "inp_2", "shape": [bz, 256, 512], "dtype": "float32", "layout": "NIO"}, ], - "outputs": [{"name": "add", "shape": [4, 128, 512], "dtype": "float32", "layout": "NCD"}], + "outputs": [{"name": "add", "shape": [bz, 128, 512], "dtype": "float32", "layout": "NCD"}], "nodes": {"total": 5, "input": 3, "matmul": 1, "add": 1}, } - - class BAddBMM2(Module): - def forward(self, c, x, y): - return torch.baddbmm(c, x, y, alpha=2, beta=0) - expected2 = { "inputs": [ - {"name": "inp_0", "shape": [4, 128, 512], "dtype": "float32", "layout": ""}, - {"name": "inp_1", "shape": [4, 128, 256], "dtype": "float32", "layout": "NCD"}, - {"name": "inp_2", "shape": [4, 256, 512], "dtype": "float32", "layout": "NIO"}, + {"name": "inp_0", "shape": [bz, 128, 512], "dtype": "float32", "layout": ""}, + {"name": "inp_1", "shape": [bz, 128, 256], "dtype": "float32", "layout": "NCD"}, + {"name": "inp_2", "shape": [bz, 256, 512], "dtype": "float32", "layout": "NIO"}, ], "outputs": [ - {"name": "multiply", "shape": [4, 128, 512], "dtype": "float32", "layout": "NCD"} + {"name": "multiply", "shape": [bz, 128, 512], "dtype": "float32", "layout": "NCD"} ], "nodes": {"total": 6, "input": 3, "matmul": 1, "constant": 1, "multiply": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} input_info = [ - ((4, 128, 512), "float32"), - ((4, 128, 256), "float32"), - ((4, 256, 512), "float32"), + ((bz, 128, 512), "float32"), + ((bz, 128, 256), "float32"), + ((bz, 256, 512), "float32"), ] verify_model(BAddBMM1(), input_info, expected1) verify_model(BAddBMM2(), input_info, expected2) -def test_relu(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_relu(dynamic): """test graph builder for relu""" class ReLU(Module): @@ -264,18 +293,22 @@ class ReLU1(Module): def forward(self, data): return torch.nn.functional.relu(data) + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [10, 10], 
"dtype": "float32", "layout": "AB"}], - "outputs": [{"name": "relu", "shape": [10, 10], "dtype": "float32", "layout": "AB"}], + "inputs": [{"name": "inp_0", "shape": [bz, 10], "dtype": "float32", "layout": "AB"}], + "outputs": [{"name": "relu", "shape": [bz, 10], "dtype": "float32", "layout": "AB"}], "nodes": {"total": 2, "input": 1, "nn.relu": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([10, 10], "float32")] + input_info = [([bz, 10], "float32")] verify_model(ReLU(), input_info, expected) verify_model(ReLU1(), input_info, expected) -def test_relu6(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_relu6(dynamic): """test graph builder for relu6""" class ReLU6(Module): @@ -286,16 +319,21 @@ def __init__(self): def forward(self, data): return self.relu6(data) + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "clip", "shape": [10, 10], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 10], "dtype": "float32", "layout": ""}], + "outputs": [{"name": "clip", "shape": [bz, 10], "dtype": "float32", "layout": ""}], "nodes": {"total": 2, "input": 1, "clip": 1}, } - input_info = [([10, 10], "float32")] + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} + + input_info = [([bz, 10], "float32")] verify_model(ReLU6(), input_info, expected) -def test_maxpool2d(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_maxpool2d(dynamic): """test graph builder for maxpool2d""" class MaxPool2d(Module): @@ -306,16 +344,6 @@ def __init__(self): def forward(self, data): return self.pool(data) - expected1 = { - "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [ - {"name": "max_pool2d", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "nodes": {"total": 2, "input": 1, "nn.max_pool2d": 1}, - } - class MaxPool2d2(Module): def __init__(self): super().__init__() @@ -324,16 +352,6 @@ def __init__(self): def forward(self, data): return self.pool(data) - expected2 = { - "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [ - {"name": "max_pool2d", "shape": [1, 3, 4, 4], "dtype": "float32", "layout": "NCHW"} - ], - "nodes": {"total": 2, "input": 1, "nn.max_pool2d": 1}, - } - class MaxPool2d3(Module): def __init__(self): super().__init__() @@ -342,23 +360,47 @@ def __init__(self): def forward(self, data): return self.pool(data) + bz = "bz" if dynamic else 1 + expected1 = { + "inputs": [ + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + ], + "outputs": [ + {"name": "max_pool2d", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + ], + "nodes": {"total": 2, "input": 1, "nn.max_pool2d": 1}, + } + expected2 = { + "inputs": [ + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + ], + "outputs": [ + {"name": "max_pool2d", "shape": [bz, 3, 4, 4], "dtype": "float32", "layout": "NCHW"} + ], + "nodes": {"total": 2, "input": 1, "nn.max_pool2d": 1}, + } expected3 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "outputs": [ - {"name": "max_pool2d", "shape": [1, 3, 6, 6], "dtype": "float32", "layout": "NCHW"} + {"name": "max_pool2d", "shape": [bz, 3, 6, 6], "dtype": 
"float32", "layout": "NCHW"} ], "nodes": {"total": 2, "input": 1, "nn.max_pool2d": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} + expected3["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(MaxPool2d(), input_info, expected1) verify_model(MaxPool2d2(), input_info, expected2) verify_model(MaxPool2d3(), input_info, expected3) -def test_avgpool2d(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_avgpool2d(dynamic): """test graph builder for avgpool2d""" class AvgPool2d(Module): @@ -369,16 +411,6 @@ def __init__(self): def forward(self, data): return self.pool(data) - expected1 = { - "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "outputs": [ - {"name": "avg_pool2d", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} - ], - "nodes": {"total": 2, "input": 1, "nn.avg_pool2d": 1}, - } - class AvgPool2d2(Module): def __init__(self): super().__init__() @@ -387,22 +419,36 @@ def __init__(self): def forward(self, data): return self.pool(data) + bz = "bz" if dynamic else 1 + expected1 = { + "inputs": [ + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + ], + "outputs": [ + {"name": "avg_pool2d", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + ], + "nodes": {"total": 2, "input": 1, "nn.avg_pool2d": 1}, + } expected2 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "outputs": [ - {"name": "avg_pool2d", "shape": [1, 3, 6, 6], "dtype": "float32", "layout": "NCHW"} + {"name": "avg_pool2d", "shape": [bz, 3, 6, 6], "dtype": "float32", "layout": "NCHW"} ], "nodes": {"total": 2, "input": 1, "nn.avg_pool2d": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(AvgPool2d(), input_info, expected1) verify_model(AvgPool2d2(), input_info, expected2) -def test_adaptive_avgpool2d(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_adaptive_avgpool2d(dynamic): """test graph builder for adaptive_avgpool2d""" class AdaptiveAvgPool2d0(Module): @@ -413,26 +459,30 @@ def __init__(self): def forward(self, data): return self.pool(data) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "outputs": [ { "name": "adaptive_avg_pool2d", - "shape": [1, 3, 10, 10], + "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW", } ], "nodes": {"total": 2, "input": 1, "nn.adaptive_avg_pool2d": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(AdaptiveAvgPool2d0(), input_info, expected) -def test_flatten(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_flatten(dynamic): """test graph builder for flatten""" class Flatten(Module): @@ -443,18 +493,26 @@ def __init__(self): def forward(self, data): return self.f(data) + bz = "bz" if dynamic else 1 + dim = "dim" if dynamic else 10 + out_dim = "MUL_3" if dynamic else 100 expected = { - 
"inputs": [{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "reshape", "shape": [1, 3, 100], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 3, 10, dim], "dtype": "float32", "layout": ""}], + "outputs": [ + {"name": "reshape", "shape": [bz, 3, out_dim], "dtype": "float32", "layout": ""} + ], "nodes": {"total": 2, "input": 1, "reshape": 1}, } + if dynamic: + expected["prims"] = {"total": 4, "shape": 2, "Int": 1, "Mul": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, dim], "float32")] verify_model(Flatten(), input_info, expected) verify_model(torch.nn.Flatten(2, -1), input_info, expected) -def test_batchnorm2d(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_batchnorm2d(dynamic): """test graph builder for batchnorm2d""" class BatchNorm2d(Module): @@ -465,26 +523,30 @@ def __init__(self): def forward(self, data): return self.batchnorm(data) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "outputs": [ { "name": "batch_norm.0", - "shape": [1, 3, 10, 10], + "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW", } ], "nodes": {"total": 3, "input": 1, "nn.batch_norm": 1, "get_item": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(BatchNorm2d(), input_info, expected) -def test_embedding(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_embedding(dynamic): """test graph builder for embedding""" class Embedding(Module): @@ -495,23 +557,34 @@ def __init__(self): def forward(self, data): return self.embedding(data) + vocab = "vocab" if dynamic else 4 expected1 = { - "inputs": [{"name": "inp_0", "shape": [4], "dtype": "int64", "layout": "A"}], - "outputs": [{"name": "take", "shape": [4, 3], "dtype": "float32", "layout": "NA"}], + "inputs": [{"name": "inp_0", "shape": [vocab], "dtype": "int64", "layout": "A"}], + "outputs": [{"name": "take", "shape": [vocab, 3], "dtype": "float32", "layout": "AB"}], "nodes": {"total": 2, "input": 1, "msc.embedding": 1}, } - expected2 = { - "inputs": [{"name": "inp_0", "shape": [4, 5], "dtype": "int64", "layout": "AB"}], - "outputs": [{"name": "take", "shape": [4, 5, 3], "dtype": "float32", "layout": "CNB"}], + "inputs": [{"name": "inp_0", "shape": [vocab, 5], "dtype": "int64", "layout": "AB"}], + "outputs": [ + { + "name": "take", + "shape": [vocab, 5, 3], + "dtype": "float32", + "layout": "" if dynamic else "CBA", + } + ], "nodes": {"total": 2, "input": 1, "msc.embedding": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 3, "shape": 1, "Int": 1, "Mul": 1} - verify_model(Embedding(), [([4], "int64")], expected1) - verify_model(Embedding(), [([4, 5], "int64")], expected2) + verify_model(Embedding(), [([vocab], "int64")], expected1) + verify_model(Embedding(), [([vocab, 5], "int64")], expected2) -def test_dropout(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_dropout(dynamic): """test graph builder for dropout""" class Dropout1(Module): @@ -526,18 +599,22 @@ class Dropout2(Module): def forward(self, data): return torch.dropout(data, 0.5, train=True) + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": 
"float32", "layout": ""}], - "outputs": [{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}], + "outputs": [{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}], "nodes": {"total": 1, "input": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(Dropout1(), input_info, expected) verify_model(Dropout2(), input_info, expected) -def test_layernorm(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_layernorm(dynamic): """test graph builder for layernorm""" class LayerNorm(Module): @@ -548,21 +625,25 @@ def __init__(self): def forward(self, data): return self.layernorm(data) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "outputs": [ - {"name": "layer_norm", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "layer_norm", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "nodes": {"total": 2, "input": 1, "nn.layer_norm": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(LayerNorm(), input_info, expected) -def test_functional_layernorm(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_functional_layernorm(dynamic): """test graph builder for functional_layernorm""" class LayerNorm(Module): @@ -576,21 +657,25 @@ def forward(self, data): data, self.weight.shape, self.weight, self.bias, 1e-5 ) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "outputs": [ - {"name": "layer_norm", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "layer_norm", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "nodes": {"total": 2, "input": 1, "nn.layer_norm": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(LayerNorm((10, 10)), input_info, expected) -def test_cross_entropy(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_cross_entropy(dynamic): """test graph builder for cross_entropy""" class CrossEntropy1(Module): @@ -601,15 +686,6 @@ def __init__(self): def forward(self, logits, targets): return self.loss(logits, targets) - expected1 = { - "inputs": [ - {"name": "inp_0", "shape": [3, 2], "dtype": "float32", "layout": ""}, - {"name": "inp_1", "shape": [3], "dtype": "int32", "layout": ""}, - ], - "outputs": [{"name": "nll_loss", "shape": [], "dtype": "float32", "layout": ""}], - "nodes": {"total": 4, "input": 2, "nn.log_softmax": 1, "nn.nll_loss": 1}, - } - class CrossEntropy2(Module): def __init__(self): super().__init__() @@ -619,15 +695,6 @@ def __init__(self): def forward(self, logits, targets): return self.loss(logits, targets) - expected2 = { - "inputs": [ - {"name": "inp_0", "shape": [3, 2], "dtype": "float32", "layout": ""}, - {"name": "inp_1", "shape": [3], "dtype": "int32", "layout": ""}, - ], - "outputs": [{"name": "nll_loss", "shape": [], 
"dtype": "float32", "layout": ""}], - "nodes": {"total": 5, "input": 2, "nn.log_softmax": 1, "constant": 1, "nn.nll_loss": 1}, - } - class CrossEntropy3(Module): def __init__(self): super().__init__() @@ -636,42 +703,68 @@ def __init__(self): def forward(self, logits, targets): return self.loss(logits, targets) + bz = "bz" if dynamic else 1 + expected1 = { + "inputs": [ + {"name": "inp_0", "shape": [bz, 2], "dtype": "float32", "layout": ""}, + {"name": "inp_1", "shape": [bz], "dtype": "int32", "layout": ""}, + ], + "outputs": [{"name": "nll_loss", "shape": [], "dtype": "float32", "layout": ""}], + "nodes": {"total": 4, "input": 2, "nn.log_softmax": 1, "nn.nll_loss": 1}, + } + expected2 = { + "inputs": [ + {"name": "inp_0", "shape": [bz, 2], "dtype": "float32", "layout": ""}, + {"name": "inp_1", "shape": [bz], "dtype": "int32", "layout": ""}, + ], + "outputs": [{"name": "nll_loss", "shape": [], "dtype": "float32", "layout": ""}], + "nodes": {"total": 5, "input": 2, "nn.log_softmax": 1, "constant": 1, "nn.nll_loss": 1}, + } expected3 = { "inputs": [ - {"name": "inp_0", "shape": [3, 2], "dtype": "float32", "layout": ""}, - {"name": "inp_1", "shape": [3], "dtype": "int32", "layout": ""}, + {"name": "inp_0", "shape": [bz, 2], "dtype": "float32", "layout": ""}, + {"name": "inp_1", "shape": [bz], "dtype": "int32", "layout": ""}, ], "outputs": [{"name": "nll_loss", "shape": [], "dtype": "float32", "layout": ""}], "nodes": {"total": 4, "input": 2, "nn.log_softmax": 1, "nn.nll_loss": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} + expected3["prims"] = {"total": 1, "shape": 1} - input_info = [([3, 2], "float32"), ([3], "int32")] + input_info = [([bz, 2], "float32"), ([bz], "int32")] verify_model(CrossEntropy1(), input_info, expected1) verify_model(CrossEntropy2(), input_info, expected2) verify_model(CrossEntropy3(), input_info, expected3) -def test_functional_cross_entropy(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_functional_cross_entropy(dynamic): """test graph builder for functional_cross_entropy""" class CrossEntropy(Module): def forward(self, logits, targets): return torch.nn.functional.cross_entropy(logits, targets) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [3, 10], "dtype": "float32", "layout": ""}, - {"name": "inp_1", "shape": [3], "dtype": "int32", "layout": ""}, + {"name": "inp_0", "shape": [bz, 10], "dtype": "float32", "layout": ""}, + {"name": "inp_1", "shape": [bz], "dtype": "int32", "layout": ""}, ], "outputs": [{"name": "nll_loss", "shape": [], "dtype": "float32", "layout": ""}], "nodes": {"total": 4, "input": 2, "nn.log_softmax": 1, "nn.nll_loss": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([3, 10], "float32"), ([3], "int32")] + input_info = [([bz, 10], "float32"), ([bz], "int32")] verify_model(CrossEntropy(), input_info, expected) -def test_silu(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_silu(dynamic): """test graph builder for silu""" class SiLU(Module): @@ -686,22 +779,26 @@ class SiLU2(Module): def forward(self, data): return torch.nn.functional.silu(data) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "silu", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "silu", 
"shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 2, "input": 1, "nn.silu": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(SiLU(), input_info, expected) verify_model(SiLU2(), input_info, expected) -def test_groupnorm(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_groupnorm(dynamic): """test graph builder for groupnorm""" class GroupNorm(Module): @@ -712,21 +809,25 @@ def __init__(self): def forward(self, data): return self.groupnorm(data) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "outputs": [ - {"name": "group_norm", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "group_norm", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "nodes": {"total": 2, "input": 1, "nn.group_norm": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(GroupNorm(), input_info, expected) -def test_softmax(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_softmax(dynamic): """test graph builder for softmax""" class Softmax(Module): @@ -737,51 +838,62 @@ def __init__(self): def forward(self, data): return self.softmax(data) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "softmax", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "softmax", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 2, "input": 1, "nn.softmax": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(Softmax(), input_info, expected) -def test_binary(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_binary(dynamic): """test graph builder for binary""" - input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")] - input_info2 = [([1, 3, 10, 10], "float32")] + bz = "bz" if dynamic else 1 + input_info1 = [([bz, 3, 10, 10], "float32"), ([bz, 3, 10, 10], "float32")] + input_info2 = [([bz, 3, 10, 10], "float32")] # Add class Add1(Module): def forward(self, lhs, rhs): return lhs + rhs + class Add2(Module): + def forward(self, lhs): + return lhs + 1.0 + expected_add1 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "inp_1", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + ], + "outputs": [ + {"name": "add", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], - "outputs": [{"name": "add", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}], "nodes": {"total": 3, "input": 2, "add": 1}, } - - class Add2(Module): - def forward(self, lhs): - return lhs + 1.0 - expected_add2 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", 
"layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + ], + "outputs": [ + {"name": "add", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], - "outputs": [{"name": "add", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}], "nodes": {"total": 3, "input": 1, "constant": 1, "add": 1}, } + if dynamic: + expected_add1["prims"] = {"total": 1, "shape": 1} + expected_add2["prims"] = {"total": 1, "shape": 1} verify_model(Add1(), input_info1, expected_add1) verify_model(Add2(), input_info2, expected_add2) @@ -791,30 +903,32 @@ class Sub1(Module): def forward(self, lhs, rhs): return lhs - rhs + class Sub2(Module): + def forward(self, lhs): + return lhs - 1.0 + expected_sub1 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "inp_1", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, ], "outputs": [ - {"name": "subtract", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "subtract", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 3, "input": 2, "subtract": 1}, } - - class Sub2(Module): - def forward(self, lhs): - return lhs - 1.0 - expected_sub2 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "subtract", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "subtract", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 3, "input": 1, "constant": 1, "subtract": 1}, } + if dynamic: + expected_sub1["prims"] = {"total": 1, "shape": 1} + expected_sub2["prims"] = {"total": 1, "shape": 1} verify_model(Sub1(), input_info1, expected_sub1) verify_model(Sub2(), input_info2, expected_sub2) @@ -824,30 +938,32 @@ class Mul1(Module): def forward(self, lhs, rhs): return lhs * rhs + class Mul2(Module): + def forward(self, lhs): + return lhs * 1.0 + expected_mul1 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "inp_1", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, ], "outputs": [ - {"name": "multiply", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "multiply", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 3, "input": 2, "multiply": 1}, } - - class Mul2(Module): - def forward(self, lhs): - return lhs * 1.0 - expected_mul2 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "multiply", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "multiply", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 3, "input": 1, "constant": 1, "multiply": 1}, } + if dynamic: + expected_mul1["prims"] = {"total": 1, "shape": 1} + expected_mul2["prims"] = {"total": 1, "shape": 1} verify_model(Mul1(), input_info1, expected_mul1) 
verify_model(Mul2(), input_info2, expected_mul2) @@ -857,30 +973,32 @@ class TrueDiv1(Module): def forward(self, lhs, rhs): return lhs / rhs + class TrueDiv2(Module): + def forward(self, lhs): + return lhs / 1.0 + expected_div1 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "inp_1", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, ], "outputs": [ - {"name": "divide", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "divide", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 3, "input": 2, "divide": 1}, } - - class TrueDiv2(Module): - def forward(self, lhs): - return lhs / 1.0 - expected_div2 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "divide", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "divide", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 3, "input": 1, "constant": 1, "divide": 1}, } + if dynamic: + expected_div1["prims"] = {"total": 1, "shape": 1} + expected_div2["prims"] = {"total": 1, "shape": 1} verify_model(TrueDiv1(), input_info1, expected_div1) verify_model(TrueDiv2(), input_info2, expected_div2) @@ -890,40 +1008,42 @@ class FloorDiv1(Module): def forward(self, lhs, rhs): return lhs // rhs + class FloorDiv2(Module): + def forward(self, lhs): + return lhs // 1.0 + expected_floordiv1 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "inp_1", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, ], "outputs": [ { "name": "floor_divide", - "shape": [1, 3, 10, 10], + "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD", } ], "nodes": {"total": 3, "input": 2, "floor_divide": 1}, } - - class FloorDiv2(Module): - def forward(self, lhs): - return lhs // 1.0 - expected_floordiv2 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ { "name": "floor_divide", - "shape": [1, 3, 10, 10], + "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD", } ], "nodes": {"total": 3, "input": 1, "constant": 1, "floor_divide": 1}, } + if dynamic: + expected_floordiv1["prims"] = {"total": 1, "shape": 1} + expected_floordiv2["prims"] = {"total": 1, "shape": 1} verify_model(FloorDiv1(), input_info1, expected_floordiv1) verify_model(FloorDiv2(), input_info2, expected_floordiv2) @@ -933,30 +1053,32 @@ class Power1(Module): def forward(self, lhs, rhs): return lhs**rhs + class Power2(Module): + def forward(self, lhs): + return lhs**1.0 + expected_power1 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "inp_1", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", 
"layout": "ABCD"}, ], "outputs": [ - {"name": "power", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "power", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 3, "input": 2, "power": 1}, } - - class Power2(Module): - def forward(self, lhs): - return lhs**1.0 - expected_power2 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "power", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "power", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 3, "input": 1, "constant": 1, "power": 1}, } + if dynamic: + expected_power1["prims"] = {"total": 1, "shape": 1} + expected_power2["prims"] = {"total": 1, "shape": 1} verify_model(Power1(), input_info1, expected_power1) verify_model(Power2(), input_info2, expected_power2) @@ -966,176 +1088,214 @@ class LT1(Module): def forward(self, lhs, rhs): return lhs < rhs + class LT2(Module): + def forward(self, lhs): + return lhs < 1.0 + expected_lt1 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "inp_1", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_1", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}, ], - "outputs": [{"name": "less", "shape": [1, 3, 10, 10], "dtype": "bool", "layout": "ABCD"}], + "outputs": [{"name": "less", "shape": [bz, 3, 10, 10], "dtype": "bool", "layout": "ABCD"}], "nodes": {"total": 3, "input": 2, "less": 1}, } - - class LT2(Module): - def forward(self, lhs): - return lhs < 1.0 - expected_lt2 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], - "outputs": [{"name": "less", "shape": [1, 3, 10, 10], "dtype": "bool", "layout": "ABCD"}], + "outputs": [{"name": "less", "shape": [bz, 3, 10, 10], "dtype": "bool", "layout": "ABCD"}], "nodes": {"total": 3, "input": 1, "constant": 1, "less": 1}, } + if dynamic: + expected_lt1["prims"] = {"total": 1, "shape": 1} + expected_lt2["prims"] = {"total": 1, "shape": 1} verify_model(LT1(), input_info1, expected_lt1) verify_model(LT2(), input_info2, expected_lt2) -def test_size(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_size(dynamic): """test graph builder for size""" class Size(Module): def forward(self, data): return data.size() + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}], "outputs": [{"name": "shape", "shape": [4], "dtype": "int32", "layout": "O"}], "nodes": {"total": 2, "input": 1, "shape": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(Size(), input_info, expected) -def test_squeeze(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_squeeze(dynamic): """test graph builder for squeeze""" class Squeeze1(Module): def forward(self, data): return data.squeeze(1) - expected1 = { - "inputs": [{"name": "inp_0", "shape": [3, 1, 4, 1], "dtype": "float32", "layout": 
"ANBC"}], - "outputs": [{"name": "squeeze", "shape": [3, 4, 1], "dtype": "float32", "layout": "ABC"}], - "nodes": {"total": 2, "input": 1, "squeeze": 1}, - } - class Squeeze2(Module): def forward(self, data): return data.squeeze() - expected2 = { - "inputs": [{"name": "inp_0", "shape": [3, 1, 4, 1], "dtype": "float32", "layout": "ANBC"}], - "outputs": [{"name": "squeeze", "shape": [3, 4], "dtype": "float32", "layout": "AB"}], + bz = "bz" if dynamic else 10 + expected1 = { + "inputs": [{"name": "inp_0", "shape": [bz, 1, 4, 1], "dtype": "float32", "layout": "ADBC"}], + "outputs": [{"name": "squeeze", "shape": [bz, 4, 1], "dtype": "float32", "layout": "ABC"}], "nodes": {"total": 2, "input": 1, "squeeze": 1}, } - - input_info = [([3, 1, 4, 1], "float32")] + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2 = { + "inputs": [ + {"name": "inp_0", "shape": [bz, 1, 4, 1], "dtype": "float32", "layout": "ACBD"} + ], + "outputs": [{"name": "squeeze", "shape": [], "dtype": "float32", "layout": "AB"}], + "nodes": {"total": 2, "input": 1, "squeeze": 1}, + "prims": {"total": 1, "shape": 1}, + } + else: + expected2 = { + "inputs": [ + {"name": "inp_0", "shape": [bz, 1, 4, 1], "dtype": "float32", "layout": "ACBD"} + ], + "outputs": [{"name": "squeeze", "shape": [bz, 4], "dtype": "float32", "layout": "AB"}], + "nodes": {"total": 2, "input": 1, "squeeze": 1}, + } + input_info = [([bz, 1, 4, 1], "float32")] verify_model(Squeeze1(), input_info, expected1) verify_model(Squeeze2(), input_info, expected2) -def test_unsqueeze(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_unsqueeze(dynamic): """test graph builder for unsqueeze""" class Unsqueeze1(Module): def forward(self, data): return data.unsqueeze(1) + class Unsqueeze2(Module): + def forward(self, data): + return data.unsqueeze(-1) + + bz = "bz" if dynamic else 1 expected1 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ACDE"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ACDE"} ], "outputs": [ { "name": "expand_dims", - "shape": [1, 1, 3, 10, 10], + "shape": [bz, 1, 3, 10, 10], "dtype": "float32", "layout": "ABCDE", } ], "nodes": {"total": 2, "input": 1, "expand_dims": 1}, } - - class Unsqueeze2(Module): - def forward(self, data): - return data.unsqueeze(-1) - expected2 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCE"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCE"} ], "outputs": [ { "name": "expand_dims", - "shape": [1, 3, 10, 10, 1], + "shape": [bz, 3, 10, 10, 1], "dtype": "float32", "layout": "ABCDE", } ], "nodes": {"total": 2, "input": 1, "expand_dims": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(Unsqueeze1(), input_info, expected1) verify_model(Unsqueeze2(), input_info, expected2) -def test_getattr(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_getattr(dynamic): """test graph builder for getattr""" class GetAttr1(Module): def forward(self, data): return data.shape + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}], "outputs": [{"name": "shape", "shape": [4], "dtype": "int32", "layout": "O"}], 
"nodes": {"total": 2, "input": 1, "shape": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(GetAttr1(), input_info, expected) -def test_getitem(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_getitem(dynamic): """test graph builder for getitem""" class Slice1(Module): def forward(self, x): return x[0, 1::2, :, :3] + class Slice2(Module): + def forward(self, x): + return x[:, None, None, :, None] + + bz = "bz" if dynamic else 1 expected1 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "reshape", "shape": [1, 1, 10, 3], "dtype": "float32", "layout": "ABCD"} + { + "name": "reshape", + "shape": ["MIN_2" if dynamic else 1, 1, 10, 3], + "dtype": "float32", + "layout": "ABCD", + } ], "nodes": {"total": 3, "input": 1, "strided_slice": 1, "reshape": 1}, } - - class Slice2(Module): - def forward(self, x): - return x[:, None, None, :, None] - expected2 = { - "inputs": [{"name": "inp_0", "shape": [8, 16], "dtype": "float32", "layout": "AB"}], + "inputs": [{"name": "inp_0", "shape": [bz, 16], "dtype": "float32", "layout": "AB"}], "outputs": [ - {"name": "reshape", "shape": [8, 1, 1, 16, 1], "dtype": "float32", "layout": "ANCHB"} + {"name": "reshape", "shape": [bz, 1, 1, 16, 1], "dtype": "float32", "layout": "CDAEB"} ], "nodes": {"total": 3, "input": 1, "strided_slice": 1, "reshape": 1}, } + if dynamic: + expected1["prims"] = {"total": 3, "shape": 1, "Int": 1, "Min": 1} + expected2["prims"] = {"total": 1, "shape": 1} - verify_model(Slice1(), [([1, 3, 10, 10], "float32")], expected1) - verify_model(Slice2(), [([8, 16], "float32")], expected2) + verify_model(Slice1(), [([bz, 3, 10, 10], "float32")], expected1) + verify_model(Slice2(), [([bz, 16], "float32")], expected2) -def test_unary(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_unary(dynamic): """test graph builder for unary""" - input_info = [([1, 3, 10, 10], "float32")] + bz = "bz" if dynamic else 1 + input_info = [([bz, 3, 10, 10], "float32")] # sin class Sin(Module): @@ -1144,11 +1304,15 @@ def forward(self, data): expected_sin = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + ], + "outputs": [ + {"name": "sin", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], - "outputs": [{"name": "sin", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}], "nodes": {"total": 2, "input": 1, "sin": 1}, } + if dynamic: + expected_sin["prims"] = {"total": 1, "shape": 1} verify_model(Sin(), input_info, expected_sin) @@ -1159,11 +1323,15 @@ def forward(self, data): expected_cos = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + ], + "outputs": [ + {"name": "cos", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], - "outputs": [{"name": "cos", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}], "nodes": {"total": 2, "input": 1, "cos": 1}, } + if dynamic: + expected_cos["prims"] = {"total": 1, "shape": 1} verify_model(Cos(), input_info, expected_cos) @@ -1174,11 +1342,15 @@ def forward(self, data): expected_exp = { "inputs": [ - 
{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + ], + "outputs": [ + {"name": "exp", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], - "outputs": [{"name": "exp", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"}], "nodes": {"total": 2, "input": 1, "exp": 1}, } + if dynamic: + expected_exp["prims"] = {"total": 1, "shape": 1} verify_model(Exp(), input_info, expected_exp) @@ -1189,13 +1361,15 @@ def forward(self, data): expected_sqrt = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "sqrt", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "sqrt", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 2, "input": 1, "sqrt": 1}, } + if dynamic: + expected_sqrt["prims"] = {"total": 1, "shape": 1} verify_model(Sqrt(), input_info, expected_sqrt) @@ -1206,13 +1380,15 @@ def forward(self, data): expected_sigmoid = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "sigmoid", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "sigmoid", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 2, "input": 1, "sigmoid": 1}, } + if dynamic: + expected_sigmoid["prims"] = {"total": 1, "shape": 1} verify_model(Sigmoid(), input_info, expected_sigmoid) @@ -1223,123 +1399,144 @@ def forward(self, data): expected_round = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "round", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "round", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 2, "input": 1, "round": 1}, } + if dynamic: + expected_round["prims"] = {"total": 1, "shape": 1} verify_model(Round(), input_info, expected_round) -def test_gelu(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_gelu(dynamic): """test graph builder for gelu""" class Gelu(Module): def forward(self, data): return torch.nn.functional.gelu(data) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "gelu", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "gelu", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 2, "input": 1, "nn.gelu": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(Gelu(), input_info, expected) -def test_tanh(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_tanh(dynamic): """test graph builder for tanh""" class Tanh(Module): def forward(self, data): return torch.tanh(data) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", 
"shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "tanh", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "tanh", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 2, "input": 1, "tanh": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(Tanh(), input_info, expected) -def test_clamp(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_clamp(dynamic): """test graph builder for clamp""" class Clamp(Module): def forward(self, data): return torch.clamp(data, min=0.1, max=0.5) + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "clip", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}], + "outputs": [{"name": "clip", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": ""}], "nodes": {"total": 2, "input": 1, "clip": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(Clamp(), input_info, expected) -def test_interpolate(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_interpolate(dynamic): """test graph builder for interpolate""" class Interpolate(Module): def forward(self, data): return torch.nn.functional.interpolate(data, (5, 5)) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "outputs": [ - {"name": "resize2d", "shape": [1, 3, 5, 5], "dtype": "float32", "layout": "NCHW"} + {"name": "resize2d", "shape": [bz, 3, 5, 5], "dtype": "float32", "layout": "NCHW"} ], "nodes": {"total": 2, "input": 1, "image.resize2d": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(Interpolate(), input_info, expected) -def test_addmm(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_addmm(dynamic): """test graph builder for addmm""" class Addmm(Module): def forward(self, x_1, x_2, x_3): return torch.addmm(x_1, x_2, x_3) + mdim = "mdim" if dynamic else 10 + ndim = "ndim" if dynamic else 20 + kdim = "kdim" if dynamic else 30 expected = { "inputs": [ - {"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": "NC"}, - {"name": "inp_1", "shape": [10, 10], "dtype": "float32", "layout": "NC"}, - {"name": "inp_2", "shape": [10, 10], "dtype": "float32", "layout": "IO"}, + {"name": "inp_0", "shape": [mdim, ndim], "dtype": "float32", "layout": "NC"}, + {"name": "inp_1", "shape": [mdim, kdim], "dtype": "float32", "layout": "NC"}, + {"name": "inp_2", "shape": [kdim, ndim], "dtype": "float32", "layout": "IO"}, ], - "outputs": [{"name": "add", "shape": [10, 10], "dtype": "float32", "layout": "NC"}], + "outputs": [{"name": "add", "shape": [mdim, ndim], "dtype": "float32", "layout": "NC"}], "nodes": {"total": 5, "input": 3, "matmul": 1, "add": 1}, } + if dynamic: + expected["prims"] = {"total": 3, "shape": 3} - input_info = [ - ([10, 10], "float32"), - ([10, 10], "float32"), - ([10, 10], "float32"), - ] + input_info = [([mdim, ndim], "float32"), 
([mdim, kdim], "float32"), ([kdim, ndim], "float32")] verify_model(Addmm(), input_info, expected) -def test_split(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_split(dynamic): """test graph builder for split""" class Split1(Module): @@ -1350,98 +1547,114 @@ class Split2(Module): def forward(self, data): return torch.split(data, [1, 2], dim=1) + bz = "bz" if dynamic else 1 expected1 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "split_0", "shape": [1, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "split_1", "shape": [1, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "split_2", "shape": [1, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "split_0", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "split_1", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "split_2", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, ], "nodes": {"total": 2, "input": 1, "split": 1}, } - expected2 = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "split_0", "shape": [1, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "split_1", "shape": [1, 2, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "split_0", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "split_1", "shape": [bz, 2, 10, 10], "dtype": "float32", "layout": "ABCD"}, ], "nodes": {"total": 2, "input": 1, "split": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(Split1(), input_info, expected1) verify_model(Split2(), input_info, expected2) -def test_unbind(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_unbind(dynamic): """test graph builder for unbind""" class Unbind(Module): def forward(self, data): return torch.unbind(data, dim=1) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "tuple_0", "shape": [1, 10, 10], "dtype": "float32", "layout": "ACD"}, - {"name": "tuple_1", "shape": [1, 10, 10], "dtype": "float32", "layout": "ACD"}, - {"name": "tuple_2", "shape": [1, 10, 10], "dtype": "float32", "layout": "ACD"}, + {"name": "tuple_0", "shape": [bz, 10, 10], "dtype": "float32", "layout": "ACD"}, + {"name": "tuple_1", "shape": [bz, 10, 10], "dtype": "float32", "layout": "ACD"}, + {"name": "tuple_2", "shape": [bz, 10, 10], "dtype": "float32", "layout": "ACD"}, ], "nodes": {"total": 9, "input": 1, "split": 1, "get_item": 3, "squeeze": 3, "tuple": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(Unbind(), input_info, expected) -def test_cumsum(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_cumsum(dynamic): """test graph builder for cumsum""" class Cumsum(Module): def forward(self, data): return torch.cumsum(data, dim=1, dtype=torch.int32) + bz = 
"bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "cumsum", "shape": [1, 2, 3, 4], "dtype": "int32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": ""}], + "outputs": [{"name": "cumsum", "shape": [bz, 2, 3, 4], "dtype": "int32", "layout": ""}], "nodes": {"total": 2, "input": 1, "cumsum": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 2, 3, 4], "float32")] + input_info = [([bz, 2, 3, 4], "float32")] verify_model(Cumsum(), input_info, expected) -def test_chunk(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_chunk(dynamic): """test graph builder for chunk""" class Chunk(Module): def forward(self, data): return torch.chunk(data, 3, dim=1) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "ABCD"} ], "outputs": [ - {"name": "split_0", "shape": [1, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "split_1", "shape": [1, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, - {"name": "split_2", "shape": [1, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "split_0", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "split_1", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, + {"name": "split_2", "shape": [bz, 1, 10, 10], "dtype": "float32", "layout": "ABCD"}, ], "nodes": {"total": 2, "input": 1, "split": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 3, 10, 10], "float32")] + input_info = [([bz, 3, 10, 10], "float32")] verify_model(Chunk(), input_info, expected) -def test_inplace_fill(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_inplace_fill(dynamic): """test graph builder for inplace_fill""" class InplaceFill(Module): @@ -1449,13 +1662,21 @@ def forward(self, data): data.fill_(1.5) return data - expected = { - "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "const", "shape": [10, 10], "dtype": "float32", "layout": ""}], - "nodes": {"total": 2, "input": 1, "constant": 1}, - } - - verify_model(InplaceFill(), [([10, 10], "float32")], expected) + bz = "bz" if dynamic else 1 + if dynamic: + expected = { + "inputs": [{"name": "inp_0", "shape": [bz, 10], "dtype": "float32", "layout": ""}], + "outputs": [{"name": "full", "shape": [bz, 10], "dtype": "float32", "layout": ""}], + "nodes": {"total": 3, "input": 1, "constant": 1, "full": 1}, + "prims": {"total": 1, "shape": 1}, + } + else: + expected = { + "inputs": [{"name": "inp_0", "shape": [bz, 10], "dtype": "float32", "layout": ""}], + "outputs": [{"name": "const", "shape": [bz, 10], "dtype": "float32", "layout": ""}], + "nodes": {"total": 2, "input": 1, "constant": 1}, + } + verify_model(InplaceFill(), [([bz, 10], "float32")], expected) def test_arange(): @@ -1517,7 +1738,8 @@ def forward(self): verify_model(Empty2(), [([10, 10], "float32")], expected2) -def test_tril(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_tril(dynamic): """test graph builder for tril""" class Tril(Module): @@ -1529,18 +1751,23 @@ def forward(self, data): data.tril_(1) return data + row = "row" if dynamic else 10 + col = "col" if dynamic else 10 expected = { - "inputs": [{"name": "inp_0", "shape": [10, 10], 
"dtype": "float32", "layout": ""}], - "outputs": [{"name": "tril", "shape": [10, 10], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [row, col], "dtype": "float32", "layout": ""}], + "outputs": [{"name": "tril", "shape": [row, col], "dtype": "float32", "layout": ""}], "nodes": {"total": 2, "input": 1, "tril": 1}, } + if dynamic: + expected["prims"] = {"total": 2, "shape": 2} - input_info = [([10, 10], "float32")] + input_info = [([row, col], "float32")] verify_model(Tril(), input_info, expected) verify_model(InplaceTril(), input_info, expected) -def test_triu(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_triu(dynamic): """test graph builder for triu""" class Triu(Module): @@ -1552,13 +1779,17 @@ def forward(self, data): data.triu_(1) return data + row = "row" if dynamic else 10 + col = "col" if dynamic else 10 expected = { - "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "triu", "shape": [10, 10], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [row, col], "dtype": "float32", "layout": ""}], + "outputs": [{"name": "triu", "shape": [row, col], "dtype": "float32", "layout": ""}], "nodes": {"total": 2, "input": 1, "triu": 1}, } + if dynamic: + expected["prims"] = {"total": 2, "shape": 2} - input_info = [([10, 10], "float32")] + input_info = [([row, col], "float32")] verify_model(Triu(), input_info, expected) verify_model(InplaceTriu(), input_info, expected) @@ -1580,7 +1811,8 @@ def forward(self, x): verify_model(NewOnes(), input_info, expected) -def test_expand(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_expand(dynamic): """test graph builder for expand""" class Expand1(Module): @@ -1591,20 +1823,24 @@ class Expand2(Module): def forward(self, x): return x.expand(4, -1, -1, 4) + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": ""}], "outputs": [ {"name": "broadcast_to", "shape": [4, 2, 3, 4], "dtype": "float32", "layout": ""} ], "nodes": {"total": 2, "input": 1, "broadcast_to": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 2, 3, 4], "float32")] + input_info = [([bz, 2, 3, 4], "float32")] verify_model(Expand1(), input_info, expected) verify_model(Expand2(), input_info, expected) -def test_reduce(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_reduce(dynamic): """test graph builder for reduce""" # sum @@ -1612,20 +1848,25 @@ class Sum(Module): def forward(self, x): return torch.sum(x, (2, 1)) + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ANCB"}], - "outputs": [{"name": "sum", "shape": [1, 4], "dtype": "float32", "layout": "AB"}], + "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ACDB"}], + "outputs": [{"name": "sum", "shape": [bz, 4], "dtype": "float32", "layout": "AB"}], "nodes": {"total": 2, "input": 1, "sum": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([1, 2, 3, 4], "float32")] + input_info = [([bz, 2, 3, 4], "float32")] verify_model(Sum(), input_info, expected) -def test_datatype(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_datatype(dynamic): """test graph builder for datatype""" - input_info = [([1, 2, 3, 4], "float32")] + bz = "bz" if 
dynamic else 1 + input_info = [([bz, 2, 3, 4], "float32")] # float class ToFloat(Module): @@ -1633,12 +1874,14 @@ def forward(self, x): return x.float() expected1 = { - "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}], + "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}], "outputs": [ - {"name": "astype", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"} + {"name": "astype", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 2, "input": 1, "astype": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} verify_model(ToFloat(), input_info, expected1) @@ -1648,12 +1891,14 @@ def forward(self, x): return x.half() expected2 = { - "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}], + "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}], "outputs": [ - {"name": "astype", "shape": [1, 2, 3, 4], "dtype": "float16", "layout": "ABCD"} + {"name": "astype", "shape": [bz, 2, 3, 4], "dtype": "float16", "layout": "ABCD"} ], "nodes": {"total": 2, "input": 1, "astype": 1}, } + if dynamic: + expected2["prims"] = {"total": 1, "shape": 1} verify_model(ToHalf(), input_info, expected2) @@ -1663,12 +1908,14 @@ def forward(self, x): return x.type(torch.float32) expected3 = { - "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}], + "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}], "outputs": [ - {"name": "astype", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"} + {"name": "astype", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 2, "input": 1, "astype": 1}, } + if dynamic: + expected3["prims"] = {"total": 1, "shape": 1} # type class TypeFromAttr(Module): @@ -1676,12 +1923,14 @@ def forward(self, x): return x.type(x.getattr("dtype")) expected4 = { - "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}], + "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}], "outputs": [ - {"name": "astype", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"} + {"name": "astype", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 2, "input": 1, "astype": 1}, } + if dynamic: + expected4["prims"] = {"total": 1, "shape": 1} # astype class AsType(Module): @@ -1689,91 +1938,140 @@ def forward(self, x): return x.astype(torch.float32) expected5 = { - "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}], + "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"}], "outputs": [ - {"name": "astype", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ABCD"} + {"name": "astype", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": "ABCD"} ], "nodes": {"total": 2, "input": 1, "astype": 1}, } + if dynamic: + expected5["prims"] = {"total": 1, "shape": 1} verify_model(Type(), input_info, expected3) verify_model(TypeFromAttr(), input_info, expected4) verify_model(AsType(), input_info, expected5) -def test_permute(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_permute(dynamic): """test graph builder for permute""" class Permute(Module): def forward(self, x): return x.permute(0, 3, 2, 1) + bz = "bz" if dynamic else 1 + channel = "channel" if dynamic else 2 expected = { - "inputs": [{"name": "inp_0", 
"shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ADCB"}], + "inputs": [ + {"name": "inp_0", "shape": [bz, channel, 3, 4], "dtype": "float32", "layout": "ADCB"} + ], "outputs": [ - {"name": "permute_dims", "shape": [1, 4, 3, 2], "dtype": "float32", "layout": "ABCD"} + { + "name": "permute_dims", + "shape": [bz, 4, 3, channel], + "dtype": "float32", + "layout": "ABCD", + } ], "nodes": {"total": 2, "input": 1, "permute_dims": 1}, } + if dynamic: + expected["prims"] = {"total": 2, "shape": 2} - input_info = [([1, 2, 3, 4], "float32")] + input_info = [([bz, channel, 3, 4], "float32")] verify_model(Permute(), input_info, expected) -def test_reshape(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_reshape(dynamic): """test graph builder for reshape""" class Reshape(Module): def forward(self, x): - return x.reshape(2, 12) + return x.reshape(-1, 12) + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "reshape", "shape": [2, 12], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": ""}], + "outputs": [ + { + "name": "reshape", + "shape": ["MUL_2" if dynamic else 2, 12], + "dtype": "float32", + "layout": "", + } + ], "nodes": {"total": 2, "input": 1, "reshape": 1}, } + if dynamic: + expected["prims"] = {"total": 3, "shape": 1, "Int": 1, "Mul": 1} - input_info = [([1, 2, 3, 4], "float32")] + input_info = [([bz, 2, 3, 4], "float32")] verify_model(Reshape(), input_info, expected) -def test_transpose(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_transpose(dynamic): """test graph builder for transpose""" class Transpose(Module): def forward(self, x): return x.transpose(1, 3) + bz = "bz" if dynamic else 1 + hidden = "hidden" if dynamic else 4 expected = { - "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": "ADCB"}], + "inputs": [ + {"name": "inp_0", "shape": [bz, 2, 3, hidden], "dtype": "float32", "layout": "ADCB"} + ], "outputs": [ - {"name": "permute_dims", "shape": [1, 4, 3, 2], "dtype": "float32", "layout": "ABCD"} + { + "name": "permute_dims", + "shape": [bz, hidden, 3, 2], + "dtype": "float32", + "layout": "ABCD", + } ], "nodes": {"total": 2, "input": 1, "permute_dims": 1}, } + if dynamic: + expected["prims"] = {"total": 2, "shape": 2} - input_info = [([1, 2, 3, 4], "float32")] + input_info = [([bz, 2, 3, hidden], "float32")] verify_model(Transpose(), input_info, expected) -def test_view(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_view(dynamic): """test graph builder for view""" class View(Module): def forward(self, x): - return x.view(2, 12) + return x.view(-1, 12) + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [1, 2, 3, 4], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "reshape", "shape": [2, 12], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 2, 3, 4], "dtype": "float32", "layout": ""}], + "outputs": [ + { + "name": "reshape", + "shape": ["MUL_2" if dynamic else 2, 12], + "dtype": "float32", + "layout": "", + } + ], "nodes": {"total": 2, "input": 1, "reshape": 1}, } + if dynamic: + expected["prims"] = {"total": 3, "shape": 1, "Int": 1, "Mul": 1} - input_info = [([1, 2, 3, 4], "float32")] + input_info = [([bz, 2, 3, 4], "float32")] verify_model(View(), input_info, expected) -def test_keep_params(): +@pytest.mark.parametrize("dynamic", [True, False]) +def 
test_keep_params(dynamic): """test graph builder for keep_params""" class Conv2D1(Module): @@ -1784,228 +2082,271 @@ def __init__(self): def forward(self, data): return self.conv(data) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [1, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} + {"name": "inp_0", "shape": [bz, 3, 10, 10], "dtype": "float32", "layout": "NCHW"} ], "outputs": [ { "name": "conv2d", - "shape": [1, 6, 4, 4], + "shape": [bz, 6, 4, 4], "dtype": "float32", "layout": "NCHW", } ], "nodes": {"total": 2, "input": 1, "msc.conv2d_bias": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - verify_model(Conv2D1(), [([1, 3, 10, 10], "float32")], expected) + verify_model(Conv2D1(), [([bz, 3, 10, 10], "float32")], expected) -def test_unwrap_unit_return_tuple(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_unwrap_unit_return_tuple(dynamic): """test graph builder for unwrap_unit_return_tuple""" class Identity(Module): def forward(self, x): return (x,) + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "tuple", "shape": [256, 256], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}], + "outputs": [{"name": "tuple", "shape": [bz, 256], "dtype": "float32", "layout": ""}], "nodes": {"total": 2, "input": 1, "tuple": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - verify_model(Identity(), [([256, 256], "float32")], expected) + verify_model(Identity(), [([bz, 256], "float32")], expected) -def test_no_bind_return_tuple(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_no_bind_return_tuple(dynamic): """test graph builder for no_bind_return_tuple""" class Identity(Module): def forward(self, x, y): return (x, y) + bz_x = "bz" if dynamic else 1 + bz_y = "bz" if dynamic else 2 expected = { "inputs": [ - {"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": ""}, - {"name": "inp_1", "shape": [256, 256], "dtype": "float32", "layout": ""}, + {"name": "inp_0", "shape": [bz_x, 256], "dtype": "float32", "layout": ""}, + {"name": "inp_1", "shape": [bz_y, 256], "dtype": "float32", "layout": ""}, ], "outputs": [ - {"name": "tuple_0", "shape": [256, 256], "dtype": "float32", "layout": ""}, - {"name": "tuple_1", "shape": [256, 256], "dtype": "float32", "layout": ""}, + {"name": "tuple_0", "shape": [bz_x, 256], "dtype": "float32", "layout": ""}, + {"name": "tuple_1", "shape": [bz_y, 256], "dtype": "float32", "layout": ""}, ], "nodes": {"total": 3, "input": 2, "tuple": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - input_info = [([256, 256], "float32"), ([256, 256], "float32")] + input_info = [([bz_x, 256], "float32"), ([bz_y, 256], "float32")] verify_model(Identity(), input_info, expected) -def test_argmax(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_argmax(dynamic): """test graph builder for argmax""" class Argmax1(Module): def forward(self, data): return torch.argmax(data, dim=-1) - expected1 = { - "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "argmax", "shape": [256], "dtype": "int64", "layout": ""}], - "nodes": {"total": 2, "input": 1, "argmax": 1}, - } - class Argmax2(Module): def forward(self, data): return torch.argmax(data, dim=-1, keepdim=True) + bz = "bz" if dynamic else 1 + expected1 = { + "inputs": [{"name": 
"inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}], + "outputs": [{"name": "argmax", "shape": [bz], "dtype": "int64", "layout": ""}], + "nodes": {"total": 2, "input": 1, "argmax": 1}, + } expected2 = { - "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "argmax", "shape": [256, 1], "dtype": "int64", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}], + "outputs": [{"name": "argmax", "shape": [bz, 1], "dtype": "int64", "layout": ""}], "nodes": {"total": 2, "input": 1, "argmax": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} - verify_model(Argmax1(), [([256, 256], "float32")], expected1) - verify_model(Argmax2(), [([256, 256], "float32")], expected2) + verify_model(Argmax1(), [([bz, 256], "float32")], expected1) + verify_model(Argmax2(), [([bz, 256], "float32")], expected2) -def test_argmin(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_argmin(dynamic): """test graph builder for argmin""" class Argmin1(Module): def forward(self, data): return torch.argmin(data) - expected1 = { - "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "argmin", "shape": [], "dtype": "int64", "layout": ""}], - "nodes": {"total": 2, "input": 1, "argmin": 1}, - } - class Argmin2(Module): def forward(self, data): return torch.argmin(data, keepdim=True) + bz = "bz" if dynamic else 1 + expected1 = { + "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}], + "outputs": [{"name": "argmin", "shape": [], "dtype": "int64", "layout": ""}], + "nodes": {"total": 2, "input": 1, "argmin": 1}, + } expected2 = { - "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}], "outputs": [{"name": "argmin", "shape": [1, 1], "dtype": "int64", "layout": ""}], "nodes": {"total": 2, "input": 1, "argmin": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} - verify_model(Argmin1(), [([256, 256], "float32")], expected1) - verify_model(Argmin2(), [([256, 256], "float32")], expected2) + verify_model(Argmin1(), [([bz, 256], "float32")], expected1) + verify_model(Argmin2(), [([bz, 256], "float32")], expected2) -def test_to(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_to(dynamic): """test graph builder for to""" class To1(Module): def forward(self, data): return data.to(torch.float16) - expected1 = { - "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": "AB"}], - "outputs": [{"name": "astype", "shape": [256, 256], "dtype": "float16", "layout": "AB"}], - "nodes": {"total": 2, "input": 1, "astype": 1}, - } - class To2(Module): def forward(self, data): return data.to("cpu") + bz = "bz" if dynamic else 1 + expected1 = { + "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}], + "outputs": [{"name": "astype", "shape": [bz, 256], "dtype": "float16", "layout": "AB"}], + "nodes": {"total": 2, "input": 1, "astype": 1}, + } expected2 = { - "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": ""}], - "outputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": ""}], + "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}], + "outputs": 
[{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": ""}], "nodes": {"total": 1, "input": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} - verify_model(To1(), [([256, 256], "float32")], expected1) - verify_model(To2(), [([256, 256], "float32")], expected2) + verify_model(To1(), [([bz, 256], "float32")], expected1) + verify_model(To2(), [([bz, 256], "float32")], expected2) -def test_mean(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_mean(dynamic): """test graph builder for mean""" class Mean(Module): def forward(self, data): return data.mean(-1) - expected1 = { - "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": "AN"}], - "outputs": [{"name": "mean", "shape": [256], "dtype": "float32", "layout": "A"}], - "nodes": {"total": 2, "input": 1, "mean": 1}, - } - class MeanKeepDim(Module): def forward(self, data): return data.mean(-1, keepdim=True) + bz = "bz" if dynamic else 1 + expected1 = { + "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}], + "outputs": [{"name": "mean", "shape": [bz], "dtype": "float32", "layout": "A"}], + "nodes": {"total": 2, "input": 1, "mean": 1}, + } expected2 = { - "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": "AB"}], - "outputs": [{"name": "mean", "shape": [256, 1], "dtype": "float32", "layout": "AB"}], + "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}], + "outputs": [{"name": "mean", "shape": [bz, 1], "dtype": "float32", "layout": "AB"}], "nodes": {"total": 2, "input": 1, "mean": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} + expected2["prims"] = {"total": 1, "shape": 1} - verify_model(Mean(), [([256, 256], "float32")], expected1) - verify_model(MeanKeepDim(), [([256, 256], "float32")], expected2) + verify_model(Mean(), [([bz, 256], "float32")], expected1) + verify_model(MeanKeepDim(), [([bz, 256], "float32")], expected2) -def test_rsqrt(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_rsqrt(dynamic): """test graph builder for rsqrt""" class Rsqrt(Module): def forward(self, data): return torch.rsqrt(data) + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": "AB"}], - "outputs": [{"name": "rsqrt", "shape": [256, 256], "dtype": "float32", "layout": "AB"}], + "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}], + "outputs": [{"name": "rsqrt", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}], "nodes": {"total": 2, "input": 1, "rsqrt": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - verify_model(Rsqrt(), [([256, 256], "float32")], expected) + verify_model(Rsqrt(), [([bz, 256], "float32")], expected) -def test_neg(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_neg(dynamic): """test graph builder for neg""" class Neg(Module): def forward(self, data): return -data + bz = "bz" if dynamic else 1 expected = { - "inputs": [{"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": "AB"}], - "outputs": [{"name": "negative", "shape": [256, 256], "dtype": "float32", "layout": "AB"}], + "inputs": [{"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}], + "outputs": [{"name": "negative", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}], "nodes": {"total": 2, "input": 1, "negative": 1}, } + if dynamic: + expected["prims"] = 
{"total": 1, "shape": 1} - verify_model(Neg(), [([256, 256], "float32")], expected) + verify_model(Neg(), [([bz, 256], "float32")], expected) -def test_max(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_max(dynamic): """test graph builder for max""" class Max(Module): def forward(self, x, y): return torch.max(x, y) + bz = "bz" if dynamic else 1 expected = { "inputs": [ - {"name": "inp_0", "shape": [256, 256], "dtype": "float32", "layout": "AB"}, - {"name": "inp_1", "shape": [256, 256], "dtype": "float32", "layout": "AB"}, + {"name": "inp_0", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}, + {"name": "inp_1", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}, ], - "outputs": [{"name": "maximum", "shape": [256, 256], "dtype": "float32", "layout": "AB"}], + "outputs": [{"name": "maximum", "shape": [bz, 256], "dtype": "float32", "layout": "AB"}], "nodes": {"total": 3, "input": 2, "maximum": 1}, } + if dynamic: + expected["prims"] = {"total": 1, "shape": 1} - verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")], expected) + verify_model(Max(), [([bz, 256], "float32"), ([bz, 256], "float32")], expected) -def test_attention(): +@pytest.mark.parametrize("dynamic", [True, False]) +def test_attention(dynamic): """test graph builder for attention""" # pylint: disable=import-outside-toplevel import torch.nn.functional as F + seq = "seq" if dynamic else 128 + class Attention1(Module): def forward(self, q_data, k_data, v_data): return F.scaled_dot_product_attention(q_data, k_data, v_data) @@ -2016,25 +2357,27 @@ def forward(self, q_data, k_data, v_data): expected1 = { "inputs": [ - {"name": "inp_0", "shape": [32, 8, 128, 64], "dtype": "float32", "layout": "ACBD"}, - {"name": "inp_1", "shape": [32, 8, 128, 64], "dtype": "float32", "layout": "ACBD"}, - {"name": "inp_2", "shape": [32, 8, 128, 64], "dtype": "float32", "layout": "ACBD"}, + {"name": "inp_0", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"}, + {"name": "inp_1", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"}, + {"name": "inp_2", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"}, ], "outputs": [ { "name": "attention", - "shape": [32, 128, 8, 64], + "shape": [1, seq, 8, 64], "dtype": "float32", "layout": "ABCD", } ], "nodes": {"total": 4, "input": 3, "msc.attention": 1}, } + if dynamic: + expected1["prims"] = {"total": 1, "shape": 1} input_info = [ - ([32, 8, 128, 64], "float32"), - ([32, 8, 128, 64], "float32"), - ([32, 8, 128, 64], "float32"), + ([1, 8, seq, 64], "float32"), + ([1, 8, seq, 64], "float32"), + ([1, 8, seq, 64], "float32"), ] verify_model(Attention1(), input_info, expected1) verify_model(Attention2(), input_info, expected1) @@ -2045,28 +2388,31 @@ def forward(self, q_data, k_data, v_data, mask): expected2 = { "inputs": [ - {"name": "inp_0", "shape": [32, 8, 128, 64], "dtype": "float32", "layout": "ACBD"}, - {"name": "inp_1", "shape": [32, 8, 128, 64], "dtype": "float32", "layout": "ACBD"}, - {"name": "inp_2", "shape": [32, 8, 128, 64], "dtype": "float32", "layout": "ACBD"}, - {"name": "inp_3", "shape": [32, 8, 128, 128], "dtype": "float32", "layout": "ABCD"}, + {"name": "inp_0", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"}, + {"name": "inp_1", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"}, + {"name": "inp_2", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"}, + {"name": "inp_3", "shape": [1, 8, seq, seq], "dtype": "float32", "layout": "ABCD"}, ], "outputs": [ { "name": 
"attention_bias", - "shape": [32, 128, 8, 64], + "shape": [1, seq, 8, 64], "dtype": "float32", "layout": "ABCD", } ], "nodes": {"total": 5, "input": 4, "msc.attention": 1}, } + if dynamic: + expected2["prims"] = {"total": 1, "shape": 1} + verify_model( Attention3(), [ - ([32, 8, 128, 64], "float32"), - ([32, 8, 128, 64], "float32"), - ([32, 8, 128, 64], "float32"), - ([32, 8, 128, 128], "float32"), + ([1, 8, seq, 64], "float32"), + ([1, 8, seq, 64], "float32"), + ([1, 8, seq, 64], "float32"), + ([1, 8, seq, seq], "float32"), ], expected2, ) diff --git a/tests/python/contrib/test_msc/test_pipeline.py b/tests/python/contrib/test_msc/test_pipeline.py index 149041959416..ddc70243887b 100644 --- a/tests/python/contrib/test_msc/test_pipeline.py +++ b/tests/python/contrib/test_msc/test_pipeline.py @@ -37,7 +37,7 @@ def _get_config(model_type, compile_type, inputs, outputs, dynamic=False, atol=1 path = "test_pipe_{}_{}_{}".format(model_type, compile_type, "dynamic" if dynamic else "static") return { - "workspace": msc_utils.msc_dir(path), + "workspace": msc_utils.msc_dir(path, keep_history=False), "verbose": "critical", "model_type": model_type, "inputs": inputs, @@ -161,7 +161,7 @@ def test_tvm_pipeline(dynamic): "inputs": [ {"name": "input_0", "shape": [1, 3, 224, 224], "dtype": "float32", "layout": "NCHW"} ], - "outputs": [{"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": "NC"}], + "outputs": [{"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": "NW"}], "nodes": { "total": 229, "input": 1, @@ -217,7 +217,7 @@ def test_torch_pipeline(dynamic): "inputs": [ {"name": "input_0", "shape": [1, 3, 224, 224], "dtype": "float32", "layout": "NCHW"} ], - "outputs": [{"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": "NC"}], + "outputs": [{"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": "NW"}], "nodes": { "total": 229, "input": 1, diff --git a/tests/python/contrib/test_msc/test_runner.py b/tests/python/contrib/test_msc/test_runner.py index 55fc9dd43e4f..031572a98e4a 100644 --- a/tests/python/contrib/test_msc/test_runner.py +++ b/tests/python/contrib/test_msc/test_runner.py @@ -84,13 +84,15 @@ def _test_from_torch(runner_cls, device, training=False, atol=1e-1, rtol=1e-1): torch_model = _get_torch_model("resnet50", training) if torch_model: path = "test_runner_torch_{}_{}".format(runner_cls.__name__, device) - workspace = msc_utils.set_workspace(msc_utils.msc_dir(path)) + workspace = msc_utils.set_workspace(msc_utils.msc_dir(path, keep_history=False)) log_path = workspace.relpath("MSC_LOG", keep_history=False) msc_utils.set_global_logger("critical", log_path) input_info = [([1, 3, 224, 224], "float32")] datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info] torch_datas = [torch.from_numpy(d) for d in datas] graph_model = fx.symbolic_trace(torch_model) + if training: + input_info = [([tvm.tir.Var("bz", "int64"), 3, 224, 224], "float32")] with torch.no_grad(): golden = torch_model(*torch_datas) mod = from_fx(graph_model, input_info) @@ -103,34 +105,34 @@ def _test_from_torch(runner_cls, device, training=False, atol=1e-1, rtol=1e-1): tvm.testing.assert_allclose(gol_r, msc_utils.cast_array(out_r), atol=atol, rtol=rtol) -def test_tvm_runner_cpu(): +@pytest.mark.parametrize("training", [True, False]) +def test_tvm_runner_cpu(training): """Test runner for tvm on cpu""" - for training in [True, False]: - _test_from_torch(TVMRunner, "cpu", training=training) + _test_from_torch(TVMRunner, "cpu", training=training) 
@tvm.testing.requires_cuda -def test_tvm_runner_cuda(): +@pytest.mark.parametrize("training", [True, False]) +def test_tvm_runner_cuda(training): """Test runner for tvm on cuda""" - for training in [True, False]: - _test_from_torch(TVMRunner, "cuda", training=training) + _test_from_torch(TVMRunner, "cuda", training=training) -def test_torch_runner_cpu(): +@pytest.mark.parametrize("training", [True, False]) +def test_torch_runner_cpu(training): """Test runner for torch on cpu""" - for training in [True, False]: - _test_from_torch(TorchRunner, "cpu", training=training) + _test_from_torch(TorchRunner, "cpu", training=training) @tvm.testing.requires_cuda -def test_torch_runner_cuda(): +@pytest.mark.parametrize("training", [True, False]) +def test_torch_runner_cuda(training): """Test runner for torch on cuda""" - for training in [True, False]: - _test_from_torch(TorchRunner, "cuda", training=training, atol=1e-1, rtol=1e-1) + _test_from_torch(TorchRunner, "cuda", training=training, atol=1e-1, rtol=1e-1) @requires_tensorrt @@ -146,7 +148,7 @@ def test_tensorflow_runner(): tf_graph, graph_def = _get_tf_graph() if tf_graph and graph_def: path = "test_runner_tf" - workspace = msc_utils.set_workspace(msc_utils.msc_dir(path)) + workspace = msc_utils.set_workspace(msc_utils.msc_dir(path, keep_history=False)) log_path = workspace.relpath("MSC_LOG", keep_history=False) msc_utils.set_global_logger("critical", log_path) data = np.random.uniform(size=(1, 224, 224, 3)).astype("float32") diff --git a/tests/python/contrib/test_msc/test_tools.py b/tests/python/contrib/test_msc/test_tools.py index 22354bb2c131..ac6f2d6c6f74 100644 --- a/tests/python/contrib/test_msc/test_tools.py +++ b/tests/python/contrib/test_msc/test_tools.py @@ -47,7 +47,7 @@ def _get_config( path = "_".join(["test_tools", model_type, compile_type] + [t["tool_type"] for t in tools]) return { - "workspace": msc_utils.msc_dir(path), + "workspace": msc_utils.msc_dir(path, keep_history=False), "verbose": "critical", "model_type": model_type, "inputs": inputs, @@ -229,7 +229,7 @@ def get_model_info(compile_type): "inputs": [ {"name": "input_0", "shape": [1, 3, 224, 224], "dtype": "float32", "layout": "NCHW"} ], - "outputs": [{"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": "NC"}], + "outputs": [{"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": "NW"}], "nodes": { "total": 229, "input": 1,