Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVM PyTorch Integration] libstdc++ CXX11 ABI Compatibility & boolean tensor support #12232

Merged
merged 35 commits into from
Aug 17, 2022
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion apps/pt_tvmdsoop/tests/test_as_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
# specific language governing permissions and limitations
# under the License.
"""Test script for tvm torch module"""
import tempfile

import numpy as np

import torch
Expand Down Expand Up @@ -190,7 +192,10 @@ def test_tvmscript_torch_gpu():
q1 = torch.arange(8, device=cuda0).type(torch.float32)
q2 = torch.zeros((8,), dtype=torch.float32, device=cuda0)

ModuleGPU(q1, q2)
with tempfile.NamedTemporaryFile(suffix=".pt") as tmp:
torch.save(ModuleGPU, tmp.name)
loaded_mod = torch.load(tmp.name)
loaded_mod(q1, q2)

tvm.testing.assert_allclose(q2.cpu().numpy(), (q1 + 1).cpu().numpy(), atol=1e-5, rtol=1e-5)

Expand Down
131 changes: 131 additions & 0 deletions apps/pt_tvmdsoop/tests/test_boolean_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
#!/usr/bin/env python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test script for boolean tensor support"""
import tempfile

import numpy as np
import torch

import tvm
import tvm.testing
from tvm.contrib.torch import as_torch, optimize_torch
from tvm.meta_schedule.tune import TuneConfig
from tvm.script import tir as T


def negate(x):
return x.logical_not()


def sum_up_tensor(x):
return x.size(dim=0) - torch.sum(x.int())


def tensor_boolean_operation(x):
arr1 = (x + 0.3).floor().bool()
arr2 = (~((x + 0.7).int().bool())).bool()
ret = ((arr1 & arr2).byte() + 0.5).half()
return ~(ret.bool())


def test_bool_tensor_negate():
input = torch.ones(1, dtype=torch.bool)
optimized_negate = optimize_torch(
negate,
input,
)
with tempfile.NamedTemporaryFile(suffix=".pt") as tmp:
torch.save(optimized_negate, tmp.name)
loaded_mod = torch.load(tmp.name)
output = loaded_mod(negate(input))
tvm.testing.assert_allclose(input.numpy(), output.numpy(), atol=1e-5, rtol=1e-5)


def test_sum_up_tensor():
x = torch.randint(0, 2, (16,))
y = x.bool()
optimized_func = optimize_torch(
sum_up_tensor,
(y,),
)
ret1 = (x[x == 0]).size(dim=0)
ret2 = optimized_func(y).numpy()
tvm.testing.assert_allclose(ret1, ret2, atol=1e-5, rtol=1e-5)


def test_tensor_boolean_operation():
input = torch.rand(200)
model = optimize_torch(
tensor_boolean_operation,
input,
)
ret1 = tensor_boolean_operation(input)
ret2 = model(input)
tvm.testing.assert_allclose(ret1, ret2, atol=1e-5, rtol=1e-5)


@as_torch
@T.prim_func
def negate_tvmscript(
X: T.Buffer[(8, 8), "bool"],
Y: T.Buffer[(8, 8), "float32"],
Z: T.Buffer[(8, 8), "bool"],
U: T.Buffer[(8, 8), "float32"],
) -> None:
for i, j in T.grid(8, 8):
with T.block():
if Y[i, j] > 0.0:
Z[i, j] = X[i, j]
U[i, j] = Y[i, j]
else:
Z[i, j] = not X[i, j]
U[i, j] = 0.0 - Y[i, j]


def negate_vanila(x, y):
z = torch.zeros(8, 8).bool()
for i in range(8):
for j in range(8):
if y[i, j] > 0:
z[i, j] = x[i, j]
else:
z[i, j] = ~x[i, j]
return z


def test_tvmscript_torch_decorator():
q1 = (torch.rand(8, 8) + 0.5).int().bool()
q2 = torch.rand(8, 8) - 0.5
q3 = torch.zeros(8, 8).bool()
q4 = torch.zeros(8, 8)

std1 = negate_vanila(q1, q2)
std2 = torch.abs(q2)

negate_tvmscript(q1, q2, q3, q4)

tvm.testing.assert_allclose(std1.numpy(), q3.numpy(), atol=1e-5, rtol=1e-5)
tvm.testing.assert_allclose(std2.numpy(), q4.numpy(), atol=1e-5, rtol=1e-5)


