Skip to content

Commit

Permalink
Fix TorchScript fallback build
Browse files Browse the repository at this point in the history
This build was missing a header, `libtorch_runtime.h`. The test in `test_libtorch_ops.py` is also currently being skipped in CI since `torch` isn't available, but that is left for a follow-up.

cc @t-vi @masahi

commit-id:f8998762
  • Loading branch information
driazati committed Mar 9, 2022
1 parent f9d3918 commit 0576780
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 3 deletions.
38 changes: 38 additions & 0 deletions include/tvm/runtime/contrib/libtorch_runtime.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
 * \file tvm/runtime/contrib/libtorch_runtime.h
 * \brief Runtime implementation for LibTorch/TorchScript.
 */
#ifndef TVM_RUNTIME_CONTRIB_LIBTORCH_RUNTIME_H_
#define TVM_RUNTIME_CONTRIB_LIBTORCH_RUNTIME_H_

#include <tvm/runtime/module.h>

#include <string>

namespace tvm {
namespace runtime {
namespace contrib {

/*!
 * \brief Create a runtime module wrapping a serialized TorchScript function.
 * \param symbol_name The symbol name the function is exposed under.
 * \param serialized_function The TorchScript function, serialized to a string.
 * \return A runtime::Module that executes the TorchScript function.
 */
runtime::Module TorchRuntimeCreate(const String& symbol_name,
                                   const std::string& serialized_function);

}  // namespace contrib
}  // namespace runtime
}  // namespace tvm

#endif  // TVM_RUNTIME_CONTRIB_LIBTORCH_RUNTIME_H_
2 changes: 1 addition & 1 deletion src/relay/backend/contrib/libtorch/libtorch_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
#include <tvm/relay/op.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/contrib/libtorch/libtorch_runtime.h>
#include <tvm/runtime/contrib/libtorch_runtime.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/op.h>
Expand Down
1 change: 1 addition & 0 deletions src/runtime/contrib/libtorch/libtorch_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/runtime/module.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/contrib/libtorch_runtime.h>

#include <ATen/dlpack.h>
#include <ATen/DLConvertor.h>
Expand Down
7 changes: 5 additions & 2 deletions tests/python/contrib/test_libtorch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@
import tvm.relay
from tvm.relay.op.contrib import torchop

import_torch_error = None

try:
import torch
except ImportError as _:
except ImportError as e:
torch = None
import_torch_error = str(e)


@pytest.mark.skipif(torch is None, reason="PyTorch is not available")
@pytest.mark.skipif(torch is None, reason=f"PyTorch is not available: {import_torch_error}")
def test_backend():
@torch.jit.script
def script_fn(x, y):
Expand Down

0 comments on commit 0576780

Please sign in to comment.