add plugin in manager
Archermmt committed Jan 31, 2024
1 parent 0628bdb commit bd2b472
Showing 20 changed files with 375 additions and 74 deletions.
9 changes: 9 additions & 0 deletions python/tvm/contrib/msc/core/runtime/runner.py
@@ -57,6 +57,8 @@ class BaseRunner(object):
Whether compile model to trainable
stage: str
The stage of runner.
plugin: PluginManager
The plugin manager.
name: str
The name of the runner
debug_level: int
@@ -75,6 +77,7 @@ def __init__(
device: str = "cpu",
training: bool = False,
stage: str = "default",
plugin: Any = None,
name: str = "main",
debug_level: int = 0,
logger: logging.Logger = None,
@@ -86,6 +89,7 @@ def __init__(
self._build_config = msc_utils.copy_dict(build_config)
self._device = device if self._device_enabled(device) else "cpu"
self._stage = stage
self._plugin = plugin
self._name = name
self._debug_level = debug_level
self._training, self._trained = training, training
@@ -123,8 +127,11 @@ def setup(self) -> dict:
stage=self._stage,
**config,
)
if self._plugin:
self._update_codegen({"use_plugin": True})
return {
"tools": {k: v.tool_style() for k, v in self._tools.items()},
"plugin": self._plugin,
"translate_config": self._translate_config,
"generate_config": self._generate_config,
"build_config": self._build_config,
@@ -1069,6 +1076,7 @@ def _generate_model(self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.arra
codegen_config=self._generate_config.get("codegen"),
print_config=self._generate_config.get("print"),
build_folder=self._generate_config["build_folder"],
plugin=self._plugin,
)

def _inspect_model(self) -> dict:
@@ -1226,6 +1234,7 @@ def _generate_model(self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.arra
extra_options=extra_option,
build_folder=self._generate_config["build_folder"],
output_folder=self._generate_config.get("output_folder", msc_utils.get_output_dir()),
plugin=self._plugin,
)

def _build_runnable(self, model: Any) -> Any:
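
The runner changes follow one pattern end to end: an optional plugin manager is stored on the runner, codegen is switched into plugin-aware mode when a manager is present, and the manager object is forwarded to the framework codegen call. A minimal standalone sketch of that flow, using hypothetical names rather than the MSC classes:

from typing import Any, Optional

class ToyRunner:
    """Illustration only: thread an optional plugin manager from a runner into codegen."""

    def __init__(self, plugin: Optional[Any] = None):
        self._plugin = plugin
        self._codegen_config = {}

    def setup(self) -> dict:
        # Mirrors the diff: flag plugin support only when a manager is given.
        if self._plugin:
            self._codegen_config["use_plugin"] = True
        return {"plugin": self._plugin, "codegen": self._codegen_config}

    def generate_model(self) -> str:
        # The manager is forwarded so codegen can emit and link plugin sources.
        return to_framework(self._codegen_config, plugin=self._plugin)

def to_framework(codegen_config: dict, plugin: Optional[Any] = None) -> str:
    """Stand-in for to_torch/to_relax/to_tensorflow: append the plugin to the loader args only when present."""
    model_args = ["inputs", "weights"]
    if plugin:
        model_args = model_args + [plugin]
    return "loaded with {} args".format(len(model_args))

print(ToyRunner().generate_model())                  # no plugin -> 2 args
print(ToyRunner(plugin="manager").generate_model())  # with plugin -> 3 args
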
1 change: 1 addition & 0 deletions python/tvm/contrib/msc/core/tools/prune/pruner.py
@@ -82,6 +82,7 @@ def _parse_strategys(self, strategy_list: dict) -> Dict[str, ToolStrategy]:
def _update_stages(strategy):
if "stages" not in strategy:
strategy["stages"] = [msc_utils.MSCStage.PRUNE]
strategy["tensor_types"] = ["weight", "output"]
return strategy

return super()._parse_strategys([_update_stages(s) for s in strategy_list])
5 changes: 5 additions & 0 deletions python/tvm/contrib/msc/core/tools/quantize/quantizer.py
@@ -114,6 +114,11 @@ def _check_tensor(self, name: str, consumer: str) -> bool:
Whether to process the tensor.
"""

if self._calibrated:
tensor_id = self.to_tensor_id(name, consumer)
if tensor_id not in self._plan:
return False
return self._plan.get(tensor_id, {}).get("nbits", 8) != -1
strategys = self._get_tensor_strategys(name, consumer)
if not strategys:
return False
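
After calibration the quantizer stops consulting strategys for a tensor and relies on the recorded plan alone: a tensor is processed only if it appears in the plan and its planned bit width is not the skip value -1. The decision in isolation, as a sketch with a toy plan (the tensor-id strings here are made up):

def check_calibrated(plan: dict, tensor_id: str) -> bool:
    """Mirror of the calibrated branch of _check_tensor."""
    if tensor_id not in plan:
        return False
    # nbits defaults to 8; an explicit -1 marks the tensor as skipped.
    return plan.get(tensor_id, {}).get("nbits", 8) != -1

plan = {"conv1_out-conv2": {"nbits": 8}, "relu1_out-exit": {"nbits": -1}}
assert check_calibrated(plan, "conv1_out-conv2")      # quantized at 8 bits
assert not check_calibrated(plan, "relu1_out-exit")   # explicitly skipped
assert not check_calibrated(plan, "missing_tensor")   # not in the plan
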
47 changes: 26 additions & 21 deletions python/tvm/contrib/msc/core/tools/tool.py
@@ -409,7 +409,7 @@ def _parse_strategys(self, strategy_list: List[dict]) -> Dict[str, ToolStrategy]
tensor_names = strategy.pop("tensor_names")
marks = [(n, "tensor") for n in tensor_names]
else:
marks = [("default", t) for t in ["input", "output", "weight"]]
marks = [("default." + str(t), t) for t in tensor_types]
stages = strategy.pop("stages") if "stages" in strategy else ["default"]
for mark, t_type in marks:
if mark not in strategys:
@@ -1212,33 +1212,38 @@ def _get_tensor_strategys(self, name: str, consumer: str) -> List[ToolStrategy]:

tensor_id = self.to_tensor_id(name, consumer)
mark = "strategy.{}".format(self._stage)

def _check_strategy(s_ref):
return s_ref in self._strategys and self._strategys[s_ref].support_stage(self._stage)

if mark not in self._tensor_cache.get(tensor_id, {}):
if self.is_weight(name):
strategys = []
tensor_strategy = self._strategys.get(tensor_id)
if tensor_strategy and tensor_strategy.support_stage(self._stage):
strategys.append(tensor_strategy)
elif self.is_weight(name):
consumer = self.find_node(consumer)
name_refs = [consumer.name + ".weight", consumer.optype + ".weight"]
for ref in [consumer.name, consumer.optype, "default"]:
if _check_strategy(ref + ".weight"):
strategys.append(self._strategys[ref + ".weight"])
break
elif consumer == "exit":
producer = self.find_producer(name)
name_refs = [producer.name + ".output", producer.optype + ".output"]
for ref in [producer.name, producer.optype, "exit", "default"]:
if _check_strategy(ref + ".output"):
strategys.append(self._strategys[ref + ".output"])
break
else:
consumer = self.find_node(consumer)
for ref in [consumer.name, consumer.optype, "default"]:
if _check_strategy(ref + ".input"):
strategys.append(self._strategys[ref + ".input"])
break
producer = self.find_producer(name)
name_refs = [
producer.name + ".output",
producer.optype + ".output",
consumer.name + ".input",
consumer.optype + ".input",
]
strategys = []
tensor_strategy = self._strategys.get(tensor_id)
if tensor_strategy and tensor_strategy.support_stage(self._stage):
strategys.append(tensor_strategy)
if not strategys:
for n in name_refs:
if n in self._strategys and self._strategys[n].support_stage(self._stage):
strategys.append(self._strategys[n])
d_strategy = self._strategys.get("default")
if not strategys and d_strategy and d_strategy.support_stage(self._stage):
strategys.append(d_strategy)
for ref in [producer.name, producer.optype, "default"]:
if _check_strategy(ref + ".output"):
strategys.append(self._strategys[ref + ".output"])
break
self._save_tensor_cache(name, consumer, mark, strategys)
return self._get_tensor_cache(name, consumer, mark)
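
The rewritten lookup replaces the flat name_refs list with an ordered search: an exact tensor-id strategy wins; otherwise references are tried from most to least specific (node name, op type, then "exit"/"default"), each suffixed with the tensor type ("weight", "input" or "output"), stopping at the first entry that supports the current stage. A self-contained sketch of that resolution order, with a hypothetical registry:

def resolve_strategy(strategys: dict, tensor_id: str, refs: list, t_type: str, stage: str):
    """Return the first matching strategy: exact tensor id first, then each
    reference plus the tensor-type suffix, most specific reference first."""

    def _supports(key):
        return key in strategys and stage in strategys[key]["stages"]

    if _supports(tensor_id):
        return strategys[tensor_id]
    for ref in refs:
        key = "{}.{}".format(ref, t_type)
        if _supports(key):
            return strategys[key]
    return None

# Hypothetical registry: an op-type entry for conv weights plus a default.
registry = {
    "conv2d.weight": {"stages": ["prune"], "density": 0.5},
    "default.weight": {"stages": ["prune", "quantize"], "density": 0.8},
}
refs = ["conv1", "conv2d", "default"]  # consumer name, op type, default
print(resolve_strategy(registry, "conv1.weight-conv1", refs, "weight", "prune"))     # op-type entry
print(resolve_strategy(registry, "conv1.weight-conv1", refs, "weight", "quantize"))  # falls back to default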

10 changes: 8 additions & 2 deletions python/tvm/contrib/msc/framework/tensorflow/codegen/codegen.py
@@ -16,7 +16,7 @@
# under the License.
"""tvm.contrib.msc.framework.tensorflow.codegen.codegen"""

from typing import Dict, Optional
from typing import Dict, Optional, Any

import tvm
from tvm.contrib.msc.core.ir import MSCGraph
@@ -32,6 +32,7 @@ def to_tensorflow(
codegen_config: Optional[Dict[str, str]] = None,
print_config: Optional[Dict[str, str]] = None,
build_folder: msc_utils.MSCDirectory = None,
plugin: Any = None,
) -> tf_v1.Graph:
"""Change MSCGraph to tensorflow graph.
@@ -47,6 +48,8 @@ def to_tensorflow(
The config for print.
build_folder: MSCDirectory
The folder for saving scripts and datas.
plugin: PluginManager
The plugin manager.
Returns
-------
@@ -63,4 +66,7 @@ def _save_weights(folder: msc_utils.MSCDirectory):
codegen = CodeGen(
graph, _ffi_api.GetTensorflowSources, codegen_config, print_config, build_folder
)
return codegen.load(inputs + [weights], pre_load=_save_weights)
model_args = inputs + [weights]
if plugin:
model_args = model_args + [plugin]
return codegen.load(model_args, pre_load=_save_weights)
17 changes: 16 additions & 1 deletion python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py
@@ -18,7 +18,7 @@

import os
import subprocess
from typing import Dict, Optional, List, Union
from typing import Dict, Optional, List, Union, Any
import numpy as np

import tvm
@@ -38,6 +38,7 @@ def to_sub_tensorrt(
print_config: Optional[Dict[str, str]] = None,
build_folder: msc_utils.MSCDirectory = None,
output_folder: msc_utils.MSCDirectory = None,
plugin: Any = None,
) -> str:
"""Change MSCGraph to TensorRT engine file.
@@ -55,6 +56,8 @@ def to_sub_tensorrt(
The folder for saving sources and datas.
export_folder: MSCDirectory
The folder for saving outputs.
plugin: PluginManager
The plugin manager.
Returns
-------
@@ -90,6 +93,10 @@ def _create_depends(folder: msc_utils.MSCDirectory) -> str:
f.write("{}\n".format(len(engine_wts)))
for name, data in engine_wts.items():
write_weight(name, msc_utils.cast_array(data), f)
# copy plugin
if plugin:
plugin.copy_libs("plugin_lib")
plugin.copy_includes("plugin")
# save utils sources
with folder.create_dir("utils") as utils_folder:
for name, source in get_trt_sources().items():
@@ -115,6 +122,10 @@ def _build_engine(engine_name: str, folder: msc_utils.MSCDirectory) -> str:

with build_folder as folder:
sub_folder = folder.create_dir(graph.name)
if plugin:
codegen_config["extern_libs"] = [
sub_folder.create_dir("plugin_lib").relpath(f) for f in plugin.list_libs()
]
codegen = CodeGen(
graph,
_ffi_api.GetTensorRTSources,
@@ -140,6 +151,7 @@ def to_tensorrt(
extra_options: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None,
build_folder: msc_utils.MSCDirectory = None,
output_folder: msc_utils.MSCDirectory = None,
plugin: Any = None,
) -> Dict[str, str]:
"""Change all MSCGraphs to TensorRT engine files.
@@ -161,6 +173,8 @@ def to_tensorrt(
The folder for saving sources and datas.
export_folder: MSCDirectory
The folder for saving outputs.
plugin: PluginManager
The plugin manager.
Returns
-------
@@ -183,6 +197,7 @@ def to_tensorrt(
print_configs[idx],
build_folder,
output_folder,
plugin=plugin,
)
if extra_options[idx]:
options.update(extra_options[idx])
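
On the TensorRT side the manager does two extra jobs: its headers and shared libraries are copied next to the generated sources, and the copied library paths are handed to codegen as extern_libs so the engine build can link against them. A rough sketch of that wiring with pathlib and a stand-in plugin object (names here are illustrative, not the MSC API):

import shutil
from pathlib import Path

class FakePlugin:
    """Stand-in exposing the one method the sketch needs."""

    def list_libs(self):
        return []  # e.g. ["libmsc_plugin.so"]

def wire_plugin_libs(plugin, sub_folder: Path, codegen_config: dict) -> None:
    """Copy plugin shared libraries beside the generated sources and record
    their paths so the codegen can pass them to the engine build."""
    lib_dir = sub_folder / "plugin_lib"
    lib_dir.mkdir(parents=True, exist_ok=True)
    extern_libs = []
    for lib in plugin.list_libs():
        dst = lib_dir / Path(lib).name
        shutil.copy(lib, dst)
        extern_libs.append(str(dst))
    codegen_config["extern_libs"] = extern_libs

config = {}
wire_plugin_libs(FakePlugin(), Path("build/example_graph"), config)
print(config)  # {'extern_libs': []} until real libraries are listed
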
48 changes: 48 additions & 0 deletions python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py
@@ -20,11 +20,13 @@
from typing import Mapping, Tuple, List, Union, Callable, Dict
from functools import wraps, partial

import tvm
from tvm import relax
from tvm.relax.dpl import pattern
from tvm.relax.transform import PatternCheckContext, FusionPattern
from tvm.relax.backend.pattern_registry import register_patterns
from tvm.contrib.msc.core.transform import pattern as msc_pattern
from tvm.contrib.msc.core import _ffi_api


def basic_pattern(
@@ -234,6 +236,43 @@ def _take_check(context: PatternCheckContext) -> bool:
return _check_expr(context.annotated_expr["input_1"], ("int32"))


def _plugin_check(context: PatternCheckContext) -> bool:
"""Check if the plugin pattern is correct.
Returns
-------
pass: bool
Whether the pattern is correct.
"""

ext_func = context.annotated_expr["out"].args[0]
return bool(_ffi_api.IsPlugin(ext_func.global_symbol))


def plugin_attrs_getter(
annotated_expr: Dict[str, tvm.relax.Expr],
) -> Dict[str, str]:
"""Get attributes for plugin pattern
Parameters
----------
annotated_expr: dict<str,Expr>
The annotated exprs during fusing pattern
Returns
-------
attrs: dict<str,str>
The extra attributes for msc.
"""

attrs = msc_pattern.msc_attrs_getter(annotated_expr, anchor="out")
ext_func = annotated_expr["out"].args[0]
attrs[_ffi_api.ToAttrKey("optype")] = ext_func.global_symbol
return attrs


def wrap_basic_check(
func: Callable[[PatternCheckContext], bool]
) -> Callable[[PatternCheckContext], bool]:
@@ -410,6 +449,15 @@ def get_patterns(target) -> List[Pattern]:
),
]
)
# plugin ops
patterns.append(
(
target + ".plugin",
*basic_pattern("relax.call_dps_packed", ["input", "input"]),
_plugin_check,
plugin_attrs_getter,
)
)

return patterns
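
The new pattern matches a relax.call_dps_packed whose callee is a registered plugin: the check accepts the match only when the callee's global_symbol is a known plugin, and the attrs getter records that symbol as the op type so the MSC graph carries it. A toy mirror of the two callbacks with stand-in classes (nothing below is the TVM API):

from typing import Dict

PLUGIN_REGISTRY = {"my_plugin_resize2d"}  # hypothetical registered plugin symbols

class ExternFunc:
    def __init__(self, global_symbol: str):
        self.global_symbol = global_symbol

class Call:
    def __init__(self, func: ExternFunc):
        self.args = [func]

def plugin_check(annotated_expr: Dict[str, Call]) -> bool:
    """Mirror of _plugin_check: the callee's symbol must be a registered plugin."""
    ext_func = annotated_expr["out"].args[0]
    return ext_func.global_symbol in PLUGIN_REGISTRY

def plugin_attrs(annotated_expr: Dict[str, Call]) -> Dict[str, str]:
    """Mirror of plugin_attrs_getter: record the plugin symbol as the op type."""
    ext_func = annotated_expr["out"].args[0]
    return {"optype": ext_func.global_symbol}

expr = {"out": Call(ExternFunc("my_plugin_resize2d"))}
assert plugin_check(expr)
print(plugin_attrs(expr))  # {'optype': 'my_plugin_resize2d'}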

8 changes: 6 additions & 2 deletions python/tvm/contrib/msc/framework/torch/codegen/codegen.py
@@ -16,7 +16,7 @@
# under the License.
"""tvm.contrib.msc.framework.torch.codegen.codegen"""

from typing import Dict, Optional
from typing import Dict, Optional, Any
import torch

import tvm
@@ -32,6 +32,7 @@ def to_torch(
codegen_config: Optional[Dict[str, str]] = None,
print_config: Optional[Dict[str, str]] = None,
build_folder: msc_utils.MSCDirectory = None,
plugin: Any = None,
) -> torch.nn.Module:
"""Change MSCGraph to torch nn.Module.
@@ -47,6 +48,8 @@ def to_torch(
The config for print.
build_folder: MSCDirectory
The folder for saving scripts and datas.
plugin: PluginManager
The plugin manager.
Returns
-------
@@ -73,4 +76,5 @@ def _bind_weights(model: torch.nn.Module, folder: msc_utils.MSCDirectory) -> tor
return model

codegen = CodeGen(graph, _ffi_api.GetTorchSources, codegen_config, print_config, build_folder)
return codegen.load([], pre_load=_save_weights, post_load=_bind_weights)
model_args = [plugin] if plugin else []
return codegen.load(model_args, pre_load=_save_weights, post_load=_bind_weights)
10 changes: 8 additions & 2 deletions python/tvm/contrib/msc/framework/tvm/codegen/codegen.py
@@ -16,7 +16,7 @@
# under the License.
"""tvm.contrib.msc.framework.tvm.codegen.codegen"""

from typing import Dict, Optional
from typing import Dict, Optional, Any

import tvm
from tvm.relax.transform import BindParams
@@ -32,6 +32,7 @@ def to_relax(
codegen_config: Optional[Dict[str, str]] = None,
print_config: Optional[Dict[str, str]] = None,
build_folder: msc_utils.MSCDirectory = None,
plugin: Any = None,
) -> tvm.IRModule:
"""Change MSCGraph to IRModule.
@@ -47,6 +48,8 @@ def to_relax(
The config for print.
build_folder: MSCDirectory
The folder for saving scripts and datas.
plugin: PluginManager
The plugin manager.
Returns
-------
@@ -81,4 +84,7 @@ def _post_proc(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRModul
)(mod)

codegen = CodeGen(graph, _ffi_api.GetRelaxSources, codegen_config, print_config, build_folder)
return codegen.load(inputs, pre_load=_save_weights, post_load=_post_proc)
model_args = inputs
if plugin:
model_args = model_args + [plugin]
return codegen.load(model_args, pre_load=_save_weights, post_load=_post_proc)
1 change: 1 addition & 0 deletions python/tvm/contrib/msc/pipeline/__init__.py
@@ -17,3 +17,4 @@
"""tvm.contrib.msc.pipeline"""

from .manager import *
from .wrapper import *