Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: initial stab at TorchScript fallback #7401

Merged
merged 4 commits into from
Mar 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ tvm_option(USE_MIOPEN "Build with ROCM:MIOpen" OFF)
tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF)
tvm_option(USE_SORT "Build with sort support" ON)
tvm_option(USE_NNPACK "Build with nnpack support" OFF)
tvm_option(USE_LIBTORCH "Build with libtorch support" OFF)
tvm_option(USE_RANDOM "Build with random support" ON)
tvm_option(USE_MICRO_STANDALONE_RUNTIME "Build with micro.standalone_runtime support" OFF)
tvm_option(USE_CPP_RPC "Build CPP RPC" OFF)
Expand Down Expand Up @@ -455,6 +456,7 @@ include(cmake/modules/contrib/Posit.cmake)
include(cmake/modules/contrib/MicroStandaloneRuntime.cmake)
include(cmake/modules/contrib/Sort.cmake)
include(cmake/modules/contrib/NNPack.cmake)
include(cmake/modules/contrib/LibTorch.cmake)
include(cmake/modules/contrib/HybridDump.cmake)
include(cmake/modules/contrib/TFLite.cmake)
include(cmake/modules/contrib/TF_TVMDSOOP.cmake)
Expand Down
5 changes: 5 additions & 0 deletions cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -364,3 +364,8 @@ set(USE_CUTLASS OFF)

# Enable to show a summary of TVM options
set(SUMMARIZE OFF)

# Whether to use LibTorch as backend
# To enable pass the path to the root libtorch (or PyTorch) directory
# OFF or /path/to/torch/
set(USE_LIBTORCH OFF)
30 changes: 30 additions & 0 deletions cmake/modules/contrib/LibTorch.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

if(USE_LIBTORCH)
find_package(Torch REQUIRED PATHS ${USE_LIBTORCH}/share/cmake/Torch
)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${TORCH_LIBRARIES})
include_directories(${TORCH_INCLUDE_DIRS})

file(GLOB LIBTORCH_RELAY_CONTRIB_SRC
src/relay/backend/contrib/libtorch/libtorch_codegen.cc
src/runtime/contrib/libtorch/libtorch_runtime.cc
)
list(APPEND COMPILER_SRCS ${LIBTORCH_RELAY_CONTRIB_SRC})

endif(USE_LIBTORCH)
1 change: 1 addition & 0 deletions python/tvm/relay/op/contrib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,6 @@
from .bnns import *
from .coreml import *
from .ethosn import *
from .libtorch import *
from .tensorrt import *
from .cutlass import *
36 changes: 36 additions & 0 deletions python/tvm/relay/op/contrib/libtorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument, no-else-return, E1102
"""Torch codegen operators"""

from tvm import relay
from tvm.relay.op.annotation import compiler_begin, compiler_end


def torchop(script_fn, *params):
"""Insert an Operation executed in the PyTorch JIT

The operation includes backend annotation

Currently, only tensors are supported. The shape inferrence
assumes that input shapes (and not values) determine output shapes."""
return compiler_end(
relay.op._make.torchop(
[compiler_begin(p, "torch") for p in params], script_fn.save_to_buffer()
),
"torch",
)
138 changes: 138 additions & 0 deletions src/relay/backend/contrib/libtorch/libtorch_codegen.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relay/backend/contrib/libtorch/codegen.cc
* \brief Implementation of libtorch codegen.
*/

#include <ATen/DLConvertor.h>
#include <dlpack/dlpack.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/torch.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/contrib/libtorch/libtorch_runtime.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/op.h>

#include <fstream>
#include <numeric>
#include <sstream>

#include "../../utils.h"