if __name__ == "__main__":
test_tvmscript_torch_decorator()
test_bool_tensor_negate()
test_sum_up_tensor()
test_tensor_boolean_operation()
55 changes: 45 additions & 10 deletions cmake/modules/contrib/PT_TVMDSOOP.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
Expand All @@ -21,38 +21,73 @@ if(NOT USE_PT_TVMDSOOP STREQUAL "OFF")
execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "import torch; print(torch.__path__[0].strip())"
OUTPUT_VARIABLE PT_PATH
RESULT_VARIABLE PT_STATUS)
if (NOT ${PT_STATUS} EQUAL 0)

if(NOT ${PT_STATUS} EQUAL 0)
message(FATAL_ERROR "Fail to get pytorch path")
endif()

string(REGEX REPLACE "\n" "" PT_PATH "${PT_PATH}")
message(STATUS "PyTorch path: ${PT_PATH}")

set(PT_COMPILE_FLAGS_STR "-I${PT_PATH}/include -D_GLIBCXX_USE_CXX11_ABI=0")
execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "import torch;print(torch.compiled_with_cxx11_abi())"
OUTPUT_VARIABLE PT_CXX_FLAG
RESULT_VARIABLE PT_STATUS)

string(REGEX REPLACE "\n" "" PT_CXX_FLAG "${PT_CXX_FLAG}")
message(STATUS "Found TORCH_BUILT_WITH_CXX_ABI=${PT_CXX_FLAG} ")

if(${PT_CXX_FLAG} STREQUAL "False")
set(CXX_ABI_ENABLED 0)
else()
set(CXX_ABI_ENABLED 1)
endif()

