From f5ab695179c13d4762006bfa377e142e3a8a181e Mon Sep 17 00:00:00 2001 From: Archermmt Date: Wed, 11 Sep 2024 21:31:12 +0800 Subject: [PATCH] bug fix for pruner --- python/tvm/contrib/msc/core/tools/prune/pruner.py | 7 ++++++- python/tvm/contrib/msc/core/tools/tool.py | 3 +++ src/contrib/msc/core/ir/graph.cc | 5 ++--- tests/python/contrib/test_msc/test_tools.py | 2 +- 4 files changed, 12 insertions(+), 5 deletions(-) diff --git a/python/tvm/contrib/msc/core/tools/prune/pruner.py b/python/tvm/contrib/msc/core/tools/prune/pruner.py index 90273e25416b1..a008100be2524 100644 --- a/python/tvm/contrib/msc/core/tools/prune/pruner.py +++ b/python/tvm/contrib/msc/core/tools/prune/pruner.py @@ -340,7 +340,12 @@ def _prune_by_shape(tensor: MSCTensor, shape: List[int]): def _prune_by_channel(tensor: MSCTensor, dim, channel_axis: int = None): shape = tensor.get_shape() if channel_axis is None: - channel_axis = tensor.layout_of("C") + if self.has_w_node(tensor.name): + w_node = self.find_w_node(tensor.name) + _, channel_axis = self._get_io_axes(w_node) + else: + channel_axis = tensor.layout_of("C") + assert channel_axis >= 0, "Can not infer channel_axis for " + str(tensor) shape[channel_axis] = dim return _prune_by_shape(tensor, shape) diff --git a/python/tvm/contrib/msc/core/tools/tool.py b/python/tvm/contrib/msc/core/tools/tool.py index 626ae312bcf4e..06a16f2bbe493 100644 --- a/python/tvm/contrib/msc/core/tools/tool.py +++ b/python/tvm/contrib/msc/core/tools/tool.py @@ -1620,6 +1620,9 @@ def _get_io_axes(self, w_node: WeightJoint) -> Tuple[int, int]: in_axis, out_axis = w_node.weight.layout_of("I"), w_node.weight.layout_of("O") if in_axis >= 0 and out_axis >= 0: return in_axis, out_axis + if w_node.weight.ndim == 2 and w_node.weight.dim_at("N") > 0: + io_axis = 1 - w_node.weight.layout_of("N") + return io_axis, io_axis if w_node.weight.layout_of("C") >= 0: return w_node.weight.layout_of("C"), w_node.weight.layout_of("C") raise Exception("Can not infer in_axis/out_axis from " + str(w_node)) diff --git a/src/contrib/msc/core/ir/graph.cc b/src/contrib/msc/core/ir/graph.cc index 71d15dccdf30e..ae42537a4ce11 100644 --- a/src/contrib/msc/core/ir/graph.cc +++ b/src/contrib/msc/core/ir/graph.cc @@ -1116,9 +1116,8 @@ void WeightGraphNode::Build(const MSCGraph& graph, const MapOutputAt(0); Map attrs; attrs.Set("producer_type", node->optype); - if (node->optype == "reshape" && node->InputAt(0)->LayoutOf("C") >= 0 && - node->OutputAt(0)->LayoutOf("C") >= 0 && - node->InputAt(0)->DimAt("C")->value == node->OutputAt(0)->DimAt("C")->value) { + if (node->optype == "reshape") { + // TODO(archermmt): check non-passby reshape attrs.Set("weight_strategy", "passby"); } else { attrs.Set("weight_strategy", relation_wtypes[node->optype]); diff --git a/tests/python/contrib/test_msc/test_tools.py b/tests/python/contrib/test_msc/test_tools.py index 7b20f2b6dfcc8..ac6f2d6c6f745 100644 --- a/tests/python/contrib/test_msc/test_tools.py +++ b/tests/python/contrib/test_msc/test_tools.py @@ -47,7 +47,7 @@ def _get_config( path = "_".join(["test_tools", model_type, compile_type] + [t["tool_type"] for t in tools]) return { - "workspace": msc_utils.msc_dir(path), + "workspace": msc_utils.msc_dir(path, keep_history=False), "verbose": "critical", "model_type": model_type, "inputs": inputs,