
Commit

Merge remote-tracking branch 'main' into unity
junrushao committed Nov 6, 2023
2 parents 3f1347c + ffa0033 commit 3de77f8
Showing 47 changed files with 768 additions and 272 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/main.yml
@@ -71,8 +71,9 @@ jobs:
python -m pytest -v tests/python/all-platform-minimal-test
- name: Minimal Metal Compile-Only
shell: bash -l {0}
run: >-
run: |
python -m pytest -v -s 'tests/python/unittest/test_allreduce.py::test_allreduce_sum_compile'
python -m pytest -v -s 'tests/python/unittest/test_target_codegen_metal.py::test_func_with_trailing_pod_params'
- name: Minimal Metal Compile-and-Run
shell: bash -l {0}
run: >-
9 changes: 9 additions & 0 deletions cmake/utils/FindLLVM.cmake
@@ -111,6 +111,14 @@ macro(find_llvm use_llvm)
message(FATAL_ERROR "Fatal error executing: ${LLVM_CONFIG} --libdir")
endif()
message(STATUS "LLVM libdir: ${__llvm_libdir}")
execute_process(COMMAND ${LLVM_CONFIG} --cmakedir
RESULT_VARIABLE __llvm_exit_code
OUTPUT_VARIABLE __llvm_cmakedir
OUTPUT_STRIP_TRAILING_WHITESPACE)
if(NOT "${__llvm_exit_code}" STREQUAL "0")
message(FATAL_ERROR "Fatal error executing: ${LLVM_CONFIG} --cmakedir")
endif()
message(STATUS "LLVM cmakedir: ${__llvm_cmakedir}")
# map prefix => $
# to handle the case when the prefix contains space.
string(REPLACE ${__llvm_prefix} "$" __llvm_cxxflags ${__llvm_cxxflags_space})
@@ -165,6 +173,7 @@ macro(find_llvm use_llvm)
find_package(ZLIB REQUIRED)
list(APPEND LLVM_LIBS "ZLIB::ZLIB")
elseif("${__flag}" STREQUAL "-lzstd" OR ("${__flag}" STREQUAL "zstd.dll.lib"))
list(APPEND CMAKE_MODULE_PATH "${__llvm_cmakedir}")
find_package(zstd REQUIRED)
if (TARGET "zstd::libzstd_static")
message(STATUS "LLVM links against static zstd")
4 changes: 2 additions & 2 deletions docs/how_to/deploy/tensorrt.rst
@@ -96,7 +96,7 @@ regular TVM CUDA compilation and code generation.
.. code:: python
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt
mod, config = partition_for_tensorrt(mod, params)
mod = partition_for_tensorrt(mod, params)
Build the Relay graph, using the new module and config returned by partition_for_tensorrt. The
@@ -107,7 +107,7 @@ PassContext so the values can be read during compilation.
.. code:: python
target = "cuda"
with tvm.transform.PassContext(opt_level=3, config={'relay.ext.tensorrt.options': config}):
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target, params=params)
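A minimal sketch of the updated API from this doc change, assuming `mod` and `params` come from an earlier Relay frontend import: partition_for_tensorrt now returns only the partitioned module, and relay.build picks up the TensorRT options from the ambient PassContext rather than an explicit config dict.

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

    mod = partition_for_tensorrt(mod, params)  # no separate config returned anymore
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="cuda", params=params)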
4 changes: 4 additions & 0 deletions include/tvm/runtime/memory/memory_manager.h
@@ -89,6 +89,8 @@ class Allocator {
* \param buffer The buffer to free.
*/
virtual void Free(const Buffer& buffer) = 0;
/*! \brief Clear the allocated memory. */
virtual void Clear();
/*! \brief The amount of memory currently allocated.
* \return The amount of memory currently allocated.
*/
@@ -119,6 +121,8 @@ class MemoryManager {
* \return The memory allocator.
*/
static Allocator* GetAllocator(Device dev, AllocatorType type);
/*! \brief Clear the allocators. */
static void Clear();

private:
MemoryManager() {}
80 changes: 80 additions & 0 deletions include/tvm/runtime/packed_func.h
@@ -1145,6 +1145,31 @@ struct PackedFuncValueConverter {
} \
}

#define TVM_MODULE_VTABLE_BEGIN(TypeKey) \
const char* type_key() const final { return TypeKey; } \
PackedFunc GetFunction(const String& _name, const ObjectPtr<Object>& _self) final { \
using SelfPtr = std::remove_cv_t<decltype(this)>;
#define TVM_MODULE_VTABLE_END() \
return PackedFunc(nullptr); \
}
#define TVM_MODULE_VTABLE_ENTRY(Name, MemFunc) \
if (_name == Name) { \
return PackedFunc([_self](TVMArgs args, TVMRetValue* rv) -> void { \
using Helper = ::tvm::runtime::detail::ModuleVTableEntryHelper<decltype(MemFunc)>; \
SelfPtr self = static_cast<SelfPtr>(_self.get()); \
CHECK_EQ(args.size(), Helper::LenArgs) \
<< "Function `" << self->type_key() << "::" << Name << "` requires " << Helper::LenArgs \
<< " arguments, but got " << args.size(); \
Helper::Call(rv, self, MemFunc, args, Helper::IndexSeq{}); \
}); \
}
#define TVM_MODULE_VTABLE_ENTRY_PACKED(Name, Func) \
if (_name == Name) { \
auto f = (Func); \
using FType = ::tvm::runtime::detail::function_signature<decltype(f)>::FType; \
return TypedPackedFunc<FType>(std::move(f)).packed(); \
}

/*!
* \brief Export typed function as a PackedFunc
* that can be loaded by LibraryModule.
@@ -1330,6 +1355,61 @@ inline void for_each(const F& f, Args&&... args) { // NOLINT(*)
for_each_dispatcher<sizeof...(Args) == 0, 0, F>::run(f, std::forward<Args>(args)...);
}

template <typename T>
struct ModuleVTableEntryHelper {};

template <typename T, typename R, typename... Args>
struct ModuleVTableEntryHelper<R (T::*)(Args...) const> {
using MemFnType = R (T::*)(Args...) const;
using IndexSeq = std::index_sequence_for<Args...>;
static constexpr const std::size_t LenArgs = sizeof...(Args);

template <std::size_t... Is>
static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args,
std::index_sequence<Is...>) {
*rv = (self->*f)(args[Is]...);
}
};

template <typename T, typename R, typename... Args>
struct ModuleVTableEntryHelper<R (T::*)(Args...)> {
using MemFnType = R (T::*)(Args...);
using IndexSeq = std::index_sequence_for<Args...>;
static constexpr const std::size_t LenArgs = sizeof...(Args);

template <std::size_t... Is>
static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args,
std::index_sequence<Is...>) {
*rv = (self->*f)(args[Is]...);
}
};

template <typename T, typename... Args>
struct ModuleVTableEntryHelper<void (T::*)(Args...) const> {
using MemFnType = void (T::*)(Args...) const;
using IndexSeq = std::index_sequence_for<Args...>;
static constexpr const std::size_t LenArgs = sizeof...(Args);

template <std::size_t... Is>
static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args,
std::index_sequence<Is...>) {
(self->*f)(args[Is]...);
}
};

template <typename T, typename... Args>
struct ModuleVTableEntryHelper<void (T::*)(Args...)> {
using MemFnType = void (T::*)(Args...);
using IndexSeq = std::index_sequence_for<Args...>;
static constexpr const std::size_t LenArgs = sizeof...(Args);

template <std::size_t... Is>
static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args,
std::index_sequence<Is...>) {
(self->*f)(args[Is]...);
}
};

namespace parameter_pack {

template <typename... EnumArgs>
39 changes: 23 additions & 16 deletions include/tvm/runtime/vm/executable.h
@@ -57,19 +57,28 @@ struct VMFunction;
*/
class TVM_DLL Executable : public ModuleNode {
public:
/*!
* \brief Get a PackedFunc from an executable module.
*
* \param name the name of the function.
* \param sptr_to_self The shared_ptr that points to this module node.
*
* \return PackedFunc or nullptr when it is not available.
*/
PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) final;
TVM_MODULE_VTABLE_BEGIN("VMExecutable");
TVM_MODULE_VTABLE_ENTRY("get_lib", &Executable::GetLib);
TVM_MODULE_VTABLE_ENTRY("get_bytecode", &Executable::GetBytecode);
TVM_MODULE_VTABLE_ENTRY("get_constants", &Executable::GetConstants);
TVM_MODULE_VTABLE_ENTRY("get_virtual_devices", &Executable::GetVirtualDevices);
TVM_MODULE_VTABLE_ENTRY("get_primitives", &Executable::GetPrimitives);
TVM_MODULE_VTABLE_ENTRY("get_stats", &Executable::Stats);
TVM_MODULE_VTABLE_ENTRY("save", &Executable::Save);
TVM_MODULE_VTABLE_ENTRY("get_function_arity", &Executable::GetFunctionArity);
TVM_MODULE_VTABLE_ENTRY("get_function_param_name", &Executable::GetFunctionParameterName);
TVM_MODULE_VTABLE_ENTRY("vm_load_executable", &Executable::VMLoadExecutable);
TVM_MODULE_VTABLE_ENTRY("move_late_bound_consts", &Executable::MoveLateBoundConstantsToFile);
TVM_MODULE_VTABLE_ENTRY("get_late_bound_consts", &Executable::GetLateBoundConstants);
TVM_MODULE_VTABLE_ENTRY("load_late_bound_consts", &Executable::LoadLateBoundConstantsFromFile);
TVM_MODULE_VTABLE_ENTRY("load_late_bound_consts_from_map",
&Executable::LoadLateBoundConstantsFromMap);
TVM_MODULE_VTABLE_END();

/*! \brief Get the property of the runtime module. */
int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; };

/*! \brief Creates a VM that loads `this` as the executable. */
Module VMLoadExecutable();
/*!
* \brief Write the Executable to the binary stream in serialized form.
*
@@ -123,17 +132,17 @@ class TVM_DLL Executable : public ModuleNode {
* Must be called before \p SaveToBinary and friends if late-bound constants are
 * desired. Otherwise it can be ignored.
*/
void MoveLateBoundConstantsToStream(dmlc::Stream* stream, size_t byte_limit);
void MoveLateBoundConstantsToStream(dmlc::Stream* stream, int64_t byte_limit);

/*!
* \brief As for \p MoveLateBoundConstantsToStream, but save to file at \p path.
*/
void MoveLateBoundConstantsToFile(const std::string& path, size_t byte_limit);
void MoveLateBoundConstantsToFile(const std::string& path, int64_t byte_limit);

/*!
 * \brief Get a map of all constants larger than byte_limit in size.
*/
Map<String, NDArray> GetLateBoundConstants(size_t byte_limit);
Map<String, NDArray> GetLateBoundConstants(int64_t byte_limit);

/*!
* \brief Restores the late-bound constants for the executable (if any) from given byte-stream.
@@ -255,12 +264,10 @@ class TVM_DLL Executable : public ModuleNode {
* \param index Parameter index.
* \return The parameter name.
*/
std::string GetFunctionParameterName(std::string func, uint32_t index) const;
std::string GetFunctionParameterName(std::string func, int index) const;

virtual ~Executable() {}

const char* type_key() const final { return "VMExecutable"; }

/*!
* \brief The (compile-time, virtual) devices corresponding to each device index.
* This vector contains a pair Device and its memory_scope.
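With the vtable above, each registered name resolves to a bound member function on the runtime module. A hedged sketch from Python, assuming `exe` is the underlying runtime module of a compiled Relay VM executable (the names match the TVM_MODULE_VTABLE_ENTRY entries; the variable names are illustrative):

    bytecode = exe["get_bytecode"]()      # dispatches to Executable::GetBytecode
    vm_mod = exe["vm_load_executable"]()  # builds a VM that loads this executable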
2 changes: 2 additions & 0 deletions python/tvm/_ffi/_ctypes/packed_func.py
@@ -340,6 +340,8 @@ def _init_pythonapi_inc_def_ref():
register_func = _LIB.TVMBackendRegisterEnvCAPI
register_func(c_str("Py_IncRef"), ctypes.pythonapi.Py_IncRef)
register_func(c_str("Py_DecRef"), ctypes.pythonapi.Py_DecRef)
register_func(c_str("PyGILState_Ensure"), ctypes.pythonapi.PyGILState_Ensure)
register_func(c_str("PyGILState_Release"), ctypes.pythonapi.PyGILState_Release)


_init_pythonapi_inc_def_ref()
4 changes: 3 additions & 1 deletion python/tvm/_ffi/_cython/packed_func.pxi
@@ -17,7 +17,7 @@

import ctypes
import traceback
from cpython cimport Py_INCREF, Py_DECREF
from cpython cimport Py_INCREF, Py_DECREF, PyGILState_Ensure, PyGILState_Release
from numbers import Number, Integral
from ..base import string_types, py2cerror
from ..runtime_ctypes import DataType, Device, TVMByteArray, ObjectRValueRef
@@ -381,5 +381,7 @@ def _init_pythonapi_inc_def_ref():
register_func = TVMBackendRegisterEnvCAPI
register_func(c_str("Py_IncRef"), <void*>_py_incref_wrapper)
register_func(c_str("Py_DecRef"), <void*>_py_decref_wrapper)
register_func(c_str("PyGILState_Ensure"), <void*>PyGILState_Ensure)
register_func(c_str("PyGILState_Release"), <void*>PyGILState_Release)

_init_pythonapi_inc_def_ref()
15 changes: 10 additions & 5 deletions python/tvm/_ffi/base.py
@@ -24,7 +24,7 @@
import sys
import types

from typing import Callable, Sequence
from typing import Callable, Sequence, Optional

import numpy as np

@@ -340,15 +340,16 @@ def get_last_ffi_error():
return ERROR_TYPE.get(err_type, TVMError)(py_err_msg)


def _append_traceback_frame(tb, func_name, filepath, lineno):
def _append_traceback_frame(tb, func_name, filepath, lineno: Optional[int]):
"""Append a dummy frame to appear in the Python traceback"""

# Compile a dummy function to Python bytecode, so that with the
# filepath that we want to appear in the traceback. Any external
# debugger (e.g. pdb) that catches the exception will use the
# filepath to show code snippets from that FFI file.
header = "" if lineno is None else "\n" * (lineno - 1)
code = compile(
"{}def dummy_func(): raise NotImplementedError()".format("\n" * (lineno - 1)),
f"{header}def dummy_func(): raise NotImplementedError()",
filepath,
"exec",
)
@@ -446,10 +447,14 @@ def raise_last_ffi_error():
for frame in frames:
if " at " in frame:
func_name, frame = frame.split(" at ", 1)
filename, lineno = frame.rsplit(":", 1)
if ":" in frame:
filename, lineno = frame.rsplit(":", 1)
lineno = int(lineno.strip())
else:
filename = frame
lineno = None
func_name = func_name.strip()
filename = filename.strip()
lineno = int(lineno.strip())

tb = _append_traceback_frame(tb, func_name, filename, lineno)

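A small illustration of the hardened frame parsing above: frames of the form "<func> at <file>:<line>" still yield an integer line number, while frames without a ":<line>" suffix now fall back to lineno=None instead of raising ValueError in the rsplit/int step. The sample frame string is hypothetical:

    frame = "TVMFuncCall at libtvm.so"  # no ":<line>" suffix
    func_name, rest = frame.split(" at ", 1)
    if ":" in rest:
        filename, lineno = rest.rsplit(":", 1)
        lineno = int(lineno.strip())
    else:
        filename, lineno = rest, None  # handled by the Optional[int] lineno path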
19 changes: 17 additions & 2 deletions python/tvm/contrib/nvcc.py
@@ -67,17 +67,32 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target
arch = ["-gencode", f"arch=compute_{compute_version},code=sm_{compute_version}"]

temp = utils.tempdir()
file_name = "tvm_kernels"
if target_format not in ["cubin", "ptx", "fatbin"]:
raise ValueError("target_format must be in cubin, ptx, fatbin")
temp_code = temp.relpath("my_kernel.cu")
temp_target = temp.relpath(f"my_kernel.{target_format}")
temp_code = temp.relpath(f"{file_name}.cu")
temp_target = temp.relpath(f"{file_name}.{target_format}")

pass_context = tvm.get_global_func("transform.GetCurrentPassContext")()
kernels_output_dir = (
pass_context.config["cuda.kernels_output_dir"]
if "cuda.kernels_output_dir" in pass_context.config
else None
)
if kernels_output_dir is not None:
if not os.path.isdir(kernels_output_dir):
os.makedirs(kernels_output_dir)
temp_code = os.path.join(kernels_output_dir, f"{file_name}.cu")
temp_target = os.path.join(kernels_output_dir, f"{file_name}.{target_format}")

with open(temp_code, "w") as out_file:
out_file.write(code)

file_target = path_target if path_target else temp_target
cmd = ["nvcc"]
cmd += [f"--{target_format}", "-O3"]
if kernels_output_dir is not None:
cmd += ["-lineinfo"]
if isinstance(arch, list):
cmd += arch
elif isinstance(arch, str):
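A hedged sketch of the new option: when "cuda.kernels_output_dir" is set in the current PassContext, compile_cuda writes tvm_kernels.cu and the compiled artifact into that directory and adds -lineinfo to the nvcc command line (useful for profilers that map SASS back to source). This assumes the config key is registered by an accompanying C++ change and that `mod` is an IRModule with a CUDA-targeted kernel:

    import tvm
    with tvm.transform.PassContext(config={"cuda.kernels_output_dir": "/tmp/tvm_kernels"}):
        lib = tvm.build(mod, target="cuda")  # kernel sources land in /tmp/tvm_kernels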
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -1184,7 +1184,7 @@ def callback(
# Find the tensors that are inputs to the concat and the scales and zero points
concat_args = list()
for arg in post.args:
if isinstance(arg, tvm.relay.expr.Call):
if isinstance(arg, (tvm.relay.expr.Call, tvm.relay.expr.TupleGetItem)):
concat_args.append(arg)

axis = post.op.body.attrs.axis
19 changes: 19 additions & 0 deletions src/runtime/memory/memory_manager.cc
@@ -22,6 +22,7 @@
* \brief Allocate and manage memory for the runtime.
*/
#include <tvm/runtime/memory/memory_manager.h>
#include <tvm/runtime/registry.h>

#include <memory>
#include <utility>
@@ -166,6 +167,16 @@ Allocator* MemoryManager::GetAllocator(Device dev, AllocatorType type) {
return it->second.at(type).get();
}

void MemoryManager::Clear() {
MemoryManager* m = MemoryManager::Global();
std::lock_guard<std::mutex> lock(m->mu_);
for (const auto& [device, allocators] : m->allocators_) {
for (const auto& [allocator_type, allocator] : allocators) {
allocator->Clear();
}
}
}

NDArray Allocator::Empty(ShapeTuple shape, DLDataType dtype, DLDevice dev,
Optional<String> mem_scope) {
VerifyDataType(dtype);
@@ -198,6 +209,14 @@ Buffer Allocator::Alloc(Device dev, ShapeTuple shape, DLDataType type_hint,
return {};
}

void Allocator::Clear() {
  // By default this function does nothing.
  // For the naive allocator, no explicit clear is needed.
  // The pooled allocator overrides this method.
}

TVM_REGISTER_GLOBAL("vm.builtin.memory_manager.clear").set_body_typed(MemoryManager::Clear);

} // namespace memory
} // namespace runtime
} // namespace tvm
