Skip to content

Commit

Permalink
bug fix for pruner
Browse files Browse the repository at this point in the history
  • Loading branch information
Archermmt committed Sep 11, 2024
1 parent ad45afd commit 78f6496
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 5 deletions.
7 changes: 6 additions & 1 deletion python/tvm/contrib/msc/core/tools/prune/pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,12 @@ def _prune_by_shape(tensor: MSCTensor, shape: List[int]):
def _prune_by_channel(tensor: MSCTensor, dim, channel_axis: int = None):
shape = tensor.get_shape()
if channel_axis is None:
channel_axis = tensor.layout_of("C")
if self.has_w_node(tensor.name):
w_node = self.find_w_node(tensor.name)
_, channel_axis = self._get_io_axes(w_node)
else:
channel_axis = tensor.layout_of("C")
assert channel_axis >= 0, "Can not infer channel_axis for " + str(tensor)
shape[channel_axis] = dim
return _prune_by_shape(tensor, shape)

Expand Down
3 changes: 3 additions & 0 deletions python/tvm/contrib/msc/core/tools/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -1620,6 +1620,9 @@ def _get_io_axes(self, w_node: WeightJoint) -> Tuple[int, int]:
in_axis, out_axis = w_node.weight.layout_of("I"), w_node.weight.layout_of("O")
if in_axis >= 0 and out_axis >= 0:
return in_axis, out_axis
if w_node.weight.ndim == 2 and w_node.weight.dim_at("N") > 0:
io_axis = 1 - w_node.weight.layout_of("N")
return io_axis, io_axis
if w_node.weight.layout_of("C") >= 0:
return w_node.weight.layout_of("C"), w_node.weight.layout_of("C")
raise Exception("Can not infer in_axis/out_axis from " + str(w_node))
Expand Down
5 changes: 2 additions & 3 deletions src/contrib/msc/core/ir/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1116,9 +1116,8 @@ void WeightGraphNode::Build(const MSCGraph& graph, const Map<String, Array<Strin
const auto& tensor = node->OutputAt(0);
Map<String, String> attrs;
attrs.Set("producer_type", node->optype);
if (node->optype == "reshape" && node->InputAt(0)->LayoutOf("C") >= 0 &&
node->OutputAt(0)->LayoutOf("C") >= 0 &&
node->InputAt(0)->DimAt("C")->value == node->OutputAt(0)->DimAt("C")->value) {
if (node->optype == "reshape") {
// TODO(archermmt): check non-passby reshape
attrs.Set("weight_strategy", "passby");
} else {
attrs.Set("weight_strategy", relation_wtypes[node->optype]);
Expand Down
2 changes: 1 addition & 1 deletion tests/python/contrib/test_msc/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _get_config(

path = "_".join(["test_tools", model_type, compile_type] + [t["tool_type"] for t in tools])
return {
"workspace": msc_utils.msc_dir(path),
"workspace": msc_utils.msc_dir(path, keep_history=False),
"verbose": "critical",
"model_type": model_type,
"inputs": inputs,
Expand Down

0 comments on commit 78f6496

Please sign in to comment.