Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MSC][Refactor] Support dynamic shape #17351

Merged
merged 3 commits into from
Sep 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions python/tvm/contrib/msc/core/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,10 @@ def visit_var_binding_(self, binding: relax.VarBinding) -> None:

def _to_var(tensor: MSCTensor):
v_name = tensor.alias if use_alias else graph.find_producer(tensor).name
return tvm.relax.Var(
v_name, tvm.relax.TensorStructInfo(tensor.get_shape(), tensor.dtype_name)
)
dims = [
d if isinstance(d, int) else tvm.tir.Var(d, "int64") for d in tensor.get_shape(True)
]
return tvm.relax.Var(v_name, tvm.relax.TensorStructInfo(dims, tensor.dtype_name))

def _save_weights(folder: msc_utils.MSCDirectory):
if weights:
Expand Down
38 changes: 38 additions & 0 deletions python/tvm/contrib/msc/core/frontend/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,44 @@
from tvm.contrib.msc.core.ir import MSCGraph, MSCTensor


def normalize_inputs(inputs: List[tuple]) -> List[tuple]:
"""Normalize the inputs info

Parameters
----------
inputs: list of <name, shape, dtype>
The inputs info.

Returns
-------
inputs: list of <name, shape, dtype>
The normalized inputs info.
"""

recorded_vars = {}

def _normalize_input(inp):
def _normalize(info):
if not isinstance(info, (tuple, list)):
return info
dims = []
for dim in info:
if isinstance(dim, int):
dims.append(dim)
elif dim in recorded_vars:
dims.append(recorded_vars[dim])
elif isinstance(dim, str):
recorded_vars[dim] = tvm.tir.Var(dim, "int64")
dims.append(recorded_vars[dim])
else:
raise TypeError("Unexpected dim {} in shape {}".format(dim, info))
return dims

return [_normalize(i) for i in inp]

return [_normalize_input(inp) for inp in inputs]


def normalize_weights(
t_weights: Dict[MSCTensor, tvm.nd.array], graph: MSCGraph
) -> Dict[str, tvm.nd.array]:
Expand Down
93 changes: 84 additions & 9 deletions python/tvm/contrib/msc/core/ir/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class MSCTensor(Object):
The shape of the tensor.
alias: string
The alias of the tensor.
prims: list<str>
The prims of the tensor.
"""

def __init__(
Expand All @@ -50,15 +52,31 @@ def __init__(
layout: str,
shape: List[int],
alias: Optional[str] = None,
prims: List[str] = None,
):
if not isinstance(dtype, tvm.DataType):
dtype = tvm.DataType(dtype)
self.__init_handle_by_constructor__(
_ffi_api.MSCTensor, name, dtype, layout, shape, alias or ""
_ffi_api.MSCTensor, name, dtype, layout, shape, alias or "", prims or []
)

def get_shape(self) -> List[int]:
return [int(i) for i in self.shape]
def get_shape(self, with_prims: bool = False) -> List[Union[int, str]]:
"""Get shape of the tensor

Parameters
-------
with_prims: bool
Whether get shape with prims.

Returns
-------
shape: list<str|int>
The shape of tensor.
"""

if not self.prims or not with_prims:
return [int(i) for i in self.shape]
return [int(p) if p.isdigit() else p for p in self.prims]

def get_size(self) -> int:
return int(_ffi_api.MSCTensorGetSize(self))
Expand Down Expand Up @@ -98,7 +116,7 @@ def equal(self, other: Object) -> bool:

if not isinstance(other, MSCTensor):
return False
if self.get_shape() != other.get_shape():
if self.get_shape(True) != other.get_shape(True):
return False
if self.dtype != other.dtype:
return False
Expand All @@ -124,7 +142,7 @@ def inspect(self) -> dict:
The tensor description in json format.
"""

tensor_des = {"name": self.alias, "shape": self.get_shape(), "dtype": self.dtype_name}
tensor_des = {"name": self.alias, "shape": self.get_shape(True), "dtype": self.dtype_name}
tensor_des["layout"] = self.layout.name if self.layout else ""
return tensor_des

Expand Down Expand Up @@ -405,6 +423,30 @@ def equal(self, other: BaseJoint) -> bool:
return msc_utils.dict_equal(self.get_attrs(), other.get_attrs())