set_property(
SOURCE
${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc
APPEND PROPERTY
COMPILE_OPTIONS
"-D_GLIBCXX_USE_CXX11_ABI=${CXX_ABI_ENABLED}"
"-I${PT_PATH}/include"
)

set_property(
SOURCE
${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/pt_call_tvm/tvm_class.cc
APPEND PROPERTY
COMPILE_OPTIONS
"-I${PT_PATH}/include"
)

set(PT_LINK_FLAGS_STR "-L${PT_PATH}/lib -l:libtorch.so -l:libtorch_python.so")

if(NOT USE_CUDA STREQUAL "OFF")
add_definitions(-DPT_TVMDSOOP_ENABLE_GPU)
endif()


string(REGEX REPLACE "\n" " " PT_FLAGS "${PT_COMPILE_FLAGS} ${PT_LINK_FLAGS}")
separate_arguments(PT_COMPILE_FLAGS UNIX_COMMAND ${PT_COMPILE_FLAGS_STR})
separate_arguments(PT_COMPILE_FLAGS UNIX_COMMAND)
separate_arguments(PT_LINK_FLAGS UNIX_COMMAND ${PT_LINK_FLAGS_STR})


set(LIBRARY_NAME pt_tvmdsoop)
tvm_file_glob(GLOB_RECURSE PTTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/**/*.cc)
set(LIBRARY_TORCH_NAME pt_tvmdsoop_new)
tvm_file_glob(GLOB_RECURSE PTTVM_TORCH ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/tvm_module_wrapper/*.cc)

tvm_file_glob(GLOB_RECURSE PTTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/pt_call_tvm/*.cc)

add_library(${LIBRARY_NAME} SHARED ${PTTVM_SRCS})
add_library(${LIBRARY_TORCH_NAME} SHARED ${PTTVM_TORCH})
set(PTTVM_LINK_FLAGS -ltvm -L${CMAKE_CURRENT_BINARY_DIR})

if (NOT BUILD_PT_TVMDSOOP_ONLY STREQUAL "ON")
add_dependencies(${LIBRARY_NAME} tvm)
if(NOT BUILD_PT_TVMDSOOP_ONLY STREQUAL "ON")
add_dependencies(${LIBRARY_NAME} tvm)
endif()

target_compile_options(${LIBRARY_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} ${PT_COMPILE_FLAGS})
target_link_libraries(${LIBRARY_NAME} PUBLIC ${PTTVM_LINK_FLAGS} ${PT_LINK_FLAGS})
target_compile_definitions(${LIBRARY_NAME} PUBLIC DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)

target_compile_options(${LIBRARY_TORCH_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} ${PT_COMPILE_FLAGS})
target_link_libraries(${LIBRARY_TORCH_NAME} PUBLIC ${PTTVM_LINK_FLAGS} ${PT_LINK_FLAGS})
target_compile_definitions(${LIBRARY_TORCH_NAME} PUBLIC DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
endif()

25 changes: 21 additions & 4 deletions python/tvm/contrib/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
"""Module container of Pytorch custom class"""
import os
import platform
import warnings
import torch
from tvm._ffi import libinfo


def _load_platform_specific_library(lib_name="libpt_tvmdsoop"):
def _load_platform_specific_library(lib_name):
system = platform.system()
if system == "Darwin":
lib_file_name = lib_name + ".dylib"
Expand All @@ -33,11 +34,27 @@ def _load_platform_specific_library(lib_name="libpt_tvmdsoop"):
lib_path = libinfo.find_lib_path()[0]
lib_dir = os.path.dirname(lib_path)
lib_file_path = os.path.join(lib_dir, lib_file_name)
torch.classes.load_library(lib_file_path)
try:
torch.classes.load_library(lib_file_path)
except OSError as err:
errmsg = str(err)
if errmsg.find("undefined symbol") != -1:
reason = " ".join(
(
"Got undefined symbol error,",
"which might be due to the CXXABI incompatibility.",
)
)
else:
reason = errmsg
warnings.warn(
f"The library {lib_name} is not built successfully. {reason}",
RuntimeWarning,
)


_load_platform_specific_library()

_load_platform_specific_library("libpt_tvmdsoop")
_load_platform_specific_library("libpt_tvmdsoop_new")

from . import module

Expand Down
17 changes: 17 additions & 0 deletions python/tvm/contrib/torch/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
# under the License.
# pylint: disable=invalid-name
"""Module container of PyTorch custom class"""
import warnings
from typing import List

import torch


Expand All @@ -29,6 +31,11 @@ def shape_repr(cls, input_shapes):
return torch.ops.tvm_dsoop.tvm_shape_repr(input_shapes)

def __init__(self, num_inputs, num_outputs, device=None):
warnings.warn(
"This module will be removed at TVM version 0.11",
DeprecationWarning,
stacklevel=2,
)
super().__init__()
self.dummy_param = torch.nn.Parameter(torch.empty(0))
self.engine = None
Expand Down Expand Up @@ -67,6 +74,11 @@ def shape_repr(cls, input_shapes):
return torch.ops.tvm_dsoop.tvm_shape_repr(input_shapes)

def __init__(self, num_inputs, num_outputs, device=None):
warnings.warn(
"This module will be removed at TVM version 0.11",
DeprecationWarning,
stacklevel=2,
)
super().__init__()
self.dummy_param = torch.nn.Parameter(torch.empty(0))
self.engine = None
Expand Down Expand Up @@ -113,6 +125,11 @@ class TraceTvmModule(torch.nn.Module):
"""

def __init__(self, tvm_module):
warnings.warn(
"This module will be removed at TVM version 0.11",
DeprecationWarning,
stacklevel=2,
)
super().__init__()
self.tvm_module = tvm_module

Expand Down
21 changes: 21 additions & 0 deletions python/tvm/contrib/torch/pytorch_tvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# pylint: disable=redefined-builtin
"""`compile` api that convert torch module to torch tvm module"""
import os
import warnings
import tvm
import tvm.testing
from tvm import relay, autotvm
Expand Down Expand Up @@ -183,6 +184,16 @@ def load_tvm(self, export_dir):

def build_pytorch_module(self, num_inputs, num_outputs, input_infos=None):
"""Build pytorch module containing TVM Graph Module"""
warnings.warn(
" ".join(
(
"This function will be removed at TVM version 0.11,",
"we suggest users to use `optimized_torch` for tuning Torch modules instead.",
)
),
DeprecationWarning,
stacklevel=2,
)
assert self.export_dir, "you must build_tvm or load_tvm before"
input_infos = input_infos or self.input_infos
assert input_infos
Expand Down Expand Up @@ -224,6 +235,16 @@ def compile(script_module, option):
pytorch_tvm_module = compile(script_module, option)
pytorch_tvm_module("model_tvm.pt")
"""
warnings.warn(
" ".join(
(
"This function will be removed at TVM version 0.11,",
"we suggest users to use `optimized_torch` for tuning Torch modules instead.",
)
),
DeprecationWarning,
stacklevel=2,
)
input_infos = option["input_infos"]
default_dtype = option.get("default_dtype", "float32")
export_dir = option.get("export_dir", "pytorch_compiled")
Expand Down
Loading