From 1e50a03ad4b34f68f5aca267b24919f1e6aed264 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Wed, 15 Sep 2021 11:26:05 +0200 Subject: [PATCH 1/3] initial stab at TorchScript fallback --- CMakeLists.txt | 2 + cmake/config.cmake | 5 + cmake/modules/contrib/LibTorch.cmake | 30 +++ python/tvm/relay/op/contrib/__init__.py | 1 + python/tvm/relay/op/contrib/libtorch.py | 36 ++++ .../contrib/libtorch/libtorch_codegen.cc | 147 +++++++++++++++ src/relay/transforms/type_infer.cc | 48 +++-- .../contrib/libtorch/libtorch_runtime.cc | 173 ++++++++++++++++++ tests/python/contrib/test_libtorch_ops.py | 81 ++++++++ 9 files changed, 510 insertions(+), 13 deletions(-) create mode 100644 cmake/modules/contrib/LibTorch.cmake create mode 100644 python/tvm/relay/op/contrib/libtorch.py create mode 100644 src/relay/backend/contrib/libtorch/libtorch_codegen.cc create mode 100644 src/runtime/contrib/libtorch/libtorch_runtime.cc create mode 100644 tests/python/contrib/test_libtorch_ops.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 62598cbdf4a7..599f66c1e23b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -73,6 +73,7 @@ tvm_option(USE_MIOPEN "Build with ROCM:MIOpen" OFF) tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF) tvm_option(USE_SORT "Build with sort support" ON) tvm_option(USE_NNPACK "Build with nnpack support" OFF) +tvm_option(USE_LIBTORCH "Build with libtorch support" OFF) tvm_option(USE_RANDOM "Build with random support" ON) tvm_option(USE_MICRO_STANDALONE_RUNTIME "Build with micro.standalone_runtime support" OFF) tvm_option(USE_CPP_RPC "Build CPP RPC" OFF) @@ -412,6 +413,7 @@ include(cmake/modules/contrib/Posit.cmake) include(cmake/modules/contrib/MicroStandaloneRuntime.cmake) include(cmake/modules/contrib/Sort.cmake) include(cmake/modules/contrib/NNPack.cmake) +include(cmake/modules/contrib/LibTorch.cmake) include(cmake/modules/contrib/HybridDump.cmake) include(cmake/modules/contrib/TFLite.cmake) include(cmake/modules/contrib/TF_TVMDSOOP.cmake) diff --git a/cmake/config.cmake b/cmake/config.cmake index 8d8186c1b4f0..bf9db180e70f 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -326,3 +326,8 @@ set(USE_CCACHE AUTO) # - OFF: disable PAPI support. # - /path/to/folder/containing/: Path to folder containing papi.pc. set(USE_PAPI OFF) + +# Whether to use LibTorch as backend +# To enable pass the path to the root libtorch (or PyTorch) directory +# OFF or /path/to/torch/ +set(USE_LIBTORCH OFF) diff --git a/cmake/modules/contrib/LibTorch.cmake b/cmake/modules/contrib/LibTorch.cmake new file mode 100644 index 000000000000..0e00e7008c2f --- /dev/null +++ b/cmake/modules/contrib/LibTorch.cmake @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +if(USE_LIBTORCH) + find_package(Torch REQUIRED PATHS ${USE_LIBTORCH}/share/cmake/Torch + ) + list(APPEND TVM_RUNTIME_LINKER_LIBS ${TORCH_LIBRARIES}) + include_directories(${TORCH_INCLUDE_DIRS}) + + file(GLOB LIBTORCH_RELAY_CONTRIB_SRC + src/relay/backend/contrib/libtorch/libtorch_codegen.cc + src/runtime/contrib/libtorch/libtorch_runtime.cc + ) + list(APPEND COMPILER_SRCS ${LIBTORCH_RELAY_CONTRIB_SRC}) + +endif(USE_LIBTORCH) diff --git a/python/tvm/relay/op/contrib/__init__.py b/python/tvm/relay/op/contrib/__init__.py index 30c2db0ddf0b..bc2d212444ff 100644 --- a/python/tvm/relay/op/contrib/__init__.py +++ b/python/tvm/relay/op/contrib/__init__.py @@ -23,4 +23,5 @@ from .bnns import * from .coreml import * from .ethosn import * +from .libtorch import * from .tensorrt import * diff --git a/python/tvm/relay/op/contrib/libtorch.py b/python/tvm/relay/op/contrib/libtorch.py new file mode 100644 index 000000000000..2827c2abd88b --- /dev/null +++ b/python/tvm/relay/op/contrib/libtorch.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument, no-else-return, E1102 +"""Torch codegen operators""" + +from tvm import relay +from tvm.relay.op.annotation import compiler_begin, compiler_end + + +def torchop(script_fn, *params): + """Insert an Operation executed in the PyTorch JIT + + The operation includes backend annotation + + Currently, only tensors are supported. The shape inferrence + assumes that input shapes (and not values) determine output shapes.""" + return compiler_end( + relay.op._make.torchop( + [compiler_begin(p, "torch") for p in params], script_fn.save_to_buffer() + ), + "torch", + ) diff --git a/src/relay/backend/contrib/libtorch/libtorch_codegen.cc b/src/relay/backend/contrib/libtorch/libtorch_codegen.cc new file mode 100644 index 000000000000..a05e6b672786 --- /dev/null +++ b/src/relay/backend/contrib/libtorch/libtorch_codegen.cc @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/contrib/libtorch/codegen.cc + * \brief Implementation of libtorch codegen. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../../utils.h" + +namespace tvm { +namespace relay { +namespace contrib { + +using namespace backend; + +/*! \brief Attributes of a TorchFunction node */ +struct TorchFunctionAttrs : public tvm::AttrsNode { + std::string serialized_function; + int64_t len; + + TVM_DECLARE_ATTRS(TorchFunctionAttrs, "relay.attrs.TorchFunctionAttrs") { + TVM_ATTR_FIELD(serialized_function).set_default("").describe("Function from fn.save(...)"); + TVM_ATTR_FIELD(len).set_default(-1).describe("Function from fn.save(...)"); + } +}; + +TVM_REGISTER_NODE_TYPE(TorchFunctionAttrs); + +bool TorchOpRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + const auto* sfattrs = attrs.as(); + std::stringstream str(sfattrs->serialized_function); + torch::jit::Module mod = torch::jit::load(str); + + std::vector inputs; + for (int i = 0; i < num_inputs; i++) { + auto* ty = types[i].as(); + ICHECK(ty) << "only accept tensors as inputs"; + std::vector shape; + for (const auto& s : ty->shape) { + auto* si = s.as(); + if (!si) { + return false; + } + shape.push_back(si->value); + } + ICHECK(ty->dtype == DataType::Float(32)) + << "only float supported"; // TODO(my_username): replace with lookup map + inputs.emplace_back(torch::randn(shape)); + } + auto res = mod.forward(inputs); + auto res_t = res.toTensor(); + ICHECK((int)types.size() == num_inputs + 1) << "only single output supported"; + ICHECK(res_t.dtype() == c10::kFloat); + Array res_sizes; + for (int d = 0; d < res_t.dim(); d++) { + res_sizes.push_back(IntImm(DataType::Int(32), res_t.size(d))); + } + reporter->Assign(types[num_inputs], TensorType(res_sizes, DataType::Float(32))); + return true; +} + +RELAY_REGISTER_OP("torch_op") + .set_support_level(99) + .add_type_rel("TorchOpRel", TorchOpRel) + .set_attrs_type(); + +Expr MakeTorchOp(Array args, const std::string& serialized_function) { + static const Op& op = Op::Get("torch_op"); + auto attrs = make_object(); + attrs->serialized_function = serialized_function; + attrs->len = serialized_function.size(); + return Call(op, args, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.torchop").set_body_typed(MakeTorchOp); + +/*! + * \brief The external compiler/codegen tool. It takes a Relay expression/module and + * compile it into a runtime module. + */ +runtime::Module TorchCompiler(const ObjectRef& ref) { + ICHECK(ref->IsInstance()) << "The input ref is expected to be a Relay function."; + Function func = Downcast(ref); + std::string func_name = backend::GetExtSymbol(func); + + ICHECK(func.defined()) << "Input error: expect a Relay function."; + const auto* call = func->body.as(); + ICHECK(call) << "Expected call node\n"; + const auto* op_node = call->op.as(); + ICHECK(op_node) << "Expect OpNode, but got " << call->op->GetTypeKey(); + const auto op_name = GetRef(op_node)->name; + ICHECK(op_name == "torch_op") << "Unsupported op: " << AsText(call->op, false) << "\n"; + + const auto* attrs = call->attrs.as(); + + // TensorRTJSONSerializer serializer(func_name, func); + // serializer.serialize(); + // std::string graph_json = serializer.GetJSON(); + + const auto* pf = runtime::Registry::Get("runtime.torch_runtime_create"); + ICHECK(pf != nullptr) << "Cannot find Torch runtime module create function."; + TVMByteArray serialized_function{attrs->serialized_function.c_str(), + attrs->serialized_function.length()}; + runtime::Module lib = (*pf)(func_name, serialized_function); + return lib; +} + +TVM_REGISTER_GLOBAL("relay.ext.torch").set_body_typed(TorchCompiler); + +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 6c2371716b16..c3da2291cb5f 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -499,27 +499,49 @@ class TypeInferencer : private ExprFunctor, size_t type_arity = fn_ty->arg_types.size(); size_t number_of_args = arg_types.size(); + bool is_variable = false; - if (type_arity != number_of_args) { - if (type_arity < number_of_args) { - this->EmitFatal(Diagnostic::Error(call->span) - << "the function is provided too many arguments " - << "expected " << type_arity << ", found " << number_of_args); - } else { - this->EmitFatal(Diagnostic::Error(call->span) - << "the function is provided too few arguments " - << "expected " << type_arity << ", found " << number_of_args); + if (const OpNode* opnode = call->op.as()) { + if (opnode->num_inputs == -1) { + is_variable = true; } } - for (size_t i = 0; i < fn_ty->arg_types.size(); i++) { - this->Unify(fn_ty->arg_types[i], arg_types[i], call->span, true, false); + if ((type_arity < number_of_args) && !is_variable) { + this->EmitFatal(Diagnostic::Error(call->span) + << "the function is provided too many arguments " + << "expected " << type_arity << ", found " << number_of_args); + } else if (type_arity > number_of_args) { + this->EmitFatal(Diagnostic::Error(call->span) + << "the function is provided too few arguments " + << "expected " << type_arity << ", found " << number_of_args); } + Array unified_arg_types; + if (!is_variable) { + for (size_t i = 0; i < fn_ty->arg_types.size(); i++) { + this->Unify(fn_ty->arg_types[i], arg_types[i], call->span, true, false); + } + } else { + for (size_t i = 0; i < number_of_args; i++) { + if (i < fn_ty->arg_types.size()) { + unified_arg_types.push_back( + this->Unify(fn_ty->arg_types[i], arg_types[i], call->span, false, false)); + } else { + unified_arg_types.push_back(arg_types[i]); + } + } + unified_arg_types.push_back(fn_ty->ret_type); + } for (auto cs : fn_ty->type_constraints) { if (const auto* tr = cs.as()) { - solver_.AddConstraint(TypeRelation(tr->func, tr->args, tr->num_inputs, call->attrs), - call->span); + if (!is_variable) { + solver_.AddConstraint(TypeRelation(tr->func, tr->args, tr->num_inputs, call->attrs), + call->span); + } else { + solver_.AddConstraint( + TypeRelation(tr->func, unified_arg_types, number_of_args, call->attrs), call->span); + } } else { solver_.AddConstraint(cs, call->span); } diff --git a/src/runtime/contrib/libtorch/libtorch_runtime.cc b/src/runtime/contrib/libtorch/libtorch_runtime.cc new file mode 100644 index 000000000000..75fde9ac538c --- /dev/null +++ b/src/runtime/contrib/libtorch/libtorch_runtime.cc @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/libtorch/libtorch_runtime.cc + * \brief runtime implementation for LibTorch/TorchScript. + */ + +// we do not want clang to reorder our includes +// clang-format off +#include +#include +#include + +#include +#include +#include +#include + +// clang-format on + +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace contrib { + +static void monly_deleter(DLManagedTensor* self) { delete self; } + +void run_torch_module(torch::jit::Module* module, TVMArgs args, TVMRetValue* rv) { + std::vector inputs; + std::vector outputs; + auto m = module->get_method("forward"); + for (int i = 0; i < args.size(); i++) { + const DLTensor* arg; + if (args[i].IsObjectRef()) { + NDArray arr = args[i]; + arg = arr.operator->(); + } else { + arg = args[i].operator DLTensor*(); + } + DLManagedTensor* inp = new DLManagedTensor{}; + inp->dl_tensor = *arg; + inp->deleter = &monly_deleter; + // m.num_inputs includes the self argument of forward(self, ...) + // num_inputs - 1 is the number of (Tensor) inputs + if (i < static_cast(m.num_inputs()) - 1) { + inputs.emplace_back(at::fromDLPack(inp)); + } else { + outputs.emplace_back(at::fromDLPack(inp)); + } + } + ICHECK(outputs.size() == 1) << "wrong number of args, can handle only one output"; + torch::Tensor res = module->forward(inputs).toTensor(); + outputs[0].copy_(res); // too bad + // what to do about rv? + // NDArray res_array = NDArray::FromDLPack(at::toDLPack(res)); + // *rv = res_array; +} + +/*! + * \brief A json runtime that executes the serialized JSON format. This runtime + * can be extended by user defined runtime for execution. + */ +class TorchModuleNode : public ModuleNode { + public: + TorchModuleNode(const std::string& symbol_name, const torch::jit::Module& module) + : symbol_name_(symbol_name), module_(module) {} + + const char* type_key() const { return "torch"; } + + /*! + * \brief Get a packed function. + * \param name The name/symbol of the function. + * \param sptr_to_self The pointer to the module node. + * \return The packed function. + */ + virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { + if (name == "get_symbol") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->symbol_name_; }); + } else if (name == "get_const_vars") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = Array{}; }); + } else if (this->symbol_name_ == name) { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + run_torch_module(&module_, args, rv); + }); + } else if ("__init_" + this->symbol_name_ == name) { + // The function to initialize constant tensors. + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = 0; }); + } else { + return PackedFunc(nullptr); + } + } + + virtual void SaveToBinary(dmlc::Stream* stream) { + // Save the symbol + stream->Write(symbol_name_); + // Save the module + std::stringstream str; + module_.save(str); + stream->Write(str.str()); + } + + static Module LoadFromBinary(void* strm) { + dmlc::Stream* stream = static_cast(strm); + std::string symbol; + std::string module_str; + // Load the symbol + ICHECK(stream->Read(&symbol)) << "Loading symbol name failed"; + ICHECK(stream->Read(&module_str)) << "Loading module str failed"; + std::stringstream str(module_str); + torch::jit::Module mod = torch::jit::load(str); + auto n = make_object(symbol, mod); + return Module(n); + } + + /*! + * \brief Get the source generated by codegen. + * + * \param format the format to return. + * \return A string of JSON. + */ + std::string GetSource(const std::string& format = "json") override { + return module_.dump_to_str(true, true, true); + } + + protected: + /*! \brief The only subgraph name for this module. */ + std::string symbol_name_; + /*! \brief Module. */ + torch::jit::Module module_; +}; + +runtime::Module TorchRuntimeCreate(const String& symbol_name, + const std::string& serialized_function) { + std::stringstream str(serialized_function); + torch::jit::Module mod = torch::jit::load(str); + auto n = make_object(symbol_name, mod); + return runtime::Module(n); +} + +TVM_REGISTER_GLOBAL("runtime.torch_runtime_create").set_body_typed(TorchRuntimeCreate); + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_torch") + .set_body_typed(TorchModuleNode::LoadFromBinary); + +} // namespace contrib +} // namespace runtime +} // namespace tvm diff --git a/tests/python/contrib/test_libtorch_ops.py b/tests/python/contrib/test_libtorch_ops.py new file mode 100644 index 000000000000..e1ca3ed69178 --- /dev/null +++ b/tests/python/contrib/test_libtorch_ops.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +import tvm.relay +from tvm.relay.op.contrib import torchop + +try: + import torch +except ImportError as _: + torch = None + + +@pytest.mark.skipif(torch is None, reason="PyTorch is not available") +def test_backend(): + @torch.jit.script + def script_fn(x, y): + res = x * y + return res + + x2 = tvm.relay.var("x", shape=[1, 2]) + y2 = tvm.relay.var("y", shape=[2, 2]) + + x3 = tvm.relay.var("x", shape=[1, 3]) + y3 = tvm.relay.var("y", shape=[3, 3]) + + test_body = tvm.relay.sum(torchop(script_fn, x2, y2)) + tvm.relay.sum( + torchop(script_fn, x3, y3) + ) + test_fn = tvm.relay.Function([x2, y2, x3, y3], test_body) + mod = tvm.IRModule({"main": test_fn}) + + tvm.relay.transform.InferType()(mod) + + # mod = tvm.relay.transform.AnnotateTarget("target.torch")(mod) + mod = tvm.relay.transform.MergeCompilerRegions()(mod) + mod = tvm.relay.transform.PartitionGraph()(mod) + mod = tvm.relay.transform.InferType()(mod) + + target = "llvm" + with tvm.transform.PassContext(opt_level=3): + lib = tvm.relay.build(mod, target, params={}) + + ctx = tvm.cpu(0) + rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](ctx)) + + x2t = torch.randn(1, 2) + y2t = torch.randn(2, 2) + x3t = torch.randn(1, 3) + y3t = torch.randn(3, 3) + # Set inputs + rt_mod.set_input(0, x2t) + rt_mod.set_input(1, y2t) + rt_mod.set_input(2, x3t) + rt_mod.set_input(3, y3t) + # Execute + rt_mod.run() + # Get outputs + tvm_output = rt_mod.get_output(0).asnumpy() + expected = (script_fn(x2t, y2t).sum() + script_fn(x3t, y3t).sum()).numpy() + + tvm.testing.assert_allclose(tvm_output, expected) + + +if __name__ == "__main__": + pytest.main([__file__]) From 84aa0fd98ebe311bf59e450e36171646071c4e95 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 8 Mar 2022 15:38:25 +0100 Subject: [PATCH 2/3] make types more flexible --- .../contrib/libtorch/libtorch_codegen.cc | 10 +- tests/python/contrib/test_libtorch_ops.py | 91 ++++++++++--------- 2 files changed, 54 insertions(+), 47 deletions(-) diff --git a/src/relay/backend/contrib/libtorch/libtorch_codegen.cc b/src/relay/backend/contrib/libtorch/libtorch_codegen.cc index a05e6b672786..424352aed2de 100644 --- a/src/relay/backend/contrib/libtorch/libtorch_codegen.cc +++ b/src/relay/backend/contrib/libtorch/libtorch_codegen.cc @@ -22,6 +22,7 @@ * \brief Implementation of libtorch codegen. */ +#include #include #include #include @@ -78,19 +79,18 @@ bool TorchOpRel(const Array& types, int num_inputs, const Attrs& attrs, } shape.push_back(si->value); } - ICHECK(ty->dtype == DataType::Float(32)) - << "only float supported"; // TODO(my_username): replace with lookup map - inputs.emplace_back(torch::randn(shape)); + auto torchScalarType = at::toScalarType(ty->dtype); + + inputs.emplace_back(torch::zeros(shape, at::TensorOptions().dtype(torchScalarType))); } auto res = mod.forward(inputs); auto res_t = res.toTensor(); ICHECK((int)types.size() == num_inputs + 1) << "only single output supported"; - ICHECK(res_t.dtype() == c10::kFloat); Array res_sizes; for (int d = 0; d < res_t.dim(); d++) { res_sizes.push_back(IntImm(DataType::Int(32), res_t.size(d))); } - reporter->Assign(types[num_inputs], TensorType(res_sizes, DataType::Float(32))); + reporter->Assign(types[num_inputs], TensorType(res_sizes, DataType(at::getDLDataType(res_t)))); return true; } diff --git a/tests/python/contrib/test_libtorch_ops.py b/tests/python/contrib/test_libtorch_ops.py index e1ca3ed69178..751a547f94f5 100644 --- a/tests/python/contrib/test_libtorch_ops.py +++ b/tests/python/contrib/test_libtorch_ops.py @@ -33,48 +33,55 @@ def script_fn(x, y): res = x * y return res - x2 = tvm.relay.var("x", shape=[1, 2]) - y2 = tvm.relay.var("y", shape=[2, 2]) - - x3 = tvm.relay.var("x", shape=[1, 3]) - y3 = tvm.relay.var("y", shape=[3, 3]) - - test_body = tvm.relay.sum(torchop(script_fn, x2, y2)) + tvm.relay.sum( - torchop(script_fn, x3, y3) - ) - test_fn = tvm.relay.Function([x2, y2, x3, y3], test_body) - mod = tvm.IRModule({"main": test_fn}) - - tvm.relay.transform.InferType()(mod) - - # mod = tvm.relay.transform.AnnotateTarget("target.torch")(mod) - mod = tvm.relay.transform.MergeCompilerRegions()(mod) - mod = tvm.relay.transform.PartitionGraph()(mod) - mod = tvm.relay.transform.InferType()(mod) - - target = "llvm" - with tvm.transform.PassContext(opt_level=3): - lib = tvm.relay.build(mod, target, params={}) - - ctx = tvm.cpu(0) - rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](ctx)) - - x2t = torch.randn(1, 2) - y2t = torch.randn(2, 2) - x3t = torch.randn(1, 3) - y3t = torch.randn(3, 3) - # Set inputs - rt_mod.set_input(0, x2t) - rt_mod.set_input(1, y2t) - rt_mod.set_input(2, x3t) - rt_mod.set_input(3, y3t) - # Execute - rt_mod.run() - # Get outputs - tvm_output = rt_mod.get_output(0).asnumpy() - expected = (script_fn(x2t, y2t).sum() + script_fn(x3t, y3t).sum()).numpy() - - tvm.testing.assert_allclose(tvm_output, expected) + for torch_dt, dt in ( + (torch.int32, "int32"), + (torch.float32, "float32"), + (torch.float64, "float64"), + ): + x2 = tvm.relay.var("x", shape=[1, 2], dtype=dt) + y2 = tvm.relay.var("y", shape=[2, 2], dtype=dt) + + x3 = tvm.relay.var("x", shape=[1, 3], dtype=dt) + y3 = tvm.relay.var("y", shape=[3, 3], dtype=dt) + + test_body = tvm.relay.sum(torchop(script_fn, x2, y2)) + tvm.relay.sum( + torchop(script_fn, x3, y3) + ) + test_fn = tvm.relay.Function([x2, y2, x3, y3], test_body) + mod = tvm.IRModule({"main": test_fn}) + + tvm.relay.transform.InferType()(mod) + + # mod = tvm.relay.transform.AnnotateTarget("target.torch")(mod) + mod = tvm.relay.transform.MergeCompilerRegions()(mod) + mod = tvm.relay.transform.PartitionGraph()(mod) + mod = tvm.relay.transform.InferType()(mod) + + target = "llvm" + with tvm.transform.PassContext(opt_level=3): + lib = tvm.relay.build(mod, target, params={}) + + ctx = tvm.cpu(0) + rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](ctx)) + + # int does not have randn, so we cast... + x2t = torch.randn(1, 2).to(dtype=torch_dt) + y2t = torch.randn(2, 2).to(dtype=torch_dt) + x3t = torch.randn(1, 3).to(dtype=torch_dt) + y3t = torch.randn(3, 3).to(dtype=torch_dt) + # Set inputs + rt_mod.set_input(0, x2t) + rt_mod.set_input(1, y2t) + rt_mod.set_input(2, x3t) + rt_mod.set_input(3, y3t) + # Execute + rt_mod.run() + # Get outputs + tvm_output = rt_mod.get_output(0).numpy() + expected = (script_fn(x2t, y2t).sum() + script_fn(x3t, y3t).sum()).numpy() + print(tvm_output.dtype) + print(expected.dtype) + tvm.testing.assert_allclose(tvm_output, expected) if __name__ == "__main__": From ad25aa6cc86d94bd73b1de1e7b9c81d3e625aa33 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Wed, 9 Mar 2022 13:32:25 +0100 Subject: [PATCH 3/3] Easy review bits. Thank you @masahi --- .../backend/contrib/libtorch/libtorch_codegen.cc | 13 ++----------- src/runtime/contrib/libtorch/libtorch_runtime.cc | 5 ----- 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/src/relay/backend/contrib/libtorch/libtorch_codegen.cc b/src/relay/backend/contrib/libtorch/libtorch_codegen.cc index 424352aed2de..25bfbfad4443 100644 --- a/src/relay/backend/contrib/libtorch/libtorch_codegen.cc +++ b/src/relay/backend/contrib/libtorch/libtorch_codegen.cc @@ -32,6 +32,7 @@ #include #include #include +#include #include #include #include @@ -127,17 +128,7 @@ runtime::Module TorchCompiler(const ObjectRef& ref) { ICHECK(op_name == "torch_op") << "Unsupported op: " << AsText(call->op, false) << "\n"; const auto* attrs = call->attrs.as(); - - // TensorRTJSONSerializer serializer(func_name, func); - // serializer.serialize(); - // std::string graph_json = serializer.GetJSON(); - - const auto* pf = runtime::Registry::Get("runtime.torch_runtime_create"); - ICHECK(pf != nullptr) << "Cannot find Torch runtime module create function."; - TVMByteArray serialized_function{attrs->serialized_function.c_str(), - attrs->serialized_function.length()}; - runtime::Module lib = (*pf)(func_name, serialized_function); - return lib; + return tvm::runtime::contrib::TorchRuntimeCreate(func_name, attrs->serialized_function); } TVM_REGISTER_GLOBAL("relay.ext.torch").set_body_typed(TorchCompiler); diff --git a/src/runtime/contrib/libtorch/libtorch_runtime.cc b/src/runtime/contrib/libtorch/libtorch_runtime.cc index 75fde9ac538c..5076b967a1de 100644 --- a/src/runtime/contrib/libtorch/libtorch_runtime.cc +++ b/src/runtime/contrib/libtorch/libtorch_runtime.cc @@ -75,9 +75,6 @@ void run_torch_module(torch::jit::Module* module, TVMArgs args, TVMRetValue* rv) ICHECK(outputs.size() == 1) << "wrong number of args, can handle only one output"; torch::Tensor res = module->forward(inputs).toTensor(); outputs[0].copy_(res); // too bad - // what to do about rv? - // NDArray res_array = NDArray::FromDLPack(at::toDLPack(res)); - // *rv = res_array; } /*! @@ -163,8 +160,6 @@ runtime::Module TorchRuntimeCreate(const String& symbol_name, return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.torch_runtime_create").set_body_typed(TorchRuntimeCreate); - TVM_REGISTER_GLOBAL("runtime.module.loadbinary_torch") .set_body_typed(TorchModuleNode::LoadFromBinary);