@tvm._ffi.register_object("msc.core.MSCPrim")
class MSCPrim(BaseJoint):
"""Prim in MSCGraph

Parameters
----------
index: int
The index of the prim.
name: string
The name of the prim.
optype: string
The optype of the prim.
attrs: dict<string, string>
The attributes of the node.
parents: list<MSCPrim>
The parents of the prim.
"""

def __init__(
self, index: int, name: str, optype: str, attrs: Dict[str, str], parents: List[BaseJoint]
):
self.__init_handle_by_constructor__(_ffi_api.MSCPrim, index, name, optype, attrs, parents)


@tvm._ffi.register_object("msc.core.WeightJoint")
class WeightJoint(BaseJoint):
"""Node in WeightGraph
Expand Down Expand Up @@ -586,6 +628,22 @@ def find_node(self, name: str) -> MSCJoint:

return _ffi_api.MSCGraphFindNode(self, name)

def find_prim(self, name: str) -> MSCPrim:
"""Find prim by name.

Parameters
----------
name: string
The name of the prim.

Returns
-------
prim: MSCPrim
The found prim.
"""

return _ffi_api.MSCGraphFindPrim(self, name)

def has_tensor(self, name: str) -> bool:
"""Check if tensor in the graph.

Expand Down Expand Up @@ -679,6 +737,18 @@ def get_nodes(self) -> Iterable[MSCJoint]:
for n in self.node_names:
yield self.find_node(n)

def get_prims(self) -> Iterable[MSCPrim]:
"""Get all the prims in the graph.

Returns
-------
prims: generator<MSCPrim>
The generator of prims.
"""

for n in self.prim_names:
yield self.find_prim(n)

def get_weights(self) -> Iterable[MSCTensor]:
"""Get all the weights in the graph.

Expand Down Expand Up @@ -789,11 +859,16 @@ def inspect(self) -> dict:
"nodes": {"total": 0},
}
for node in self.get_nodes():
graph_des["nodes"].setdefault(node.optype, 0)
graph_des["nodes"]["total"] += 1
if node.optype not in graph_des["nodes"]:
graph_des["nodes"][node.optype] = 1
else:
graph_des["nodes"][node.optype] += 1
graph_des["nodes"][node.optype] += 1
prims = {"total": 0}
for prim in self.get_prims():
prims.setdefault(prim.optype, 0)
prims["total"] += 1
prims[prim.optype] += 1
if prims["total"] > 0:
graph_des["prims"] = prims
return graph_des

@classmethod
Expand Down
7 changes: 6 additions & 1 deletion python/tvm/contrib/msc/core/tools/prune/pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,12 @@ def _prune_by_shape(tensor: MSCTensor, shape: List[int]):
def _prune_by_channel(tensor: MSCTensor, dim, channel_axis: int = None):
shape = tensor.get_shape()
if channel_axis is None:
channel_axis = tensor.layout_of("C")
if self.has_w_node(tensor.name):
w_node = self.find_w_node(tensor.name)
_, channel_axis = self._get_io_axes(w_node)
else:
channel_axis = tensor.layout_of("C")
assert channel_axis >= 0, "Can not infer channel_axis for " + str(tensor)
shape[channel_axis] = dim
return _prune_by_shape(tensor, shape)

Expand Down
3 changes: 3 additions & 0 deletions python/tvm/contrib/msc/core/tools/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -1620,6 +1620,9 @@ def _get_io_axes(self, w_node: WeightJoint) -> Tuple[int, int]:
in_axis, out_axis = w_node.weight.layout_of("I"), w_node.weight.layout_of("O")
if in_axis >= 0 and out_axis >= 0:
return in_axis, out_axis
if w_node.weight.ndim == 2 and w_node.weight.dim_at("N") > 0:
io_axis = 1 - w_node.weight.layout_of("N")
return io_axis, io_axis
if w_node.weight.layout_of("C") >= 0:
return w_node.weight.layout_of("C"), w_node.weight.layout_of("C")
raise Exception("Can not infer in_axis/out_axis from " + str(w_node))
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/contrib/msc/framework/torch/frontend/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@
import torch
import tvm
from tvm.relax.frontend.torch import from_fx

from tvm.contrib.msc.core.ir.graph import MSCGraph
from tvm.contrib.msc.core.frontend import from_relax
from tvm.contrib.msc.core.frontend import from_relax, normalize_inputs
from tvm.contrib.msc.core.codegen import relay_to_relax


Expand Down Expand Up @@ -104,6 +103,7 @@ def from_torch(
"""

