Skip to content

Commit

Permalink
[MSC][Refactor] Support dynamic shape (#17351)
Browse files Browse the repository at this point in the history
* support prims for tir.Var

* minor fix

* bug fix for pruner
  • Loading branch information
Archermmt authored Sep 15, 2024
1 parent 48d661c commit 11198f6
Show file tree
Hide file tree
Showing 33 changed files with 1,939 additions and 842 deletions.
7 changes: 4 additions & 3 deletions python/tvm/contrib/msc/core/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,10 @@ def visit_var_binding_(self, binding: relax.VarBinding) -> None:

def _to_var(tensor: MSCTensor):
v_name = tensor.alias if use_alias else graph.find_producer(tensor).name
return tvm.relax.Var(
v_name, tvm.relax.TensorStructInfo(tensor.get_shape(), tensor.dtype_name)
)
dims = [
d if isinstance(d, int) else tvm.tir.Var(d, "int64") for d in tensor.get_shape(True)
]
return tvm.relax.Var(v_name, tvm.relax.TensorStructInfo(dims, tensor.dtype_name))

def _save_weights(folder: msc_utils.MSCDirectory):
if weights:
Expand Down
38 changes: 38 additions & 0 deletions python/tvm/contrib/msc/core/frontend/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,44 @@
from tvm.contrib.msc.core.ir import MSCGraph, MSCTensor


def normalize_inputs(inputs: List[tuple]) -> List[tuple]:
"""Normalize the inputs info
Parameters
----------
inputs: list of <name, shape, dtype>
The inputs info.
Returns
-------
inputs: list of <name, shape, dtype>
The normalized inputs info.
"""

recorded_vars = {}

def _normalize_input(inp):
def _normalize(info):
if not isinstance(info, (tuple, list)):
return info
dims = []
for dim in info:
if isinstance(dim, int):
dims.append(dim)
elif dim in recorded_vars:
dims.append(recorded_vars[dim])
elif isinstance(dim, str):
recorded_vars[dim] = tvm.tir.Var(dim, "int64")
dims.append(recorded_vars[dim])
else:
raise TypeError("Unexpected dim {} in shape {}".format(dim, info))
return dims

return [_normalize(i) for i in inp]

return [_normalize_input(inp) for inp in inputs]


def normalize_weights(
t_weights: Dict[MSCTensor, tvm.nd.array], graph: MSCGraph
) -> Dict[str, tvm.nd.array]:
Expand Down
93 changes: 84 additions & 9 deletions python/tvm/contrib/msc/core/ir/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class MSCTensor(Object):
The shape of the tensor.
alias: string
The alias of the tensor.
prims: list<str>
The prims of the tensor.
"""

def __init__(
Expand All @@ -50,15 +52,31 @@ def __init__(
layout: str,
shape: List[int],
alias: Optional[str] = None,
prims: List[str] = None,
):
if not isinstance(dtype, tvm.DataType):
dtype = tvm.DataType(dtype)
self.__init_handle_by_constructor__(
_ffi_api.MSCTensor, name, dtype, layout, shape, alias or ""
_ffi_api.MSCTensor, name, dtype, layout, shape, alias or "", prims or []
)

def get_shape(self) -> List[int]:
return [int(i) for i in self.shape]
def get_shape(self, with_prims: bool = False) -> List[Union[int, str]]:
"""Get shape of the tensor
Parameters
-------
with_prims: bool
Whether get shape with prims.
Returns
-------
shape: list<str|int>
The shape of tensor.
"""

if not self.prims or not with_prims:
return [int(i) for i in self.shape]
return [int(p) if p.isdigit() else p for p in self.prims]

def get_size(self) -> int:
return int(_ffi_api.MSCTensorGetSize(self))
Expand Down Expand Up @@ -98,7 +116,7 @@ def equal(self, other: Object) -> bool:

if not isinstance(other, MSCTensor):
return False
if self.get_shape() != other.get_shape():
if self.get_shape(True) != other.get_shape(True):
return False
if self.dtype != other.dtype:
return False
Expand All @@ -124,7 +142,7 @@ def inspect(self) -> dict:
The tensor description in json format.
"""

tensor_des = {"name": self.alias, "shape": self.get_shape(), "dtype": self.dtype_name}
tensor_des = {"name": self.alias, "shape": self.get_shape(True), "dtype": self.dtype_name}
tensor_des["layout"] = self.layout.name if self.layout else ""
return tensor_des

Expand Down Expand Up @@ -405,6 +423,30 @@ def equal(self, other: BaseJoint) -> bool:
return msc_utils.dict_equal(self.get_attrs(), other.get_attrs())


@tvm._ffi.register_object("msc.core.MSCPrim")
class MSCPrim(BaseJoint):
"""Prim in MSCGraph
Parameters
----------
index: int
The index of the prim.
name: string
The name of the prim.
optype: string
The optype of the prim.
attrs: dict<string, string>
The attributes of the node.
parents: list<MSCPrim>
The parents of the prim.
"""

def __init__(
self, index: int, name: str, optype: str, attrs: Dict[str, str], parents: List[BaseJoint]
):
self.__init_handle_by_constructor__(_ffi_api.MSCPrim, index, name, optype, attrs, parents)


@tvm._ffi.register_object("msc.core.WeightJoint")
class WeightJoint(BaseJoint):
"""Node in WeightGraph
Expand Down Expand Up @@ -586,6 +628,22 @@ def find_node(self, name: str) -> MSCJoint:

return _ffi_api.MSCGraphFindNode(self, name)

def find_prim(self, name: str) -> MSCPrim:
"""Find prim by name.
Parameters
----------
name: string
The name of the prim.
Returns
-------
prim: MSCPrim
The found prim.
"""

return _ffi_api.MSCGraphFindPrim(self, name)

def has_tensor(self, name: str) -> bool:
"""Check if tensor in the graph.
Expand Down Expand Up @@ -679,6 +737,18 @@ def get_nodes(self) -> Iterable[MSCJoint]:
for n in self.node_names:
yield self.find_node(n)

def get_prims(self) -> Iterable[MSCPrim]:
"""Get all the prims in the graph.
Returns
-------
prims: generator<MSCPrim>
The generator of prims.
"""

for n in self.prim_names:
yield self.find_prim(n)

def get_weights(self) -> Iterable[MSCTensor]:
"""Get all the weights in the graph.
Expand Down Expand Up @@ -789,11 +859,16 @@ def inspect(self) -> dict:
"nodes": {"total": 0},
}
for node in self.get_nodes():
graph_des["nodes"].setdefault(node.optype, 0)
graph_des["nodes"]["total"] += 1
if node.optype not in graph_des["nodes"]:
graph_des["nodes"][node.optype] = 1
else:
graph_des["nodes"][node.optype] += 1
graph_des["nodes"][node.optype] += 1
prims = {"total": 0}
for prim in self.get_prims():
prims.setdefault(prim.optype, 0)
prims["total"] += 1
prims[prim.optype] += 1
if prims["total"] > 0:
graph_des["prims"] = prims
return graph_des

@classmethod
Expand Down
7 changes: 6 additions & 1 deletion python/tvm/contrib/msc/core/tools/prune/pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,12 @@ def _prune_by_shape(tensor: MSCTensor, shape: List[int]):
def _prune_by_channel(tensor: MSCTensor, dim, channel_axis: int = None):
shape = tensor.get_shape()
if channel_axis is None:
channel_axis = tensor.layout_of("C")
if self.has_w_node(tensor.name):
w_node = self.find_w_node(tensor.name)
_, channel_axis = self._get_io_axes(w_node)
else:
channel_axis = tensor.layout_of("C")
assert channel_axis >= 0, "Can not infer channel_axis for " + str(tensor)
shape[channel_axis] = dim
return _prune_by_shape(tensor, shape)

Expand Down
3 changes: 3 additions & 0 deletions python/tvm/contrib/msc/core/tools/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -1620,6 +1620,9 @@ def _get_io_axes(self, w_node: WeightJoint) -> Tuple[int, int]:
in_axis, out_axis = w_node.weight.layout_of("I"), w_node.weight.layout_of("O")
if in_axis >= 0 and out_axis >= 0:
return in_axis, out_axis
if w_node.weight.ndim == 2 and w_node.weight.dim_at("N") > 0:
io_axis = 1 - w_node.weight.layout_of("N")
return io_axis, io_axis
if w_node.weight.layout_of("C") >= 0:
return w_node.weight.layout_of("C"), w_node.weight.layout_of("C")
raise Exception("Can not infer in_axis/out_axis from " + str(w_node))
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/contrib/msc/framework/torch/frontend/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@
import torch
import tvm
from tvm.relax.frontend.torch import from_fx

from tvm.contrib.msc.core.ir.graph import MSCGraph
from tvm.contrib.msc.core.frontend import from_relax
from tvm.contrib.msc.core.frontend import from_relax, normalize_inputs
from tvm.contrib.msc.core.codegen import relay_to_relax


Expand Down Expand Up @@ -104,6 +103,7 @@ def from_torch(
"""

if via_relax:
input_info = normalize_inputs(input_info)
graph_model, params = torch.fx.symbolic_trace(model), None
with torch.no_grad():
relax_mod = from_fx(graph_model, input_info, custom_convert_map=custom_convert_map)
Expand Down
12 changes: 11 additions & 1 deletion python/tvm/contrib/msc/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,10 +676,20 @@ def _get_loader(self, name: str = MSCStage.PREPARE) -> Any:
max_batch = config.get("max_batch", 5)

def get_random():
def _to_data(inp):
shape = [1 if isinstance(d, str) else d for d in inp[1]]
return np.random.rand(*shape).astype(inp[2])

for _ in range(max_batch):
yield {i[0]: np.random.rand(*i[1]).astype(i[2]) for i in self._config["inputs"]}
yield {i[0]: _to_data(i) for i in self._config["inputs"]}

loader, source_type = get_random, "random"
elif isinstance(source_loader, dict):

def load_data():
return [source_loader]

loader, source_type = load_data, "dict"
elif msc_utils.is_io_dataset(source_loader):
max_batch = config.get("max_batch", -1)

Expand Down
37 changes: 24 additions & 13 deletions python/tvm/contrib/msc/pipeline/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
"""tvm.contrib.msc.pipeline.config"""

import copy
from typing import List, Union, Dict, Tuple

from tvm.contrib.msc.core.tools import ToolType
Expand Down Expand Up @@ -129,6 +130,7 @@ def create_config(
dataset: Dict[str, dict] = None,
tools: List[Tuple[str, Union[dict, str]]] = None,
dynamic: bool = False,
run_config: Dict[str, dict] = None,
skip_config: Dict[str, str] = None,
**extra_config,
) -> dict:
Expand Down Expand Up @@ -160,11 +162,13 @@ def create_config(
The extra config.
"""

all_stages = [MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE]
baseline_type = baseline_type or model_type
optimize_type = optimize_type or baseline_type
compile_type = compile_type or optimize_type
tools = tools or []
tools = [config_tool(t_type, t_config) for t_type, t_config in tools]
extra_config = extra_config or {}
# basic config
config = {
"model_type": model_type,
Expand Down Expand Up @@ -194,27 +198,34 @@ def create_config(
"profile": {"check": {"atol": 1e-3, "rtol": 1e-3}, "benchmark": {"repeat": -1}},
}

# update run config
if run_config:
if "all" in run_config:
all_config = run_config.pop("all")
run_config.update({s: copy.deepcopy(all_config) for s in all_stages})
for stage, r_config in run_config.items():
extra_config.setdefault(stage, {}).setdefault("run_config", {}).update(r_config)

# update config
if extra_config:
config = msc_utils.update_dict(config, extra_config)

# skip stages
skip_config = skip_config or {}
for stage in [MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE]:
if stage not in config:
continue
for key in ["all", stage]:
if key not in skip_config:
if skip_config:
if "all" in run_config:
all_config = skip_config.pop("all")
skip_config.update({s: copy.deepcopy(all_config) for s in all_stages})
for stage, s_type in skip_config.items():
if stage not in config:
continue
if skip_config[key] == "stage":
if s_type == "stage":
config.pop(stage)
elif skip_config[key] == "profile":
elif s_type == "profile":
config[stage].pop("profile")
elif skip_config[key] == "check":
config[stage]["profile"].pop("check")
elif skip_config[key] == "benchmark":
elif s_type == "check":
config[stage]["profile"]["check"]["err_rate"] = -1
elif s_type == "benchmark":
config[stage]["profile"].pop("benchmark")
else:
raise TypeError("Unexpected skip type " + str(skip_config[key]))

raise TypeError("Unexpected skip type " + str(s_type))
return config
3 changes: 3 additions & 0 deletions python/tvm/contrib/msc/pipeline/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,9 @@ class TorchWrapper(BaseWrapper):
"""Wrapper of torch models"""

def __call__(self, *inputs):
return self.forward(*inputs)

def forward(self, *inputs):
framework = self._get_framework()
if framework != MSCFramework.TORCH:
inputs = [msc_utils.cast_array(i, framework, self.device) for i in inputs]
Expand Down
Loading

0 comments on commit 11198f6

Please sign in to comment.