From 8af066b2439f573961321a6f9c459d6fc2d3021d Mon Sep 17 00:00:00 2001 From: juda Date: Tue, 26 Jul 2022 01:57:58 -0700 Subject: [PATCH 01/33] first commit --- python/tvm/contrib/torch/__init__.py | 11 - python/tvm/contrib/torch/module.py | 121 --- python/tvm/contrib/torch/pytorch_tvm.py | 249 ------- .../torch/pt_call_tvm/RuntimeModuleWrapper.cc | 244 +++---- .../pt_call_tvm/RuntimeModuleWrapperTorch.cc | 149 ++++ .../torch/pt_call_tvm/runtime_bridge.h | 52 ++ src/contrib/torch/pt_call_tvm/tvm_class.cc | 686 ------------------ 7 files changed, 290 insertions(+), 1222 deletions(-) delete mode 100644 python/tvm/contrib/torch/module.py delete mode 100644 python/tvm/contrib/torch/pytorch_tvm.py create mode 100644 src/contrib/torch/pt_call_tvm/RuntimeModuleWrapperTorch.cc create mode 100644 src/contrib/torch/pt_call_tvm/runtime_bridge.h delete mode 100644 src/contrib/torch/pt_call_tvm/tvm_class.cc diff --git a/python/tvm/contrib/torch/__init__.py b/python/tvm/contrib/torch/__init__.py index 340f9cef9e58..1c22228381f8 100644 --- a/python/tvm/contrib/torch/__init__.py +++ b/python/tvm/contrib/torch/__init__.py @@ -39,17 +39,6 @@ def _load_platform_specific_library(lib_name="libpt_tvmdsoop"): _load_platform_specific_library() -from . import module - -GraphModule = module.GraphModule -VMModule = module.VMModule -TraceTvmModule = module.TraceTvmModule - -from . import pytorch_tvm - -PyTorchTVMModule = pytorch_tvm.PyTorchTVMModule -compile = pytorch_tvm.compile - from . import as_torch TVMScriptIRModule = as_torch.OperatorModuleWrapper diff --git a/python/tvm/contrib/torch/module.py b/python/tvm/contrib/torch/module.py deleted file mode 100644 index 3da9c6f591ce..000000000000 --- a/python/tvm/contrib/torch/module.py +++ /dev/null @@ -1,121 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# pylint: disable=invalid-name -"""Module container of PyTorch custom class""" -from typing import List -import torch - - -class GraphModule(torch.nn.Module): - r"""Module container of Pytorch class which wraps exported - TVM op implementation library to be called on Pytorch side""" - - @classmethod - def shape_repr(cls, input_shapes): - return torch.ops.tvm_dsoop.tvm_shape_repr(input_shapes) - - def __init__(self, num_inputs, num_outputs, device=None): - super().__init__() - self.dummy_param = torch.nn.Parameter(torch.empty(0)) - self.engine = None - - if device is not None: - self.to(device) - self.engine = torch.classes.tvm_dsoop.TvmGraphModule(num_inputs, num_outputs, self.device) - - def init(self, input_shapes, lib_path, graph_path, params_path): - r"""Load tvm module""" - self.engine.load_tvm_module(input_shapes, lib_path, graph_path, params_path) - - def forward(self, inputs: List[torch.Tensor]): - r"""Call tvm module to forward""" - return self.engine.forward(inputs) - - @property - def device(self): - r"""Get the device string""" - return str(self.dummy_param.device) - - def _apply(self, fn): - r"""Override to device function, manually move tvm module to desired device""" - super()._apply(fn) - if self.engine is not None: - self.engine.to(self.device) - return self - - -class VMModule(torch.nn.Module): - r"""Module container of Pytorch class which wraps exported - TVM op implementation library to be called on Pytorch side""" - - @classmethod - def shape_repr(cls, input_shapes): - return torch.ops.tvm_dsoop.tvm_shape_repr(input_shapes) - - def __init__(self, num_inputs, num_outputs, device=None): - super().__init__() - self.dummy_param = torch.nn.Parameter(torch.empty(0)) - self.engine = None - - if device is not None: - self.to(device) - self.engine = torch.classes.tvm_dsoop.TvmVMModule(num_inputs, num_outputs, self.device) - - def init(self, input_shapes, lib_path, code_path): - r"""Load tvm module""" - self.engine.load_tvm_module(input_shapes, lib_path, code_path) - - def forward(self, inputs: List[torch.Tensor]): - r"""Call tvm module to forward""" - return self.engine.forward(inputs) - - @property - def device(self): - r"""Get the device string""" - return str(self.dummy_param.device) - - def _apply(self, fn): - r"""Override to device function, manually move tvm module to desired device""" - super()._apply(fn) - if self.engine is not None: - self.engine.to(self.device) - return self - - -class TraceTvmModule(torch.nn.Module): - r"""Wrapper for trace GraphModule - - GraphModule and VMModule only supports List[Tensor] inputs and cannot be traced. 
- This is a wrapper class for trace GraphModule or VMModule in order to support - arbitrary number of inputs - - Example: - import tvm.contrib.torch - tvm_module = tvm.contrib.torch.GraphModule(1, 1, 'cuda:0') - tvm_module.init(input_shapes, lib_path, graph_path, params_path) - - trace_wrapper = tvm.contrib.torch.TraceGraphModule(torch.jit.script(tvm_module)) - traced = torch.jit.trace(trace_wrapper, example_inputs) - """ - - def __init__(self, tvm_module): - super().__init__() - self.tvm_module = tvm_module - - def forward(self, *inputs): - outputs = self.tvm_module(inputs) - return outputs[0] if len(outputs) == 1 else tuple(outputs) diff --git a/python/tvm/contrib/torch/pytorch_tvm.py b/python/tvm/contrib/torch/pytorch_tvm.py deleted file mode 100644 index 1e50c98ab883..000000000000 --- a/python/tvm/contrib/torch/pytorch_tvm.py +++ /dev/null @@ -1,249 +0,0 @@ -#!/usr/bin/env python - -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=redefined-builtin -"""`compile` api that convert torch module to torch tvm module""" -import os -import tvm -import tvm.testing -from tvm import relay, autotvm -from tvm.runtime import load_module -from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner -from tvm.contrib import graph_executor -from tvm.contrib.debugger import debug_executor -from . 
import GraphModule - - -def tune_tasks( - tasks, - measure_option, - tuner="xgb", - n_trial=1000, - early_stopping=None, - log_filename="tuning.log", - use_transfer_learning=True, -): - """Tune tasks and generate tuning log to file""" - # create tmp log file - tmp_log_file = log_filename + ".tmp" - if os.path.exists(tmp_log_file): - os.remove(tmp_log_file) - - for i, tsk in enumerate(reversed(tasks)): - prefix = f"[Task {i + 1:2d}/{len(tasks):2d}] " - - # create tuner - if tuner in ("xgb", "sgb-rank"): - tuner_obj = XGBTuner(tsk, loss_type="rank") - elif tuner == "ga": - tuner_obj = GATuner(tsk, pop_size=100) - elif tuner == "random": - tuner_obj = RandomTuner(tsk) - elif tuner == "gridsearch": - tuner_obj = GridSearchTuner(tsk) - else: - raise ValueError("Invalid tuner: " + tuner) - - if use_transfer_learning: - if os.path.isfile(tmp_log_file): - tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file)) - - # do tuning - tsk_trial = min(n_trial, len(tsk.config_space)) - tuner_obj.tune( - n_trial=tsk_trial, - early_stopping=early_stopping, - measure_option=measure_option, - callbacks=[ - autotvm.callback.progress_bar(tsk_trial, prefix=prefix), - autotvm.callback.log_to_file(tmp_log_file), - ], - ) - - # pick best records to a cache file - if not os.path.exists(log_filename): - with open(log_filename, "w", encoding="utf-8"): - pass - if os.path.exists(tmp_log_file): - autotvm.record.pick_best(tmp_log_file, log_filename) - os.remove(tmp_log_file) - - -def get_tuning_opt(log_file="tuning.log", n_trial=200): - """Returns tuning options""" - tuning_opt = { - "log_filename": log_file, - "tuner": "random", - "n_trial": n_trial, - "early_stopping": 60, - "measure_option": autotvm.measure_option( - builder=autotvm.LocalBuilder(timeout=10), - runner=autotvm.LocalRunner(number=20, repeat=3, timeout=4, min_repeat_ms=150), - ), - } - return tuning_opt - - -TVM_ASSETS = ["mod.so", "graph.json", "params"] - - -class PyTorchTVMModule: - """Helper class for compiling pytorch module to tvm module""" - - def __init__(self, target="cuda", device=tvm.cuda(0)) -> None: - self.script_module = None - self.input_infos = None - self.default_dtype = "float32" - self.mod = None - self.params = None - self.tasks = None - self.target = target - self.dev = device - self.log_file = None - self.tvm_module = None - self.tvm_graph = None - self.tvm_lib = None - self.tvm_params = None - - def from_pytorch(self, script_module, input_infos, default_dtype="float32"): - self.script_module = script_module - self.input_infos = input_infos - self.default_dtype = default_dtype - self.mod, self.params = relay.frontend.from_pytorch( - script_module, input_infos, default_dtype=default_dtype - ) - - def tune_tvm(self, log_file="tuning.log", n_trial=200): - self.tasks = autotvm.task.extract_from_program( - self.mod["main"], - target=self.target, - params=self.params, - ) - self.log_file = log_file - tuning_opt = get_tuning_opt(log_file, n_trial) - tune_tasks(self.tasks, **tuning_opt) - - def build_tvm(self, export_dir, debug_runtime=False): - tvm_mod = self._build_tvm(debug_runtime) - self._export_tvm(export_dir) - return tvm_mod - - def _build_tvm(self, debug_runtime=False): - # compile kernels with history best records - with autotvm.apply_history_best(self.log_file): - with tvm.transform.PassContext(opt_level=3): - self.tvm_graph, self.tvm_lib, self.tvm_params = relay.build( - self.mod, target=self.target, params=self.params - ) - - if not debug_runtime: - self.tvm_module = graph_executor.create(self.tvm_graph, self.tvm_lib, 
device=self.dev) - else: - self.tvm_module = debug_executor.create(self.tvm_graph, self.tvm_lib, device=self.dev) - self.tvm_module.set_input(**self.tvm_params) - return self.tvm_module - - def _export_tvm(self, export_dir): - if not os.path.isdir(export_dir): - os.makedirs(export_dir) - self.export_dir = export_dir - self.tvm_lib.export_library(os.path.join(export_dir, TVM_ASSETS[0])) - with open(os.path.join(export_dir, TVM_ASSETS[1]), "w", encoding="utf8") as fout: - fout.write(self.tvm_graph) - with open(os.path.join(export_dir, TVM_ASSETS[2]), "wb") as fout: - fout.write(relay.save_param_dict(self.tvm_params)) - - def load_tvm(self, export_dir): - """Load tvm module from export directory""" - self.export_dir = export_dir - self.tvm_lib = load_module(os.path.join(export_dir, TVM_ASSETS[0])) - with open(os.path.join(export_dir, TVM_ASSETS[1]), "r", encoding="utf8") as f: - self.tvm_graph = f.read() - with open(os.path.join(export_dir, TVM_ASSETS[2]), "rb") as f: - self.tvm_params = relay.load_param_dict(f.read()) - - self.tvm_module = graph_executor.create(self.tvm_graph, self.tvm_lib, device=self.dev) - self.tvm_module.set_input(**self.tvm_params) - return self.tvm_module - - def build_pytorch_module(self, num_inputs, num_outputs, input_infos=None): - """Build pytorch module containing TVM Graph Module""" - assert self.export_dir, "you must build_tvm or load_tvm before" - input_infos = input_infos or self.input_infos - assert input_infos - assert len(input_infos) == num_inputs - assets = [os.path.join(self.export_dir, i) for i in TVM_ASSETS] - input_shapes = [i[1] for i in input_infos] - - def _tvm_dev_to_pt_dev(device): - """convert tvm device to pytorch device string""" - if tvm.runtime.Device.MASK2STR[device.device_type] == "cpu": - return "cpu" - if tvm.runtime.Device.MASK2STR[device.device_type] == "cuda": - return f"cuda:{device.device_id}" - raise ValueError(f"unsupported device for pt graph module: {device}") - - mod = GraphModule(num_inputs=num_inputs, num_outputs=num_outputs).to( - _tvm_dev_to_pt_dev(self.dev) - ) - mod.init(input_shapes, *assets) - return mod - - -def compile(script_module, option): - """ - example: - option = { - "input_infos": [ - ("x", (1, 3, 244, 244)), - ], - "default_dtype": "float16", - "export_dir": "pytorch_compiled", - "num_outputs": 1, - "tuning_n_trials": 20, # set zero to skip tuning - "tuning_log_file": "tuning.log", - "target": "llvm", - "device": tvm.cpu(), - } - script_module = torch.jit.script(model) - pytorch_tvm_module = compile(script_module, option) - pytorch_tvm_module("model_tvm.pt") - """ - input_infos = option["input_infos"] - default_dtype = option.get("default_dtype", "float32") - export_dir = option.get("export_dir", "pytorch_compiled") - tuning_log_file = option.get("tuning_log_file", "tuning.log") - tuning_n_trials = option.get("tuning_n_trials", 20) - num_outputs = option.get("num_outputs", 1) - target = option.get("target", "cuda") - device = option.get("device", tvm.cuda(0)) - - mod = PyTorchTVMModule(target=target, device=device) - print("Converting...") - - mod.log_file = tuning_log_file - mod.from_pytorch(script_module, input_infos, default_dtype) - - if tuning_n_trials > 0: - print("Tuning...") - mod.tune_tvm(log_file=tuning_log_file, n_trial=tuning_n_trials) - - print("Building...") - mod.build_tvm(export_dir) - pytorch_mod = mod.build_pytorch_module(num_inputs=len(input_infos), num_outputs=num_outputs) - return pytorch_mod diff --git a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc 
b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc index 12c1017bea76..fdd773df77d8 100644 --- a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc +++ b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc @@ -16,11 +16,8 @@ * specific language governing permissions and limitations * under the License. */ -#include #include #include -#include -#include #include #include #include @@ -33,13 +30,8 @@ #include "../../../runtime/graph_executor/graph_executor_factory.h" #include "../base64.h" +#include "runtime_bridge.h" -namespace tvm { -namespace contrib { - -/** - * We pass the TVM module by TVM's FFI because Torch's FFI cannot recognize such TVM objects - */ struct ThreadLocalStore { tvm::runtime::Module mod; static ThreadLocalStore* ThreadLocal() { @@ -48,9 +40,10 @@ struct ThreadLocalStore { } }; -using SerializationType = std::string; // base64 stream +namespace tvm { +namespace contrib { -SerializationType serialize(tvm::runtime::Module module) { +std::string serialize(tvm::runtime::Module module) { static const runtime::PackedFunc* f_to_str = runtime::Registry::Get("script_torch.save_to_base64"); ICHECK(f_to_str) << "IndexError: Cannot find the packed function " @@ -63,12 +56,12 @@ struct Deleter { // deleter void operator()(FILE* p) const { fclose(p); ICHECK(remove(file_name.c_str()) == 0) - << "Failed to remove temporary file (" << file_name << ")"; + << "remove temporary file (" << file_name << ") unsuccessfully"; } std::string file_name; }; -tvm::runtime::Module deserialize(SerializationType state) { +tvm::runtime::Module deserialize(std::string state) { auto length = tvm::support::b64strlen(state); std::vector bytes(length); @@ -92,168 +85,109 @@ tvm::runtime::Module deserialize(SerializationType state) { return ret; } -/** - * @brief A Torch's module which wraps TVM's OperatorModule Class. - * The basic forward function calling TVM's runtime is provided. - * The TVM module can be serialized/deserialized as a Torch module. 
- */ -class OperatorModuleWrapper : public torch::jit::CustomClassHolder { - public: - OperatorModuleWrapper() { runtime_module = ThreadLocalStore::ThreadLocal()->mod; } - - void forward(const c10::List& inputs) { - int input_length = inputs.size(); - - std::vector tensors; - - for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPack(inputs[i])); - - tvm::runtime::PackedFunc run = runtime_module.GetFunction("__tvm_main__"); - - std::vector tvm_values(input_length); - std::vector tvm_type_codes(input_length); - tvm::runtime::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data()); - for (int k = 0; k < input_length; ++k) { - setter(k, &tensors[k]->dl_tensor); - } - - run.CallPacked(tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), input_length), - nullptr); +TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod").set_body_typed([](tvm::runtime::Module mod) { + ThreadLocalStore::ThreadLocal()->mod = mod; +}); - for (int k = 0; k < input_length; ++k) { - tensors[k]->deleter(tensors[k]); - } +tvm::runtime::NDArray NDArrayFromDlPackExt(DLPackTensorExt dlpack_ext) { + using tvm::runtime::NDArray; + + NDArray array; + auto& dl_tensor = dlpack_ext.dl_managed_tensor->dl_tensor; + bool is_zero_copy = + tvm::runtime::NDArray::AbilityOfZeroCopyForDLTensor(&dl_tensor, dl_tensor.device); + if (is_zero_copy) { + // Zero-copy if data pointer is aligned + array = NDArray::FromDLPack(dlpack_ext.dl_managed_tensor); + } else { + // Copy if data pointer isn't aligned to the kAllocAlignment of TVM + array = NDArray::NewFromDLTensor(&dl_tensor, dl_tensor.device); + dlpack_ext.dl_managed_tensor->deleter(dlpack_ext.dl_managed_tensor); } - - SerializationType Serialize() { return serialize(runtime_module); } - - explicit OperatorModuleWrapper(SerializationType state) { runtime_module = deserialize(state); } - - private: - tvm::runtime::Module runtime_module; -}; - -tvm::Device getDevice(const at::Tensor& tensor) { - tvm::Device dev; - dev.device_id = tensor.get_device(); - switch (tensor.device().type()) { - case at::DeviceType::CPU: - dev.device_type = DLDeviceType::kDLCPU; - if (dev.device_id == -1) { - /* - * In PyTorch the device ID for cpu is -1, sometimes causing error during tuning - * Thus we manually set the device ID as 0 for avoiding potentially error of index out of - * bounds - */ - dev.device_id = 0; - } - break; - case at::DeviceType::CUDA: - dev.device_type = DLDeviceType::kDLCUDA; - break; - default: - TORCH_CHECK(false, "PyTorch TVM integration doesn't support device " + tensor.device().str()); + if (dlpack_ext.is_bool) { + auto result = tvm::runtime::NDArray::Empty(array.Shape(), DataType::Bool(), array->device); + result.CopyFrom(array); + return result; } - return dev; -} -/** - * @brief A Torch's module which wraps TVM's GraphExecutorFactory Class. - * The basic forward function calling TVM's runtime is provided. - * The TVM module can be serialized/deserialized as a Torch module. 
- */ -class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder { - public: - explicit GraphExecutorFactoryWrapper(tvm::runtime::Module executor_factory) - : executor_factory_(executor_factory) { - CHECK(executor_factory_->IsInstance()) - << "module is not an instance of GraphExecutorFactory"; - } + return array; +} - GraphExecutorFactoryWrapper() - : GraphExecutorFactoryWrapper(ThreadLocalStore::ThreadLocal()->mod) {} +} // namespace contrib +} // namespace tvm - c10::List forward(const c10::List& inputs) { - int input_length = inputs.size(); +extern "C" { - if (!executor_.defined()) { - TORCH_CHECK(input_length > 0, "Receive empty list of input tensors"); - DLDevice input_device = getDevice(inputs.get(0)); +struct RuntimeModulePointer { + tvm::runtime::Module mod; - auto tmp = executor_factory_.GetFunction("default"); + RuntimeModulePointer(tvm::runtime::Module mod) : mod(mod) {} +}; - executor_ = tmp(input_device); - } +RuntimeModulePointer* get_last_saved_runtime_module() { + return new RuntimeModulePointer(ThreadLocalStore::ThreadLocal()->mod); +} - std::vector tensors; +void operator_module_forward(RuntimeModulePointer* runtime_module, TensorList inputs, + size_t input_size) { + tvm::runtime::PackedFunc run = runtime_module->mod.GetFunction("__tvm_main__"); - for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPack(inputs[i])); + std::vector tvm_values(input_size); + std::vector tvm_type_codes(input_size); + tvm::runtime::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data()); + for (int k = 0; k < input_size; ++k) { + // auto datum = tvm::contrib::NDArrayFromDlPackExt(inputs[k]); + setter(k, &inputs[k]->dl_tensor); + } - tvm::runtime::PackedFunc run = executor_.GetFunction("run"); - tvm::runtime::PackedFunc set_input = executor_.GetFunction("set_input"); - tvm::runtime::PackedFunc get_output = executor_.GetFunction("get_output"); - tvm::runtime::PackedFunc get_num_outputs = executor_.GetFunction("get_num_outputs"); + run.CallPacked(tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), input_size), + nullptr); +} - for (int k = 0; k < input_length; ++k) { - set_input(k, &tensors[k]->dl_tensor); - } +tvm::Device getDeviceInfo(DLManagedTensor* input_device) { + return {.device_type = input_device->dl_tensor.device.device_type, + .device_id = input_device->dl_tensor.device.device_id}; +} - run(); +int64_t graph_executor_module_forward(RuntimeModulePointer* graph_module, TensorList inputs, + size_t input_size, TensorList* outputs) { + tvm::runtime::PackedFunc built_module = graph_module->mod.GetFunction("default"); + tvm::runtime::Module runtime_module = built_module(getDeviceInfo(inputs[0])); + tvm::runtime::PackedFunc run = runtime_module.GetFunction("run"); + tvm::runtime::PackedFunc set_input = runtime_module.GetFunction("set_input"); + tvm::runtime::PackedFunc get_output = runtime_module.GetFunction("get_output"); + tvm::runtime::PackedFunc get_num_outputs = runtime_module.GetFunction("get_num_outputs"); + + for (int k = 0; k < input_size; ++k) { + set_input(k, &inputs[k]->dl_tensor); + } - int64_t output_length = get_num_outputs(); + run(); - c10::List outputs; - outputs.reserve(output_length); + int64_t output_length = get_num_outputs(); - for (int k = 0; k < output_length; ++k) { - tvm::runtime::NDArray results = get_output(k); - at::Tensor atTensor = at::fromDLPack(results.ToDLPack()); - outputs.emplace_back(atTensor); - } + auto out_ptr = new DLManagedTensor*[output_length]; + *outputs = out_ptr; - for (int k = 0; k < input_length; ++k) 
{ - tensors[k]->deleter(tensors[k]); - } - return outputs; + for (int k = 0; k < output_length; ++k) { + tvm::runtime::NDArray results = get_output(k); + auto tensor = results.ToDLPack(); + out_ptr[k] = tensor; } - SerializationType Serialize() { return serialize(executor_factory_); } - - explicit GraphExecutorFactoryWrapper(SerializationType state) { - executor_factory_ = deserialize(state); - } - - private: - tvm::runtime::Module executor_factory_; - tvm::runtime::Module executor_; -}; - -TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod").set_body_typed([](tvm::runtime::Module mod) { - ThreadLocalStore::ThreadLocal()->mod = mod; -}); + return output_length; +} -TORCH_LIBRARY(tvm_torch, m) { - m.class_("OperatorModuleWrapper") - .def(torch::init<>()) - .def("forward", &OperatorModuleWrapper::forward) - .def_pickle( - [](const c10::intrusive_ptr& self) -> SerializationType { - return self->Serialize(); - }, - [](SerializationType state) { - return c10::make_intrusive(state); - }); - m.class_("GraphExecutorFactoryWrapper") - .def(torch::init<>()) - .def("forward", &GraphExecutorFactoryWrapper::forward) - .def_pickle( - [](const c10::intrusive_ptr& self) -> SerializationType { - return self->Serialize(); - }, - [](SerializationType state) { - return c10::make_intrusive(state); - }); +char* encode(RuntimeModulePointer* runtime_module) { + auto std = tvm::contrib::serialize(runtime_module->mod); + auto* ret = new char[std.length() + 1]; + strcpy(ret, std.c_str()); + ret[std.length()] = '\0'; + return ret; } -} // namespace contrib -} // namespace tvm +RuntimeModulePointer* decode(const char* state) { + auto ret = tvm::contrib::deserialize(state); + return new RuntimeModulePointer(ret); +} +} \ No newline at end of file diff --git a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapperTorch.cc b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapperTorch.cc new file mode 100644 index 000000000000..2ce306a167cc --- /dev/null +++ b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapperTorch.cc @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include + +#include + +#include "runtime_bridge.h" + +namespace tvm { +namespace contrib { + +DLPackTensorExt toDlPackExt(const at::Tensor& src) { + if (!src.is_contiguous()) { + return toDlPackExt(src.contiguous()); + } + + if (src.dtype().isScalarType(torch::kBool)) { + auto temp = src.toType(torch::kUInt8); + return {.dl_managed_tensor = at::toDLPack(temp), .is_bool = false}; + } + + return {.dl_managed_tensor = at::toDLPack(src), .is_bool = false}; +} + +/** + * @brief A Torch's module which wraps TVM's OperatorModule Class. + * The basic forward function calling TVM's runtime is provided. + * The TVM module can be serialized/deserialized as a Torch module. 
+ */ +class OperatorModuleWrapper : public torch::jit::CustomClassHolder { + public: + OperatorModuleWrapper() { runtime_module = get_last_saved_runtime_module(); } + + void forward(const c10::List& inputs) { + int input_length = inputs.size(); + + std::vector tensors; + + for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPack(inputs[i])); + + operator_module_forward(this->runtime_module, static_cast(tensors.data()), + tensors.size()); + + for (int k = 0; k < input_length; ++k) { + tensors[k]->deleter(tensors[k]); + } + } + + std::string Serialize() { return std::string(encode(runtime_module)); } + + explicit OperatorModuleWrapper(std::string state) { runtime_module = decode(state.c_str()); } + + private: + RuntimeModulePointer* runtime_module; +}; + +/** + * @brief A Torch's module which wraps TVM's GraphExecutorFactory Class. + * The basic forward function calling TVM's runtime is provided. + * The TVM module can be serialized/deserialized as a Torch module. + */ +class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder { + public: + explicit GraphExecutorFactoryWrapper(RuntimeModulePointer* executor_factory) + : executor_factory_(executor_factory) {} + + GraphExecutorFactoryWrapper() : GraphExecutorFactoryWrapper(get_last_saved_runtime_module()) {} + std::string Serialize() { return encode(executor_factory_); } + + explicit GraphExecutorFactoryWrapper(std::string state) { + executor_factory_ = decode(state.c_str()); + } + + c10::List forward(const c10::List& inputs) { + int input_length = inputs.size(); + + TORCH_CHECK(input_length > 0, "Receive empty list of input tensors"); + + std::vector tensors; + + for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPack(inputs[i])); + + TensorList* outputs = new TensorList; + + auto num_outputs = graph_executor_module_forward( + executor_factory_, static_cast(tensors.data()), tensors.size(), outputs); + + c10::List ret; + ret.reserve(num_outputs); + + for (int k = 0; k < num_outputs; ++k) { + at::Tensor atTensor = at::fromDLPack((*outputs)[k]); + ret.emplace_back(atTensor); + } + + for (int k = 0; k < input_length; ++k) { + tensors[k]->deleter(tensors[k]); + } + + delete outputs; + + return ret; + } + + private: + RuntimeModulePointer* executor_factory_; +}; + +TORCH_LIBRARY(tvm_torch, m) { + m.class_("OperatorModuleWrapper") + .def(torch::init<>()) + .def("forward", &OperatorModuleWrapper::forward) + .def_pickle( + [](const c10::intrusive_ptr& self) -> std::string { + return self->Serialize(); + }, + [](std::string state) { return c10::make_intrusive(state); }); + m.class_("GraphExecutorFactoryWrapper") + .def(torch::init<>()) + .def("forward", &GraphExecutorFactoryWrapper::forward) + .def_pickle( + [](const c10::intrusive_ptr& self) -> std::string { + return self->Serialize(); + }, + [](std::string state) { + return c10::make_intrusive(state); + }); +} + +} // namespace contrib +} // namespace tvm \ No newline at end of file diff --git a/src/contrib/torch/pt_call_tvm/runtime_bridge.h b/src/contrib/torch/pt_call_tvm/runtime_bridge.h new file mode 100644 index 000000000000..930471ecc735 --- /dev/null +++ b/src/contrib/torch/pt_call_tvm/runtime_bridge.h @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file runtime_bridge.h + * \brief Util functions for pytorch tvm interaction. + */ +#ifndef TVM_CONTRIB_TORCH_RUNTIME_BRIDGE_H_ +#define TVM_CONTRIB_TORCH_RUNTIME_BRIDGE_H_ + +extern "C" { + +typedef DLManagedTensor** TensorList; + +struct DLPackTensorExt { + DLManagedTensor* dl_managed_tensor; + bool is_bool; +}; + + +struct RuntimeModulePointer; + +RuntimeModulePointer* get_last_saved_runtime_module(); + +void operator_module_forward(RuntimeModulePointer* runtime_module, TensorList inputs, + size_t input_size); + +int64_t graph_executor_module_forward(RuntimeModulePointer* graph_module, TensorList inputs, + size_t input_size, + TensorList* outputs); + +char* encode(RuntimeModulePointer* runtime_module); + +RuntimeModulePointer* decode(const char* state); +} + +#endif // TVM_CONTRIB_TORCH_RUNTIME_BRIDGE_H_ diff --git a/src/contrib/torch/pt_call_tvm/tvm_class.cc b/src/contrib/torch/pt_call_tvm/tvm_class.cc deleted file mode 100644 index 5e57dc152f11..000000000000 --- a/src/contrib/torch/pt_call_tvm/tvm_class.cc +++ /dev/null @@ -1,686 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include "../utils.h" - -namespace tvm { -namespace contrib { -namespace pytorch { - -/*! \brief Class holding necessary components to call TVM graph runtime */ -class TvmGraphModulePack { - public: - /*! - * \brief Constructor. - * - * \param path Encoded path of graph runtime assets. - * \param device_type int64_t, kDLCPU or kDLCUDA. - * \param device_id int64_t. - */ - explicit TvmGraphModulePack(std::string path, int64_t device_type, int64_t device_id) - : path_(std::move(path)) { - LOG(INFO) << "[TvmGraphModule] loading module at path: [" << path_ << "] on device [" - << (device_type == kDLCUDA ? 
"cuda:" : "cpu:") << device_id << "]..."; - std::string lib_path, graph_path, params_path; - DecodePaths(path_, &lib_path, &graph_path, ¶ms_path); - - // load graph - std::ifstream graph_in(graph_path); - std::string graph_data((std::istreambuf_iterator(graph_in)), - std::istreambuf_iterator()); - graph_in.close(); - - // load mod syslib - tvm::runtime::Module lib = tvm::runtime::Module::LoadFromFile(lib_path); - - const auto runtime_create = *tvm::runtime::Registry::Get("tvm.graph_executor.create"); - - // read params data - std::ifstream params_in(params_path, std::ios::binary); - std::string params_data((std::istreambuf_iterator(params_in)), - std::istreambuf_iterator()); - params_in.close(); - TVMByteArray params_arr; - params_arr.data = params_data.c_str(); - params_arr.size = params_data.length(); - - // set devices - module_ = runtime_create(graph_data, lib, device_type, device_id); - const tvm::runtime::PackedFunc load_params = module_.GetFunction("load_params"); - load_params(params_arr); - - set_input = module_.GetFunction("set_input_zero_copy"); - run = module_.GetFunction("run"); - get_output = module_.GetFunction("get_output"); - set_output = module_.GetFunction("set_output_zero_copy"); - num_outputs_ = module_.GetFunction("get_num_outputs")(); - } - - static constexpr char kPathDelimiter = '|'; - - /*! - * \brief Decode lib_path, graph_path, params_path from encoded path. - * - * \param path The encoded path, concated with `kPathDelimiter`. - * \param lib_path The path of .so lib file. - * \param graph_path The path of graph.json. - * \param params_path The path of params data. - */ - static void DecodePaths(const std::string& path, std::string* lib_path, std::string* graph_path, - std::string* params_path) { - std::vector paths; - for (size_t i = 0, pre = 0, lim = path.size(); i <= lim; ++i) { - if (i == lim || path.at(i) == kPathDelimiter) { - paths.push_back(path.substr(pre, i - pre)); - pre = i + 1; - } - } - CHECK_EQ(paths.size(), 3u); - *lib_path = paths.at(0); - *graph_path = paths.at(1); - *params_path = paths.at(2); - } - - /*! - * \brief Encode lib_path, graph_path, params_path by concat then with `kPathDelimiter`. - * - * \param lib_path The path of .so lib file. - * \param graph_path The path of graph.json. - * \param params_path The path of params data. - * - * \return The encoded path, concated with `kPathDelimiter`. - */ - static std::string EncodePaths(const std::string& lib_path, const std::string& graph_path, - const std::string& params_path) { - return lib_path + kPathDelimiter + graph_path + kPathDelimiter + params_path; - } - - const std::string& path() const { return path_; } - - const int64_t num_outputs() const { return num_outputs_; } - - tvm::runtime::PackedFunc set_input; - tvm::runtime::PackedFunc run; - tvm::runtime::PackedFunc get_output; - tvm::runtime::PackedFunc set_output; - - private: - tvm::runtime::Module module_; - int64_t num_outputs_; - std::string path_; -}; - -/*! \brief Class holding necessary components to call TVM VM runtime */ -class TvmVMModulePack { - public: - /*! - * \brief Constructor. - * - * \param path Encoded path of vm runtime assets. - * \param device_type int64_t, kDLCPU or kDLCUDA. - * \param device_id int64_t. - */ - explicit TvmVMModulePack(std::string path, int64_t device_type, int64_t device_id) - : path_(std::move(path)) { - LOG(INFO) << "[TvmVMModule] loading module at path: [" << path_ << "] on device [" - << (device_type == kDLCUDA ? 
"cuda:" : "cpu:") << device_id << "]..."; - // build tvm graph runtime - std::string lib_path, code_path; - DecodePaths(path_, &lib_path, &code_path); - // load lib - auto loaded_lib = tvm::runtime::Module::LoadFromFile(lib_path, "so"); - // load code - std::ifstream code_in(code_path); - std::string loaded_code((std::istreambuf_iterator(code_in)), - std::istreambuf_iterator()); - code_in.close(); - exe_ = tvm::runtime::vm::Executable::Load(loaded_code, loaded_lib); - const auto runtime_create = *tvm::runtime::Registry::Get("runtime._VirtualMachine"); - vm_ = runtime_create(exe_); - auto init_func = vm_.GetFunction("init", false); - auto alloc_type = static_cast(tvm::runtime::vm::AllocatorType::kPooled); - if (device_type != kDLCPU) { - // CPU is required for executing shape functions - init_func(static_cast(kDLCPU), 0, alloc_type, device_type, device_id, alloc_type); - } else { - init_func(device_type, device_id, alloc_type); - } - set_input = vm_.GetFunction("set_input", false); - invoke = vm_.GetFunction("invoke", false); - } - - static constexpr char kPathDelimiter = '|'; - - /*! - * \brief Decode lib_path, code_path from encoded path. - * - * \param path The encoded path, concated with `kPathDelimiter`. - * \param lib_path The path of lib file. - * \param code_path The path of code file. - */ - static void DecodePaths(const std::string& path, std::string* lib_path, std::string* code_path) { - std::vector paths; - for (size_t i = 0, pre = 0, lim = path.size(); i <= lim; ++i) { - if (i == lim || path.at(i) == kPathDelimiter) { - paths.push_back(path.substr(pre, i - pre)); - pre = i + 1; - } - } - CHECK_EQ(paths.size(), 2u); - *lib_path = paths.at(0); - *code_path = paths.at(1); - } - - /*! - * \brief Encode lib_path, code_path by concat then with `kPathDelimiter`. - * - * \param lib_path The path of vm lib file. - * \param code_path The path of code. - * - * \return The encoded path, concated with `kPathDelimiter`. - */ - static std::string EncodePaths(const std::string& lib_path, const std::string& code_path) { - return lib_path + kPathDelimiter + code_path; - } - - const std::string& path() const { return path_; } - - tvm::runtime::PackedFunc set_input; - tvm::runtime::PackedFunc invoke; - - private: - tvm::runtime::Module exe_; - tvm::runtime::Module vm_; - std::string path_; -}; - -/*! \brief Pytorch custom class to call TVM */ -class BaseTvmClass : public torch::jit::CustomClassHolder { - public: - /*! - * \brief Constructor. - * - * \param num_inputs Number of inputs. - * \param num_outputs Number of outputs. - * \param device std::string, use the pytorch device str format, e.g. `cuda:0`, 'cpu' - */ - BaseTvmClass(const int64_t num_inputs, const int64_t num_outputs, const std::string& device) - : num_inputs_(num_inputs), num_outputs_(num_outputs) { - auto torch_device = torch::Device(device); - device_type_ = torch_device.is_cuda() ? kDLCUDA : kDLCPU; - device_id_ = torch_device.index(); - } - - /*! \brief Virtual destructor. */ - virtual ~BaseTvmClass() {} - - /*! - * \brief Get repr string of pytorch input shapes. - * - * \param shapes Pytorch shapes of type List[List[int]]. - * - * \return std::string, the representation of inputs shapes. - */ - static std::string TvmShapeRepr(const c10::List>& shapes) { - std::stringstream ss; - for (const auto& shape : shapes) { - for (const auto& sz : static_cast>(shape)) { - ss << sz << "_"; - } - ss << "__"; - } - return ss.str(); - } - - /*! - * \brief Get input shapes. - * - * \param inputs Inputs with type List[Tensor]. 
- * - * \return outputs with type List[List[int]]. - */ - static c10::List> GetShapes(const c10::List& inputs) { - c10::List> shapes; - for (const auto& input : inputs) { - c10::List shape; - for (const auto sz : static_cast(input).sizes()) { - shape.push_back(sz); - } - shapes.push_back(shape); - } - return shapes; - } - - /*! - * \brief Move the TVM modules to given device. - * - * \param device String repr of the device to be moved to. - */ - virtual void to(const std::string& device) = 0; - - // getters - int64_t num_inputs() const { return num_inputs_; } - - int64_t num_outputs() const { return num_outputs_; } - - int64_t device_type() const { return device_type_; } - - int64_t device_id() const { return device_id_; } - - c10::DeviceType torch_device_type() const { - return device_type() == kDLCUDA ? torch::DeviceType::CUDA : torch::DeviceType::CPU; - } - - bool is_on_same_device(const torch::Tensor& tensor) const { - auto tensor_device_type = tensor.device().type(); - if (tensor_device_type == torch::DeviceType::CUDA) { - return tensor_device_type == torch_device_type() && device_id() == tensor.device().index(); - } - CHECK_EQ(tensor_device_type, torch::DeviceType::CPU); - return tensor_device_type == torch_device_type(); - } - - std::string device() const { return torch::Device(torch_device_type(), device_id()).str(); } - - /*! - * \brief Module forward. - * - * \param inputs Inputs with type List[Tensor]. - * - * \return outputs with type List[Tensor]. - */ - virtual c10::List forward(const c10::List& inputs) = 0; - - /*! - * \brief Serialize TVM Modules to Dict - */ - virtual c10::Dict SerializeTvmModules() const = 0; - - /*! - * \brief deserialize TVM Modules from Dict - */ - virtual void DeserializeTvmModules(const c10::Dict& shape_path_map) = 0; - - protected: - const int64_t num_inputs_; - const int64_t num_outputs_; - int64_t device_type_; - int64_t device_id_; -}; - -/*! \brief Pytorch custom class to call TVM graph runtime */ -class TvmGraphRuntimeClass : public BaseTvmClass { - public: - TvmGraphRuntimeClass(const int64_t num_inputs, const int64_t num_outputs, - const std::string& device) - : BaseTvmClass(num_inputs, num_outputs, device) {} - - /*! - * \brief Module forward. - * - * \param inputs Inputs with type List[Tensor]. - * - * \return outputs with type List[Tensor]. 
- */ - c10::List forward(const c10::List& inputs) override { - CHECK_EQ(inputs.size(), num_inputs_); - auto shape_repr = TvmShapeRepr(GetShapes(inputs)); - std::vector args(num_inputs_ + num_outputs_); - auto iter = tvm_modules_.find(shape_repr); - CHECK(iter != tvm_modules_.end()); - const auto& tvm_pack = iter->second; - std::vector buf_infos; - buf_infos.reserve(num_inputs_ + num_outputs_); - - for (int i = 0; i < num_inputs_; ++i) { - at::Tensor inp = inputs[i]; - CHECK(is_on_same_device(inp)) - << "input #" << i - << " of forward is not on the same device with TvmGraphRuntime, expected " << device() - << " but got " << inp.device().str(); - inp = inp.contiguous(); - buf_infos.emplace_back(inp); - auto& input_buf = buf_infos[i]; - input_buf.CopyFromOrigin(); - input_buf.MakeDLTensor(&args[i]); - tvm_pack.set_input(i, &args[i]); - } - // prepare output buffers - c10::List outputs; - outputs.reserve(num_outputs_); - - for (int i = 0; i < num_outputs_; ++i) { - tvm::runtime::NDArray output_arr = tvm_pack.get_output(i); - std::vector output_shape(output_arr->shape, output_arr->shape + output_arr->ndim); - - torch::ScalarType output_dtype = torch::ScalarType::Undefined; - CHECK(GetTorchDtype(output_arr.DataType(), &output_dtype)); - - CHECK(device_type_ == kDLCPU || device_type_ == kDLCUDA); - const c10::DeviceType pt_device_type = (device_type_ == kDLCUDA ? torch::kCUDA : torch::kCPU); - const auto options = - torch::TensorOptions().dtype(output_dtype).device(pt_device_type, device_id_); - - outputs.emplace_back(torch::empty(output_shape, options)); - buf_infos.emplace_back(outputs[i]); - auto& output_buf = buf_infos[num_inputs_ + i]; - output_buf.MakeDLTensor(&args[num_inputs_ + i]); - tvm_pack.set_output(i, &args[num_inputs_ + i]); - } - tvm_pack.run(); - for (int i = 0; i < num_outputs_; ++i) { - auto& output_buf = buf_infos[num_inputs_ + i]; - output_buf.CopyToOrigin(); - } - return outputs; - } - - /*! - * \brief Load TVM graph runtime module. - * - * \param shapes Input shapes. List[List[int]]. - * \param lib_path Path of .so lib file. - * \param graph_path Path of graph.json file. - * \param params_path Path of params data file. - */ - void LoadTvmModule(const c10::List>& shapes, const std::string& lib_path, - const std::string& graph_path, const std::string& params_path) { - std::string path = TvmGraphModulePack::EncodePaths(lib_path, graph_path, params_path); - auto shape_repr = TvmShapeRepr(shapes); - auto it_find = tvm_modules_.find(shape_repr); - if (it_find != tvm_modules_.end()) { - tvm_modules_.erase(it_find); - } - const auto it = - tvm_modules_.emplace(shape_repr, TvmGraphModulePack(path, device_type_, device_id_)).first; - if (it->second.num_outputs() != num_outputs_) { - LOG(FATAL) << "tvm class num outputs mismatch, expected " << num_outputs_ << ", got " - << it->second.num_outputs(); - } - } - - const std::map& tvm_modules() const { return tvm_modules_; } - - /*! - * \brief Serialize TVM modules to shape map. - * - * \return shape_path_map Dict of shape_repr to path. - */ - c10::Dict SerializeTvmModules() const override { - c10::Dict shape_path_map; - for (const auto& entry : tvm_modules()) { - shape_path_map.insert(entry.first, entry.second.path()); - } - return shape_path_map; - } - - /*! - * \brief Deserialize TVM modules from shape map. - * - * \param shape_path_map Dict of shape_repr to path. 
- */ - void DeserializeTvmModules(const c10::Dict& shape_path_map) override { - tvm_modules_.clear(); - for (const auto& entry : shape_path_map) { - const auto& shape_repr = entry.key(); - const auto& path = entry.value(); - tvm_modules_.emplace(shape_repr, TvmGraphModulePack(path, device_type_, device_id_)); - } - } - - /*! - * \brief Move the TVM modules to given device. - * - * \param device String repr of the device to be moved to. - */ - void to(const std::string& device) override { - if (device != this->device()) { - auto torch_device = torch::Device(device); - device_type_ = torch_device.is_cuda() ? kDLCUDA : kDLCPU; - device_id_ = torch_device.index(); - DeserializeTvmModules(SerializeTvmModules()); - } - } - - private: - std::map tvm_modules_; -}; - -/*! \brief Pytorch custom class to call TVM graph runtime */ -class TvmVMRuntimeClass : public BaseTvmClass { - public: - TvmVMRuntimeClass(const int64_t num_inputs, const int64_t num_outputs, const std::string& device) - : BaseTvmClass(num_inputs, num_outputs, device) {} - - /*! - * \brief Module forward. - * - * \param inputs Inputs with type List[Tensor]. - * - * \return outputs with type List[Tensor]. - */ - c10::List forward(const c10::List& inputs) override { - // get inputs repr str - auto shape_repr = TvmShapeRepr(GetShapes(inputs)); - // get tvm pack - auto iter = tvm_modules_.find(shape_repr); - CHECK(iter != tvm_modules_.end()) << "tvm module pack not found for shape_repr " << shape_repr; - const auto& tvm_pack = iter->second; - - // input tensors - CHECK_EQ(inputs.size(), num_inputs_); - std::vector args(num_inputs_); - std::vector args_arr(num_inputs_); - - for (int i = 0; i < num_inputs_; ++i) { - TensorAsBuf input_buf(inputs[i]); - input_buf.CopyFromOrigin(); - input_buf.MakeDLTensor(&args[i]); - args_arr[i] = - tvm::runtime::NDArray::FromDLPack(new DLManagedTensor({args[i], nullptr, nullptr})); - } - // set input - std::vector tvm_values(num_inputs_ + 1); - std::vector tvm_type_codes(num_inputs_ + 1); - tvm::runtime::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data()); - setter(0, "main"); - for (int k = 0; k < num_inputs_; ++k) { - setter(k + 1, args_arr[k]); - } - tvm_pack.set_input.CallPacked( - tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), num_inputs_ + 1), nullptr); - - // run tvm - tvm::runtime::TVMRetValue ret = tvm_pack.invoke("main"); - - // get outputs - std::vector output_arrs(num_outputs_); - auto output_mismatch_msg = [](int actual, int expected) { - std::stringstream ss; - ss << "num_outputs not equal, actual:[" << actual << "] != expected:[" << expected << "]"; - return ss.str(); - }; - if (ret.type_code() == kTVMNDArrayHandle) { - CHECK_EQ(num_outputs_, 1) << output_mismatch_msg(1, num_outputs_); - output_arrs.at(0) = ret.AsObjectRef(); - } else if (ret.type_code() == kTVMObjectHandle) { - const auto& adt = ret.AsObjectRef(); - CHECK_EQ(adt.size(), num_outputs_) << output_mismatch_msg(adt.size(), num_outputs_); - for (size_t i = 0; i < adt.size(); ++i) { - CHECK(adt[i]->IsInstance()) - << "adt elements not tvm::runtime::NDArray"; - output_arrs.at(i) = tvm::runtime::Downcast(adt[i]); - } - } else { - LOG(FATAL) << "unsupported return type with type_code = " << ret.type_code(); - } - - std::vector output_args(num_outputs_); - c10::List outputs; - outputs.reserve(num_outputs_); - - for (int i = 0; i < num_outputs_; ++i) { - const auto& output_arr = output_arrs[i]; - std::vector output_shape(output_arr->shape, output_arr->shape + output_arr->ndim); - - torch::ScalarType 
output_dtype = torch::ScalarType::Undefined; - CHECK(GetTorchDtype(output_arr.DataType(), &output_dtype)); - - CHECK(device_type_ == kDLCPU || device_type_ == kDLCUDA); - const c10::DeviceType pt_device_type = (device_type_ == kDLCUDA ? torch::kCUDA : torch::kCPU); - const auto options = - torch::TensorOptions().dtype(output_dtype).device(pt_device_type, device_id_); - - outputs.emplace_back(torch::empty(output_shape, options)); - TensorAsBuf output_buf(outputs[i]); - output_buf.MakeDLTensor(&output_args[i]); - output_arr.CopyTo(&output_args[i]); - output_buf.CopyToOrigin(); - } - return outputs; - } - - /*! - * \brief Load TVM vm runtime module. - * - * \param shapes Input shapes. List[List[int]]. - * \param lib_path Path of .so lib file. - * \param code_path Path of code file. Typically named code.ro - */ - void LoadTvmModule(const c10::List>& shapes, const std::string& lib_path, - const std::string& code_path) { - std::string path = TvmVMModulePack::EncodePaths(lib_path, code_path); - auto shape_repr = TvmShapeRepr(shapes); - auto it_find = tvm_modules_.find(shape_repr); - if (it_find != tvm_modules_.end()) { - tvm_modules_.erase(it_find); - } - tvm_modules_.emplace(shape_repr, TvmVMModulePack(path, device_type_, device_id_)); - } - - const std::map& tvm_modules() const { return tvm_modules_; } - - /*! - * \brief Serialize TVM modules to shape map. - * - * \return shape_path_map Dict of shape_repr to path. - */ - c10::Dict SerializeTvmModules() const override { - c10::Dict shape_path_map; - for (const auto& entry : tvm_modules()) { - shape_path_map.insert(entry.first, entry.second.path()); - } - return shape_path_map; - } - - /*! - * \brief Deserialize TVM modules from shape map. - * - * \param shape_path_map Dict of shape_repr to path. - */ - void DeserializeTvmModules(const c10::Dict& shape_path_map) override { - tvm_modules_.clear(); - for (const auto& entry : shape_path_map) { - const auto& shape_repr = entry.key(); - const auto& path = entry.value(); - tvm_modules_.emplace(shape_repr, TvmVMModulePack(path, device_type_, device_id_)); - } - } - - /*! - * \brief Move the TVM modules to given device. - * - * \param device String repr of the device to be moved to. - */ - void to(const std::string& device) override { - if (device != this->device()) { - auto torch_device = torch::Device(device); - device_type_ = torch_device.is_cuda() ? 
kDLCUDA : kDLCPU; - device_id_ = torch_device.index(); - DeserializeTvmModules(SerializeTvmModules()); - } - } - - private: - std::map tvm_modules_; -}; - -// -using SerializeTuple = - std::tuple>; - -/***** registries *****/ -static auto __tvm_dsoop_graph_runtime_registry = - torch::jit::class_("tvm_dsoop", "TvmGraphModule") - .def(torch::init()) - .def("load_tvm_module", &TvmGraphRuntimeClass::LoadTvmModule) - .def("forward", &TvmGraphRuntimeClass::forward) - .def("to", &TvmGraphRuntimeClass::to) - .def_pickle( - [](const c10::intrusive_ptr& self) -> SerializeTuple { - return std::make_tuple(self->num_inputs(), self->num_outputs(), self->device(), - self->SerializeTvmModules()); - }, - [](SerializeTuple tuple) -> c10::intrusive_ptr { - auto ptr = c10::make_intrusive( - /*num_inputs=*/std::get<0>(tuple), - /*num_outputs=*/std::get<1>(tuple), - /*device=*/std::get<2>(tuple)); - ptr->DeserializeTvmModules(std::get<3>(tuple)); - return ptr; - }); - -static auto __tvm_dsoop_vm_runtime_registry = - torch::jit::class_("tvm_dsoop", "TvmVMModule") - .def(torch::init()) - .def("load_tvm_module", &TvmVMRuntimeClass::LoadTvmModule) - .def("forward", &TvmVMRuntimeClass::forward) - .def("to", &TvmVMRuntimeClass::to) - .def_pickle( - [](const c10::intrusive_ptr& self) -> SerializeTuple { - return std::make_tuple(self->num_inputs(), self->num_outputs(), self->device(), - self->SerializeTvmModules()); - }, - [](SerializeTuple tuple) -> c10::intrusive_ptr { - auto ptr = c10::make_intrusive( - /*num_inputs=*/std::get<0>(tuple), - /*num_outputs=*/std::get<1>(tuple), - /*device=*/std::get<2>(tuple)); - ptr->DeserializeTvmModules(std::get<3>(tuple)); - return ptr; - }); - -static auto __tvm_shape_repr_fn_registry = - torch::RegisterOperators("tvm_dsoop::tvm_shape_repr", &BaseTvmClass::TvmShapeRepr); -} // namespace pytorch -} // namespace contrib -} // namespace tvm From e1278396f001ec2217ad9229c72652037898f2a8 Mon Sep 17 00:00:00 2001 From: juda Date: Tue, 26 Jul 2022 06:06:07 -0700 Subject: [PATCH 02/33] rename --- .../{RuntimeModuleWrapper.cc => RuntimeModuleWrapperTVM.cc} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/contrib/torch/pt_call_tvm/{RuntimeModuleWrapper.cc => RuntimeModuleWrapperTVM.cc} (100%) diff --git a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapperTVM.cc similarity index 100% rename from src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc rename to src/contrib/torch/pt_call_tvm/RuntimeModuleWrapperTVM.cc From 870e65a444ba678fa44b7a2929d31371c297b902 Mon Sep 17 00:00:00 2001 From: juda Date: Tue, 26 Jul 2022 07:24:38 -0700 Subject: [PATCH 03/33] cmake --- cmake/modules/contrib/PT_TVMDSOOP.cmake | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/cmake/modules/contrib/PT_TVMDSOOP.cmake b/cmake/modules/contrib/PT_TVMDSOOP.cmake index 3bad3fd966c7..cfdc564a024c 100644 --- a/cmake/modules/contrib/PT_TVMDSOOP.cmake +++ b/cmake/modules/contrib/PT_TVMDSOOP.cmake @@ -6,7 +6,7 @@ # "License"); you may not use this file except in compliance # with the License. 
You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -18,41 +18,41 @@ if(NOT USE_PT_TVMDSOOP STREQUAL "OFF") find_package(PythonInterp REQUIRED) - execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "import torch; print(torch.__path__[0].strip())" + # use ${PYTHON_EXECUTE} below + execute_process(COMMAND "/root/anaconda3/bin/python" -c "import torch; print(torch.__path__[0].strip())" OUTPUT_VARIABLE PT_PATH RESULT_VARIABLE PT_STATUS) - if (NOT ${PT_STATUS} EQUAL 0) + + if(NOT ${PT_STATUS} EQUAL 0) message(FATAL_ERROR "Fail to get pytorch path") endif() string(REGEX REPLACE "\n" "" PT_PATH "${PT_PATH}") message(STATUS "PyTorch path: ${PT_PATH}") - set(PT_COMPILE_FLAGS_STR "-I${PT_PATH}/include -D_GLIBCXX_USE_CXX11_ABI=0") + # set(PT_COMPILE_FLAGS_STR "-I${PT_PATH}/include -D_GLIBCXX_USE_CXX11_ABI=0") + set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapperTorch.cc PROPERTIES COMPILE_FLAGS "-D_GLIBCXX_USE_CXX11_ABI=1") + set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapperTorch.cc PROPERTIES COMPILE_FLAGS "-I${PT_PATH}/include") set(PT_LINK_FLAGS_STR "-L${PT_PATH}/lib -l:libtorch.so -l:libtorch_python.so") if(NOT USE_CUDA STREQUAL "OFF") add_definitions(-DPT_TVMDSOOP_ENABLE_GPU) endif() - string(REGEX REPLACE "\n" " " PT_FLAGS "${PT_COMPILE_FLAGS} ${PT_LINK_FLAGS}") - separate_arguments(PT_COMPILE_FLAGS UNIX_COMMAND ${PT_COMPILE_FLAGS_STR}) + separate_arguments(PT_COMPILE_FLAGS UNIX_COMMAND) separate_arguments(PT_LINK_FLAGS UNIX_COMMAND ${PT_LINK_FLAGS_STR}) - set(LIBRARY_NAME pt_tvmdsoop) tvm_file_glob(GLOB_RECURSE PTTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/**/*.cc) add_library(${LIBRARY_NAME} SHARED ${PTTVM_SRCS}) set(PTTVM_LINK_FLAGS -ltvm -L${CMAKE_CURRENT_BINARY_DIR}) - if (NOT BUILD_PT_TVMDSOOP_ONLY STREQUAL "ON") - add_dependencies(${LIBRARY_NAME} tvm) + if(NOT BUILD_PT_TVMDSOOP_ONLY STREQUAL "ON") + add_dependencies(${LIBRARY_NAME} tvm) endif() target_compile_options(${LIBRARY_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} ${PT_COMPILE_FLAGS}) target_link_libraries(${LIBRARY_NAME} PUBLIC ${PTTVM_LINK_FLAGS} ${PT_LINK_FLAGS}) target_compile_definitions(${LIBRARY_NAME} PUBLIC DMLC_USE_LOGGING_LIBRARY=) - endif() - From 6a61d98fab42d52a38b74faf10fde98e112266ee Mon Sep 17 00:00:00 2001 From: juda Date: Tue, 26 Jul 2022 18:43:10 -0700 Subject: [PATCH 04/33] deprecated --- python/tvm/contrib/torch/__init__.py | 13 +- python/tvm/contrib/torch/module.py | 121 ++++ python/tvm/contrib/torch/pytorch_tvm.py | 251 ++++++++ src/contrib/torch/pt_call_tvm/tvm_class.cc | 686 +++++++++++++++++++++ 4 files changed, 1070 insertions(+), 1 deletion(-) create mode 100644 python/tvm/contrib/torch/module.py create mode 100644 python/tvm/contrib/torch/pytorch_tvm.py create mode 100644 src/contrib/torch/pt_call_tvm/tvm_class.cc diff --git a/python/tvm/contrib/torch/__init__.py b/python/tvm/contrib/torch/__init__.py index 1c22228381f8..b4261f3ed906 100644 --- a/python/tvm/contrib/torch/__init__.py +++ b/python/tvm/contrib/torch/__init__.py @@ -39,6 +39,17 @@ def _load_platform_specific_library(lib_name="libpt_tvmdsoop"): _load_platform_specific_library() +from . import module + +GraphModule = module.GraphModule +VMModule = module.VMModule +TraceTvmModule = module.TraceTvmModule + +from . 
import pytorch_tvm + +PyTorchTVMModule = pytorch_tvm.PyTorchTVMModule +compile = pytorch_tvm.compile + from . import as_torch TVMScriptIRModule = as_torch.OperatorModuleWrapper @@ -47,4 +58,4 @@ def _load_platform_specific_library(lib_name="libpt_tvmdsoop"): from . import optimize_torch GraphExecutorFactoryWrapper = optimize_torch.GraphExecutorFactoryWrapper -optimize_torch = optimize_torch.optimize_torch +optimize_torch = optimize_torch.optimize_torch \ No newline at end of file diff --git a/python/tvm/contrib/torch/module.py b/python/tvm/contrib/torch/module.py new file mode 100644 index 000000000000..1652730b86eb --- /dev/null +++ b/python/tvm/contrib/torch/module.py @@ -0,0 +1,121 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Module container of PyTorch custom class""" +from typing import List +import torch + + +class GraphModule(torch.nn.Module): + r"""Module container of Pytorch class which wraps exported + TVM op implementation library to be called on Pytorch side""" + + @classmethod + def shape_repr(cls, input_shapes): + return torch.ops.tvm_dsoop.tvm_shape_repr(input_shapes) + + def __init__(self, num_inputs, num_outputs, device=None): + super().__init__() + self.dummy_param = torch.nn.Parameter(torch.empty(0)) + self.engine = None + + if device is not None: + self.to(device) + self.engine = torch.classes.tvm_dsoop.TvmGraphModule(num_inputs, num_outputs, self.device) + + def init(self, input_shapes, lib_path, graph_path, params_path): + r"""Load tvm module""" + self.engine.load_tvm_module(input_shapes, lib_path, graph_path, params_path) + + def forward(self, inputs: List[torch.Tensor]): + r"""Call tvm module to forward""" + return self.engine.forward(inputs) + + @property + def device(self): + r"""Get the device string""" + return str(self.dummy_param.device) + + def _apply(self, fn): + r"""Override to device function, manually move tvm module to desired device""" + super()._apply(fn) + if self.engine is not None: + self.engine.to(self.device) + return self + + +class VMModule(torch.nn.Module): + r"""Module container of Pytorch class which wraps exported + TVM op implementation library to be called on Pytorch side""" + + @classmethod + def shape_repr(cls, input_shapes): + return torch.ops.tvm_dsoop.tvm_shape_repr(input_shapes) + + def __init__(self, num_inputs, num_outputs, device=None): + super().__init__() + self.dummy_param = torch.nn.Parameter(torch.empty(0)) + self.engine = None + + if device is not None: + self.to(device) + self.engine = torch.classes.tvm_dsoop.TvmVMModule(num_inputs, num_outputs, self.device) + + def init(self, input_shapes, lib_path, code_path): + r"""Load tvm module""" + self.engine.load_tvm_module(input_shapes, lib_path, code_path) + + def forward(self, inputs: List[torch.Tensor]): 
+ r"""Call tvm module to forward""" + return self.engine.forward(inputs) + + @property + def device(self): + r"""Get the device string""" + return str(self.dummy_param.device) + + def _apply(self, fn): + r"""Override to device function, manually move tvm module to desired device""" + super()._apply(fn) + if self.engine is not None: + self.engine.to(self.device) + return self + + +class TraceTvmModule(torch.nn.Module): + r"""Wrapper for trace GraphModule + + GraphModule and VMModule only supports List[Tensor] inputs and cannot be traced. + This is a wrapper class for trace GraphModule or VMModule in order to support + arbitrary number of inputs + + Example: + import tvm.contrib.torch + tvm_module = tvm.contrib.torch.GraphModule(1, 1, 'cuda:0') + tvm_module.init(input_shapes, lib_path, graph_path, params_path) + + trace_wrapper = tvm.contrib.torch.TraceGraphModule(torch.jit.script(tvm_module)) + traced = torch.jit.trace(trace_wrapper, example_inputs) + """ + + def __init__(self, tvm_module): + super().__init__() + self.tvm_module = tvm_module + + def forward(self, *inputs): + outputs = self.tvm_module(inputs) + return outputs[0] if len(outputs) == 1 else tuple(outputs) \ No newline at end of file diff --git a/python/tvm/contrib/torch/pytorch_tvm.py b/python/tvm/contrib/torch/pytorch_tvm.py new file mode 100644 index 000000000000..914b86d8fc9c --- /dev/null +++ b/python/tvm/contrib/torch/pytorch_tvm.py @@ -0,0 +1,251 @@ +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin +"""`compile` api that convert torch module to torch tvm module""" +import os +import warnings +import tvm +import tvm.testing +from tvm import relay, autotvm +from tvm.runtime import load_module +from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner +from tvm.contrib import graph_executor +from tvm.contrib.debugger import debug_executor +from . 
import GraphModule + + +def tune_tasks( + tasks, + measure_option, + tuner="xgb", + n_trial=1000, + early_stopping=None, + log_filename="tuning.log", + use_transfer_learning=True, +): + """Tune tasks and generate tuning log to file""" + # create tmp log file + tmp_log_file = log_filename + ".tmp" + if os.path.exists(tmp_log_file): + os.remove(tmp_log_file) + + for i, tsk in enumerate(reversed(tasks)): + prefix = f"[Task {i + 1:2d}/{len(tasks):2d}] " + + # create tuner + if tuner in ("xgb", "sgb-rank"): + tuner_obj = XGBTuner(tsk, loss_type="rank") + elif tuner == "ga": + tuner_obj = GATuner(tsk, pop_size=100) + elif tuner == "random": + tuner_obj = RandomTuner(tsk) + elif tuner == "gridsearch": + tuner_obj = GridSearchTuner(tsk) + else: + raise ValueError("Invalid tuner: " + tuner) + + if use_transfer_learning: + if os.path.isfile(tmp_log_file): + tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file)) + + # do tuning + tsk_trial = min(n_trial, len(tsk.config_space)) + tuner_obj.tune( + n_trial=tsk_trial, + early_stopping=early_stopping, + measure_option=measure_option, + callbacks=[ + autotvm.callback.progress_bar(tsk_trial, prefix=prefix), + autotvm.callback.log_to_file(tmp_log_file), + ], + ) + + # pick best records to a cache file + if not os.path.exists(log_filename): + with open(log_filename, "w", encoding="utf-8"): + pass + if os.path.exists(tmp_log_file): + autotvm.record.pick_best(tmp_log_file, log_filename) + os.remove(tmp_log_file) + + +def get_tuning_opt(log_file="tuning.log", n_trial=200): + """Returns tuning options""" + tuning_opt = { + "log_filename": log_file, + "tuner": "random", + "n_trial": n_trial, + "early_stopping": 60, + "measure_option": autotvm.measure_option( + builder=autotvm.LocalBuilder(timeout=10), + runner=autotvm.LocalRunner(number=20, repeat=3, timeout=4, min_repeat_ms=150), + ), + } + return tuning_opt + + +TVM_ASSETS = ["mod.so", "graph.json", "params"] + + +class PyTorchTVMModule: + """Helper class for compiling pytorch module to tvm module""" + + def __init__(self, target="cuda", device=tvm.cuda(0)) -> None: + self.script_module = None + self.input_infos = None + self.default_dtype = "float32" + self.mod = None + self.params = None + self.tasks = None + self.target = target + self.dev = device + self.log_file = None + self.tvm_module = None + self.tvm_graph = None + self.tvm_lib = None + self.tvm_params = None + + def from_pytorch(self, script_module, input_infos, default_dtype="float32"): + self.script_module = script_module + self.input_infos = input_infos + self.default_dtype = default_dtype + self.mod, self.params = relay.frontend.from_pytorch( + script_module, input_infos, default_dtype=default_dtype + ) + + def tune_tvm(self, log_file="tuning.log", n_trial=200): + self.tasks = autotvm.task.extract_from_program( + self.mod["main"], + target=self.target, + params=self.params, + ) + self.log_file = log_file + tuning_opt = get_tuning_opt(log_file, n_trial) + tune_tasks(self.tasks, **tuning_opt) + + def build_tvm(self, export_dir, debug_runtime=False): + tvm_mod = self._build_tvm(debug_runtime) + self._export_tvm(export_dir) + return tvm_mod + + def _build_tvm(self, debug_runtime=False): + # compile kernels with history best records + with autotvm.apply_history_best(self.log_file): + with tvm.transform.PassContext(opt_level=3): + self.tvm_graph, self.tvm_lib, self.tvm_params = relay.build( + self.mod, target=self.target, params=self.params + ) + + if not debug_runtime: + self.tvm_module = graph_executor.create(self.tvm_graph, self.tvm_lib, 
device=self.dev) + else: + self.tvm_module = debug_executor.create(self.tvm_graph, self.tvm_lib, device=self.dev) + self.tvm_module.set_input(**self.tvm_params) + return self.tvm_module + + def _export_tvm(self, export_dir): + if not os.path.isdir(export_dir): + os.makedirs(export_dir) + self.export_dir = export_dir + self.tvm_lib.export_library(os.path.join(export_dir, TVM_ASSETS[0])) + with open(os.path.join(export_dir, TVM_ASSETS[1]), "w", encoding="utf8") as fout: + fout.write(self.tvm_graph) + with open(os.path.join(export_dir, TVM_ASSETS[2]), "wb") as fout: + fout.write(relay.save_param_dict(self.tvm_params)) + + def load_tvm(self, export_dir): + """Load tvm module from export directory""" + self.export_dir = export_dir + self.tvm_lib = load_module(os.path.join(export_dir, TVM_ASSETS[0])) + with open(os.path.join(export_dir, TVM_ASSETS[1]), "r", encoding="utf8") as f: + self.tvm_graph = f.read() + with open(os.path.join(export_dir, TVM_ASSETS[2]), "rb") as f: + self.tvm_params = relay.load_param_dict(f.read()) + + self.tvm_module = graph_executor.create(self.tvm_graph, self.tvm_lib, device=self.dev) + self.tvm_module.set_input(**self.tvm_params) + return self.tvm_module + + def build_pytorch_module(self, num_inputs, num_outputs, input_infos=None): + """Build pytorch module containing TVM Graph Module""" + assert self.export_dir, "you must build_tvm or load_tvm before" + input_infos = input_infos or self.input_infos + assert input_infos + assert len(input_infos) == num_inputs + assets = [os.path.join(self.export_dir, i) for i in TVM_ASSETS] + input_shapes = [i[1] for i in input_infos] + + def _tvm_dev_to_pt_dev(device): + """convert tvm device to pytorch device string""" + if tvm.runtime.Device.MASK2STR[device.device_type] == "cpu": + return "cpu" + if tvm.runtime.Device.MASK2STR[device.device_type] == "cuda": + return f"cuda:{device.device_id}" + raise ValueError(f"unsupported device for pt graph module: {device}") + + mod = GraphModule(num_inputs=num_inputs, num_outputs=num_outputs).to( + _tvm_dev_to_pt_dev(self.dev) + ) + mod.init(input_shapes, *assets) + return mod + + +def compile(script_module, option): + """ + example: + option = { + "input_infos": [ + ("x", (1, 3, 244, 244)), + ], + "default_dtype": "float16", + "export_dir": "pytorch_compiled", + "num_outputs": 1, + "tuning_n_trials": 20, # set zero to skip tuning + "tuning_log_file": "tuning.log", + "target": "llvm", + "device": tvm.cpu(), + } + script_module = torch.jit.script(model) + pytorch_tvm_module = compile(script_module, option) + pytorch_tvm_module("model_tvm.pt") + """ + warnings.warn("We suggest users to use `optimized_torch` for tuning Torch modules instead", DeprecationWarning) + input_infos = option["input_infos"] + default_dtype = option.get("default_dtype", "float32") + export_dir = option.get("export_dir", "pytorch_compiled") + tuning_log_file = option.get("tuning_log_file", "tuning.log") + tuning_n_trials = option.get("tuning_n_trials", 20) + num_outputs = option.get("num_outputs", 1) + target = option.get("target", "cuda") + device = option.get("device", tvm.cuda(0)) + + mod = PyTorchTVMModule(target=target, device=device) + print("Converting...") + + mod.log_file = tuning_log_file + mod.from_pytorch(script_module, input_infos, default_dtype) + + if tuning_n_trials > 0: + print("Tuning...") + mod.tune_tvm(log_file=tuning_log_file, n_trial=tuning_n_trials) + + print("Building...") + mod.build_tvm(export_dir) + pytorch_mod = mod.build_pytorch_module(num_inputs=len(input_infos), num_outputs=num_outputs) + 
return pytorch_mod \ No newline at end of file diff --git a/src/contrib/torch/pt_call_tvm/tvm_class.cc b/src/contrib/torch/pt_call_tvm/tvm_class.cc new file mode 100644 index 000000000000..2d9e4064c003 --- /dev/null +++ b/src/contrib/torch/pt_call_tvm/tvm_class.cc @@ -0,0 +1,686 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../utils.h" + +namespace tvm { +namespace contrib { +namespace pytorch { + +/*! \brief Class holding necessary components to call TVM graph runtime */ +class TvmGraphModulePack { + public: + /*! + * \brief Constructor. + * + * \param path Encoded path of graph runtime assets. + * \param device_type int64_t, kDLCPU or kDLCUDA. + * \param device_id int64_t. + */ + explicit TvmGraphModulePack(std::string path, int64_t device_type, int64_t device_id) + : path_(std::move(path)) { + LOG(INFO) << "[TvmGraphModule] loading module at path: [" << path_ << "] on device [" + << (device_type == kDLCUDA ? "cuda:" : "cpu:") << device_id << "]..."; + std::string lib_path, graph_path, params_path; + DecodePaths(path_, &lib_path, &graph_path, ¶ms_path); + + // load graph + std::ifstream graph_in(graph_path); + std::string graph_data((std::istreambuf_iterator(graph_in)), + std::istreambuf_iterator()); + graph_in.close(); + + // load mod syslib + tvm::runtime::Module lib = tvm::runtime::Module::LoadFromFile(lib_path); + + const auto runtime_create = *tvm::runtime::Registry::Get("tvm.graph_executor.create"); + + // read params data + std::ifstream params_in(params_path, std::ios::binary); + std::string params_data((std::istreambuf_iterator(params_in)), + std::istreambuf_iterator()); + params_in.close(); + TVMByteArray params_arr; + params_arr.data = params_data.c_str(); + params_arr.size = params_data.length(); + + // set devices + module_ = runtime_create(graph_data, lib, device_type, device_id); + const tvm::runtime::PackedFunc load_params = module_.GetFunction("load_params"); + load_params(params_arr); + + set_input = module_.GetFunction("set_input_zero_copy"); + run = module_.GetFunction("run"); + get_output = module_.GetFunction("get_output"); + set_output = module_.GetFunction("set_output_zero_copy"); + num_outputs_ = module_.GetFunction("get_num_outputs")(); + } + + static constexpr char kPathDelimiter = '|'; + + /*! + * \brief Decode lib_path, graph_path, params_path from encoded path. + * + * \param path The encoded path, concated with `kPathDelimiter`. + * \param lib_path The path of .so lib file. + * \param graph_path The path of graph.json. + * \param params_path The path of params data. 
+ */ + static void DecodePaths(const std::string& path, std::string* lib_path, std::string* graph_path, + std::string* params_path) { + std::vector paths; + for (size_t i = 0, pre = 0, lim = path.size(); i <= lim; ++i) { + if (i == lim || path.at(i) == kPathDelimiter) { + paths.push_back(path.substr(pre, i - pre)); + pre = i + 1; + } + } + CHECK_EQ(paths.size(), 3u); + *lib_path = paths.at(0); + *graph_path = paths.at(1); + *params_path = paths.at(2); + } + + /*! + * \brief Encode lib_path, graph_path, params_path by concat then with `kPathDelimiter`. + * + * \param lib_path The path of .so lib file. + * \param graph_path The path of graph.json. + * \param params_path The path of params data. + * + * \return The encoded path, concated with `kPathDelimiter`. + */ + static std::string EncodePaths(const std::string& lib_path, const std::string& graph_path, + const std::string& params_path) { + return lib_path + kPathDelimiter + graph_path + kPathDelimiter + params_path; + } + + const std::string& path() const { return path_; } + + const int64_t num_outputs() const { return num_outputs_; } + + tvm::runtime::PackedFunc set_input; + tvm::runtime::PackedFunc run; + tvm::runtime::PackedFunc get_output; + tvm::runtime::PackedFunc set_output; + + private: + tvm::runtime::Module module_; + int64_t num_outputs_; + std::string path_; +}; + +/*! \brief Class holding necessary components to call TVM VM runtime */ +class TvmVMModulePack { + public: + /*! + * \brief Constructor. + * + * \param path Encoded path of vm runtime assets. + * \param device_type int64_t, kDLCPU or kDLCUDA. + * \param device_id int64_t. + */ + explicit TvmVMModulePack(std::string path, int64_t device_type, int64_t device_id) + : path_(std::move(path)) { + LOG(INFO) << "[TvmVMModule] loading module at path: [" << path_ << "] on device [" + << (device_type == kDLCUDA ? "cuda:" : "cpu:") << device_id << "]..."; + // build tvm graph runtime + std::string lib_path, code_path; + DecodePaths(path_, &lib_path, &code_path); + // load lib + auto loaded_lib = tvm::runtime::Module::LoadFromFile(lib_path, "so"); + // load code + std::ifstream code_in(code_path); + std::string loaded_code((std::istreambuf_iterator(code_in)), + std::istreambuf_iterator()); + code_in.close(); + exe_ = tvm::runtime::vm::Executable::Load(loaded_code, loaded_lib); + const auto runtime_create = *tvm::runtime::Registry::Get("runtime._VirtualMachine"); + vm_ = runtime_create(exe_); + auto init_func = vm_.GetFunction("init", false); + auto alloc_type = static_cast(tvm::runtime::vm::AllocatorType::kPooled); + if (device_type != kDLCPU) { + // CPU is required for executing shape functions + init_func(static_cast(kDLCPU), 0, alloc_type, device_type, device_id, alloc_type); + } else { + init_func(device_type, device_id, alloc_type); + } + set_input = vm_.GetFunction("set_input", false); + invoke = vm_.GetFunction("invoke", false); + } + + static constexpr char kPathDelimiter = '|'; + + /*! + * \brief Decode lib_path, code_path from encoded path. + * + * \param path The encoded path, concated with `kPathDelimiter`. + * \param lib_path The path of lib file. + * \param code_path The path of code file. 
+ */ + static void DecodePaths(const std::string& path, std::string* lib_path, std::string* code_path) { + std::vector paths; + for (size_t i = 0, pre = 0, lim = path.size(); i <= lim; ++i) { + if (i == lim || path.at(i) == kPathDelimiter) { + paths.push_back(path.substr(pre, i - pre)); + pre = i + 1; + } + } + CHECK_EQ(paths.size(), 2u); + *lib_path = paths.at(0); + *code_path = paths.at(1); + } + + /*! + * \brief Encode lib_path, code_path by concat then with `kPathDelimiter`. + * + * \param lib_path The path of vm lib file. + * \param code_path The path of code. + * + * \return The encoded path, concated with `kPathDelimiter`. + */ + static std::string EncodePaths(const std::string& lib_path, const std::string& code_path) { + return lib_path + kPathDelimiter + code_path; + } + + const std::string& path() const { return path_; } + + tvm::runtime::PackedFunc set_input; + tvm::runtime::PackedFunc invoke; + + private: + tvm::runtime::Module exe_; + tvm::runtime::Module vm_; + std::string path_; +}; + +/*! \brief Pytorch custom class to call TVM */ +class BaseTvmClass : public torch::jit::CustomClassHolder { + public: + /*! + * \brief Constructor. + * + * \param num_inputs Number of inputs. + * \param num_outputs Number of outputs. + * \param device std::string, use the pytorch device str format, e.g. `cuda:0`, 'cpu' + */ + BaseTvmClass(const int64_t num_inputs, const int64_t num_outputs, const std::string& device) + : num_inputs_(num_inputs), num_outputs_(num_outputs) { + auto torch_device = torch::Device(device); + device_type_ = torch_device.is_cuda() ? kDLCUDA : kDLCPU; + device_id_ = torch_device.index(); + } + + /*! \brief Virtual destructor. */ + virtual ~BaseTvmClass() {} + + /*! + * \brief Get repr string of pytorch input shapes. + * + * \param shapes Pytorch shapes of type List[List[int]]. + * + * \return std::string, the representation of inputs shapes. + */ + static std::string TvmShapeRepr(const c10::List>& shapes) { + std::stringstream ss; + for (const auto& shape : shapes) { + for (const auto& sz : static_cast>(shape)) { + ss << sz << "_"; + } + ss << "__"; + } + return ss.str(); + } + + /*! + * \brief Get input shapes. + * + * \param inputs Inputs with type List[Tensor]. + * + * \return outputs with type List[List[int]]. + */ + static c10::List> GetShapes(const c10::List& inputs) { + c10::List> shapes; + for (const auto& input : inputs) { + c10::List shape; + for (const auto sz : static_cast(input).sizes()) { + shape.push_back(sz); + } + shapes.push_back(shape); + } + return shapes; + } + + /*! + * \brief Move the TVM modules to given device. + * + * \param device String repr of the device to be moved to. + */ + virtual void to(const std::string& device) = 0; + + // getters + int64_t num_inputs() const { return num_inputs_; } + + int64_t num_outputs() const { return num_outputs_; } + + int64_t device_type() const { return device_type_; } + + int64_t device_id() const { return device_id_; } + + c10::DeviceType torch_device_type() const { + return device_type() == kDLCUDA ? 
torch::DeviceType::CUDA : torch::DeviceType::CPU; + } + + bool is_on_same_device(const torch::Tensor& tensor) const { + auto tensor_device_type = tensor.device().type(); + if (tensor_device_type == torch::DeviceType::CUDA) { + return tensor_device_type == torch_device_type() && device_id() == tensor.device().index(); + } + CHECK_EQ(tensor_device_type, torch::DeviceType::CPU); + return tensor_device_type == torch_device_type(); + } + + std::string device() const { return torch::Device(torch_device_type(), device_id()).str(); } + + /*! + * \brief Module forward. + * + * \param inputs Inputs with type List[Tensor]. + * + * \return outputs with type List[Tensor]. + */ + virtual c10::List forward(const c10::List& inputs) = 0; + + /*! + * \brief Serialize TVM Modules to Dict + */ + virtual c10::Dict SerializeTvmModules() const = 0; + + /*! + * \brief deserialize TVM Modules from Dict + */ + virtual void DeserializeTvmModules(const c10::Dict& shape_path_map) = 0; + + protected: + const int64_t num_inputs_; + const int64_t num_outputs_; + int64_t device_type_; + int64_t device_id_; +}; + +/*! \brief Pytorch custom class to call TVM graph runtime */ +class TvmGraphRuntimeClass : public BaseTvmClass { + public: + TvmGraphRuntimeClass(const int64_t num_inputs, const int64_t num_outputs, + const std::string& device) + : BaseTvmClass(num_inputs, num_outputs, device) {} + + /*! + * \brief Module forward. + * + * \param inputs Inputs with type List[Tensor]. + * + * \return outputs with type List[Tensor]. + */ + c10::List forward(const c10::List& inputs) override { + CHECK_EQ(inputs.size(), num_inputs_); + auto shape_repr = TvmShapeRepr(GetShapes(inputs)); + std::vector args(num_inputs_ + num_outputs_); + auto iter = tvm_modules_.find(shape_repr); + CHECK(iter != tvm_modules_.end()); + const auto& tvm_pack = iter->second; + std::vector buf_infos; + buf_infos.reserve(num_inputs_ + num_outputs_); + + for (int i = 0; i < num_inputs_; ++i) { + at::Tensor inp = inputs[i]; + CHECK(is_on_same_device(inp)) + << "input #" << i + << " of forward is not on the same device with TvmGraphRuntime, expected " << device() + << " but got " << inp.device().str(); + inp = inp.contiguous(); + buf_infos.emplace_back(inp); + auto& input_buf = buf_infos[i]; + input_buf.CopyFromOrigin(); + input_buf.MakeDLTensor(&args[i]); + tvm_pack.set_input(i, &args[i]); + } + // prepare output buffers + c10::List outputs; + outputs.reserve(num_outputs_); + + for (int i = 0; i < num_outputs_; ++i) { + tvm::runtime::NDArray output_arr = tvm_pack.get_output(i); + std::vector output_shape(output_arr->shape, output_arr->shape + output_arr->ndim); + + torch::ScalarType output_dtype = torch::ScalarType::Undefined; + CHECK(GetTorchDtype(output_arr.DataType(), &output_dtype)); + + CHECK(device_type_ == kDLCPU || device_type_ == kDLCUDA); + const c10::DeviceType pt_device_type = (device_type_ == kDLCUDA ? torch::kCUDA : torch::kCPU); + const auto options = + torch::TensorOptions().dtype(output_dtype).device(pt_device_type, device_id_); + + outputs.emplace_back(torch::empty(output_shape, options)); + buf_infos.emplace_back(outputs[i]); + auto& output_buf = buf_infos[num_inputs_ + i]; + output_buf.MakeDLTensor(&args[num_inputs_ + i]); + tvm_pack.set_output(i, &args[num_inputs_ + i]); + } + tvm_pack.run(); + for (int i = 0; i < num_outputs_; ++i) { + auto& output_buf = buf_infos[num_inputs_ + i]; + output_buf.CopyToOrigin(); + } + return outputs; + } + + /*! + * \brief Load TVM graph runtime module. + * + * \param shapes Input shapes. List[List[int]]. 
+ * \param lib_path Path of .so lib file. + * \param graph_path Path of graph.json file. + * \param params_path Path of params data file. + */ + void LoadTvmModule(const c10::List>& shapes, const std::string& lib_path, + const std::string& graph_path, const std::string& params_path) { + std::string path = TvmGraphModulePack::EncodePaths(lib_path, graph_path, params_path); + auto shape_repr = TvmShapeRepr(shapes); + auto it_find = tvm_modules_.find(shape_repr); + if (it_find != tvm_modules_.end()) { + tvm_modules_.erase(it_find); + } + const auto it = + tvm_modules_.emplace(shape_repr, TvmGraphModulePack(path, device_type_, device_id_)).first; + if (it->second.num_outputs() != num_outputs_) { + LOG(FATAL) << "tvm class num outputs mismatch, expected " << num_outputs_ << ", got " + << it->second.num_outputs(); + } + } + + const std::map& tvm_modules() const { return tvm_modules_; } + + /*! + * \brief Serialize TVM modules to shape map. + * + * \return shape_path_map Dict of shape_repr to path. + */ + c10::Dict SerializeTvmModules() const override { + c10::Dict shape_path_map; + for (const auto& entry : tvm_modules()) { + shape_path_map.insert(entry.first, entry.second.path()); + } + return shape_path_map; + } + + /*! + * \brief Deserialize TVM modules from shape map. + * + * \param shape_path_map Dict of shape_repr to path. + */ + void DeserializeTvmModules(const c10::Dict& shape_path_map) override { + tvm_modules_.clear(); + for (const auto& entry : shape_path_map) { + const auto& shape_repr = entry.key(); + const auto& path = entry.value(); + tvm_modules_.emplace(shape_repr, TvmGraphModulePack(path, device_type_, device_id_)); + } + } + + /*! + * \brief Move the TVM modules to given device. + * + * \param device String repr of the device to be moved to. + */ + void to(const std::string& device) override { + if (device != this->device()) { + auto torch_device = torch::Device(device); + device_type_ = torch_device.is_cuda() ? kDLCUDA : kDLCPU; + device_id_ = torch_device.index(); + DeserializeTvmModules(SerializeTvmModules()); + } + } + + private: + std::map tvm_modules_; +}; + +/*! \brief Pytorch custom class to call TVM graph runtime */ +class TvmVMRuntimeClass : public BaseTvmClass { + public: + TvmVMRuntimeClass(const int64_t num_inputs, const int64_t num_outputs, const std::string& device) + : BaseTvmClass(num_inputs, num_outputs, device) {} + + /*! + * \brief Module forward. + * + * \param inputs Inputs with type List[Tensor]. + * + * \return outputs with type List[Tensor]. 
+ */ + c10::List forward(const c10::List& inputs) override { + // get inputs repr str + auto shape_repr = TvmShapeRepr(GetShapes(inputs)); + // get tvm pack + auto iter = tvm_modules_.find(shape_repr); + CHECK(iter != tvm_modules_.end()) << "tvm module pack not found for shape_repr " << shape_repr; + const auto& tvm_pack = iter->second; + + // input tensors + CHECK_EQ(inputs.size(), num_inputs_); + std::vector args(num_inputs_); + std::vector args_arr(num_inputs_); + + for (int i = 0; i < num_inputs_; ++i) { + TensorAsBuf input_buf(inputs[i]); + input_buf.CopyFromOrigin(); + input_buf.MakeDLTensor(&args[i]); + args_arr[i] = + tvm::runtime::NDArray::FromDLPack(new DLManagedTensor({args[i], nullptr, nullptr})); + } + // set input + std::vector tvm_values(num_inputs_ + 1); + std::vector tvm_type_codes(num_inputs_ + 1); + tvm::runtime::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data()); + setter(0, "main"); + for (int k = 0; k < num_inputs_; ++k) { + setter(k + 1, args_arr[k]); + } + tvm_pack.set_input.CallPacked( + tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), num_inputs_ + 1), nullptr); + + // run tvm + tvm::runtime::TVMRetValue ret = tvm_pack.invoke("main"); + + // get outputs + std::vector output_arrs(num_outputs_); + auto output_mismatch_msg = [](int actual, int expected) { + std::stringstream ss; + ss << "num_outputs not equal, actual:[" << actual << "] != expected:[" << expected << "]"; + return ss.str(); + }; + if (ret.type_code() == kTVMNDArrayHandle) { + CHECK_EQ(num_outputs_, 1) << output_mismatch_msg(1, num_outputs_); + output_arrs.at(0) = ret.AsObjectRef(); + } else if (ret.type_code() == kTVMObjectHandle) { + const auto& adt = ret.AsObjectRef(); + CHECK_EQ(adt.size(), num_outputs_) << output_mismatch_msg(adt.size(), num_outputs_); + for (size_t i = 0; i < adt.size(); ++i) { + CHECK(adt[i]->IsInstance()) + << "adt elements not tvm::runtime::NDArray"; + output_arrs.at(i) = tvm::runtime::Downcast(adt[i]); + } + } else { + LOG(FATAL) << "unsupported return type with type_code = " << ret.type_code(); + } + + std::vector output_args(num_outputs_); + c10::List outputs; + outputs.reserve(num_outputs_); + + for (int i = 0; i < num_outputs_; ++i) { + const auto& output_arr = output_arrs[i]; + std::vector output_shape(output_arr->shape, output_arr->shape + output_arr->ndim); + + torch::ScalarType output_dtype = torch::ScalarType::Undefined; + CHECK(GetTorchDtype(output_arr.DataType(), &output_dtype)); + + CHECK(device_type_ == kDLCPU || device_type_ == kDLCUDA); + const c10::DeviceType pt_device_type = (device_type_ == kDLCUDA ? torch::kCUDA : torch::kCPU); + const auto options = + torch::TensorOptions().dtype(output_dtype).device(pt_device_type, device_id_); + + outputs.emplace_back(torch::empty(output_shape, options)); + TensorAsBuf output_buf(outputs[i]); + output_buf.MakeDLTensor(&output_args[i]); + output_arr.CopyTo(&output_args[i]); + output_buf.CopyToOrigin(); + } + return outputs; + } + + /*! + * \brief Load TVM vm runtime module. + * + * \param shapes Input shapes. List[List[int]]. + * \param lib_path Path of .so lib file. + * \param code_path Path of code file. 
Typically named code.ro + */ + void LoadTvmModule(const c10::List>& shapes, const std::string& lib_path, + const std::string& code_path) { + std::string path = TvmVMModulePack::EncodePaths(lib_path, code_path); + auto shape_repr = TvmShapeRepr(shapes); + auto it_find = tvm_modules_.find(shape_repr); + if (it_find != tvm_modules_.end()) { + tvm_modules_.erase(it_find); + } + tvm_modules_.emplace(shape_repr, TvmVMModulePack(path, device_type_, device_id_)); + } + + const std::map& tvm_modules() const { return tvm_modules_; } + + /*! + * \brief Serialize TVM modules to shape map. + * + * \return shape_path_map Dict of shape_repr to path. + */ + c10::Dict SerializeTvmModules() const override { + c10::Dict shape_path_map; + for (const auto& entry : tvm_modules()) { + shape_path_map.insert(entry.first, entry.second.path()); + } + return shape_path_map; + } + + /*! + * \brief Deserialize TVM modules from shape map. + * + * \param shape_path_map Dict of shape_repr to path. + */ + void DeserializeTvmModules(const c10::Dict& shape_path_map) override { + tvm_modules_.clear(); + for (const auto& entry : shape_path_map) { + const auto& shape_repr = entry.key(); + const auto& path = entry.value(); + tvm_modules_.emplace(shape_repr, TvmVMModulePack(path, device_type_, device_id_)); + } + } + + /*! + * \brief Move the TVM modules to given device. + * + * \param device String repr of the device to be moved to. + */ + void to(const std::string& device) override { + if (device != this->device()) { + auto torch_device = torch::Device(device); + device_type_ = torch_device.is_cuda() ? kDLCUDA : kDLCPU; + device_id_ = torch_device.index(); + DeserializeTvmModules(SerializeTvmModules()); + } + } + + private: + std::map tvm_modules_; +}; + +// +using SerializeTuple = + std::tuple>; + +/***** registries *****/ +static auto __tvm_dsoop_graph_runtime_registry = + torch::jit::class_("tvm_dsoop", "TvmGraphModule") + .def(torch::init()) + .def("load_tvm_module", &TvmGraphRuntimeClass::LoadTvmModule) + .def("forward", &TvmGraphRuntimeClass::forward) + .def("to", &TvmGraphRuntimeClass::to) + .def_pickle( + [](const c10::intrusive_ptr& self) -> SerializeTuple { + return std::make_tuple(self->num_inputs(), self->num_outputs(), self->device(), + self->SerializeTvmModules()); + }, + [](SerializeTuple tuple) -> c10::intrusive_ptr { + auto ptr = c10::make_intrusive( + /*num_inputs=*/std::get<0>(tuple), + /*num_outputs=*/std::get<1>(tuple), + /*device=*/std::get<2>(tuple)); + ptr->DeserializeTvmModules(std::get<3>(tuple)); + return ptr; + }); + +static auto __tvm_dsoop_vm_runtime_registry = + torch::jit::class_("tvm_dsoop", "TvmVMModule") + .def(torch::init()) + .def("load_tvm_module", &TvmVMRuntimeClass::LoadTvmModule) + .def("forward", &TvmVMRuntimeClass::forward) + .def("to", &TvmVMRuntimeClass::to) + .def_pickle( + [](const c10::intrusive_ptr& self) -> SerializeTuple { + return std::make_tuple(self->num_inputs(), self->num_outputs(), self->device(), + self->SerializeTvmModules()); + }, + [](SerializeTuple tuple) -> c10::intrusive_ptr { + auto ptr = c10::make_intrusive( + /*num_inputs=*/std::get<0>(tuple), + /*num_outputs=*/std::get<1>(tuple), + /*device=*/std::get<2>(tuple)); + ptr->DeserializeTvmModules(std::get<3>(tuple)); + return ptr; + }); + +static auto __tvm_shape_repr_fn_registry = + torch::RegisterOperators("tvm_dsoop::tvm_shape_repr", &BaseTvmClass::TvmShapeRepr); +} // namespace pytorch +} // namespace contrib +} // namespace tvm \ No newline at end of file From 99d768cb6582460ac91aa60d24593993ecc6f7db Mon Sep 17 
00:00:00 2001 From: juda Date: Tue, 26 Jul 2022 18:44:50 -0700 Subject: [PATCH 05/33] newline --- python/tvm/contrib/torch/__init__.py | 2 +- python/tvm/contrib/torch/module.py | 2 +- python/tvm/contrib/torch/pytorch_tvm.py | 7 +++++-- src/contrib/torch/pt_call_tvm/tvm_class.cc | 2 +- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/python/tvm/contrib/torch/__init__.py b/python/tvm/contrib/torch/__init__.py index b4261f3ed906..340f9cef9e58 100644 --- a/python/tvm/contrib/torch/__init__.py +++ b/python/tvm/contrib/torch/__init__.py @@ -58,4 +58,4 @@ def _load_platform_specific_library(lib_name="libpt_tvmdsoop"): from . import optimize_torch GraphExecutorFactoryWrapper = optimize_torch.GraphExecutorFactoryWrapper -optimize_torch = optimize_torch.optimize_torch \ No newline at end of file +optimize_torch = optimize_torch.optimize_torch diff --git a/python/tvm/contrib/torch/module.py b/python/tvm/contrib/torch/module.py index 1652730b86eb..3da9c6f591ce 100644 --- a/python/tvm/contrib/torch/module.py +++ b/python/tvm/contrib/torch/module.py @@ -118,4 +118,4 @@ def __init__(self, tvm_module): def forward(self, *inputs): outputs = self.tvm_module(inputs) - return outputs[0] if len(outputs) == 1 else tuple(outputs) \ No newline at end of file + return outputs[0] if len(outputs) == 1 else tuple(outputs) diff --git a/python/tvm/contrib/torch/pytorch_tvm.py b/python/tvm/contrib/torch/pytorch_tvm.py index 914b86d8fc9c..6fce09e11d23 100644 --- a/python/tvm/contrib/torch/pytorch_tvm.py +++ b/python/tvm/contrib/torch/pytorch_tvm.py @@ -225,7 +225,10 @@ def compile(script_module, option): pytorch_tvm_module = compile(script_module, option) pytorch_tvm_module("model_tvm.pt") """ - warnings.warn("We suggest users to use `optimized_torch` for tuning Torch modules instead", DeprecationWarning) + warnings.warn( + "We suggest users to use `optimized_torch` for tuning Torch modules instead", + DeprecationWarning, + ) input_infos = option["input_infos"] default_dtype = option.get("default_dtype", "float32") export_dir = option.get("export_dir", "pytorch_compiled") @@ -248,4 +251,4 @@ def compile(script_module, option): print("Building...") mod.build_tvm(export_dir) pytorch_mod = mod.build_pytorch_module(num_inputs=len(input_infos), num_outputs=num_outputs) - return pytorch_mod \ No newline at end of file + return pytorch_mod diff --git a/src/contrib/torch/pt_call_tvm/tvm_class.cc b/src/contrib/torch/pt_call_tvm/tvm_class.cc index 2d9e4064c003..5e57dc152f11 100644 --- a/src/contrib/torch/pt_call_tvm/tvm_class.cc +++ b/src/contrib/torch/pt_call_tvm/tvm_class.cc @@ -683,4 +683,4 @@ static auto __tvm_shape_repr_fn_registry = torch::RegisterOperators("tvm_dsoop::tvm_shape_repr", &BaseTvmClass::TvmShapeRepr); } // namespace pytorch } // namespace contrib -} // namespace tvm \ No newline at end of file +} // namespace tvm From 8e550d2ba99897a3692ca85de3e786d031e3cc48 Mon Sep 17 00:00:00 2001 From: juda Date: Tue, 26 Jul 2022 20:12:43 -0700 Subject: [PATCH 06/33] config --- cmake/modules/contrib/PT_TVMDSOOP.cmake | 35 +++++++++++++++++++++---- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/cmake/modules/contrib/PT_TVMDSOOP.cmake b/cmake/modules/contrib/PT_TVMDSOOP.cmake index cfdc564a024c..535c9b3b5ee5 100644 --- a/cmake/modules/contrib/PT_TVMDSOOP.cmake +++ b/cmake/modules/contrib/PT_TVMDSOOP.cmake @@ -18,8 +18,7 @@ if(NOT USE_PT_TVMDSOOP STREQUAL "OFF") find_package(PythonInterp REQUIRED) - # use ${PYTHON_EXECUTE} below - execute_process(COMMAND "/root/anaconda3/bin/python" -c "import torch; 
print(torch.__path__[0].strip())" + execute_process(COMMAND ${PYTHON_EXECUTE} -c "import torch; print(torch.__path__[0].strip())" OUTPUT_VARIABLE PT_PATH RESULT_VARIABLE PT_STATUS) @@ -30,9 +29,35 @@ if(NOT USE_PT_TVMDSOOP STREQUAL "OFF") string(REGEX REPLACE "\n" "" PT_PATH "${PT_PATH}") message(STATUS "PyTorch path: ${PT_PATH}") - # set(PT_COMPILE_FLAGS_STR "-I${PT_PATH}/include -D_GLIBCXX_USE_CXX11_ABI=0") - set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapperTorch.cc PROPERTIES COMPILE_FLAGS "-D_GLIBCXX_USE_CXX11_ABI=1") - set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapperTorch.cc PROPERTIES COMPILE_FLAGS "-I${PT_PATH}/include") + execute_process(COMMAND ${PYTHON_EXECUTE} -c "import torch;print(torch.compiled_with_cxx11_abi())" + OUTPUT_VARIABLE PT_CXX_FLAG + RESULT_VARIABLE PT_STATUS) + + string(REGEX REPLACE "\n" "" PT_CXX_FLAG "${PT_CXX_FLAG}") + message(STATUS "PT_CXX_FLAG: ${PT_CXX_FLAG} ") + + if(${PT_CXX_FLAG} STREQUAL "False") + set(CXX_ABI_ENABLED 0) + else() + set(CXX_ABI_ENABLED 1) + endif() + + message(STATUS "CXX_ABI_ENABLED: ${CXX_ABI_ENABLED} ") + set_property( + SOURCE + ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapperTorch.cc + APPEND PROPERTY + COMPILE_OPTIONS + "-D_GLIBCXX_USE_CXX11_ABI=${CXX_ABI_ENABLED}" + "-I${PT_PATH}/include" + ) + set_property( + SOURCE + ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/pt_call_tvm/tvm_class.cc + APPEND PROPERTY + COMPILE_OPTIONS + "-I${PT_PATH}/include" + ) set(PT_LINK_FLAGS_STR "-L${PT_PATH}/lib -l:libtorch.so -l:libtorch_python.so") if(NOT USE_CUDA STREQUAL "OFF") From 3c60a993025a474aadf56812473cb163a74184ac Mon Sep 17 00:00:00 2001 From: juda Date: Tue, 26 Jul 2022 20:13:59 -0700 Subject: [PATCH 07/33] config --- cmake/modules/contrib/PT_TVMDSOOP.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/modules/contrib/PT_TVMDSOOP.cmake b/cmake/modules/contrib/PT_TVMDSOOP.cmake index 535c9b3b5ee5..28af3b202c87 100644 --- a/cmake/modules/contrib/PT_TVMDSOOP.cmake +++ b/cmake/modules/contrib/PT_TVMDSOOP.cmake @@ -18,7 +18,7 @@ if(NOT USE_PT_TVMDSOOP STREQUAL "OFF") find_package(PythonInterp REQUIRED) - execute_process(COMMAND ${PYTHON_EXECUTE} -c "import torch; print(torch.__path__[0].strip())" + execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "import torch; print(torch.__path__[0].strip())" OUTPUT_VARIABLE PT_PATH RESULT_VARIABLE PT_STATUS) From 84a25f842f21810613f33d2845a451cabe7c1718 Mon Sep 17 00:00:00 2001 From: juda Date: Tue, 26 Jul 2022 20:25:11 -0700 Subject: [PATCH 08/33] typo --- cmake/modules/contrib/PT_TVMDSOOP.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/modules/contrib/PT_TVMDSOOP.cmake b/cmake/modules/contrib/PT_TVMDSOOP.cmake index 28af3b202c87..45aebaaa4d8f 100644 --- a/cmake/modules/contrib/PT_TVMDSOOP.cmake +++ b/cmake/modules/contrib/PT_TVMDSOOP.cmake @@ -29,7 +29,7 @@ if(NOT USE_PT_TVMDSOOP STREQUAL "OFF") string(REGEX REPLACE "\n" "" PT_PATH "${PT_PATH}") message(STATUS "PyTorch path: ${PT_PATH}") - execute_process(COMMAND ${PYTHON_EXECUTE} -c "import torch;print(torch.compiled_with_cxx11_abi())" + execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "import torch;print(torch.compiled_with_cxx11_abi())" OUTPUT_VARIABLE PT_CXX_FLAG RESULT_VARIABLE PT_STATUS) From 5861561c656497c45cbba7a1b0eeb5fbc2a53d3c Mon Sep 17 00:00:00 2001 From: juda Date: Tue, 26 Jul 2022 20:33:24 -0700 Subject: [PATCH 09/33] skip tvm_class --- 
cmake/modules/contrib/PT_TVMDSOOP.cmake | 11 ++--------- .../RuntimeModuleWrapperTVM.cc | 0 .../RuntimeModuleWrapperTorch.cc | 0 .../runtime_bridge.h | 0 4 files changed, 2 insertions(+), 9 deletions(-) rename src/contrib/torch/{pt_call_tvm => tvm_module_wrapper}/RuntimeModuleWrapperTVM.cc (100%) rename src/contrib/torch/{pt_call_tvm => tvm_module_wrapper}/RuntimeModuleWrapperTorch.cc (100%) rename src/contrib/torch/{pt_call_tvm => tvm_module_wrapper}/runtime_bridge.h (100%) diff --git a/cmake/modules/contrib/PT_TVMDSOOP.cmake b/cmake/modules/contrib/PT_TVMDSOOP.cmake index 45aebaaa4d8f..d896e4c2b6a3 100644 --- a/cmake/modules/contrib/PT_TVMDSOOP.cmake +++ b/cmake/modules/contrib/PT_TVMDSOOP.cmake @@ -45,19 +45,12 @@ if(NOT USE_PT_TVMDSOOP STREQUAL "OFF") message(STATUS "CXX_ABI_ENABLED: ${CXX_ABI_ENABLED} ") set_property( SOURCE - ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapperTorch.cc + ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc APPEND PROPERTY COMPILE_OPTIONS "-D_GLIBCXX_USE_CXX11_ABI=${CXX_ABI_ENABLED}" "-I${PT_PATH}/include" ) - set_property( - SOURCE - ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/pt_call_tvm/tvm_class.cc - APPEND PROPERTY - COMPILE_OPTIONS - "-I${PT_PATH}/include" - ) set(PT_LINK_FLAGS_STR "-L${PT_PATH}/lib -l:libtorch.so -l:libtorch_python.so") if(NOT USE_CUDA STREQUAL "OFF") @@ -69,7 +62,7 @@ if(NOT USE_PT_TVMDSOOP STREQUAL "OFF") separate_arguments(PT_LINK_FLAGS UNIX_COMMAND ${PT_LINK_FLAGS_STR}) set(LIBRARY_NAME pt_tvmdsoop) - tvm_file_glob(GLOB_RECURSE PTTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/**/*.cc) + tvm_file_glob(GLOB_RECURSE PTTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/tvm_module_wrapper/*.cc) add_library(${LIBRARY_NAME} SHARED ${PTTVM_SRCS}) set(PTTVM_LINK_FLAGS -ltvm -L${CMAKE_CURRENT_BINARY_DIR}) diff --git a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapperTVM.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc similarity index 100% rename from src/contrib/torch/pt_call_tvm/RuntimeModuleWrapperTVM.cc rename to src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc diff --git a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapperTorch.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc similarity index 100% rename from src/contrib/torch/pt_call_tvm/RuntimeModuleWrapperTorch.cc rename to src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc diff --git a/src/contrib/torch/pt_call_tvm/runtime_bridge.h b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h similarity index 100% rename from src/contrib/torch/pt_call_tvm/runtime_bridge.h rename to src/contrib/torch/tvm_module_wrapper/runtime_bridge.h From 340dac5dda461b29e05bbb38b2f59045962dc667 Mon Sep 17 00:00:00 2001 From: juda Date: Tue, 26 Jul 2022 21:49:58 -0700 Subject: [PATCH 10/33] rename --- .../RuntimeModuleWrapperTVM.cc | 23 ++++++++-------- .../RuntimeModuleWrapperTorch.cc | 27 ++++++++++--------- .../torch/tvm_module_wrapper/runtime_bridge.h | 19 +++++++------ 3 files changed, 36 insertions(+), 33 deletions(-) diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc index fdd773df77d8..f929269dc018 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc @@ -118,18 +118,18 @@ tvm::runtime::NDArray NDArrayFromDlPackExt(DLPackTensorExt dlpack_ext) { extern "C" 
{ -struct RuntimeModulePointer { +struct TVMContribTorchRuntimeModule { tvm::runtime::Module mod; - RuntimeModulePointer(tvm::runtime::Module mod) : mod(mod) {} + TVMContribTorchRuntimeModule(tvm::runtime::Module mod) : mod(mod) {} }; -RuntimeModulePointer* get_last_saved_runtime_module() { - return new RuntimeModulePointer(ThreadLocalStore::ThreadLocal()->mod); +TVMContribTorchRuntimeModule* tvm_contrib_torch_get_last_saved_runtime_module() { + return new TVMContribTorchRuntimeModule(ThreadLocalStore::ThreadLocal()->mod); } -void operator_module_forward(RuntimeModulePointer* runtime_module, TensorList inputs, - size_t input_size) { +void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* runtime_module, + TensorList inputs, size_t input_size) { tvm::runtime::PackedFunc run = runtime_module->mod.GetFunction("__tvm_main__"); std::vector tvm_values(input_size); @@ -149,8 +149,9 @@ tvm::Device getDeviceInfo(DLManagedTensor* input_device) { .device_id = input_device->dl_tensor.device.device_id}; } -int64_t graph_executor_module_forward(RuntimeModulePointer* graph_module, TensorList inputs, - size_t input_size, TensorList* outputs) { +int64_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeModule* graph_module, + TensorList inputs, size_t input_size, + TensorList* outputs) { tvm::runtime::PackedFunc built_module = graph_module->mod.GetFunction("default"); tvm::runtime::Module runtime_module = built_module(getDeviceInfo(inputs[0])); tvm::runtime::PackedFunc run = runtime_module.GetFunction("run"); @@ -178,7 +179,7 @@ int64_t graph_executor_module_forward(RuntimeModulePointer* graph_module, Tensor return output_length; } -char* encode(RuntimeModulePointer* runtime_module) { +char* tvm_contrib_torch_encode(TVMContribTorchRuntimeModule* runtime_module) { auto std = tvm::contrib::serialize(runtime_module->mod); auto* ret = new char[std.length() + 1]; strcpy(ret, std.c_str()); @@ -186,8 +187,8 @@ char* encode(RuntimeModulePointer* runtime_module) { return ret; } -RuntimeModulePointer* decode(const char* state) { +TVMContribTorchRuntimeModule* tvm_contrib_torch_decode(const char* state) { auto ret = tvm::contrib::deserialize(state); - return new RuntimeModulePointer(ret); + return new TVMContribTorchRuntimeModule(ret); } } \ No newline at end of file diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc index 2ce306a167cc..99b7f2aad3ac 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc @@ -47,7 +47,7 @@ DLPackTensorExt toDlPackExt(const at::Tensor& src) { */ class OperatorModuleWrapper : public torch::jit::CustomClassHolder { public: - OperatorModuleWrapper() { runtime_module = get_last_saved_runtime_module(); } + OperatorModuleWrapper() { runtime_module = tvm_contrib_torch_get_last_saved_runtime_module(); } void forward(const c10::List& inputs) { int input_length = inputs.size(); @@ -56,20 +56,22 @@ class OperatorModuleWrapper : public torch::jit::CustomClassHolder { for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPack(inputs[i])); - operator_module_forward(this->runtime_module, static_cast(tensors.data()), - tensors.size()); + tvm_contrib_torch_operator_module_forward( + this->runtime_module, static_cast(tensors.data()), tensors.size()); for (int k = 0; k < input_length; ++k) { tensors[k]->deleter(tensors[k]); } } - std::string Serialize() { return 
std::string(encode(runtime_module)); } + std::string Serialize() { return std::string(tvm_contrib_torch_encode(runtime_module)); } - explicit OperatorModuleWrapper(std::string state) { runtime_module = decode(state.c_str()); } + explicit OperatorModuleWrapper(std::string state) { + runtime_module = tvm_contrib_torch_decode(state.c_str()); + } private: - RuntimeModulePointer* runtime_module; + TVMContribTorchRuntimeModule* runtime_module; }; /** @@ -79,14 +81,15 @@ class OperatorModuleWrapper : public torch::jit::CustomClassHolder { */ class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder { public: - explicit GraphExecutorFactoryWrapper(RuntimeModulePointer* executor_factory) + explicit GraphExecutorFactoryWrapper(TVMContribTorchRuntimeModule* executor_factory) : executor_factory_(executor_factory) {} - GraphExecutorFactoryWrapper() : GraphExecutorFactoryWrapper(get_last_saved_runtime_module()) {} - std::string Serialize() { return encode(executor_factory_); } + GraphExecutorFactoryWrapper() + : GraphExecutorFactoryWrapper(tvm_contrib_torch_get_last_saved_runtime_module()) {} + std::string Serialize() { return tvm_contrib_torch_encode(executor_factory_); } explicit GraphExecutorFactoryWrapper(std::string state) { - executor_factory_ = decode(state.c_str()); + executor_factory_ = tvm_contrib_torch_decode(state.c_str()); } c10::List forward(const c10::List& inputs) { @@ -100,7 +103,7 @@ class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder { TensorList* outputs = new TensorList; - auto num_outputs = graph_executor_module_forward( + auto num_outputs = tvm_contrib_torch_graph_executor_module_forward( executor_factory_, static_cast(tensors.data()), tensors.size(), outputs); c10::List ret; @@ -121,7 +124,7 @@ class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder { } private: - RuntimeModulePointer* executor_factory_; + TVMContribTorchRuntimeModule* executor_factory_; }; TORCH_LIBRARY(tvm_torch, m) { diff --git a/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h index 930471ecc735..42ca980f6a24 100644 --- a/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h +++ b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h @@ -32,21 +32,20 @@ struct DLPackTensorExt { bool is_bool; }; +struct TVMContribTorchRuntimeModule; -struct RuntimeModulePointer; +TVMContribTorchRuntimeModule* tvm_contrib_torch_get_last_saved_runtime_module(); -RuntimeModulePointer* get_last_saved_runtime_module(); +void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* runtime_module, + TensorList inputs, size_t input_size); -void operator_module_forward(RuntimeModulePointer* runtime_module, TensorList inputs, - size_t input_size); +int64_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeModule* graph_module, + TensorList inputs, size_t input_size, + TensorList* outputs); -int64_t graph_executor_module_forward(RuntimeModulePointer* graph_module, TensorList inputs, - size_t input_size, - TensorList* outputs); +char* tvm_contrib_torch_encode(TVMContribTorchRuntimeModule* runtime_module); -char* encode(RuntimeModulePointer* runtime_module); - -RuntimeModulePointer* decode(const char* state); +TVMContribTorchRuntimeModule* tvm_contrib_torch_decode(const char* state); } #endif // TVM_CONTRIB_TORCH_RUNTIME_BRIDGE_H_ From 10edc3a8c5defd8401d94892d465c2905470d94c Mon Sep 17 00:00:00 2001 From: juda Date: Tue, 26 Jul 2022 22:13:32 -0700 Subject: [PATCH 11/33] delete ptr --- 
.../torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc | 2 +- src/contrib/torch/tvm_module_wrapper/runtime_bridge.h | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc index 99b7f2aad3ac..09194636e61a 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc @@ -118,7 +118,7 @@ class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder { tensors[k]->deleter(tensors[k]); } - delete outputs; + tvm_contrib_torch_delete_raw_pointer(static_cast(outputs)); return ret; } diff --git a/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h index 42ca980f6a24..22d4a903703f 100644 --- a/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h +++ b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h @@ -46,6 +46,8 @@ int64_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeMo char* tvm_contrib_torch_encode(TVMContribTorchRuntimeModule* runtime_module); TVMContribTorchRuntimeModule* tvm_contrib_torch_decode(const char* state); + +void tvm_contrib_torch_delete_raw_pointer(void* ptr); } #endif // TVM_CONTRIB_TORCH_RUNTIME_BRIDGE_H_ From b9de0619a30811fd04a48f9e35218ea54ed1ecda Mon Sep 17 00:00:00 2001 From: juda Date: Tue, 26 Jul 2022 22:13:40 -0700 Subject: [PATCH 12/33] delete ptr --- src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc index f929269dc018..80872a578dc1 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc @@ -191,4 +191,6 @@ TVMContribTorchRuntimeModule* tvm_contrib_torch_decode(const char* state) { auto ret = tvm::contrib::deserialize(state); return new TVMContribTorchRuntimeModule(ret); } + +void tvm_contrib_torch_delete_raw_pointer(void* ptr) { delete ptr; } } \ No newline at end of file From e1ca6d3f5a965fed563d020f1004de878c60e797 Mon Sep 17 00:00:00 2001 From: juda Date: Wed, 27 Jul 2022 05:12:00 -0700 Subject: [PATCH 13/33] save progress --- .../RuntimeModuleWrapperTVM.cc | 14 +++++++++----- .../RuntimeModuleWrapperTorch.cc | 18 +++++++++--------- .../torch/tvm_module_wrapper/runtime_bridge.h | 4 ++-- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc index 80872a578dc1..480abd6a6682 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc @@ -101,6 +101,7 @@ tvm::runtime::NDArray NDArrayFromDlPackExt(DLPackTensorExt dlpack_ext) { array = NDArray::FromDLPack(dlpack_ext.dl_managed_tensor); } else { // Copy if data pointer isn't aligned to the kAllocAlignment of TVM + LOG(INFO) << "ndim: " << dl_tensor.ndim; array = NDArray::NewFromDLTensor(&dl_tensor, dl_tensor.device); dlpack_ext.dl_managed_tensor->deleter(dlpack_ext.dl_managed_tensor); } @@ -129,17 +130,20 @@ TVMContribTorchRuntimeModule* tvm_contrib_torch_get_last_saved_runtime_module() } void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* runtime_module, - 
TensorList inputs, size_t input_size) { + DLPackTensorExt* inputs, size_t input_size) { tvm::runtime::PackedFunc run = runtime_module->mod.GetFunction("__tvm_main__"); std::vector tvm_values(input_size); std::vector tvm_type_codes(input_size); tvm::runtime::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data()); for (int k = 0; k < input_size; ++k) { - // auto datum = tvm::contrib::NDArrayFromDlPackExt(inputs[k]); - setter(k, &inputs[k]->dl_tensor); + auto datum = tvm::contrib::NDArrayFromDlPackExt(inputs[k]); + LOG(INFO) << "shape: " << datum.Shape(); + setter(k, datum); + } + for (int k = 0; k < input_size; ++k) { + LOG(INFO) << tvm_type_codes[k]; } - run.CallPacked(tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), input_size), nullptr); } @@ -192,5 +196,5 @@ TVMContribTorchRuntimeModule* tvm_contrib_torch_decode(const char* state) { return new TVMContribTorchRuntimeModule(ret); } -void tvm_contrib_torch_delete_raw_pointer(void* ptr) { delete ptr; } +void tvm_contrib_torch_delete_raw_pointer(TensorList* ptr) { delete ptr; } } \ No newline at end of file diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc index 09194636e61a..def7d2b4f51e 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc @@ -27,9 +27,9 @@ namespace tvm { namespace contrib { -DLPackTensorExt toDlPackExt(const at::Tensor& src) { +DLPackTensorExt toDLPackExt(const at::Tensor& src) { if (!src.is_contiguous()) { - return toDlPackExt(src.contiguous()); + return toDLPackExt(src.contiguous()); } if (src.dtype().isScalarType(torch::kBool)) { @@ -52,16 +52,16 @@ class OperatorModuleWrapper : public torch::jit::CustomClassHolder { void forward(const c10::List& inputs) { int input_length = inputs.size(); - std::vector tensors; + std::vector tensors; - for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPack(inputs[i])); + for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPackExt(inputs[i])); tvm_contrib_torch_operator_module_forward( - this->runtime_module, static_cast(tensors.data()), tensors.size()); + this->runtime_module, static_cast(tensors.data()), tensors.size()); - for (int k = 0; k < input_length; ++k) { - tensors[k]->deleter(tensors[k]); - } + // for (int k = 0; k < input_length; ++k) { + // tensors[k]->deleter(tensors[k]); + // } } std::string Serialize() { return std::string(tvm_contrib_torch_encode(runtime_module)); } @@ -118,7 +118,7 @@ class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder { tensors[k]->deleter(tensors[k]); } - tvm_contrib_torch_delete_raw_pointer(static_cast(outputs)); + tvm_contrib_torch_delete_raw_pointer(outputs); return ret; } diff --git a/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h index 22d4a903703f..136cd4776730 100644 --- a/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h +++ b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h @@ -37,7 +37,7 @@ struct TVMContribTorchRuntimeModule; TVMContribTorchRuntimeModule* tvm_contrib_torch_get_last_saved_runtime_module(); void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* runtime_module, - TensorList inputs, size_t input_size); + DLPackTensorExt* inputs, size_t input_size); int64_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeModule* graph_module, TensorList inputs, size_t 
input_size, @@ -47,7 +47,7 @@ char* tvm_contrib_torch_encode(TVMContribTorchRuntimeModule* runtime_module); TVMContribTorchRuntimeModule* tvm_contrib_torch_decode(const char* state); -void tvm_contrib_torch_delete_raw_pointer(void* ptr); +void tvm_contrib_torch_delete_raw_pointer(TensorList* ptr); } #endif // TVM_CONTRIB_TORCH_RUNTIME_BRIDGE_H_ From 09eb1677279b648df228699964810e2389fb7274 Mon Sep 17 00:00:00 2001 From: juda Date: Thu, 28 Jul 2022 03:49:49 -0700 Subject: [PATCH 14/33] boolean support --- .../RuntimeModuleWrapperTVM.cc | 68 +++++++------------ .../RuntimeModuleWrapperTorch.cc | 31 +++++---- .../torch/tvm_module_wrapper/runtime_bridge.h | 4 +- 3 files changed, 45 insertions(+), 58 deletions(-) diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc index 480abd6a6682..28a26fe4cf80 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc @@ -85,35 +85,15 @@ tvm::runtime::Module deserialize(std::string state) { return ret; } +tvm::Device getDeviceInfo(DLManagedTensor* input_device) { + return {.device_type = input_device->dl_tensor.device.device_type, + .device_id = input_device->dl_tensor.device.device_id}; +} + TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod").set_body_typed([](tvm::runtime::Module mod) { ThreadLocalStore::ThreadLocal()->mod = mod; }); -tvm::runtime::NDArray NDArrayFromDlPackExt(DLPackTensorExt dlpack_ext) { - using tvm::runtime::NDArray; - - NDArray array; - auto& dl_tensor = dlpack_ext.dl_managed_tensor->dl_tensor; - bool is_zero_copy = - tvm::runtime::NDArray::AbilityOfZeroCopyForDLTensor(&dl_tensor, dl_tensor.device); - if (is_zero_copy) { - // Zero-copy if data pointer is aligned - array = NDArray::FromDLPack(dlpack_ext.dl_managed_tensor); - } else { - // Copy if data pointer isn't aligned to the kAllocAlignment of TVM - LOG(INFO) << "ndim: " << dl_tensor.ndim; - array = NDArray::NewFromDLTensor(&dl_tensor, dl_tensor.device); - dlpack_ext.dl_managed_tensor->deleter(dlpack_ext.dl_managed_tensor); - } - if (dlpack_ext.is_bool) { - auto result = tvm::runtime::NDArray::Empty(array.Shape(), DataType::Bool(), array->device); - result.CopyFrom(array); - return result; - } - - return array; -} - } // namespace contrib } // namespace tvm @@ -137,47 +117,47 @@ void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* run std::vector tvm_type_codes(input_size); tvm::runtime::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data()); for (int k = 0; k < input_size; ++k) { - auto datum = tvm::contrib::NDArrayFromDlPackExt(inputs[k]); - LOG(INFO) << "shape: " << datum.Shape(); - setter(k, datum); - } - for (int k = 0; k < input_size; ++k) { - LOG(INFO) << tvm_type_codes[k]; + setter(k, &inputs[k].dl_managed_tensor->dl_tensor); } run.CallPacked(tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), input_size), nullptr); } -tvm::Device getDeviceInfo(DLManagedTensor* input_device) { - return {.device_type = input_device->dl_tensor.device.device_type, - .device_id = input_device->dl_tensor.device.device_id}; -} - int64_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeModule* graph_module, - TensorList inputs, size_t input_size, - TensorList* outputs) { + DLPackTensorExt* inputs, size_t input_size, + DLPackTensorExt** outputs) { tvm::runtime::PackedFunc built_module = graph_module->mod.GetFunction("default"); - tvm::runtime::Module 
runtime_module = built_module(getDeviceInfo(inputs[0])); + auto device_info = tvm::contrib::getDeviceInfo(inputs[0].dl_managed_tensor); + tvm::runtime::Module runtime_module = built_module(device_info); tvm::runtime::PackedFunc run = runtime_module.GetFunction("run"); tvm::runtime::PackedFunc set_input = runtime_module.GetFunction("set_input"); tvm::runtime::PackedFunc get_output = runtime_module.GetFunction("get_output"); tvm::runtime::PackedFunc get_num_outputs = runtime_module.GetFunction("get_num_outputs"); for (int k = 0; k < input_size; ++k) { - set_input(k, &inputs[k]->dl_tensor); + set_input(k, &inputs[k].dl_managed_tensor->dl_tensor); } run(); int64_t output_length = get_num_outputs(); - auto out_ptr = new DLManagedTensor*[output_length]; - *outputs = out_ptr; + auto outputs_ptr = new DLPackTensorExt[output_length]; + *outputs = outputs_ptr; for (int k = 0; k < output_length; ++k) { tvm::runtime::NDArray results = get_output(k); - auto tensor = results.ToDLPack(); - out_ptr[k] = tensor; + auto is_bool = results.DataType().is_bool(); + DLManagedTensor* tensor; + if (is_bool) { + auto tmp = + tvm::runtime::NDArray::Empty(results.Shape(), DLDataType{kDLInt, 8, 1}, device_info); + results.CopyTo(tmp); + tensor = tmp.ToDLPack(); + } else { + tensor = results.ToDLPack(); + } + outputs_ptr[k] = {.dl_managed_tensor = tensor, .is_bool = is_bool}; } return output_length; diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc index def7d2b4f51e..d9aeb2a699a5 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc @@ -34,12 +34,20 @@ DLPackTensorExt toDLPackExt(const at::Tensor& src) { if (src.dtype().isScalarType(torch::kBool)) { auto temp = src.toType(torch::kUInt8); - return {.dl_managed_tensor = at::toDLPack(temp), .is_bool = false}; + return {.dl_managed_tensor = at::toDLPack(temp), .is_bool = true}; } return {.dl_managed_tensor = at::toDLPack(src), .is_bool = false}; } +at::Tensor fromDLPackExt(const DLPackTensorExt& src) { + if (src.is_bool) { + return at::fromDLPack(src.dl_managed_tensor).toType(torch::kBool); + } else { + return at::fromDLPack(src.dl_managed_tensor); + } +} + /** * @brief A Torch's module which wraps TVM's OperatorModule Class. * The basic forward function calling TVM's runtime is provided. 
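This file now pairs `toDLPackExt` with a new `fromDLPackExt`: a boolean tensor is exported as `kUInt8` data plus an `is_bool` flag, and the flag is used to restore `kBool` after the call. A short usage sketch of the pair, assuming the definitions above; the tensor values and the check are illustrative, not part of the patch:

    // Round trip of a boolean tensor through the DLPackTensorExt helpers.
    at::Tensor mask = torch::ones({4}, torch::kBool);
    DLPackTensorExt ext = toDLPackExt(mask);   // exported as kUInt8, ext.is_bool == true
    at::Tensor back = fromDLPackExt(ext);      // restores the kBool dtype
    TORCH_CHECK(back.scalar_type() == torch::kBool);

`at::fromDLPack` takes ownership of the `DLManagedTensor`, so its deleter is not invoked by the caller once a tensor has been converted back.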
@@ -55,13 +63,12 @@ class OperatorModuleWrapper : public torch::jit::CustomClassHolder { std::vector tensors; for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPackExt(inputs[i])); - tvm_contrib_torch_operator_module_forward( this->runtime_module, static_cast(tensors.data()), tensors.size()); - // for (int k = 0; k < input_length; ++k) { - // tensors[k]->deleter(tensors[k]); - // } + for (int k = 0; k < input_length; ++k) { + tensors[k].dl_managed_tensor->deleter(tensors[k].dl_managed_tensor); + } } std::string Serialize() { return std::string(tvm_contrib_torch_encode(runtime_module)); } @@ -97,28 +104,28 @@ class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder { TORCH_CHECK(input_length > 0, "Receive empty list of input tensors"); - std::vector tensors; + std::vector tensors; - for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPack(inputs[i])); + for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPackExt(inputs[i])); - TensorList* outputs = new TensorList; + auto outputs = new DLPackTensorExt*; auto num_outputs = tvm_contrib_torch_graph_executor_module_forward( - executor_factory_, static_cast(tensors.data()), tensors.size(), outputs); + executor_factory_, static_cast(tensors.data()), tensors.size(), outputs); c10::List ret; ret.reserve(num_outputs); for (int k = 0; k < num_outputs; ++k) { - at::Tensor atTensor = at::fromDLPack((*outputs)[k]); + at::Tensor atTensor = fromDLPackExt((*outputs)[k]); ret.emplace_back(atTensor); } for (int k = 0; k < input_length; ++k) { - tensors[k]->deleter(tensors[k]); + tensors[k].dl_managed_tensor->deleter(tensors[k].dl_managed_tensor); } - tvm_contrib_torch_delete_raw_pointer(outputs); + delete outputs; return ret; } diff --git a/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h index 136cd4776730..e7c9694a5cc4 100644 --- a/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h +++ b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h @@ -40,8 +40,8 @@ void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* run DLPackTensorExt* inputs, size_t input_size); int64_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeModule* graph_module, - TensorList inputs, size_t input_size, - TensorList* outputs); + DLPackTensorExt* inputs, size_t input_size, + DLPackTensorExt** outputs); char* tvm_contrib_torch_encode(TVMContribTorchRuntimeModule* runtime_module); From 12eb80b2c5512af32ea6789de5f1c2db5db595cc Mon Sep 17 00:00:00 2001 From: juda Date: Thu, 28 Jul 2022 20:04:07 -0700 Subject: [PATCH 15/33] cmake file --- cmake/modules/contrib/PT_TVMDSOOP.cmake | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cmake/modules/contrib/PT_TVMDSOOP.cmake b/cmake/modules/contrib/PT_TVMDSOOP.cmake index d896e4c2b6a3..48cc03e4f5f0 100644 --- a/cmake/modules/contrib/PT_TVMDSOOP.cmake +++ b/cmake/modules/contrib/PT_TVMDSOOP.cmake @@ -34,7 +34,6 @@ if(NOT USE_PT_TVMDSOOP STREQUAL "OFF") RESULT_VARIABLE PT_STATUS) string(REGEX REPLACE "\n" "" PT_CXX_FLAG "${PT_CXX_FLAG}") - message(STATUS "PT_CXX_FLAG: ${PT_CXX_FLAG} ") if(${PT_CXX_FLAG} STREQUAL "False") set(CXX_ABI_ENABLED 0) @@ -42,7 +41,7 @@ if(NOT USE_PT_TVMDSOOP STREQUAL "OFF") set(CXX_ABI_ENABLED 1) endif() - message(STATUS "CXX_ABI_ENABLED: ${CXX_ABI_ENABLED} ") + message(STATUS "SET CXX_ABI_ENABLED: ${CXX_ABI_ENABLED} ") set_property( SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc From 
1f31ded287d93f20c5b881021535163bc36a25dc Mon Sep 17 00:00:00 2001 From: juda Date: Fri, 29 Jul 2022 01:40:36 -0700 Subject: [PATCH 16/33] polish code --- apps/pt_tvmdsoop/tests/test_boolean_tensor.py | 58 +++++++++++++++++++ cmake/modules/contrib/PT_TVMDSOOP.cmake | 2 +- python/tvm/contrib/torch/pytorch_tvm.py | 6 ++ .../RuntimeModuleWrapperTVM.cc | 7 +-- .../RuntimeModuleWrapperTorch.cc | 2 +- .../torch/tvm_module_wrapper/runtime_bridge.h | 6 +- 6 files changed, 72 insertions(+), 9 deletions(-) create mode 100644 apps/pt_tvmdsoop/tests/test_boolean_tensor.py diff --git a/apps/pt_tvmdsoop/tests/test_boolean_tensor.py b/apps/pt_tvmdsoop/tests/test_boolean_tensor.py new file mode 100644 index 000000000000..2138a7cac777 --- /dev/null +++ b/apps/pt_tvmdsoop/tests/test_boolean_tensor.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test script for boolean tensor support""" +import numpy as np + +import torch + +import tvm +import tvm.testing +from tvm.contrib.torch import optimize_torch + + +def negate(x): + return x.logical_not() + + +def sum_up_tensor(x, y): + return torch.sum(x[y]) + + +def test_bool_tensor_negate(): + input = torch.ones(1, dtype=torch.bool) + optimized_negate = optimize_torch( + negate, + input, + ) + output = optimized_negate(negate(input)) + tvm.testing.assert_allclose(input.numpy(), output.numpy(), atol=1e-5, rtol=1e-5) + + +def test_sum_up_tensor(): + x = torch.randint(0, 2, (8,)) + y = x.bool() + optimized_func = optimize_torch(sum_up_tensor, (x, y)) + ret1 = torch.sum(x).numpy() + ret2 = optimized_func(x, y).numpy() + tvm.testing.assert_allclose(ret1, ret2, atol=1e-5, rtol=1e-5) + + +if __name__ == "__main__": + test_bool_tensor_negate() + test_sum_up_tensor() diff --git a/cmake/modules/contrib/PT_TVMDSOOP.cmake b/cmake/modules/contrib/PT_TVMDSOOP.cmake index 48cc03e4f5f0..e8f800f1fd84 100644 --- a/cmake/modules/contrib/PT_TVMDSOOP.cmake +++ b/cmake/modules/contrib/PT_TVMDSOOP.cmake @@ -34,6 +34,7 @@ if(NOT USE_PT_TVMDSOOP STREQUAL "OFF") RESULT_VARIABLE PT_STATUS) string(REGEX REPLACE "\n" "" PT_CXX_FLAG "${PT_CXX_FLAG}") + message(STATUS "Found TORCH_BUILT_WITH_CXX_ABI=${PT_CXX_FLAG} ") if(${PT_CXX_FLAG} STREQUAL "False") set(CXX_ABI_ENABLED 0) @@ -41,7 +42,6 @@ if(NOT USE_PT_TVMDSOOP STREQUAL "OFF") set(CXX_ABI_ENABLED 1) endif() - message(STATUS "SET CXX_ABI_ENABLED: ${CXX_ABI_ENABLED} ") set_property( SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc diff --git a/python/tvm/contrib/torch/pytorch_tvm.py b/python/tvm/contrib/torch/pytorch_tvm.py index 6fce09e11d23..27b46c4fd6bf 100644 --- a/python/tvm/contrib/torch/pytorch_tvm.py +++ b/python/tvm/contrib/torch/pytorch_tvm.py @@ -184,6 +184,11 @@ def load_tvm(self, export_dir): def 
build_pytorch_module(self, num_inputs, num_outputs, input_infos=None): """Build pytorch module containing TVM Graph Module""" + warnings.warn( + "We suggest users to use `optimized_torch` for tuning Torch modules instead", + DeprecationWarning, + stacklevel=2, + ) assert self.export_dir, "you must build_tvm or load_tvm before" input_infos = input_infos or self.input_infos assert input_infos @@ -228,6 +233,7 @@ def compile(script_module, option): warnings.warn( "We suggest users to use `optimized_torch` for tuning Torch modules instead", DeprecationWarning, + stacklevel=2, ) input_infos = option["input_infos"] default_dtype = option.get("default_dtype", "float32") diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc index 28a26fe4cf80..912ab7a72f21 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc @@ -102,7 +102,7 @@ extern "C" { struct TVMContribTorchRuntimeModule { tvm::runtime::Module mod; - TVMContribTorchRuntimeModule(tvm::runtime::Module mod) : mod(mod) {} + explicit TVMContribTorchRuntimeModule(tvm::runtime::Module mod) : mod(mod) {} }; TVMContribTorchRuntimeModule* tvm_contrib_torch_get_last_saved_runtime_module() { @@ -166,8 +166,7 @@ int64_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeMo char* tvm_contrib_torch_encode(TVMContribTorchRuntimeModule* runtime_module) { auto std = tvm::contrib::serialize(runtime_module->mod); auto* ret = new char[std.length() + 1]; - strcpy(ret, std.c_str()); - ret[std.length()] = '\0'; + snprintf(ret, std.length() + 1, "%s", std.c_str()); return ret; } @@ -177,4 +176,4 @@ TVMContribTorchRuntimeModule* tvm_contrib_torch_decode(const char* state) { } void tvm_contrib_torch_delete_raw_pointer(TensorList* ptr) { delete ptr; } -} \ No newline at end of file +} diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc index d9aeb2a699a5..b809dc7632f9 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc @@ -156,4 +156,4 @@ TORCH_LIBRARY(tvm_torch, m) { } } // namespace contrib -} // namespace tvm \ No newline at end of file +} // namespace tvm diff --git a/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h index e7c9694a5cc4..e409a950051b 100644 --- a/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h +++ b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h @@ -20,8 +20,8 @@ * \file runtime_bridge.h * \brief Util functions for pytorch tvm interaction. 
*/ -#ifndef TVM_CONTRIB_TORCH_RUNTIME_BRIDGE_H_ -#define TVM_CONTRIB_TORCH_RUNTIME_BRIDGE_H_ +#ifndef TVM_CONTRIB_TORCH_TVM_MODULE_WRAPPER_RUNTIME_BRIDGE_H_ +#define TVM_CONTRIB_TORCH_TVM_MODULE_WRAPPER_RUNTIME_BRIDGE_H_ extern "C" { @@ -50,4 +50,4 @@ TVMContribTorchRuntimeModule* tvm_contrib_torch_decode(const char* state); void tvm_contrib_torch_delete_raw_pointer(TensorList* ptr); } -#endif // TVM_CONTRIB_TORCH_RUNTIME_BRIDGE_H_ +#endif // TVM_CONTRIB_TORCH_TVM_MODULE_WRAPPER_RUNTIME_BRIDGE_H_ From 12d38facfd4eed50bb0f653aefbc15e826e43651 Mon Sep 17 00:00:00 2001 From: juda Date: Sat, 30 Jul 2022 02:45:07 -0700 Subject: [PATCH 17/33] compile config --- apps/pt_tvmdsoop/tests/test_boolean_tensor.py | 12 +++++------ cmake/modules/contrib/PT_TVMDSOOP.cmake | 20 ++++++++++++++++++- python/tvm/contrib/torch/__init__.py | 12 ++++++++--- python/tvm/contrib/torch/module.py | 16 +++++++++++++++ python/tvm/contrib/torch/pytorch_tvm.py | 14 +++++++++++-- 5 files changed, 62 insertions(+), 12 deletions(-) diff --git a/apps/pt_tvmdsoop/tests/test_boolean_tensor.py b/apps/pt_tvmdsoop/tests/test_boolean_tensor.py index 2138a7cac777..3d9f2b70221d 100644 --- a/apps/pt_tvmdsoop/tests/test_boolean_tensor.py +++ b/apps/pt_tvmdsoop/tests/test_boolean_tensor.py @@ -30,8 +30,8 @@ def negate(x): return x.logical_not() -def sum_up_tensor(x, y): - return torch.sum(x[y]) +def sum_up_tensor(x): + return x.size(dim=0) - torch.sum(x.int()) def test_bool_tensor_negate(): @@ -45,11 +45,11 @@ def test_bool_tensor_negate(): def test_sum_up_tensor(): - x = torch.randint(0, 2, (8,)) + x = torch.randint(0, 2, (16,)) y = x.bool() - optimized_func = optimize_torch(sum_up_tensor, (x, y)) - ret1 = torch.sum(x).numpy() - ret2 = optimized_func(x, y).numpy() + optimized_func = optimize_torch(sum_up_tensor, (y,)) + ret1 = (x[x == 0]).size(dim=0) + ret2 = optimized_func(y).numpy() tvm.testing.assert_allclose(ret1, ret2, atol=1e-5, rtol=1e-5) diff --git a/cmake/modules/contrib/PT_TVMDSOOP.cmake b/cmake/modules/contrib/PT_TVMDSOOP.cmake index e8f800f1fd84..22b108e583a0 100644 --- a/cmake/modules/contrib/PT_TVMDSOOP.cmake +++ b/cmake/modules/contrib/PT_TVMDSOOP.cmake @@ -50,6 +50,15 @@ if(NOT USE_PT_TVMDSOOP STREQUAL "OFF") "-D_GLIBCXX_USE_CXX11_ABI=${CXX_ABI_ENABLED}" "-I${PT_PATH}/include" ) + + set_property( + SOURCE + ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/pt_call_tvm/tvm_class.cc + APPEND PROPERTY + COMPILE_OPTIONS + "-I${PT_PATH}/include" + ) + set(PT_LINK_FLAGS_STR "-L${PT_PATH}/lib -l:libtorch.so -l:libtorch_python.so") if(NOT USE_CUDA STREQUAL "OFF") @@ -61,8 +70,13 @@ if(NOT USE_PT_TVMDSOOP STREQUAL "OFF") separate_arguments(PT_LINK_FLAGS UNIX_COMMAND ${PT_LINK_FLAGS_STR}) set(LIBRARY_NAME pt_tvmdsoop) - tvm_file_glob(GLOB_RECURSE PTTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/tvm_module_wrapper/*.cc) + set(LIBRARY_TORCH_NAME pt_tvmintg) + tvm_file_glob(GLOB_RECURSE PTTVM_TORCH ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/tvm_module_wrapper/*.cc) + + tvm_file_glob(GLOB_RECURSE PTTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/pt_call_tvm/*.cc) + add_library(${LIBRARY_NAME} SHARED ${PTTVM_SRCS}) + add_library(${LIBRARY_TORCH_NAME} SHARED ${PTTVM_TORCH}) set(PTTVM_LINK_FLAGS -ltvm -L${CMAKE_CURRENT_BINARY_DIR}) if(NOT BUILD_PT_TVMDSOOP_ONLY STREQUAL "ON") @@ -72,4 +86,8 @@ if(NOT USE_PT_TVMDSOOP STREQUAL "OFF") target_compile_options(${LIBRARY_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} ${PT_COMPILE_FLAGS}) target_link_libraries(${LIBRARY_NAME} PUBLIC ${PTTVM_LINK_FLAGS} ${PT_LINK_FLAGS}) 
target_compile_definitions(${LIBRARY_NAME} PUBLIC DMLC_USE_LOGGING_LIBRARY=) + + target_compile_options(${LIBRARY_TORCH_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} ${PT_COMPILE_FLAGS}) + target_link_libraries(${LIBRARY_TORCH_NAME} PUBLIC ${PTTVM_LINK_FLAGS} ${PT_LINK_FLAGS}) + target_compile_definitions(${LIBRARY_TORCH_NAME} PUBLIC DMLC_USE_LOGGING_LIBRARY=) endif() diff --git a/python/tvm/contrib/torch/__init__.py b/python/tvm/contrib/torch/__init__.py index 340f9cef9e58..59002f5588d6 100644 --- a/python/tvm/contrib/torch/__init__.py +++ b/python/tvm/contrib/torch/__init__.py @@ -18,11 +18,12 @@ """Module container of Pytorch custom class""" import os import platform +import warnings import torch from tvm._ffi import libinfo -def _load_platform_specific_library(lib_name="libpt_tvmdsoop"): +def _load_platform_specific_library(lib_name="libpt_tvmintg"): system = platform.system() if system == "Darwin": lib_file_name = lib_name + ".dylib" @@ -33,12 +34,17 @@ def _load_platform_specific_library(lib_name="libpt_tvmdsoop"): lib_path = libinfo.find_lib_path()[0] lib_dir = os.path.dirname(lib_path) lib_file_path = os.path.join(lib_dir, lib_file_name) - torch.classes.load_library(lib_file_path) + try: + torch.classes.load_library(lib_file_path) + except: + warnings.warn( + f"The library {lib_name} is not built and loaded successfully.", RuntimeWarning + ) +_load_platform_specific_library("libpt_tvmdsoop") _load_platform_specific_library() - from . import module GraphModule = module.GraphModule diff --git a/python/tvm/contrib/torch/module.py b/python/tvm/contrib/torch/module.py index 3da9c6f591ce..624477293d90 100644 --- a/python/tvm/contrib/torch/module.py +++ b/python/tvm/contrib/torch/module.py @@ -18,6 +18,7 @@ """Module container of PyTorch custom class""" from typing import List import torch +import warnings class GraphModule(torch.nn.Module): @@ -29,6 +30,11 @@ def shape_repr(cls, input_shapes): return torch.ops.tvm_dsoop.tvm_shape_repr(input_shapes) def __init__(self, num_inputs, num_outputs, device=None): + warnings.warn( + "This module will be removed at TVM version 0.11", + DeprecationWarning, + stacklevel=2, + ) super().__init__() self.dummy_param = torch.nn.Parameter(torch.empty(0)) self.engine = None @@ -67,6 +73,11 @@ def shape_repr(cls, input_shapes): return torch.ops.tvm_dsoop.tvm_shape_repr(input_shapes) def __init__(self, num_inputs, num_outputs, device=None): + warnings.warn( + "This module will be removed at TVM version 0.11", + DeprecationWarning, + stacklevel=2, + ) super().__init__() self.dummy_param = torch.nn.Parameter(torch.empty(0)) self.engine = None @@ -113,6 +124,11 @@ class TraceTvmModule(torch.nn.Module): """ def __init__(self, tvm_module): + warnings.warn( + "This module will be removed at TVM version 0.11", + DeprecationWarning, + stacklevel=2, + ) super().__init__() self.tvm_module = tvm_module diff --git a/python/tvm/contrib/torch/pytorch_tvm.py b/python/tvm/contrib/torch/pytorch_tvm.py index 27b46c4fd6bf..ffab4fa0d246 100644 --- a/python/tvm/contrib/torch/pytorch_tvm.py +++ b/python/tvm/contrib/torch/pytorch_tvm.py @@ -185,7 +185,12 @@ def load_tvm(self, export_dir): def build_pytorch_module(self, num_inputs, num_outputs, input_infos=None): """Build pytorch module containing TVM Graph Module""" warnings.warn( - "We suggest users to use `optimized_torch` for tuning Torch modules instead", + " ".join( + ( + "This function will be removed at TVM version 0.11,", + "we suggest users to use `optimized_torch` for tuning Torch modules instead.", + ) + ), DeprecationWarning, 
stacklevel=2, ) @@ -231,7 +236,12 @@ def compile(script_module, option): pytorch_tvm_module("model_tvm.pt") """ warnings.warn( - "We suggest users to use `optimized_torch` for tuning Torch modules instead", + " ".join( + ( + "This function will be removed at TVM version 0.11,", + "we suggest users to use `optimized_torch` for tuning Torch modules instead.", + ) + ), DeprecationWarning, stacklevel=2, ) From 180dfd1ff8a68e887fe0f0d4545681f441a2cb75 Mon Sep 17 00:00:00 2001 From: juda Date: Sun, 31 Jul 2022 21:03:50 -0700 Subject: [PATCH 18/33] improving the codes --- apps/pt_tvmdsoop/tests/test_as_torch.py | 7 +- apps/pt_tvmdsoop/tests/test_boolean_tensor.py | 33 +++++++- python/tvm/contrib/torch/__init__.py | 5 +- python/tvm/contrib/torch/module.py | 3 +- .../RuntimeModuleWrapperTVM.cc | 83 ++++++++++++------- .../RuntimeModuleWrapperTorch.cc | 57 +++++++++---- .../torch/tvm_module_wrapper/runtime_bridge.h | 17 ++-- 7 files changed, 145 insertions(+), 60 deletions(-) diff --git a/apps/pt_tvmdsoop/tests/test_as_torch.py b/apps/pt_tvmdsoop/tests/test_as_torch.py index 2c454e9454e7..a13d669e7f36 100644 --- a/apps/pt_tvmdsoop/tests/test_as_torch.py +++ b/apps/pt_tvmdsoop/tests/test_as_torch.py @@ -17,6 +17,8 @@ # specific language governing permissions and limitations # under the License. """Test script for tvm torch module""" +import tempfile + import numpy as np import torch @@ -190,7 +192,10 @@ def test_tvmscript_torch_gpu(): q1 = torch.arange(8, device=cuda0).type(torch.float32) q2 = torch.zeros((8,), dtype=torch.float32, device=cuda0) - ModuleGPU(q1, q2) + with tempfile.NamedTemporaryFile(suffix=".pt") as tmp: + torch.save(ModuleGPU, tmp.name) + loaded_mod = torch.load(tmp.name) + loaded_mod(q1, q2) tvm.testing.assert_allclose(q2.cpu().numpy(), (q1 + 1).cpu().numpy(), atol=1e-5, rtol=1e-5) diff --git a/apps/pt_tvmdsoop/tests/test_boolean_tensor.py b/apps/pt_tvmdsoop/tests/test_boolean_tensor.py index 3d9f2b70221d..600e412d3007 100644 --- a/apps/pt_tvmdsoop/tests/test_boolean_tensor.py +++ b/apps/pt_tvmdsoop/tests/test_boolean_tensor.py @@ -17,11 +17,13 @@ # specific language governing permissions and limitations # under the License. 
"""Test script for boolean tensor support""" -import numpy as np +import tempfile +import numpy as np import torch import tvm +from tvm.meta_schedule.tune import TuneConfig import tvm.testing from tvm.contrib.torch import optimize_torch @@ -34,25 +36,50 @@ def sum_up_tensor(x): return x.size(dim=0) - torch.sum(x.int()) +def tensor_boolean_operation(x): + arr1 = (x + 0.3).floor().bool() + arr2 = (~((x + 0.7).int().bool())).bool() + ret = ((arr1 & arr2).byte() + 0.5).half() + return ~(ret.bool()) + + def test_bool_tensor_negate(): input = torch.ones(1, dtype=torch.bool) optimized_negate = optimize_torch( negate, input, ) - output = optimized_negate(negate(input)) + with tempfile.NamedTemporaryFile(suffix=".pt") as tmp: + torch.save(optimized_negate, tmp.name) + loaded_mod = torch.load(tmp.name) + output = loaded_mod(negate(input)) tvm.testing.assert_allclose(input.numpy(), output.numpy(), atol=1e-5, rtol=1e-5) def test_sum_up_tensor(): x = torch.randint(0, 2, (16,)) y = x.bool() - optimized_func = optimize_torch(sum_up_tensor, (y,)) + optimized_func = optimize_torch( + sum_up_tensor, + (y,), + ) ret1 = (x[x == 0]).size(dim=0) ret2 = optimized_func(y).numpy() tvm.testing.assert_allclose(ret1, ret2, atol=1e-5, rtol=1e-5) +def test_tensor_boolean_operation(): + input = torch.rand(200) + model = optimize_torch( + tensor_boolean_operation, + input, + ) + ret1 = tensor_boolean_operation(input) + ret2 = model(input) + tvm.testing.assert_allclose(ret1, ret2, atol=1e-5, rtol=1e-5) + + if __name__ == "__main__": test_bool_tensor_negate() test_sum_up_tensor() + test_tensor_boolean_operation() diff --git a/python/tvm/contrib/torch/__init__.py b/python/tvm/contrib/torch/__init__.py index 59002f5588d6..95bdf0a29c90 100644 --- a/python/tvm/contrib/torch/__init__.py +++ b/python/tvm/contrib/torch/__init__.py @@ -36,9 +36,10 @@ def _load_platform_specific_library(lib_name="libpt_tvmintg"): lib_file_path = os.path.join(lib_dir, lib_file_name) try: torch.classes.load_library(lib_file_path) - except: + except OSError: warnings.warn( - f"The library {lib_name} is not built and loaded successfully.", RuntimeWarning + f"The library {lib_name} is not built successfully due to the CXXABI incompatibility.", + RuntimeWarning, ) diff --git a/python/tvm/contrib/torch/module.py b/python/tvm/contrib/torch/module.py index 624477293d90..cfa3ad264c3a 100644 --- a/python/tvm/contrib/torch/module.py +++ b/python/tvm/contrib/torch/module.py @@ -16,9 +16,10 @@ # under the License. 
# pylint: disable=invalid-name """Module container of PyTorch custom class""" +import warnings from typing import List + import torch -import warnings class GraphModule(torch.nn.Module): diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc index 912ab7a72f21..c884fb9c55b6 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc @@ -32,6 +32,9 @@ #include "../base64.h" #include "runtime_bridge.h" +namespace tvm { +namespace contrib { + struct ThreadLocalStore { tvm::runtime::Module mod; static ThreadLocalStore* ThreadLocal() { @@ -40,9 +43,6 @@ struct ThreadLocalStore { } }; -namespace tvm { -namespace contrib { - std::string serialize(tvm::runtime::Module module) { static const runtime::PackedFunc* f_to_str = runtime::Registry::Get("script_torch.save_to_base64"); @@ -86,14 +86,30 @@ tvm::runtime::Module deserialize(std::string state) { } tvm::Device getDeviceInfo(DLManagedTensor* input_device) { - return {.device_type = input_device->dl_tensor.device.device_type, - .device_id = input_device->dl_tensor.device.device_id}; + tvm::Device ret{input_device->dl_tensor.device.device_type, + input_device->dl_tensor.device.device_id}; + return ret; } TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod").set_body_typed([](tvm::runtime::Module mod) { ThreadLocalStore::ThreadLocal()->mod = mod; }); +DLPackTensorExt create_dlpack_tensor_ext(tvm::runtime::NDArray* results, bool is_bool, + tvm::Device* device_info) { + DLManagedTensor* tensor; + if (is_bool) { + auto tmp = + tvm::runtime::NDArray::Empty(results->Shape(), DLDataType{kDLInt, 8, 1}, *device_info); + results->CopyTo(tmp); + tensor = tmp.ToDLPack(); + } else { + tensor = results->ToDLPack(); + } + DLPackTensorExt ret{tensor, is_bool}; + return ret; +} + } // namespace contrib } // namespace tvm @@ -102,11 +118,11 @@ extern "C" { struct TVMContribTorchRuntimeModule { tvm::runtime::Module mod; - explicit TVMContribTorchRuntimeModule(tvm::runtime::Module mod) : mod(mod) {} + explicit TVMContribTorchRuntimeModule(tvm::runtime::Module& mod) : mod(mod) {} }; TVMContribTorchRuntimeModule* tvm_contrib_torch_get_last_saved_runtime_module() { - return new TVMContribTorchRuntimeModule(ThreadLocalStore::ThreadLocal()->mod); + return new TVMContribTorchRuntimeModule(tvm::contrib::ThreadLocalStore::ThreadLocal()->mod); } void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* runtime_module, @@ -123,16 +139,22 @@ void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* run nullptr); } -int64_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeModule* graph_module, - DLPackTensorExt* inputs, size_t input_size, - DLPackTensorExt** outputs) { +TVMContribTorchRuntimeModule* tvm_contrib_torch_create_graph_runtime_module( + TVMContribTorchRuntimeModule* graph_module, DLManagedTensor* input_example) { tvm::runtime::PackedFunc built_module = graph_module->mod.GetFunction("default"); - auto device_info = tvm::contrib::getDeviceInfo(inputs[0].dl_managed_tensor); + tvm::Device device_info = tvm::contrib::getDeviceInfo(input_example); tvm::runtime::Module runtime_module = built_module(device_info); - tvm::runtime::PackedFunc run = runtime_module.GetFunction("run"); - tvm::runtime::PackedFunc set_input = runtime_module.GetFunction("set_input"); - tvm::runtime::PackedFunc get_output = runtime_module.GetFunction("get_output"); - 
tvm::runtime::PackedFunc get_num_outputs = runtime_module.GetFunction("get_num_outputs"); + return new TVMContribTorchRuntimeModule(runtime_module); +} + +size_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeModule* runtime_module, + DLPackTensorExt* inputs, size_t input_size, + DLPackTensorExt** outputs) { + tvm::runtime::PackedFunc run = runtime_module->mod.GetFunction("run"); + tvm::runtime::PackedFunc set_input = runtime_module->mod.GetFunction("set_input"); + tvm::runtime::PackedFunc get_output = runtime_module->mod.GetFunction("get_output"); + tvm::runtime::PackedFunc get_num_outputs = runtime_module->mod.GetFunction("get_num_outputs"); + tvm::Device device_info = tvm::contrib::getDeviceInfo(inputs[0].dl_managed_tensor); for (int k = 0; k < input_size; ++k) { set_input(k, &inputs[k].dl_managed_tensor->dl_tensor); @@ -142,38 +164,37 @@ int64_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeMo int64_t output_length = get_num_outputs(); - auto outputs_ptr = new DLPackTensorExt[output_length]; + DLPackTensorExt* outputs_ptr = new DLPackTensorExt[output_length]; *outputs = outputs_ptr; for (int k = 0; k < output_length; ++k) { tvm::runtime::NDArray results = get_output(k); - auto is_bool = results.DataType().is_bool(); - DLManagedTensor* tensor; - if (is_bool) { - auto tmp = - tvm::runtime::NDArray::Empty(results.Shape(), DLDataType{kDLInt, 8, 1}, device_info); - results.CopyTo(tmp); - tensor = tmp.ToDLPack(); - } else { - tensor = results.ToDLPack(); - } - outputs_ptr[k] = {.dl_managed_tensor = tensor, .is_bool = is_bool}; + bool is_bool = results.DataType().is_bool(); + outputs_ptr[k] = tvm::contrib::create_dlpack_tensor_ext(&results, is_bool, &device_info); } return output_length; } char* tvm_contrib_torch_encode(TVMContribTorchRuntimeModule* runtime_module) { - auto std = tvm::contrib::serialize(runtime_module->mod); - auto* ret = new char[std.length() + 1]; + std::string std = tvm::contrib::serialize(runtime_module->mod); + char* ret = new char[std.length() + 1]; snprintf(ret, std.length() + 1, "%s", std.c_str()); return ret; } TVMContribTorchRuntimeModule* tvm_contrib_torch_decode(const char* state) { - auto ret = tvm::contrib::deserialize(state); + tvm::runtime::Module ret = tvm::contrib::deserialize(state); return new TVMContribTorchRuntimeModule(ret); } -void tvm_contrib_torch_delete_raw_pointer(TensorList* ptr) { delete ptr; } +void tvm_contrib_torch_free_runtime_module(TVMContribTorchRuntimeModule* module_ptr) { + delete module_ptr; +} + +void tvm_contrib_torch_free_dlpack_tensor_ext_array(DLPackTensorExt* dlpack_ptr) { + delete dlpack_ptr; +} + +void tvm_contrib_torch_free_encoding(char* encoding) { delete encoding; } } diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc index b809dc7632f9..78ddbc1becb2 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc @@ -31,13 +31,17 @@ DLPackTensorExt toDLPackExt(const at::Tensor& src) { if (!src.is_contiguous()) { return toDLPackExt(src.contiguous()); } - + DLPackTensorExt ret; if (src.dtype().isScalarType(torch::kBool)) { auto temp = src.toType(torch::kUInt8); - return {.dl_managed_tensor = at::toDLPack(temp), .is_bool = true}; + ret.dl_managed_tensor = at::toDLPack(temp); + ret.is_bool = true; + } else { + ret.dl_managed_tensor = at::toDLPack(src); + ret.is_bool = false; } - return 
{.dl_managed_tensor = at::toDLPack(src), .is_bool = false}; + return ret; } at::Tensor fromDLPackExt(const DLPackTensorExt& src) { @@ -55,7 +59,8 @@ at::Tensor fromDLPackExt(const DLPackTensorExt& src) { */ class OperatorModuleWrapper : public torch::jit::CustomClassHolder { public: - OperatorModuleWrapper() { runtime_module = tvm_contrib_torch_get_last_saved_runtime_module(); } + OperatorModuleWrapper() { runtime_module_ = tvm_contrib_torch_get_last_saved_runtime_module(); } + ~OperatorModuleWrapper() { tvm_contrib_torch_free_runtime_module(runtime_module_); } void forward(const c10::List& inputs) { int input_length = inputs.size(); @@ -64,21 +69,26 @@ class OperatorModuleWrapper : public torch::jit::CustomClassHolder { for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPackExt(inputs[i])); tvm_contrib_torch_operator_module_forward( - this->runtime_module, static_cast(tensors.data()), tensors.size()); + this->runtime_module_, static_cast(tensors.data()), tensors.size()); for (int k = 0; k < input_length; ++k) { tensors[k].dl_managed_tensor->deleter(tensors[k].dl_managed_tensor); } } - std::string Serialize() { return std::string(tvm_contrib_torch_encode(runtime_module)); } + std::string Serialize() { + auto encoding = tvm_contrib_torch_encode(runtime_module_); + auto ret = std::string(encoding); + tvm_contrib_torch_free_encoding(encoding); + return ret; + } explicit OperatorModuleWrapper(std::string state) { - runtime_module = tvm_contrib_torch_decode(state.c_str()); + runtime_module_ = tvm_contrib_torch_decode(state.c_str()); } private: - TVMContribTorchRuntimeModule* runtime_module; + TVMContribTorchRuntimeModule* runtime_module_; }; /** @@ -89,14 +99,26 @@ class OperatorModuleWrapper : public torch::jit::CustomClassHolder { class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder { public: explicit GraphExecutorFactoryWrapper(TVMContribTorchRuntimeModule* executor_factory) - : executor_factory_(executor_factory) {} + : executor_factory_(executor_factory), executor_factory_runtime_(nullptr) {} + + ~GraphExecutorFactoryWrapper() { + tvm_contrib_torch_free_runtime_module(executor_factory_); + tvm_contrib_torch_free_runtime_module(executor_factory_runtime_); + } GraphExecutorFactoryWrapper() : GraphExecutorFactoryWrapper(tvm_contrib_torch_get_last_saved_runtime_module()) {} - std::string Serialize() { return tvm_contrib_torch_encode(executor_factory_); } + + std::string Serialize() { + auto encoding = tvm_contrib_torch_encode(executor_factory_); + auto ret = std::string(encoding); + tvm_contrib_torch_free_encoding(encoding); + return ret; + } explicit GraphExecutorFactoryWrapper(std::string state) { executor_factory_ = tvm_contrib_torch_decode(state.c_str()); + executor_factory_runtime_ = nullptr; } c10::List forward(const c10::List& inputs) { @@ -108,30 +130,33 @@ class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder { for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPackExt(inputs[i])); - auto outputs = new DLPackTensorExt*; - + DLPackTensorExt* outputs; + if (executor_factory_runtime_ == nullptr) { + executor_factory_runtime_ = tvm_contrib_torch_create_graph_runtime_module( + this->executor_factory_, tensors[0].dl_managed_tensor); + } auto num_outputs = tvm_contrib_torch_graph_executor_module_forward( - executor_factory_, static_cast(tensors.data()), tensors.size(), outputs); + executor_factory_runtime_, tensors.data(), tensors.size(), &outputs); c10::List ret; ret.reserve(num_outputs); for (int k = 0; k < num_outputs; ++k) { - 
at::Tensor atTensor = fromDLPackExt((*outputs)[k]); + at::Tensor atTensor = fromDLPackExt(outputs[k]); ret.emplace_back(atTensor); } for (int k = 0; k < input_length; ++k) { tensors[k].dl_managed_tensor->deleter(tensors[k].dl_managed_tensor); } - - delete outputs; + tvm_contrib_torch_free_dlpack_tensor_ext_array(outputs); return ret; } private: TVMContribTorchRuntimeModule* executor_factory_; + TVMContribTorchRuntimeModule* executor_factory_runtime_; }; TORCH_LIBRARY(tvm_torch, m) { diff --git a/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h index e409a950051b..7697f7c8aa80 100644 --- a/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h +++ b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h @@ -25,8 +25,6 @@ extern "C" { -typedef DLManagedTensor** TensorList; - struct DLPackTensorExt { DLManagedTensor* dl_managed_tensor; bool is_bool; @@ -36,18 +34,25 @@ struct TVMContribTorchRuntimeModule; TVMContribTorchRuntimeModule* tvm_contrib_torch_get_last_saved_runtime_module(); +void tvm_contrib_torch_free_runtime_module(TVMContribTorchRuntimeModule* module_ptr); + +TVMContribTorchRuntimeModule* tvm_contrib_torch_create_graph_runtime_module( + TVMContribTorchRuntimeModule* graph_module, DLManagedTensor* input_example); + void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* runtime_module, DLPackTensorExt* inputs, size_t input_size); -int64_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeModule* graph_module, - DLPackTensorExt* inputs, size_t input_size, - DLPackTensorExt** outputs); +size_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeModule* graph_module, + DLPackTensorExt* inputs, size_t input_size, + DLPackTensorExt** outputs); char* tvm_contrib_torch_encode(TVMContribTorchRuntimeModule* runtime_module); TVMContribTorchRuntimeModule* tvm_contrib_torch_decode(const char* state); -void tvm_contrib_torch_delete_raw_pointer(TensorList* ptr); +void tvm_contrib_torch_free_dlpack_tensor_ext_array(DLPackTensorExt*); + +void tvm_contrib_torch_free_encoding(char* encoding); } #endif // TVM_CONTRIB_TORCH_TVM_MODULE_WRAPPER_RUNTIME_BRIDGE_H_ From 14aec157a805b5dccbdf729c382de34b261f3a58 Mon Sep 17 00:00:00 2001 From: juda Date: Mon, 1 Aug 2022 00:31:55 -0700 Subject: [PATCH 19/33] format --- src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc index c884fb9c55b6..97d1a4f013ee 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc @@ -96,7 +96,7 @@ TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod").set_body_typed([](tvm::runtime: }); DLPackTensorExt create_dlpack_tensor_ext(tvm::runtime::NDArray* results, bool is_bool, - tvm::Device* device_info) { + tvm::Device* device_info) { DLManagedTensor* tensor; if (is_bool) { auto tmp = From cf01c640546a02ce3d4a3439fb82c5d8e4cc01f3 Mon Sep 17 00:00:00 2001 From: juda Date: Tue, 2 Aug 2022 20:45:36 -0700 Subject: [PATCH 20/33] doc&errormsg --- cmake/modules/contrib/PT_TVMDSOOP.cmake | 2 +- python/tvm/contrib/torch/__init__.py | 18 +++++-- .../RuntimeModuleWrapperTVM.cc | 14 ++++- .../RuntimeModuleWrapperTorch.cc | 25 +++++++++ .../torch/tvm_module_wrapper/runtime_bridge.h | 51 +++++++++++++++++++ 5 files changed, 104 insertions(+), 6 
deletions(-) diff --git a/cmake/modules/contrib/PT_TVMDSOOP.cmake b/cmake/modules/contrib/PT_TVMDSOOP.cmake index 22b108e583a0..b8b0f09e3640 100644 --- a/cmake/modules/contrib/PT_TVMDSOOP.cmake +++ b/cmake/modules/contrib/PT_TVMDSOOP.cmake @@ -70,7 +70,7 @@ if(NOT USE_PT_TVMDSOOP STREQUAL "OFF") separate_arguments(PT_LINK_FLAGS UNIX_COMMAND ${PT_LINK_FLAGS_STR}) set(LIBRARY_NAME pt_tvmdsoop) - set(LIBRARY_TORCH_NAME pt_tvmintg) + set(LIBRARY_TORCH_NAME pt_tvmdsoop_new) tvm_file_glob(GLOB_RECURSE PTTVM_TORCH ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/tvm_module_wrapper/*.cc) tvm_file_glob(GLOB_RECURSE PTTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/pt_call_tvm/*.cc) diff --git a/python/tvm/contrib/torch/__init__.py b/python/tvm/contrib/torch/__init__.py index 95bdf0a29c90..c3dd34d47044 100644 --- a/python/tvm/contrib/torch/__init__.py +++ b/python/tvm/contrib/torch/__init__.py @@ -23,7 +23,7 @@ from tvm._ffi import libinfo -def _load_platform_specific_library(lib_name="libpt_tvmintg"): +def _load_platform_specific_library(lib_name): system = platform.system() if system == "Darwin": lib_file_name = lib_name + ".dylib" @@ -36,15 +36,25 @@ def _load_platform_specific_library(lib_name="libpt_tvmintg"): lib_file_path = os.path.join(lib_dir, lib_file_name) try: torch.classes.load_library(lib_file_path) - except OSError: + except OSError as err: + errmsg = str(err) + if errmsg.find("undefined symbol") != -1: + reason = " ".join( + ( + "Got undefined symbol error,", + "which might be due to the CXXABI incompatibility.", + ) + ) + else: + reason = errmsg warnings.warn( - f"The library {lib_name} is not built successfully due to the CXXABI incompatibility.", + f"The library {lib_name} is not built successfully. {reason}", RuntimeWarning, ) _load_platform_specific_library("libpt_tvmdsoop") -_load_platform_specific_library() +_load_platform_specific_library("libpt_tvmdsoop_new") from . import module diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc index 97d1a4f013ee..2f0505a4ce58 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc @@ -43,6 +43,9 @@ struct ThreadLocalStore { } }; +/* + * Encode TVM runtime module to base64 stream + */ std::string serialize(tvm::runtime::Module module) { static const runtime::PackedFunc* f_to_str = runtime::Registry::Get("script_torch.save_to_base64"); @@ -61,10 +64,13 @@ struct Deleter { // deleter std::string file_name; }; +/* + * Decode TVM runtime module from base64 stream + */ tvm::runtime::Module deserialize(std::string state) { auto length = tvm::support::b64strlen(state); - std::vector bytes(length); + std::vector bytes(length); // bytes stream tvm::support::b64decode(state, bytes.data()); const std::string name = tmpnam(NULL); @@ -95,12 +101,18 @@ TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod").set_body_typed([](tvm::runtime: ThreadLocalStore::ThreadLocal()->mod = mod; }); +/* + * Convert NDArray to DLPack extend tensor. + * @param results Pointer to NDArray + * @return DLPack extend tensor + */ DLPackTensorExt create_dlpack_tensor_ext(tvm::runtime::NDArray* results, bool is_bool, tvm::Device* device_info) { DLManagedTensor* tensor; if (is_bool) { auto tmp = tvm::runtime::NDArray::Empty(results->Shape(), DLDataType{kDLInt, 8, 1}, *device_info); + // Here memory copy is imposed which might cause performance penalty. 
results->CopyTo(tmp); tensor = tmp.ToDLPack(); } else { diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc index 78ddbc1becb2..3432d30d3ad8 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc @@ -27,6 +27,12 @@ namespace tvm { namespace contrib { +/* + * Convert Torch tensor to DLPack extended tensor. + * The boolean Torch tensor will convert to DLtensor with `is_bool=True` flag. + * @param src Torch tensor + * @return DLPack extended tensor + */ DLPackTensorExt toDLPackExt(const at::Tensor& src) { if (!src.is_contiguous()) { return toDLPackExt(src.contiguous()); @@ -44,6 +50,11 @@ DLPackTensorExt toDLPackExt(const at::Tensor& src) { return ret; } +/* + * Convert DLPack extended tensor to Torch tensor. + * @param src DLPack extended tensor + * @return Torch tensor + */ at::Tensor fromDLPackExt(const DLPackTensorExt& src) { if (src.is_bool) { return at::fromDLPack(src.dl_managed_tensor).toType(torch::kBool); @@ -67,6 +78,8 @@ class OperatorModuleWrapper : public torch::jit::CustomClassHolder { std::vector tensors; + // Torch tensor supports boolean type while DLpack does not, + // we convert Torch tensor to an extension of DLPack tensor for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPackExt(inputs[i])); tvm_contrib_torch_operator_module_forward( this->runtime_module_, static_cast(tensors.data()), tensors.size()); @@ -88,6 +101,9 @@ class OperatorModuleWrapper : public torch::jit::CustomClassHolder { } private: + /* + * TVM runtime module wrapper + */ TVMContribTorchRuntimeModule* runtime_module_; }; @@ -128,6 +144,8 @@ class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder { std::vector tensors; + // Torch tensor supports boolean type while DLpack does not, + // we convert Torch tensor to an extension of DLPack tensor for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPackExt(inputs[i])); DLPackTensorExt* outputs; @@ -155,7 +173,14 @@ class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder { } private: + /* + * TVM Graph Executor Factory module wrapper + */ TVMContribTorchRuntimeModule* executor_factory_; + + /* + * TVM runtime module wrapper + */ TVMContribTorchRuntimeModule* executor_factory_runtime_; }; diff --git a/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h index 7697f7c8aa80..89cbb1fadac5 100644 --- a/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h +++ b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h @@ -25,33 +25,84 @@ extern "C" { +/* + * DLPack data structure extend with `is_bool` flag. + * DLPack haven't support boolean tensor, + * thus a boolean tensor will be regarded as a UInt8 tensor. + */ struct DLPackTensorExt { DLManagedTensor* dl_managed_tensor; bool is_bool; }; +/* + * A wrapper pointing to TVM runtime module. + */ struct TVMContribTorchRuntimeModule; +/* + * Obtain a saved runtime module passed by TVM FFI. + * @return A TVM runtime module wrapper. + */ TVMContribTorchRuntimeModule* tvm_contrib_torch_get_last_saved_runtime_module(); +/* + * Delete TVMContribTorchRuntimeModule pointer. + */ void tvm_contrib_torch_free_runtime_module(TVMContribTorchRuntimeModule* module_ptr); +/* + * Obtain ExecutorFactory runtime module from ExecutorFactory class. 
+ * @param graph_module ExecutorFactory class + * @param input_example For obtaining device information + * @return ExecutorFactory TVM runtime module wrapper + */ TVMContribTorchRuntimeModule* tvm_contrib_torch_create_graph_runtime_module( TVMContribTorchRuntimeModule* graph_module, DLManagedTensor* input_example); +/* + * Forward method for OperatorModuleWrapper. + * @param runtime_module TVM runtime module wrapper + * @param inputs Array pointer of the input tensors + * @param input_size The number of input tensors + */ void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* runtime_module, DLPackTensorExt* inputs, size_t input_size); +/* + * Forward method for GraphExecutorFactoryWrapper. + * @param graph_module TVM runtime module wrapper + * @param inputs Array pointer of the input tensors + * @param input_size The number of input tensors + * @param outputs The resulting output tensors pointer + * @return The number of output tensors + */ size_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeModule* graph_module, DLPackTensorExt* inputs, size_t input_size, DLPackTensorExt** outputs); +/* + * Encode TVM runtime module. + * @param runtime_module TVM runtime module wrapper + * @return The encoding stream (char array) + */ char* tvm_contrib_torch_encode(TVMContribTorchRuntimeModule* runtime_module); +/* + * Decode TVM runtime module. + * @param state The encoding stream (char array) of TVM runtime module + * @return TVM runtime module wrapper + */ TVMContribTorchRuntimeModule* tvm_contrib_torch_decode(const char* state); +/* + * Delete DLPackTensorExt pointer. + */ void tvm_contrib_torch_free_dlpack_tensor_ext_array(DLPackTensorExt*); +/* + * Delete char array pointer. + */ void tvm_contrib_torch_free_encoding(char* encoding); } From 7d030ff821696dc85c3a4224af994e6f1534e2fe Mon Sep 17 00:00:00 2001 From: juda Date: Tue, 2 Aug 2022 23:59:16 -0700 Subject: [PATCH 21/33] zero-cost copy --- .../tvm_module_wrapper/RuntimeModuleWrapperTVM.cc | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc index 2f0505a4ce58..4a829e939a5c 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc @@ -102,21 +102,18 @@ TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod").set_body_typed([](tvm::runtime: }); /* - * Convert NDArray to DLPack extend tensor. - * @param results Pointer to NDArray + * Convert NDArray to DLPack extend tensor. It should be zero-cost. + * @param src Pointer to NDArray * @return DLPack extend tensor */ -DLPackTensorExt create_dlpack_tensor_ext(tvm::runtime::NDArray* results, bool is_bool, +DLPackTensorExt create_dlpack_tensor_ext(tvm::runtime::NDArray* src, bool is_bool, tvm::Device* device_info) { DLManagedTensor* tensor; if (is_bool) { - auto tmp = - tvm::runtime::NDArray::Empty(results->Shape(), DLDataType{kDLInt, 8, 1}, *device_info); - // Here memory copy is imposed which might cause performance penalty. 
- results->CopyTo(tmp); + auto tmp = src->CreateView(src->Shape(), DLDataType{kDLInt, 8, 1}); tensor = tmp.ToDLPack(); } else { - tensor = results->ToDLPack(); + tensor = src->ToDLPack(); } DLPackTensorExt ret{tensor, is_bool}; return ret; From f361984aafdf193663ba626ef4ad2329669bfbb7 Mon Sep 17 00:00:00 2001 From: juda Date: Wed, 3 Aug 2022 06:10:41 -0700 Subject: [PATCH 22/33] one step --- .../RuntimeModuleWrapperTVM.cc | 65 +++++++++++++++++-- .../RuntimeModuleWrapperTorch.cc | 2 +- 2 files changed, 60 insertions(+), 7 deletions(-) diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc index 4a829e939a5c..fd3c1953195b 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc @@ -104,12 +104,13 @@ TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod").set_body_typed([](tvm::runtime: /* * Convert NDArray to DLPack extend tensor. It should be zero-cost. * @param src Pointer to NDArray - * @return DLPack extend tensor + * @return DLPack extended tensor */ -DLPackTensorExt create_dlpack_tensor_ext(tvm::runtime::NDArray* src, bool is_bool, - tvm::Device* device_info) { +DLPackTensorExt create_dlpack_tensor_ext(tvm::runtime::NDArray* src, bool is_bool) { DLManagedTensor* tensor; if (is_bool) { + // If we change DLDataType{kDLInt, 8, 1} to DataType::Bool() + // we will get `RuntimeError: Unsupported kUInt bits 1` auto tmp = src->CreateView(src->Shape(), DLDataType{kDLInt, 8, 1}); tensor = tmp.ToDLPack(); } else { @@ -119,6 +120,46 @@ DLPackTensorExt create_dlpack_tensor_ext(tvm::runtime::NDArray* src, bool is_boo return ret; } +/* + * Create an NDArray with boolean type. (One memory copy) + * @param src DLpack extended tensor + * @return a new NDArray + */ +tvm::runtime::NDArray create_bool_ndarray(DLPackTensorExt* src) { + auto& dl_tensor = src->dl_managed_tensor->dl_tensor; + std::vector shape; + shape.resize(dl_tensor.ndim); + shape.assign(dl_tensor.shape, dl_tensor.shape + dl_tensor.ndim); + auto shapeTuple = ShapeTuple(shape); + auto ret = tvm::runtime::NDArray::Empty(shapeTuple, DataType::Bool(), dl_tensor.device); + ret.CopyFrom(&src->dl_managed_tensor->dl_tensor); + return ret; +} + +/* + * Create an NDArray from DLpack extended tensor. 
+ * @param src DLpack extended tensor + * @return a new NDArray + */ +tvm::runtime::NDArray ndarray_from_dlpack(DLPackTensorExt* src) { + using tvm::runtime::NDArray; + + NDArray array; + auto& dl_tensor = src->dl_managed_tensor->dl_tensor; + bool is_zero_copy = + tvm::runtime::NDArray::AbilityOfZeroCopyForDLTensor(&dl_tensor, dl_tensor.device); + if (src->is_bool) { + // one memory copy + array = create_bool_ndarray(src); + } else if (is_zero_copy) { + array = NDArray::FromDLPack(src->dl_managed_tensor); + } else { + // one memory copy + array = NDArray::NewFromDLTensor(&dl_tensor, dl_tensor.device); + } + return array; +} + } // namespace contrib } // namespace tvm @@ -141,11 +182,24 @@ void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* run std::vector tvm_values(input_size); std::vector tvm_type_codes(input_size); tvm::runtime::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data()); + + std::vector input_tmp(input_size); + for (int k = 0; k < input_size; ++k) { - setter(k, &inputs[k].dl_managed_tensor->dl_tensor); + auto datum = tvm::contrib::ndarray_from_dlpack(&inputs[k]); + setter(k, datum); + input_tmp.push_back(std::move(datum)); } run.CallPacked(tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), input_size), nullptr); + + for (int k = 0; k < input_size; ++k) { + tvm::runtime::NDArray* result_ndarray = + static_cast(tvm_values[k].v_handle); + // "result_ndarray->ToDLPack()->dl_tensor" will lead a memory error + // Maybe the static_cast is problematic + inputs[k].dl_managed_tensor->dl_tensor = result_ndarray->ToDLPack()->dl_tensor; + } } TVMContribTorchRuntimeModule* tvm_contrib_torch_create_graph_runtime_module( @@ -163,7 +217,6 @@ size_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeMod tvm::runtime::PackedFunc set_input = runtime_module->mod.GetFunction("set_input"); tvm::runtime::PackedFunc get_output = runtime_module->mod.GetFunction("get_output"); tvm::runtime::PackedFunc get_num_outputs = runtime_module->mod.GetFunction("get_num_outputs"); - tvm::Device device_info = tvm::contrib::getDeviceInfo(inputs[0].dl_managed_tensor); for (int k = 0; k < input_size; ++k) { set_input(k, &inputs[k].dl_managed_tensor->dl_tensor); @@ -179,7 +232,7 @@ size_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeMod for (int k = 0; k < output_length; ++k) { tvm::runtime::NDArray results = get_output(k); bool is_bool = results.DataType().is_bool(); - outputs_ptr[k] = tvm::contrib::create_dlpack_tensor_ext(&results, is_bool, &device_info); + outputs_ptr[k] = tvm::contrib::create_dlpack_tensor_ext(&results, is_bool); } return output_length; diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc index 3432d30d3ad8..43e0cfa0187b 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc @@ -82,7 +82,7 @@ class OperatorModuleWrapper : public torch::jit::CustomClassHolder { // we convert Torch tensor to an extension of DLPack tensor for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPackExt(inputs[i])); tvm_contrib_torch_operator_module_forward( - this->runtime_module_, static_cast(tensors.data()), tensors.size()); + this->runtime_module_, tensors.data(), tensors.size()); for (int k = 0; k < input_length; ++k) { tensors[k].dl_managed_tensor->deleter(tensors[k].dl_managed_tensor); From 
eed12b6bb8cd3a577d6c28cd10567bacc2c0f418 Mon Sep 17 00:00:00 2001 From: juda Date: Wed, 3 Aug 2022 19:16:38 -0700 Subject: [PATCH 23/33] to ndarray --- .../RuntimeModuleWrapperTVM.cc | 28 ++++++++++++------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc index fd3c1953195b..703e1d57b848 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc @@ -121,21 +121,32 @@ DLPackTensorExt create_dlpack_tensor_ext(tvm::runtime::NDArray* src, bool is_boo } /* - * Create an NDArray with boolean type. (One memory copy) + * Create an empty NDArray with boolean type. * @param src DLpack extended tensor - * @return a new NDArray + * @return an empty NDArray */ -tvm::runtime::NDArray create_bool_ndarray(DLPackTensorExt* src) { +tvm::runtime::NDArray create_empty_bool_ndarray(DLPackTensorExt* src) { auto& dl_tensor = src->dl_managed_tensor->dl_tensor; std::vector shape; shape.resize(dl_tensor.ndim); shape.assign(dl_tensor.shape, dl_tensor.shape + dl_tensor.ndim); + LOG(INFO) << shape[0] << " " << shape[1]; auto shapeTuple = ShapeTuple(shape); auto ret = tvm::runtime::NDArray::Empty(shapeTuple, DataType::Bool(), dl_tensor.device); - ret.CopyFrom(&src->dl_managed_tensor->dl_tensor); return ret; } +/* + * Create an NDArray with boolean type. (One memory copy) + * @param src DLpack extended tensor + * @return a new NDArray + */ +tvm::runtime::NDArray create_bool_ndarray(DLPackTensorExt* src) { + auto&& ret = create_empty_bool_ndarray(src); + ret.CopyFrom(&src->dl_managed_tensor->dl_tensor); + return std::move(ret); +} + /* * Create an NDArray from DLpack extended tensor. 
* @param src DLpack extended tensor @@ -152,7 +163,7 @@ tvm::runtime::NDArray ndarray_from_dlpack(DLPackTensorExt* src) { // one memory copy array = create_bool_ndarray(src); } else if (is_zero_copy) { - array = NDArray::FromDLPack(src->dl_managed_tensor); + array = NDArray::FromExternalDLTensor(src->dl_managed_tensor->dl_tensor); } else { // one memory copy array = NDArray::NewFromDLTensor(&dl_tensor, dl_tensor.device); @@ -194,11 +205,8 @@ void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* run nullptr); for (int k = 0; k < input_size; ++k) { - tvm::runtime::NDArray* result_ndarray = - static_cast(tvm_values[k].v_handle); - // "result_ndarray->ToDLPack()->dl_tensor" will lead a memory error - // Maybe the static_cast is problematic - inputs[k].dl_managed_tensor->dl_tensor = result_ndarray->ToDLPack()->dl_tensor; + auto result = static_cast(tvm_values[k].v_handle); + tvm::runtime::NDArray::CopyFromTo(&result->dl_tensor, &inputs[k].dl_managed_tensor->dl_tensor); } } From ce5bd2634092c7d2b132f4f5713160a8a730be2a Mon Sep 17 00:00:00 2001 From: juda Date: Wed, 3 Aug 2022 20:51:34 -0700 Subject: [PATCH 24/33] extra output --- .../RuntimeModuleWrapperTVM.cc | 25 ++++++++++--------- .../RuntimeModuleWrapperTorch.cc | 19 +++++++++++--- .../torch/tvm_module_wrapper/runtime_bridge.h | 3 ++- 3 files changed, 31 insertions(+), 16 deletions(-) diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc index 703e1d57b848..c1425566a6c3 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc @@ -126,13 +126,12 @@ DLPackTensorExt create_dlpack_tensor_ext(tvm::runtime::NDArray* src, bool is_boo * @return an empty NDArray */ tvm::runtime::NDArray create_empty_bool_ndarray(DLPackTensorExt* src) { - auto& dl_tensor = src->dl_managed_tensor->dl_tensor; - std::vector shape; - shape.resize(dl_tensor.ndim); - shape.assign(dl_tensor.shape, dl_tensor.shape + dl_tensor.ndim); - LOG(INFO) << shape[0] << " " << shape[1]; - auto shapeTuple = ShapeTuple(shape); - auto ret = tvm::runtime::NDArray::Empty(shapeTuple, DataType::Bool(), dl_tensor.device); + auto& tensor = src->dl_managed_tensor->dl_tensor; + std::vector shape; + for (int64_t i = 0; i < tensor.ndim; i++) { + shape.push_back(tensor.shape[i]); + } + auto ret = tvm::runtime::NDArray::Empty(shape, DataType::Bool(), tensor.device); return ret; } @@ -187,26 +186,28 @@ TVMContribTorchRuntimeModule* tvm_contrib_torch_get_last_saved_runtime_module() } void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* runtime_module, - DLPackTensorExt* inputs, size_t input_size) { + DLPackTensorExt* inputs, size_t input_size, + DLPackTensorExt** outputs) { tvm::runtime::PackedFunc run = runtime_module->mod.GetFunction("__tvm_main__"); std::vector tvm_values(input_size); std::vector tvm_type_codes(input_size); tvm::runtime::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data()); - std::vector input_tmp(input_size); + DLPackTensorExt* outputs_ptr = new DLPackTensorExt[input_size]; + *outputs = outputs_ptr; for (int k = 0; k < input_size; ++k) { auto datum = tvm::contrib::ndarray_from_dlpack(&inputs[k]); + outputs_ptr[k] = tvm::contrib::create_dlpack_tensor_ext(&datum, inputs[k].is_bool); setter(k, datum); - input_tmp.push_back(std::move(datum)); } run.CallPacked(tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), input_size), nullptr); for (int 
k = 0; k < input_size; ++k) { - auto result = static_cast(tvm_values[k].v_handle); - tvm::runtime::NDArray::CopyFromTo(&result->dl_tensor, &inputs[k].dl_managed_tensor->dl_tensor); + tvm::runtime::NDArray::CopyFromTo(&outputs_ptr[k].dl_managed_tensor->dl_tensor, + &inputs[k].dl_managed_tensor->dl_tensor); } } diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc index 43e0cfa0187b..16fadac513d7 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc @@ -73,20 +73,33 @@ class OperatorModuleWrapper : public torch::jit::CustomClassHolder { OperatorModuleWrapper() { runtime_module_ = tvm_contrib_torch_get_last_saved_runtime_module(); } ~OperatorModuleWrapper() { tvm_contrib_torch_free_runtime_module(runtime_module_); } - void forward(const c10::List& inputs) { + c10::List forward(const c10::List& inputs) { int input_length = inputs.size(); std::vector tensors; + DLPackTensorExt* outputs; + // Torch tensor supports boolean type while DLpack does not, // we convert Torch tensor to an extension of DLPack tensor for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPackExt(inputs[i])); - tvm_contrib_torch_operator_module_forward( - this->runtime_module_, tensors.data(), tensors.size()); + tvm_contrib_torch_operator_module_forward(this->runtime_module_, tensors.data(), tensors.size(), + &outputs); + + c10::List ret; + ret.reserve(input_length); + + for (int k = 0; k < input_length; ++k) { + at::Tensor atTensor = fromDLPackExt(outputs[k]); + ret.emplace_back(atTensor); + } for (int k = 0; k < input_length; ++k) { tensors[k].dl_managed_tensor->deleter(tensors[k].dl_managed_tensor); } + tvm_contrib_torch_free_dlpack_tensor_ext_array(outputs); + + return ret; } std::string Serialize() { diff --git a/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h index 89cbb1fadac5..77cc85ccf87a 100644 --- a/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h +++ b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h @@ -67,7 +67,8 @@ TVMContribTorchRuntimeModule* tvm_contrib_torch_create_graph_runtime_module( * @param input_size The number of input tensors */ void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* runtime_module, - DLPackTensorExt* inputs, size_t input_size); + DLPackTensorExt* inputs, size_t input_size, + DLPackTensorExt** outputs); /* * Forward method for GraphExecutorFactoryWrapper. 
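For reference, a caller-side sketch of how the forward declarations above are meant to be consumed. This is illustrative C++ only, not an excerpt from the wrappers: the names follow runtime_bridge.h, error handling is omitted, and the helper run_graph is hypothetical.

// Minimal sketch: drive a graph executor through the C bridge declared in
// runtime_bridge.h. `inputs` is assumed to come from ToDLPackExt on the Torch side.
size_t run_graph(TVMContribTorchRuntimeModule* graph_module,
                 DLPackTensorExt* inputs, size_t input_size) {
  DLPackTensorExt* outputs = nullptr;
  size_t num_outputs = tvm_contrib_torch_graph_executor_module_forward(
      graph_module, inputs, input_size, &outputs);
  for (size_t k = 0; k < num_outputs; ++k) {
    // outputs[k].dl_managed_tensor can be handed to at::fromDLPack on the
    // Torch side; outputs[k].is_bool flags an int8 view that still needs a
    // cast back to torch::kBool.
  }
  // Releases the array of descriptors (not the tensors themselves).
  tvm_contrib_torch_free_dlpack_tensor_ext_array(outputs);
  return num_outputs;
}
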
From 71418fcc90a14a18e838d750a4d311546c7112e6 Mon Sep 17 00:00:00 2001 From: juda Date: Thu, 4 Aug 2022 00:16:43 -0700 Subject: [PATCH 25/33] delete extra codes --- .../RuntimeModuleWrapperTVM.cc | 15 ++++++++------- .../RuntimeModuleWrapperTorch.cc | 19 +++---------------- .../torch/tvm_module_wrapper/runtime_bridge.h | 3 +-- 3 files changed, 12 insertions(+), 25 deletions(-) diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc index c1425566a6c3..17cf3c5a936a 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc @@ -160,6 +160,7 @@ tvm::runtime::NDArray ndarray_from_dlpack(DLPackTensorExt* src) { tvm::runtime::NDArray::AbilityOfZeroCopyForDLTensor(&dl_tensor, dl_tensor.device); if (src->is_bool) { // one memory copy + // the code is similar to NewFromDLTensor except for the type array = create_bool_ndarray(src); } else if (is_zero_copy) { array = NDArray::FromExternalDLTensor(src->dl_managed_tensor->dl_tensor); @@ -186,28 +187,28 @@ TVMContribTorchRuntimeModule* tvm_contrib_torch_get_last_saved_runtime_module() } void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* runtime_module, - DLPackTensorExt* inputs, size_t input_size, - DLPackTensorExt** outputs) { + DLPackTensorExt* inputs, size_t input_size) { tvm::runtime::PackedFunc run = runtime_module->mod.GetFunction("__tvm_main__"); std::vector tvm_values(input_size); std::vector tvm_type_codes(input_size); tvm::runtime::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data()); - DLPackTensorExt* outputs_ptr = new DLPackTensorExt[input_size]; - *outputs = outputs_ptr; + std::vector input_cache(input_size); for (int k = 0; k < input_size; ++k) { auto datum = tvm::contrib::ndarray_from_dlpack(&inputs[k]); - outputs_ptr[k] = tvm::contrib::create_dlpack_tensor_ext(&datum, inputs[k].is_bool); + input_cache[k] = datum; // we keep the datum in a vector for future use, otherwise the datum + // will be freed after the loop setter(k, datum); } + run.CallPacked(tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), input_size), nullptr); for (int k = 0; k < input_size; ++k) { - tvm::runtime::NDArray::CopyFromTo(&outputs_ptr[k].dl_managed_tensor->dl_tensor, - &inputs[k].dl_managed_tensor->dl_tensor); + // this statement seems to not work for boolean tensor + input_cache[k].CopyTo(&inputs[k].dl_managed_tensor->dl_tensor); } } diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc index 16fadac513d7..5a707e5062d2 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc @@ -73,33 +73,20 @@ class OperatorModuleWrapper : public torch::jit::CustomClassHolder { OperatorModuleWrapper() { runtime_module_ = tvm_contrib_torch_get_last_saved_runtime_module(); } ~OperatorModuleWrapper() { tvm_contrib_torch_free_runtime_module(runtime_module_); } - c10::List forward(const c10::List& inputs) { + void forward(const c10::List& inputs) { int input_length = inputs.size(); std::vector tensors; - DLPackTensorExt* outputs; - // Torch tensor supports boolean type while DLpack does not, // we convert Torch tensor to an extension of DLPack tensor for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPackExt(inputs[i])); - 
tvm_contrib_torch_operator_module_forward(this->runtime_module_, tensors.data(), tensors.size(), - &outputs); - - c10::List ret; - ret.reserve(input_length); - - for (int k = 0; k < input_length; ++k) { - at::Tensor atTensor = fromDLPackExt(outputs[k]); - ret.emplace_back(atTensor); - } + tvm_contrib_torch_operator_module_forward(this->runtime_module_, tensors.data(), + tensors.size()); for (int k = 0; k < input_length; ++k) { tensors[k].dl_managed_tensor->deleter(tensors[k].dl_managed_tensor); } - tvm_contrib_torch_free_dlpack_tensor_ext_array(outputs); - - return ret; } std::string Serialize() { diff --git a/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h index 77cc85ccf87a..89cbb1fadac5 100644 --- a/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h +++ b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h @@ -67,8 +67,7 @@ TVMContribTorchRuntimeModule* tvm_contrib_torch_create_graph_runtime_module( * @param input_size The number of input tensors */ void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* runtime_module, - DLPackTensorExt* inputs, size_t input_size, - DLPackTensorExt** outputs); + DLPackTensorExt* inputs, size_t input_size); /* * Forward method for GraphExecutorFactoryWrapper. From 05d188e5bda1a9b247e2398ec17812513a204e15 Mon Sep 17 00:00:00 2001 From: juda Date: Thu, 4 Aug 2022 07:37:27 -0700 Subject: [PATCH 26/33] update test --- apps/pt_tvmdsoop/tests/test_boolean_tensor.py | 23 +++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/apps/pt_tvmdsoop/tests/test_boolean_tensor.py b/apps/pt_tvmdsoop/tests/test_boolean_tensor.py index 600e412d3007..327ea55dd889 100644 --- a/apps/pt_tvmdsoop/tests/test_boolean_tensor.py +++ b/apps/pt_tvmdsoop/tests/test_boolean_tensor.py @@ -23,9 +23,10 @@ import torch import tvm -from tvm.meta_schedule.tune import TuneConfig import tvm.testing -from tvm.contrib.torch import optimize_torch +from tvm.contrib.torch import as_torch, optimize_torch +from tvm.meta_schedule.tune import TuneConfig +from tvm.script import tir as T def negate(x): @@ -79,7 +80,25 @@ def test_tensor_boolean_operation(): tvm.testing.assert_allclose(ret1, ret2, atol=1e-5, rtol=1e-5) +@as_torch +@T.prim_func +def negate(X: T.Buffer[(8, 8), "bool"], Y: T.Buffer[(8, 8), "bool"]) -> None: + for i, j in T.grid(8, 8): + with T.block(): + Y[i, j] = not X[i, j] + + +def test_tvmscript_torch_decorator(): + q1 = (torch.rand(8, 8) + 0.5).int().bool() + q2 = torch.zeros((8, 8), dtype=torch.bool) + + negate(q1, q2) + + tvm.testing.assert_allclose(~q1.numpy(), q2.numpy(), atol=1e-5, rtol=1e-5) + + if __name__ == "__main__": + test_tvmscript_torch_decorator() test_bool_tensor_negate() test_sum_up_tensor() test_tensor_boolean_operation() From 56bdbb6f62a538844990f69b74334cc3996d7246 Mon Sep 17 00:00:00 2001 From: juda Date: Thu, 4 Aug 2022 18:21:59 -0700 Subject: [PATCH 27/33] boolean support --- apps/pt_tvmdsoop/tests/test_boolean_tensor.py | 4 ++-- .../torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc | 6 +++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/apps/pt_tvmdsoop/tests/test_boolean_tensor.py b/apps/pt_tvmdsoop/tests/test_boolean_tensor.py index 327ea55dd889..2e6807cd2756 100644 --- a/apps/pt_tvmdsoop/tests/test_boolean_tensor.py +++ b/apps/pt_tvmdsoop/tests/test_boolean_tensor.py @@ -82,7 +82,7 @@ def test_tensor_boolean_operation(): @as_torch @T.prim_func -def negate(X: T.Buffer[(8, 8), "bool"], Y: T.Buffer[(8, 8), "bool"]) -> None: +def 
negate_tvmscript(X: T.Buffer[(8, 8), "bool"], Y: T.Buffer[(8, 8), "bool"]) -> None: for i, j in T.grid(8, 8): with T.block(): Y[i, j] = not X[i, j] @@ -92,7 +92,7 @@ def test_tvmscript_torch_decorator(): q1 = (torch.rand(8, 8) + 0.5).int().bool() q2 = torch.zeros((8, 8), dtype=torch.bool) - negate(q1, q2) + negate_tvmscript(q1, q2) tvm.testing.assert_allclose(~q1.numpy(), q2.numpy(), atol=1e-5, rtol=1e-5) diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc index 5a707e5062d2..5fbf135212e8 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc @@ -85,7 +85,11 @@ class OperatorModuleWrapper : public torch::jit::CustomClassHolder { tensors.size()); for (int k = 0; k < input_length; ++k) { - tensors[k].dl_managed_tensor->deleter(tensors[k].dl_managed_tensor); + if (tensors[k].is_bool) { + inputs[k].copy_(fromDLPackExt(tensors[k])); + } else { + tensors[k].dl_managed_tensor->deleter(tensors[k].dl_managed_tensor); + } } } From d9803f390ac1dc3a8aa9aeee51767ce7742eb0e8 Mon Sep 17 00:00:00 2001 From: juda Date: Mon, 8 Aug 2022 00:41:41 -0700 Subject: [PATCH 28/33] strong test --- apps/pt_tvmdsoop/tests/test_boolean_tensor.py | 37 ++++++++++++++++--- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/apps/pt_tvmdsoop/tests/test_boolean_tensor.py b/apps/pt_tvmdsoop/tests/test_boolean_tensor.py index 2e6807cd2756..feed3e2c4e6c 100644 --- a/apps/pt_tvmdsoop/tests/test_boolean_tensor.py +++ b/apps/pt_tvmdsoop/tests/test_boolean_tensor.py @@ -82,19 +82,46 @@ def test_tensor_boolean_operation(): @as_torch @T.prim_func -def negate_tvmscript(X: T.Buffer[(8, 8), "bool"], Y: T.Buffer[(8, 8), "bool"]) -> None: +def negate_tvmscript( + X: T.Buffer[(8, 8), "bool"], + Y: T.Buffer[(8, 8), "float32"], + Z: T.Buffer[(8, 8), "bool"], + U: T.Buffer[(8, 8), "float32"], +) -> None: for i, j in T.grid(8, 8): with T.block(): - Y[i, j] = not X[i, j] + if Y[i, j] > 0.0: + Z[i, j] = X[i, j] + U[i, j] = Y[i, j] + else: + Z[i, j] = not X[i, j] + U[i, j] = 0.0 - Y[i, j] + + +def negate_vanila(x, y): + z = torch.zeros(8, 8).bool() + for i in range(8): + for j in range(8): + if y[i, j] > 0: + z[i, j] = x[i, j] + else: + z[i, j] = ~x[i, j] + return z def test_tvmscript_torch_decorator(): q1 = (torch.rand(8, 8) + 0.5).int().bool() - q2 = torch.zeros((8, 8), dtype=torch.bool) + q2 = torch.rand(8, 8) - 0.5 + q3 = torch.zeros(8, 8).bool() + q4 = torch.zeros(8, 8) + + std1 = negate_vanila(q1, q2) + std2 = torch.abs(q2) - negate_tvmscript(q1, q2) + negate_tvmscript(q1, q2, q3, q4) - tvm.testing.assert_allclose(~q1.numpy(), q2.numpy(), atol=1e-5, rtol=1e-5) + tvm.testing.assert_allclose(std1.numpy(), q3.numpy(), atol=1e-5, rtol=1e-5) + tvm.testing.assert_allclose(std2.numpy(), q4.numpy(), atol=1e-5, rtol=1e-5) if __name__ == "__main__": From d3d993a5bf60ce163acd0af9f33acc08361a540d Mon Sep 17 00:00:00 2001 From: juda Date: Mon, 8 Aug 2022 01:46:09 -0700 Subject: [PATCH 29/33] decrease memory copy --- .../RuntimeModuleWrapperTVM.cc | 21 +++++++++++++------ .../RuntimeModuleWrapperTorch.cc | 2 +- .../torch/tvm_module_wrapper/runtime_bridge.h | 5 +++++ 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc index 17cf3c5a936a..e5e9c661ca4b 100644 --- 
a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc @@ -146,6 +146,13 @@ tvm::runtime::NDArray create_bool_ndarray(DLPackTensorExt* src) { return std::move(ret); } +bool is_zero_copy(DLPackTensorExt* src) { + auto& dl_tensor = src->dl_managed_tensor->dl_tensor; + bool is_zero_copy = + tvm::runtime::NDArray::AbilityOfZeroCopyForDLTensor(&dl_tensor, dl_tensor.device); + return is_zero_copy; +} + /* * Create an NDArray from DLpack extended tensor. * @param src DLpack extended tensor @@ -156,13 +163,11 @@ tvm::runtime::NDArray ndarray_from_dlpack(DLPackTensorExt* src) { NDArray array; auto& dl_tensor = src->dl_managed_tensor->dl_tensor; - bool is_zero_copy = - tvm::runtime::NDArray::AbilityOfZeroCopyForDLTensor(&dl_tensor, dl_tensor.device); if (src->is_bool) { // one memory copy // the code is similar to NewFromDLTensor except for the type array = create_bool_ndarray(src); - } else if (is_zero_copy) { + } else if (is_zero_copy(src)) { array = NDArray::FromExternalDLTensor(src->dl_managed_tensor->dl_tensor); } else { // one memory copy @@ -182,6 +187,10 @@ struct TVMContribTorchRuntimeModule { explicit TVMContribTorchRuntimeModule(tvm::runtime::Module& mod) : mod(mod) {} }; +bool tvm_contrib_torch_is_be_copied(DLPackTensorExt* src) { + return (src->is_bool) || (!tvm::contrib::is_zero_copy(src)); +} + TVMContribTorchRuntimeModule* tvm_contrib_torch_get_last_saved_runtime_module() { return new TVMContribTorchRuntimeModule(tvm::contrib::ThreadLocalStore::ThreadLocal()->mod); } @@ -197,7 +206,7 @@ void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* run std::vector input_cache(input_size); for (int k = 0; k < input_size; ++k) { - auto datum = tvm::contrib::ndarray_from_dlpack(&inputs[k]); + auto datum = tvm::contrib::ndarray_from_dlpack(&inputs[k]); // could have one memory copy input_cache[k] = datum; // we keep the datum in a vector for future use, otherwise the datum // will be freed after the loop setter(k, datum); @@ -207,8 +216,8 @@ void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* run nullptr); for (int k = 0; k < input_size; ++k) { - // this statement seems to not work for boolean tensor - input_cache[k].CopyTo(&inputs[k].dl_managed_tensor->dl_tensor); + if (tvm_contrib_torch_is_be_copied(&inputs[k])) + input_cache[k].CopyTo(&inputs[k].dl_managed_tensor->dl_tensor); } } diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc index 5fbf135212e8..b071b439d0a0 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc @@ -85,7 +85,7 @@ class OperatorModuleWrapper : public torch::jit::CustomClassHolder { tensors.size()); for (int k = 0; k < input_length; ++k) { - if (tensors[k].is_bool) { + if (tvm_contrib_torch_is_be_copied(&tensors[k])) { inputs[k].copy_(fromDLPackExt(tensors[k])); } else { tensors[k].dl_managed_tensor->deleter(tensors[k].dl_managed_tensor); diff --git a/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h index 89cbb1fadac5..772d71cdf190 100644 --- a/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h +++ b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h @@ -104,6 +104,11 @@ void tvm_contrib_torch_free_dlpack_tensor_ext_array(DLPackTensorExt*); * Delete char array pointer. 
*/ void tvm_contrib_torch_free_encoding(char* encoding); + +/* + * Checking if a DLPackTensorExt is boolean or cannot be copied in zero cost. + */ +bool tvm_contrib_torch_is_be_copied(DLPackTensorExt*); } #endif // TVM_CONTRIB_TORCH_TVM_MODULE_WRAPPER_RUNTIME_BRIDGE_H_ From e3e52fc664abfa1472efd6d43382b9013d057443 Mon Sep 17 00:00:00 2001 From: juda Date: Sun, 14 Aug 2022 20:23:22 -0700 Subject: [PATCH 30/33] polish --- .../RuntimeModuleWrapperTVM.cc | 66 +++++++------------ .../RuntimeModuleWrapperTorch.cc | 4 +- .../torch/tvm_module_wrapper/runtime_bridge.h | 20 +++--- 3 files changed, 37 insertions(+), 53 deletions(-) diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc index e5e9c661ca4b..7f2a84d27f46 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc @@ -91,12 +91,6 @@ tvm::runtime::Module deserialize(std::string state) { return ret; } -tvm::Device getDeviceInfo(DLManagedTensor* input_device) { - tvm::Device ret{input_device->dl_tensor.device.device_type, - input_device->dl_tensor.device.device_id}; - return ret; -} - TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod").set_body_typed([](tvm::runtime::Module mod) { ThreadLocalStore::ThreadLocal()->mod = mod; }); @@ -106,7 +100,8 @@ TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod").set_body_typed([](tvm::runtime: * @param src Pointer to NDArray * @return DLPack extended tensor */ -DLPackTensorExt create_dlpack_tensor_ext(tvm::runtime::NDArray* src, bool is_bool) { +DLPackTensorExt createDLpackTensorExt(tvm::runtime::NDArray* src) { + auto is_bool = src->DataType().is_bool(); DLManagedTensor* tensor; if (is_bool) { // If we change DLDataType{kDLInt, 8, 1} to DataType::Bool() @@ -121,36 +116,24 @@ DLPackTensorExt create_dlpack_tensor_ext(tvm::runtime::NDArray* src, bool is_boo } /* - * Create an empty NDArray with boolean type. + * Create an NDArray with boolean type. (One memory copy) * @param src DLpack extended tensor - * @return an empty NDArray + * @return a new NDArray */ -tvm::runtime::NDArray create_empty_bool_ndarray(DLPackTensorExt* src) { +tvm::runtime::NDArray createBoolNDarray(DLPackTensorExt* src) { auto& tensor = src->dl_managed_tensor->dl_tensor; std::vector shape; for (int64_t i = 0; i < tensor.ndim; i++) { shape.push_back(tensor.shape[i]); } auto ret = tvm::runtime::NDArray::Empty(shape, DataType::Bool(), tensor.device); - return ret; -} - -/* - * Create an NDArray with boolean type. 
(One memory copy) - * @param src DLpack extended tensor - * @return a new NDArray - */ -tvm::runtime::NDArray create_bool_ndarray(DLPackTensorExt* src) { - auto&& ret = create_empty_bool_ndarray(src); ret.CopyFrom(&src->dl_managed_tensor->dl_tensor); return std::move(ret); } -bool is_zero_copy(DLPackTensorExt* src) { +bool isZeroCopy(DLPackTensorExt* src) { auto& dl_tensor = src->dl_managed_tensor->dl_tensor; - bool is_zero_copy = - tvm::runtime::NDArray::AbilityOfZeroCopyForDLTensor(&dl_tensor, dl_tensor.device); - return is_zero_copy; + return tvm::runtime::NDArray::AbilityOfZeroCopyForDLTensor(&dl_tensor, dl_tensor.device); } /* @@ -158,7 +141,7 @@ bool is_zero_copy(DLPackTensorExt* src) { * @param src DLpack extended tensor * @return a new NDArray */ -tvm::runtime::NDArray ndarray_from_dlpack(DLPackTensorExt* src) { +tvm::runtime::NDArray ndarrayFromDLpack(DLPackTensorExt* src) { using tvm::runtime::NDArray; NDArray array; @@ -166,8 +149,8 @@ tvm::runtime::NDArray ndarray_from_dlpack(DLPackTensorExt* src) { if (src->is_bool) { // one memory copy // the code is similar to NewFromDLTensor except for the type - array = create_bool_ndarray(src); - } else if (is_zero_copy(src)) { + array = createBoolNDarray(src); + } else if (isZeroCopy(src)) { array = NDArray::FromExternalDLTensor(src->dl_managed_tensor->dl_tensor); } else { // one memory copy @@ -187,8 +170,8 @@ struct TVMContribTorchRuntimeModule { explicit TVMContribTorchRuntimeModule(tvm::runtime::Module& mod) : mod(mod) {} }; -bool tvm_contrib_torch_is_be_copied(DLPackTensorExt* src) { - return (src->is_bool) || (!tvm::contrib::is_zero_copy(src)); +bool tvm_contrib_torch_tensor_ability_of_zero_copy(DLPackTensorExt* src) { + return (src->is_bool) || (!tvm::contrib::isZeroCopy(src)); } TVMContribTorchRuntimeModule* tvm_contrib_torch_get_last_saved_runtime_module() { @@ -205,8 +188,8 @@ void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* run std::vector input_cache(input_size); - for (int k = 0; k < input_size; ++k) { - auto datum = tvm::contrib::ndarray_from_dlpack(&inputs[k]); // could have one memory copy + for (size_t k = 0; k < input_size; ++k) { + auto datum = tvm::contrib::ndarrayFromDLpack(&inputs[k]); // could have one memory copy input_cache[k] = datum; // we keep the datum in a vector for future use, otherwise the datum // will be freed after the loop setter(k, datum); @@ -215,16 +198,16 @@ void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* run run.CallPacked(tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), input_size), nullptr); - for (int k = 0; k < input_size; ++k) { - if (tvm_contrib_torch_is_be_copied(&inputs[k])) + for (size_t k = 0; k < input_size; ++k) { + if (tvm_contrib_torch_tensor_ability_of_zero_copy(&inputs[k])) input_cache[k].CopyTo(&inputs[k].dl_managed_tensor->dl_tensor); } } TVMContribTorchRuntimeModule* tvm_contrib_torch_create_graph_runtime_module( - TVMContribTorchRuntimeModule* graph_module, DLManagedTensor* input_example) { - tvm::runtime::PackedFunc built_module = graph_module->mod.GetFunction("default"); - tvm::Device device_info = tvm::contrib::getDeviceInfo(input_example); + TVMContribTorchRuntimeModule* graph_executor_factory, DLManagedTensor* input_example) { + tvm::runtime::PackedFunc built_module = graph_executor_factory->mod.GetFunction("default"); + tvm::Device device_info = input_example->dl_tensor.device; tvm::runtime::Module runtime_module = built_module(device_info); return new TVMContribTorchRuntimeModule(runtime_module); } 
@@ -237,7 +220,7 @@ size_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeMod tvm::runtime::PackedFunc get_output = runtime_module->mod.GetFunction("get_output"); tvm::runtime::PackedFunc get_num_outputs = runtime_module->mod.GetFunction("get_num_outputs"); - for (int k = 0; k < input_size; ++k) { + for (size_t k = 0; k < input_size; ++k) { set_input(k, &inputs[k].dl_managed_tensor->dl_tensor); } @@ -248,10 +231,9 @@ size_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeMod DLPackTensorExt* outputs_ptr = new DLPackTensorExt[output_length]; *outputs = outputs_ptr; - for (int k = 0; k < output_length; ++k) { + for (int64_t k = 0; k < output_length; ++k) { tvm::runtime::NDArray results = get_output(k); - bool is_bool = results.DataType().is_bool(); - outputs_ptr[k] = tvm::contrib::create_dlpack_tensor_ext(&results, is_bool); + outputs_ptr[k] = tvm::contrib::createDLpackTensorExt(&results); } return output_length; @@ -274,8 +256,8 @@ void tvm_contrib_torch_free_runtime_module(TVMContribTorchRuntimeModule* module_ } void tvm_contrib_torch_free_dlpack_tensor_ext_array(DLPackTensorExt* dlpack_ptr) { - delete dlpack_ptr; + delete[] dlpack_ptr; } -void tvm_contrib_torch_free_encoding(char* encoding) { delete encoding; } +void tvm_contrib_torch_free_encoding(char* encoding) { delete[] encoding; } } diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc index b071b439d0a0..a3252d6059df 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc @@ -85,7 +85,7 @@ class OperatorModuleWrapper : public torch::jit::CustomClassHolder { tensors.size()); for (int k = 0; k < input_length; ++k) { - if (tvm_contrib_torch_is_be_copied(&tensors[k])) { + if (tvm_contrib_torch_tensor_ability_of_zero_copy(&tensors[k])) { inputs[k].copy_(fromDLPackExt(tensors[k])); } else { tensors[k].dl_managed_tensor->deleter(tensors[k].dl_managed_tensor); @@ -163,7 +163,7 @@ class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder { c10::List ret; ret.reserve(num_outputs); - for (int k = 0; k < num_outputs; ++k) { + for (size_t k = 0; k < num_outputs; ++k) { at::Tensor atTensor = fromDLPackExt(outputs[k]); ret.emplace_back(atTensor); } diff --git a/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h index 772d71cdf190..58cd53a2840d 100644 --- a/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h +++ b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h @@ -27,8 +27,10 @@ extern "C" { /* * DLPack data structure extend with `is_bool` flag. - * DLPack haven't support boolean tensor, - * thus a boolean tensor will be regarded as a UInt8 tensor. + * DLPack haven't support boolean tensor + * (https://github.com/pytorch/pytorch/blob/4618371da56c887195e2e1d16dad2b9686302800/aten/src/ATen/DLConvertor.cpp#L42), + * thus a boolean tensor will be regarded as a UInt8 tensor + * (https://github.com/apache/tvm/blob/de124862714e747764aa8b7f41a90bcb25f3c6a8/python/tvm/_ffi/runtime_ctypes.py#L91). */ struct DLPackTensorExt { DLManagedTensor* dl_managed_tensor; @@ -53,12 +55,12 @@ void tvm_contrib_torch_free_runtime_module(TVMContribTorchRuntimeModule* module_ /* * Obtain ExecutorFactory runtime module from ExecutorFactory class. 
- * @param graph_module ExecutorFactory class + * @param graph_executor_factory ExecutorFactory class * @param input_example For obtaining device information * @return ExecutorFactory TVM runtime module wrapper */ TVMContribTorchRuntimeModule* tvm_contrib_torch_create_graph_runtime_module( - TVMContribTorchRuntimeModule* graph_module, DLManagedTensor* input_example); + TVMContribTorchRuntimeModule* graph_executor_factory, DLManagedTensor* input_example); /* * Forward method for OperatorModuleWrapper. @@ -71,15 +73,15 @@ void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* run /* * Forward method for GraphExecutorFactoryWrapper. - * @param graph_module TVM runtime module wrapper + * @param graph_executor_factory TVM runtime module wrapper * @param inputs Array pointer of the input tensors * @param input_size The number of input tensors * @param outputs The resulting output tensors pointer * @return The number of output tensors */ -size_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeModule* graph_module, - DLPackTensorExt* inputs, size_t input_size, - DLPackTensorExt** outputs); +size_t tvm_contrib_torch_graph_executor_module_forward( + TVMContribTorchRuntimeModule* graph_executor_factory, DLPackTensorExt* inputs, + size_t input_size, DLPackTensorExt** outputs); /* * Encode TVM runtime module. @@ -108,7 +110,7 @@ void tvm_contrib_torch_free_encoding(char* encoding); /* * Checking if a DLPackTensorExt is boolean or cannot be copied in zero cost. */ -bool tvm_contrib_torch_is_be_copied(DLPackTensorExt*); +bool tvm_contrib_torch_tensor_ability_of_zero_copy(DLPackTensorExt*); } #endif // TVM_CONTRIB_TORCH_TVM_MODULE_WRAPPER_RUNTIME_BRIDGE_H_ From 1e4af4708b916a4dc95b01af2273aed1569d8d63 Mon Sep 17 00:00:00 2001 From: juda Date: Tue, 16 Aug 2022 06:59:01 -0700 Subject: [PATCH 31/33] reformat --- cmake/modules/contrib/PT_TVMDSOOP.cmake | 27 ++++++++++--------- .../RuntimeModuleWrapperTVM.cc | 23 +++++++++------- .../RuntimeModuleWrapperTorch.cc | 16 +++++------ 3 files changed, 36 insertions(+), 30 deletions(-) diff --git a/cmake/modules/contrib/PT_TVMDSOOP.cmake b/cmake/modules/contrib/PT_TVMDSOOP.cmake index b8b0f09e3640..a73d3f38e939 100644 --- a/cmake/modules/contrib/PT_TVMDSOOP.cmake +++ b/cmake/modules/contrib/PT_TVMDSOOP.cmake @@ -17,7 +17,6 @@ if(NOT USE_PT_TVMDSOOP STREQUAL "OFF") find_package(PythonInterp REQUIRED) - execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "import torch; print(torch.__path__[0].strip())" OUTPUT_VARIABLE PT_PATH RESULT_VARIABLE PT_STATUS) @@ -69,25 +68,29 @@ if(NOT USE_PT_TVMDSOOP STREQUAL "OFF") separate_arguments(PT_COMPILE_FLAGS UNIX_COMMAND) separate_arguments(PT_LINK_FLAGS UNIX_COMMAND ${PT_LINK_FLAGS_STR}) - set(LIBRARY_NAME pt_tvmdsoop) - set(LIBRARY_TORCH_NAME pt_tvmdsoop_new) + # This old version is depereated and will be removed after tvm 0.11 + set(LIBRARY_OLD_NAME pt_tvmdsoop) + + # This new library is set for pytorch integration, which solves the c++ abi imcompability issue + set(LIBRARY_NEW_NAME pt_tvmdsoop_new) tvm_file_glob(GLOB_RECURSE PTTVM_TORCH ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/tvm_module_wrapper/*.cc) tvm_file_glob(GLOB_RECURSE PTTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/pt_call_tvm/*.cc) - add_library(${LIBRARY_NAME} SHARED ${PTTVM_SRCS}) - add_library(${LIBRARY_TORCH_NAME} SHARED ${PTTVM_TORCH}) + add_library(${LIBRARY_OLD_NAME} SHARED ${PTTVM_SRCS}) + add_library(${LIBRARY_NEW_NAME} SHARED ${PTTVM_TORCH}) set(PTTVM_LINK_FLAGS -ltvm -L${CMAKE_CURRENT_BINARY_DIR}) 
if(NOT BUILD_PT_TVMDSOOP_ONLY STREQUAL "ON") - add_dependencies(${LIBRARY_NAME} tvm) + add_dependencies(${LIBRARY_OLD_NAME} tvm) + add_dependencies(${LIBRARY_NEW_NAME} tvm) endif() - target_compile_options(${LIBRARY_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} ${PT_COMPILE_FLAGS}) - target_link_libraries(${LIBRARY_NAME} PUBLIC ${PTTVM_LINK_FLAGS} ${PT_LINK_FLAGS}) - target_compile_definitions(${LIBRARY_NAME} PUBLIC DMLC_USE_LOGGING_LIBRARY=) + target_compile_options(${LIBRARY_OLD_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} ${PT_COMPILE_FLAGS}) + target_link_libraries(${LIBRARY_OLD_NAME} PUBLIC ${PTTVM_LINK_FLAGS} ${PT_LINK_FLAGS}) + target_compile_definitions(${LIBRARY_OLD_NAME} PUBLIC DMLC_USE_LOGGING_LIBRARY=) - target_compile_options(${LIBRARY_TORCH_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} ${PT_COMPILE_FLAGS}) - target_link_libraries(${LIBRARY_TORCH_NAME} PUBLIC ${PTTVM_LINK_FLAGS} ${PT_LINK_FLAGS}) - target_compile_definitions(${LIBRARY_TORCH_NAME} PUBLIC DMLC_USE_LOGGING_LIBRARY=) + target_compile_options(${LIBRARY_NEW_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} ${PT_COMPILE_FLAGS}) + target_link_libraries(${LIBRARY_NEW_NAME} PUBLIC ${PTTVM_LINK_FLAGS} ${PT_LINK_FLAGS}) + target_compile_definitions(${LIBRARY_NEW_NAME} PUBLIC DMLC_USE_LOGGING_LIBRARY=) endif() diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc index 7f2a84d27f46..fb570c163feb 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc @@ -35,6 +35,9 @@ namespace tvm { namespace contrib { +/* + * TVM's FFI for passing module from python to C++ + */ struct ThreadLocalStore { tvm::runtime::Module mod; static ThreadLocalStore* ThreadLocal() { @@ -100,7 +103,7 @@ TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod").set_body_typed([](tvm::runtime: * @param src Pointer to NDArray * @return DLPack extended tensor */ -DLPackTensorExt createDLpackTensorExt(tvm::runtime::NDArray* src) { +DLPackTensorExt CreateDLpackTensorExt(tvm::runtime::NDArray* src) { auto is_bool = src->DataType().is_bool(); DLManagedTensor* tensor; if (is_bool) { @@ -120,7 +123,7 @@ DLPackTensorExt createDLpackTensorExt(tvm::runtime::NDArray* src) { * @param src DLpack extended tensor * @return a new NDArray */ -tvm::runtime::NDArray createBoolNDarray(DLPackTensorExt* src) { +tvm::runtime::NDArray CreateBoolNDarray(DLPackTensorExt* src) { auto& tensor = src->dl_managed_tensor->dl_tensor; std::vector shape; for (int64_t i = 0; i < tensor.ndim; i++) { @@ -131,7 +134,7 @@ tvm::runtime::NDArray createBoolNDarray(DLPackTensorExt* src) { return std::move(ret); } -bool isZeroCopy(DLPackTensorExt* src) { +bool IsZeroCopy(DLPackTensorExt* src) { auto& dl_tensor = src->dl_managed_tensor->dl_tensor; return tvm::runtime::NDArray::AbilityOfZeroCopyForDLTensor(&dl_tensor, dl_tensor.device); } @@ -141,7 +144,7 @@ bool isZeroCopy(DLPackTensorExt* src) { * @param src DLpack extended tensor * @return a new NDArray */ -tvm::runtime::NDArray ndarrayFromDLpack(DLPackTensorExt* src) { +tvm::runtime::NDArray NDarrayFromDLpack(DLPackTensorExt* src) { using tvm::runtime::NDArray; NDArray array; @@ -149,8 +152,8 @@ tvm::runtime::NDArray ndarrayFromDLpack(DLPackTensorExt* src) { if (src->is_bool) { // one memory copy // the code is similar to NewFromDLTensor except for the type - array = createBoolNDarray(src); - } else if (isZeroCopy(src)) { + array = CreateBoolNDarray(src); + } else if (IsZeroCopy(src)) { array = 
NDArray::FromExternalDLTensor(src->dl_managed_tensor->dl_tensor); } else { // one memory copy @@ -171,7 +174,7 @@ struct TVMContribTorchRuntimeModule { }; bool tvm_contrib_torch_tensor_ability_of_zero_copy(DLPackTensorExt* src) { - return (src->is_bool) || (!tvm::contrib::isZeroCopy(src)); + return (!src->is_bool) && (tvm::contrib::IsZeroCopy(src)); } TVMContribTorchRuntimeModule* tvm_contrib_torch_get_last_saved_runtime_module() { @@ -189,7 +192,7 @@ void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* run std::vector input_cache(input_size); for (size_t k = 0; k < input_size; ++k) { - auto datum = tvm::contrib::ndarrayFromDLpack(&inputs[k]); // could have one memory copy + auto datum = tvm::contrib::NDarrayFromDLpack(&inputs[k]); // could have one memory copy input_cache[k] = datum; // we keep the datum in a vector for future use, otherwise the datum // will be freed after the loop setter(k, datum); @@ -199,7 +202,7 @@ void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* run nullptr); for (size_t k = 0; k < input_size; ++k) { - if (tvm_contrib_torch_tensor_ability_of_zero_copy(&inputs[k])) + if (!tvm_contrib_torch_tensor_ability_of_zero_copy(&inputs[k])) input_cache[k].CopyTo(&inputs[k].dl_managed_tensor->dl_tensor); } } @@ -233,7 +236,7 @@ size_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeMod for (int64_t k = 0; k < output_length; ++k) { tvm::runtime::NDArray results = get_output(k); - outputs_ptr[k] = tvm::contrib::createDLpackTensorExt(&results); + outputs_ptr[k] = tvm::contrib::CreateDLpackTensorExt(&results); } return output_length; diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc index a3252d6059df..b5496ffba38e 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc @@ -33,9 +33,9 @@ namespace contrib { * @param src Torch tensor * @return DLPack extended tensor */ -DLPackTensorExt toDLPackExt(const at::Tensor& src) { +DLPackTensorExt ToDLPackExt(const at::Tensor& src) { if (!src.is_contiguous()) { - return toDLPackExt(src.contiguous()); + return ToDLPackExt(src.contiguous()); } DLPackTensorExt ret; if (src.dtype().isScalarType(torch::kBool)) { @@ -55,7 +55,7 @@ DLPackTensorExt toDLPackExt(const at::Tensor& src) { * @param src DLPack extended tensor * @return Torch tensor */ -at::Tensor fromDLPackExt(const DLPackTensorExt& src) { +at::Tensor FromDLPackExt(const DLPackTensorExt& src) { if (src.is_bool) { return at::fromDLPack(src.dl_managed_tensor).toType(torch::kBool); } else { @@ -80,13 +80,13 @@ class OperatorModuleWrapper : public torch::jit::CustomClassHolder { // Torch tensor supports boolean type while DLpack does not, // we convert Torch tensor to an extension of DLPack tensor - for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPackExt(inputs[i])); + for (int i = 0; i < input_length; ++i) tensors.push_back(ToDLPackExt(inputs[i])); tvm_contrib_torch_operator_module_forward(this->runtime_module_, tensors.data(), tensors.size()); for (int k = 0; k < input_length; ++k) { - if (tvm_contrib_torch_tensor_ability_of_zero_copy(&tensors[k])) { - inputs[k].copy_(fromDLPackExt(tensors[k])); + if (!tvm_contrib_torch_tensor_ability_of_zero_copy(&tensors[k])) { + inputs[k].copy_(FromDLPackExt(tensors[k])); } else { tensors[k].dl_managed_tensor->deleter(tensors[k].dl_managed_tensor); } @@ -150,7 +150,7 @@ class 
GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder { // Torch tensor supports boolean type while DLpack does not, // we convert Torch tensor to an extension of DLPack tensor - for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPackExt(inputs[i])); + for (int i = 0; i < input_length; ++i) tensors.push_back(ToDLPackExt(inputs[i])); DLPackTensorExt* outputs; if (executor_factory_runtime_ == nullptr) { @@ -164,7 +164,7 @@ class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder { ret.reserve(num_outputs); for (size_t k = 0; k < num_outputs; ++k) { - at::Tensor atTensor = fromDLPackExt(outputs[k]); + at::Tensor atTensor = FromDLPackExt(outputs[k]); ret.emplace_back(atTensor); } From 10f791fdd3cb872f9ca8974d3ef269a056434fd7 Mon Sep 17 00:00:00 2001 From: juda Date: Tue, 16 Aug 2022 22:00:17 -0700 Subject: [PATCH 32/33] polish --- .../torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc index b5496ffba38e..3159438d7202 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc @@ -85,10 +85,12 @@ class OperatorModuleWrapper : public torch::jit::CustomClassHolder { tensors.size()); for (int k = 0; k < input_length; ++k) { - if (!tvm_contrib_torch_tensor_ability_of_zero_copy(&tensors[k])) { - inputs[k].copy_(FromDLPackExt(tensors[k])); - } else { + if (tvm_contrib_torch_tensor_ability_of_zero_copy(&tensors[k])) { + // We need to free memory manually tensors[k].dl_managed_tensor->deleter(tensors[k].dl_managed_tensor); + } else { + // Ownership transferred + inputs[k].copy_(FromDLPackExt(tensors[k])); } } } From 457755c97480d4b8d8a1297870985f4e1a46eb42 Mon Sep 17 00:00:00 2001 From: juda Date: Tue, 16 Aug 2022 22:07:47 -0700 Subject: [PATCH 33/33] remove redundant import --- apps/pt_tvmdsoop/tests/test_boolean_tensor.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/apps/pt_tvmdsoop/tests/test_boolean_tensor.py b/apps/pt_tvmdsoop/tests/test_boolean_tensor.py index feed3e2c4e6c..4718b4043945 100644 --- a/apps/pt_tvmdsoop/tests/test_boolean_tensor.py +++ b/apps/pt_tvmdsoop/tests/test_boolean_tensor.py @@ -19,13 +19,11 @@ """Test script for boolean tensor support""" import tempfile -import numpy as np import torch import tvm import tvm.testing from tvm.contrib.torch import as_torch, optimize_torch -from tvm.meta_schedule.tune import TuneConfig from tvm.script import tir as T
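
A closing sketch of how the encode/decode declarations in runtime_bridge.h compose. The helper names below are illustrative, and this is only an assumption about what the wrappers' Serialize/Deserialize paths reduce to, since the hunks above show them only in part.

#include <string>

// Illustrative C++ only. Assumes the encoding returned by
// tvm_contrib_torch_encode is a null-terminated char array, as the
// runtime_bridge.h comments describe.
std::string serialize_module(TVMContribTorchRuntimeModule* module) {
  char* state = tvm_contrib_torch_encode(module);
  std::string ret(state);
  tvm_contrib_torch_free_encoding(state);  // the buffer is owned by the bridge
  return ret;
}

TVMContribTorchRuntimeModule* deserialize_module(const std::string& state) {
  return tvm_contrib_torch_decode(state.c_str());
}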