if via_relax:
input_info = normalize_inputs(input_info)
graph_model, params = torch.fx.symbolic_trace(model), None
with torch.no_grad():
relax_mod = from_fx(graph_model, input_info, custom_convert_map=custom_convert_map)
Expand Down
12 changes: 11 additions & 1 deletion python/tvm/contrib/msc/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,10 +676,20 @@ def _get_loader(self, name: str = MSCStage.PREPARE) -> Any:
max_batch = config.get("max_batch", 5)

def get_random():
def _to_data(inp):
shape = [1 if isinstance(d, str) else d for d in inp[1]]
return np.random.rand(*shape).astype(inp[2])

for _ in range(max_batch):
yield {i[0]: np.random.rand(*i[1]).astype(i[2]) for i in self._config["inputs"]}
yield {i[0]: _to_data(i) for i in self._config["inputs"]}

loader, source_type = get_random, "random"
elif isinstance(source_loader, dict):

def load_data():
return [source_loader]

loader, source_type = load_data, "dict"
elif msc_utils.is_io_dataset(source_loader):
max_batch = config.get("max_batch", -1)

Expand Down
37 changes: 24 additions & 13 deletions python/tvm/contrib/msc/pipeline/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
"""tvm.contrib.msc.pipeline.config"""

import copy
from typing import List, Union, Dict, Tuple

from tvm.contrib.msc.core.tools import ToolType
Expand Down Expand Up @@ -129,6 +130,7 @@ def create_config(
dataset: Dict[str, dict] = None,
tools: List[Tuple[str, Union[dict, str]]] = None,
dynamic: bool = False,
run_config: Dict[str, dict] = None,
skip_config: Dict[str, str] = None,
**extra_config,
) -> dict:
Expand Down Expand Up @@ -160,11 +162,13 @@ def create_config(
The extra config.
"""

all_stages = [MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE]
baseline_type = baseline_type or model_type
optimize_type = optimize_type or baseline_type
compile_type = compile_type or optimize_type
tools = tools or []
tools = [config_tool(t_type, t_config) for t_type, t_config in tools]
extra_config = extra_config or {}
# basic config
config = {
"model_type": model_type,
Expand Down Expand Up @@ -194,27 +198,34 @@ def create_config(
"profile": {"check": {"atol": 1e-3, "rtol": 1e-3}, "benchmark": {"repeat": -1}},
}

# update run config
if run_config:
if "all" in run_config:
all_config = run_config.pop("all")
run_config.update({s: copy.deepcopy(all_config) for s in all_stages})
for stage, r_config in run_config.items():
extra_config.setdefault(stage, {}).setdefault("run_config", {}).update(r_config)

# update config
if extra_config:
config = msc_utils.update_dict(config, extra_config)

# skip stages
skip_config = skip_config or {}
for stage in [MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE]:
if stage not in config:
continue
for key in ["all", stage]:
if key not in skip_config:
if skip_config:
if "all" in run_config:
all_config = skip_config.pop("all")
skip_config.update({s: copy.deepcopy(all_config) for s in all_stages})
for stage, s_type in skip_config.items():
if stage not in config:
continue
if skip_config[key] == "stage":
if s_type == "stage":
config.pop(stage)
elif skip_config[key] == "profile":
elif s_type == "profile":
config[stage].pop("profile")
elif skip_config[key] == "check":
config[stage]["profile"].pop("check")
elif skip_config[key] == "benchmark":
elif s_type == "check":
config[stage]["profile"]["check"]["err_rate"] = -1
elif s_type == "benchmark":
config[stage]["profile"].pop("benchmark")
else:
raise TypeError("Unexpected skip type " + str(skip_config[key]))

raise TypeError("Unexpected skip type " + str(s_type))
return config
3 changes: 3 additions & 0 deletions python/tvm/contrib/msc/pipeline/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,9 @@ class TorchWrapper(BaseWrapper):
"""Wrapper of torch models"""

def __call__(self, *inputs):
return self.forward(*inputs)

def forward(self, *inputs):
framework = self._get_framework()
if framework != MSCFramework.TORCH:
inputs = [msc_utils.cast_array(i, framework, self.device) for i in inputs]
Expand Down
Loading
Loading