namespace tvm {
namespace relay {
namespace contrib {

using namespace backend;

/*! \brief Attributes of a TorchFunction node */
struct TorchFunctionAttrs : public tvm::AttrsNode<TorchFunctionAttrs> {
std::string serialized_function;
int64_t len;

TVM_DECLARE_ATTRS(TorchFunctionAttrs, "relay.attrs.TorchFunctionAttrs") {
TVM_ATTR_FIELD(serialized_function).set_default("").describe("Function from fn.save(...)");
TVM_ATTR_FIELD(len).set_default(-1).describe("Function from fn.save(...)");
}
};

TVM_REGISTER_NODE_TYPE(TorchFunctionAttrs);

bool TorchOpRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
const auto* sfattrs = attrs.as<TorchFunctionAttrs>();
std::stringstream str(sfattrs->serialized_function);
torch::jit::Module mod = torch::jit::load(str);

std::vector<torch::jit::IValue> inputs;
for (int i = 0; i < num_inputs; i++) {
auto* ty = types[i].as<TensorTypeNode>();
ICHECK(ty) << "only accept tensors as inputs";
std::vector<int64_t> shape;
for (const auto& s : ty->shape) {
auto* si = s.as<IntImmNode>();
if (!si) {
return false;
}
shape.push_back(si->value);
}
auto torchScalarType = at::toScalarType(ty->dtype);

inputs.emplace_back(torch::zeros(shape, at::TensorOptions().dtype(torchScalarType)));
}
auto res = mod.forward(inputs);
auto res_t = res.toTensor();
ICHECK((int)types.size() == num_inputs + 1) << "only single output supported";
Array<PrimExpr> res_sizes;
for (int d = 0; d < res_t.dim(); d++) {
res_sizes.push_back(IntImm(DataType::Int(32), res_t.size(d)));
}
reporter->Assign(types[num_inputs], TensorType(res_sizes, DataType(at::getDLDataType(res_t))));
return true;
}

RELAY_REGISTER_OP("torch_op")
.set_support_level(99)
.add_type_rel("TorchOpRel", TorchOpRel)
.set_attrs_type<TorchFunctionAttrs>();

Expr MakeTorchOp(Array<Expr> args, const std::string& serialized_function) {
static const Op& op = Op::Get("torch_op");
auto attrs = make_object<TorchFunctionAttrs>();
attrs->serialized_function = serialized_function;
attrs->len = serialized_function.size();
return Call(op, args, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.torchop").set_body_typed(MakeTorchOp);

/*!
* \brief The external compiler/codegen tool. It takes a Relay expression/module and
* compile it into a runtime module.
*/
runtime::Module TorchCompiler(const ObjectRef& ref) {
ICHECK(ref->IsInstance<FunctionNode>()) << "The input ref is expected to be a Relay function.";
Function func = Downcast<Function>(ref);
std::string func_name = backend::GetExtSymbol(func);

ICHECK(func.defined()) << "Input error: expect a Relay function.";
const auto* call = func->body.as<CallNode>();
ICHECK(call) << "Expected call node\n";
const auto* op_node = call->op.as<OpNode>();
ICHECK(op_node) << "Expect OpNode, but got " << call->op->GetTypeKey();
const auto op_name = GetRef<Op>(op_node)->name;
ICHECK(op_name == "torch_op") << "Unsupported op: " << AsText(call->op, false) << "\n";

const auto* attrs = call->attrs.as<TorchFunctionAttrs>();
return tvm::runtime::contrib::TorchRuntimeCreate(func_name, attrs->serialized_function);
}

TVM_REGISTER_GLOBAL("relay.ext.torch").set_body_typed(TorchCompiler);

} // namespace contrib
} // namespace relay
} // namespace tvm
48 changes: 35 additions & 13 deletions src/relay/transforms/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -506,27 +506,49 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,

size_t type_arity = fn_ty->arg_types.size();
size_t number_of_args = arg_types.size();
bool is_variable = false;

if (type_arity != number_of_args) {
if (type_arity < number_of_args) {
this->EmitFatal(Diagnostic::Error(call->span)
<< "the function is provided too many arguments "
<< "expected " << type_arity << ", found " << number_of_args);
} else {
this->EmitFatal(Diagnostic::Error(call->span)
<< "the function is provided too few arguments "
<< "expected " << type_arity << ", found " << number_of_args);
if (const OpNode* opnode = call->op.as<OpNode>()) {
if (opnode->num_inputs == -1) {
is_variable = true;
}
}

for (size_t i = 0; i < fn_ty->arg_types.size(); i++) {
this->Unify(fn_ty->arg_types[i], arg_types[i], call->span, true, false);
if ((type_arity < number_of_args) && !is_variable) {
this->EmitFatal(Diagnostic::Error(call->span)
<< "the function is provided too many arguments "
<< "expected " << type_arity << ", found " << number_of_args);
} else if (type_arity > number_of_args) {
this->EmitFatal(Diagnostic::Error(call->span)
<< "the function is provided too few arguments "
<< "expected " << type_arity << ", found " << number_of_args);
}

Array<Type> unified_arg_types;
if (!is_variable) {
for (size_t i = 0; i < fn_ty->arg_types.size(); i++) {
this->Unify(fn_ty->arg_types[i], arg_types[i], call->span, true, false);
}
} else {
for (size_t i = 0; i < number_of_args; i++) {
if (i < fn_ty->arg_types.size()) {
unified_arg_types.push_back(
this->Unify(fn_ty->arg_types[i], arg_types[i], call->span, false, false));
} else {
unified_arg_types.push_back(arg_types[i]);
}
}
unified_arg_types.push_back(fn_ty->ret_type);
}
for (auto cs : fn_ty->type_constraints) {
if (const auto* tr = cs.as<TypeRelationNode>()) {
solver_.AddConstraint(TypeRelation(tr->func, tr->args, tr->num_inputs, call->attrs),
call->span);
if (!is_variable) {
solver_.AddConstraint(TypeRelation(tr->func, tr->args, tr->num_inputs, call->attrs),
call->span);
} else {
solver_.AddConstraint(
TypeRelation(tr->func, unified_arg_types, number_of_args, call->attrs), call->span);
}
} else {
solver_.AddConstraint(cs, call->span);
}
Expand Down
Loading