From a3c4433d7db8a6bc93f05551fe0a3729cef14f72 Mon Sep 17 00:00:00 2001 From: driazati Date: Wed, 9 Mar 2022 15:14:42 -0800 Subject: [PATCH] Fix TorchScript fallback build This was missing a header `libtorch_runtime.h`. The test in `test_libtorch_ops.py` is also currently being skipped in CI since `torch` isn't available but that's left for a follow up cc @t-vi @masahi commit-id:f8998762 --- .../tvm/runtime/contrib/libtorch_runtime.h | 40 +++++++++++++++++++ .../contrib/libtorch/libtorch_codegen.cc | 2 +- .../contrib/libtorch/libtorch_runtime.cc | 1 + tests/python/contrib/test_libtorch_ops.py | 7 +++- 4 files changed, 47 insertions(+), 3 deletions(-) create mode 100644 include/tvm/runtime/contrib/libtorch_runtime.h diff --git a/include/tvm/runtime/contrib/libtorch_runtime.h b/include/tvm/runtime/contrib/libtorch_runtime.h new file mode 100644 index 000000000000..2645fb94d10d --- /dev/null +++ b/include/tvm/runtime/contrib/libtorch_runtime.h @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \brief runtime implementation for LibTorch/TorchScript. + */ +#ifndef TVM_RUNTIME_CONTRIB_LIBTORCH_RUNTIME_H_ +#define TVM_RUNTIME_CONTRIB_LIBTORCH_RUNTIME_H_ +#include + +#include + +namespace tvm { +namespace runtime { +namespace contrib { + +runtime::Module TorchRuntimeCreate(const String& symbol_name, + const std::string& serialized_function); + +} // namespace contrib +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTRIB_LIBTORCH_RUNTIME_H_ diff --git a/src/relay/backend/contrib/libtorch/libtorch_codegen.cc b/src/relay/backend/contrib/libtorch/libtorch_codegen.cc index 25bfbfad4443..f70466f00eed 100644 --- a/src/relay/backend/contrib/libtorch/libtorch_codegen.cc +++ b/src/relay/backend/contrib/libtorch/libtorch_codegen.cc @@ -32,7 +32,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/src/runtime/contrib/libtorch/libtorch_runtime.cc b/src/runtime/contrib/libtorch/libtorch_runtime.cc index 5076b967a1de..e76d04389ec7 100644 --- a/src/runtime/contrib/libtorch/libtorch_runtime.cc +++ b/src/runtime/contrib/libtorch/libtorch_runtime.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include diff --git a/tests/python/contrib/test_libtorch_ops.py b/tests/python/contrib/test_libtorch_ops.py index 751a547f94f5..28ae39c329f5 100644 --- a/tests/python/contrib/test_libtorch_ops.py +++ b/tests/python/contrib/test_libtorch_ops.py @@ -20,13 +20,16 @@ import tvm.relay from tvm.relay.op.contrib import torchop +import_torch_error = None + try: import torch -except ImportError as _: +except ImportError as e: torch = None + import_torch_error = str(e) -@pytest.mark.skipif(torch is None, reason="PyTorch is not available") +@pytest.mark.skipif(torch is None, reason=f"PyTorch is not available: {import_torch_error}") def test_backend(): @torch.jit.script def script_fn(x, y):