diff --git a/CMakeLists.txt b/CMakeLists.txt index 9c6a7dddfdf6..658f9963fe46 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -79,6 +79,7 @@ tvm_option(USE_MIOPEN "Build with ROCM:MIOpen" OFF) tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF) tvm_option(USE_SORT "Build with sort support" ON) tvm_option(USE_NNPACK "Build with nnpack support" OFF) +tvm_option(USE_LIBTORCH "Build with libtorch support" OFF) tvm_option(USE_RANDOM "Build with random support" ON) tvm_option(USE_MICRO_STANDALONE_RUNTIME "Build with micro.standalone_runtime support" OFF) tvm_option(USE_CPP_RPC "Build CPP RPC" OFF) @@ -455,6 +456,7 @@ include(cmake/modules/contrib/Posit.cmake) include(cmake/modules/contrib/MicroStandaloneRuntime.cmake) include(cmake/modules/contrib/Sort.cmake) include(cmake/modules/contrib/NNPack.cmake) +include(cmake/modules/contrib/LibTorch.cmake) include(cmake/modules/contrib/HybridDump.cmake) include(cmake/modules/contrib/TFLite.cmake) include(cmake/modules/contrib/TF_TVMDSOOP.cmake) diff --git a/cmake/config.cmake b/cmake/config.cmake index 62eeb34fead7..f6d99af503e4 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -364,3 +364,8 @@ set(USE_CUTLASS OFF) # Enable to show a summary of TVM options set(SUMMARIZE OFF) + +# Whether to use LibTorch as backend +# To enable pass the path to the root libtorch (or PyTorch) directory +# OFF or /path/to/torch/ +set(USE_LIBTORCH OFF) diff --git a/cmake/modules/contrib/LibTorch.cmake b/cmake/modules/contrib/LibTorch.cmake new file mode 100644 index 000000000000..0e00e7008c2f --- /dev/null +++ b/cmake/modules/contrib/LibTorch.cmake @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +if(USE_LIBTORCH) + find_package(Torch REQUIRED PATHS ${USE_LIBTORCH}/share/cmake/Torch + ) + list(APPEND TVM_RUNTIME_LINKER_LIBS ${TORCH_LIBRARIES}) + include_directories(${TORCH_INCLUDE_DIRS}) + + file(GLOB LIBTORCH_RELAY_CONTRIB_SRC + src/relay/backend/contrib/libtorch/libtorch_codegen.cc + src/runtime/contrib/libtorch/libtorch_runtime.cc + ) + list(APPEND COMPILER_SRCS ${LIBTORCH_RELAY_CONTRIB_SRC}) + +endif(USE_LIBTORCH) diff --git a/python/tvm/relay/op/contrib/__init__.py b/python/tvm/relay/op/contrib/__init__.py index 1dd6da6c2747..a03d0f6d4f1c 100644 --- a/python/tvm/relay/op/contrib/__init__.py +++ b/python/tvm/relay/op/contrib/__init__.py @@ -23,5 +23,6 @@ from .bnns import * from .coreml import * from .ethosn import * +from .libtorch import * from .tensorrt import * from .cutlass import * diff --git a/python/tvm/relay/op/contrib/libtorch.py b/python/tvm/relay/op/contrib/libtorch.py new file mode 100644 index 000000000000..2827c2abd88b --- /dev/null +++ b/python/tvm/relay/op/contrib/libtorch.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument, no-else-return, E1102 +"""Torch codegen operators""" + +from tvm import relay +from tvm.relay.op.annotation import compiler_begin, compiler_end + + +def torchop(script_fn, *params): + """Insert an Operation executed in the PyTorch JIT + + The operation includes backend annotation + + Currently, only tensors are supported. The shape inferrence + assumes that input shapes (and not values) determine output shapes.""" + return compiler_end( + relay.op._make.torchop( + [compiler_begin(p, "torch") for p in params], script_fn.save_to_buffer() + ), + "torch", + ) diff --git a/src/relay/backend/contrib/libtorch/libtorch_codegen.cc b/src/relay/backend/contrib/libtorch/libtorch_codegen.cc new file mode 100644 index 000000000000..25bfbfad4443 --- /dev/null +++ b/src/relay/backend/contrib/libtorch/libtorch_codegen.cc @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/contrib/libtorch/codegen.cc + * \brief Implementation of libtorch codegen. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../../utils.h" + +namespace tvm { +namespace relay { +namespace contrib { + +using namespace backend; + +/*! \brief Attributes of a TorchFunction node */ +struct TorchFunctionAttrs : public tvm::AttrsNode { + std::string serialized_function; + int64_t len; + + TVM_DECLARE_ATTRS(TorchFunctionAttrs, "relay.attrs.TorchFunctionAttrs") { + TVM_ATTR_FIELD(serialized_function).set_default("").describe("Function from fn.save(...)"); + TVM_ATTR_FIELD(len).set_default(-1).describe("Function from fn.save(...)"); + } +}; + +TVM_REGISTER_NODE_TYPE(TorchFunctionAttrs); + +bool TorchOpRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + const auto* sfattrs = attrs.as(); + std::stringstream str(sfattrs->serialized_function); + torch::jit::Module mod = torch::jit::load(str); + + std::vector inputs; + for (int i = 0; i < num_inputs; i++) { + auto* ty = types[i].as(); + ICHECK(ty) << "only accept tensors as inputs"; + std::vector shape; + for (const auto& s : ty->shape) { + auto* si = s.as(); + if (!si) { + return false; + } + shape.push_back(si->value); + } + auto torchScalarType = at::toScalarType(ty->dtype); + + inputs.emplace_back(torch::zeros(shape, at::TensorOptions().dtype(torchScalarType))); + } + auto res = mod.forward(inputs); + auto res_t = res.toTensor(); + ICHECK((int)types.size() == num_inputs + 1) << "only single output supported"; + Array res_sizes; + for (int d = 0; d < res_t.dim(); d++) { + res_sizes.push_back(IntImm(DataType::Int(32), res_t.size(d))); + } + reporter->Assign(types[num_inputs], TensorType(res_sizes, DataType(at::getDLDataType(res_t)))); + return true; +} + +RELAY_REGISTER_OP("torch_op") + .set_support_level(99) + .add_type_rel("TorchOpRel", TorchOpRel) + .set_attrs_type(); + +Expr MakeTorchOp(Array args, const std::string& serialized_function) { + static const Op& op = Op::Get("torch_op"); + auto attrs = make_object(); + attrs->serialized_function = serialized_function; + attrs->len = serialized_function.size(); + return Call(op, args, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.torchop").set_body_typed(MakeTorchOp); + +/*! + * \brief The external compiler/codegen tool. It takes a Relay expression/module and + * compile it into a runtime module. + */ +runtime::Module TorchCompiler(const ObjectRef& ref) { + ICHECK(ref->IsInstance()) << "The input ref is expected to be a Relay function."; + Function func = Downcast(ref); + std::string func_name = backend::GetExtSymbol(func); + + ICHECK(func.defined()) << "Input error: expect a Relay function."; + const auto* call = func->body.as(); + ICHECK(call) << "Expected call node\n"; + const auto* op_node = call->op.as(); + ICHECK(op_node) << "Expect OpNode, but got " << call->op->GetTypeKey(); + const auto op_name = GetRef(op_node)->name; + ICHECK(op_name == "torch_op") << "Unsupported op: " << AsText(call->op, false) << "\n"; + + const auto* attrs = call->attrs.as(); + return tvm::runtime::contrib::TorchRuntimeCreate(func_name, attrs->serialized_function); +} + +TVM_REGISTER_GLOBAL("relay.ext.torch").set_body_typed(TorchCompiler); + +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 7de43eb36882..9c01c40517f4 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -506,27 +506,49 @@ class TypeInferencer : private ExprFunctor, size_t type_arity = fn_ty->arg_types.size(); size_t number_of_args = arg_types.size(); + bool is_variable = false; - if (type_arity != number_of_args) { - if (type_arity < number_of_args) { - this->EmitFatal(Diagnostic::Error(call->span) - << "the function is provided too many arguments " - << "expected " << type_arity << ", found " << number_of_args); - } else { - this->EmitFatal(Diagnostic::Error(call->span) - << "the function is provided too few arguments " - << "expected " << type_arity << ", found " << number_of_args); + if (const OpNode* opnode = call->op.as()) { + if (opnode->num_inputs == -1) { + is_variable = true; } } - for (size_t i = 0; i < fn_ty->arg_types.size(); i++) { - this->Unify(fn_ty->arg_types[i], arg_types[i], call->span, true, false); + if ((type_arity < number_of_args) && !is_variable) { + this->EmitFatal(Diagnostic::Error(call->span) + << "the function is provided too many arguments " + << "expected " << type_arity << ", found " << number_of_args); + } else if (type_arity > number_of_args) { + this->EmitFatal(Diagnostic::Error(call->span) + << "the function is provided too few arguments " + << "expected " << type_arity << ", found " << number_of_args); } + Array unified_arg_types; + if (!is_variable) { + for (size_t i = 0; i < fn_ty->arg_types.size(); i++) { + this->Unify(fn_ty->arg_types[i], arg_types[i], call->span, true, false); + } + } else { + for (size_t i = 0; i < number_of_args; i++) { + if (i < fn_ty->arg_types.size()) { + unified_arg_types.push_back( + this->Unify(fn_ty->arg_types[i], arg_types[i], call->span, false, false)); + } else { + unified_arg_types.push_back(arg_types[i]); + } + } + unified_arg_types.push_back(fn_ty->ret_type); + } for (auto cs : fn_ty->type_constraints) { if (const auto* tr = cs.as()) { - solver_.AddConstraint(TypeRelation(tr->func, tr->args, tr->num_inputs, call->attrs), - call->span); + if (!is_variable) { + solver_.AddConstraint(TypeRelation(tr->func, tr->args, tr->num_inputs, call->attrs), + call->span); + } else { + solver_.AddConstraint( + TypeRelation(tr->func, unified_arg_types, number_of_args, call->attrs), call->span); + } } else { solver_.AddConstraint(cs, call->span); } diff --git a/src/runtime/contrib/libtorch/libtorch_runtime.cc b/src/runtime/contrib/libtorch/libtorch_runtime.cc new file mode 100644 index 000000000000..5076b967a1de --- /dev/null +++ b/src/runtime/contrib/libtorch/libtorch_runtime.cc @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/libtorch/libtorch_runtime.cc + * \brief runtime implementation for LibTorch/TorchScript. + */ + +// we do not want clang to reorder our includes +// clang-format off +#include +#include +#include + +#include +#include +#include +#include + +// clang-format on + +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace contrib { + +static void monly_deleter(DLManagedTensor* self) { delete self; } + +void run_torch_module(torch::jit::Module* module, TVMArgs args, TVMRetValue* rv) { + std::vector inputs; + std::vector outputs; + auto m = module->get_method("forward"); + for (int i = 0; i < args.size(); i++) { + const DLTensor* arg; + if (args[i].IsObjectRef()) { + NDArray arr = args[i]; + arg = arr.operator->(); + } else { + arg = args[i].operator DLTensor*(); + } + DLManagedTensor* inp = new DLManagedTensor{}; + inp->dl_tensor = *arg; + inp->deleter = &monly_deleter; + // m.num_inputs includes the self argument of forward(self, ...) + // num_inputs - 1 is the number of (Tensor) inputs + if (i < static_cast(m.num_inputs()) - 1) { + inputs.emplace_back(at::fromDLPack(inp)); + } else { + outputs.emplace_back(at::fromDLPack(inp)); + } + } + ICHECK(outputs.size() == 1) << "wrong number of args, can handle only one output"; + torch::Tensor res = module->forward(inputs).toTensor(); + outputs[0].copy_(res); // too bad +} + +/*! + * \brief A json runtime that executes the serialized JSON format. This runtime + * can be extended by user defined runtime for execution. + */ +class TorchModuleNode : public ModuleNode { + public: + TorchModuleNode(const std::string& symbol_name, const torch::jit::Module& module) + : symbol_name_(symbol_name), module_(module) {} + + const char* type_key() const { return "torch"; } + + /*! + * \brief Get a packed function. + * \param name The name/symbol of the function. + * \param sptr_to_self The pointer to the module node. + * \return The packed function. + */ + virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { + if (name == "get_symbol") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->symbol_name_; }); + } else if (name == "get_const_vars") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = Array{}; }); + } else if (this->symbol_name_ == name) { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + run_torch_module(&module_, args, rv); + }); + } else if ("__init_" + this->symbol_name_ == name) { + // The function to initialize constant tensors. + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = 0; }); + } else { + return PackedFunc(nullptr); + } + } + + virtual void SaveToBinary(dmlc::Stream* stream) { + // Save the symbol + stream->Write(symbol_name_); + // Save the module + std::stringstream str; + module_.save(str); + stream->Write(str.str()); + } + + static Module LoadFromBinary(void* strm) { + dmlc::Stream* stream = static_cast(strm); + std::string symbol; + std::string module_str; + // Load the symbol + ICHECK(stream->Read(&symbol)) << "Loading symbol name failed"; + ICHECK(stream->Read(&module_str)) << "Loading module str failed"; + std::stringstream str(module_str); + torch::jit::Module mod = torch::jit::load(str); + auto n = make_object(symbol, mod); + return Module(n); + } + + /*! + * \brief Get the source generated by codegen. + * + * \param format the format to return. + * \return A string of JSON. + */ + std::string GetSource(const std::string& format = "json") override { + return module_.dump_to_str(true, true, true); + } + + protected: + /*! \brief The only subgraph name for this module. */ + std::string symbol_name_; + /*! \brief Module. */ + torch::jit::Module module_; +}; + +runtime::Module TorchRuntimeCreate(const String& symbol_name, + const std::string& serialized_function) { + std::stringstream str(serialized_function); + torch::jit::Module mod = torch::jit::load(str); + auto n = make_object(symbol_name, mod); + return runtime::Module(n); +} + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_torch") + .set_body_typed(TorchModuleNode::LoadFromBinary); + +} // namespace contrib +} // namespace runtime +} // namespace tvm diff --git a/tests/python/contrib/test_libtorch_ops.py b/tests/python/contrib/test_libtorch_ops.py new file mode 100644 index 000000000000..751a547f94f5 --- /dev/null +++ b/tests/python/contrib/test_libtorch_ops.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +import tvm.relay +from tvm.relay.op.contrib import torchop + +try: + import torch +except ImportError as _: + torch = None + + +@pytest.mark.skipif(torch is None, reason="PyTorch is not available") +def test_backend(): + @torch.jit.script + def script_fn(x, y): + res = x * y + return res + + for torch_dt, dt in ( + (torch.int32, "int32"), + (torch.float32, "float32"), + (torch.float64, "float64"), + ): + x2 = tvm.relay.var("x", shape=[1, 2], dtype=dt) + y2 = tvm.relay.var("y", shape=[2, 2], dtype=dt) + + x3 = tvm.relay.var("x", shape=[1, 3], dtype=dt) + y3 = tvm.relay.var("y", shape=[3, 3], dtype=dt) + + test_body = tvm.relay.sum(torchop(script_fn, x2, y2)) + tvm.relay.sum( + torchop(script_fn, x3, y3) + ) + test_fn = tvm.relay.Function([x2, y2, x3, y3], test_body) + mod = tvm.IRModule({"main": test_fn}) + + tvm.relay.transform.InferType()(mod) + + # mod = tvm.relay.transform.AnnotateTarget("target.torch")(mod) + mod = tvm.relay.transform.MergeCompilerRegions()(mod) + mod = tvm.relay.transform.PartitionGraph()(mod) + mod = tvm.relay.transform.InferType()(mod) + + target = "llvm" + with tvm.transform.PassContext(opt_level=3): + lib = tvm.relay.build(mod, target, params={}) + + ctx = tvm.cpu(0) + rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](ctx)) + + # int does not have randn, so we cast... + x2t = torch.randn(1, 2).to(dtype=torch_dt) + y2t = torch.randn(2, 2).to(dtype=torch_dt) + x3t = torch.randn(1, 3).to(dtype=torch_dt) + y3t = torch.randn(3, 3).to(dtype=torch_dt) + # Set inputs + rt_mod.set_input(0, x2t) + rt_mod.set_input(1, y2t) + rt_mod.set_input(2, x3t) + rt_mod.set_input(3, y3t) + # Execute + rt_mod.run() + # Get outputs + tvm_output = rt_mod.get_output(0).numpy() + expected = (script_fn(x2t, y2t).sum() + script_fn(x3t, y3t).sum()).numpy() + print(tvm_output.dtype) + print(expected.dtype) + tvm.testing.assert_allclose(tvm_output, expected) + + +if __name__ == "__main__": + pytest.main([__file__])