diff --git a/src/contrib/torch/pt_call_tvm/tvm_class.cc b/src/contrib/torch/pt_call_tvm/tvm_class.cc index 5e57dc152f11..f5ae95a5a73d 100644 --- a/src/contrib/torch/pt_call_tvm/tvm_class.cc +++ b/src/contrib/torch/pt_call_tvm/tvm_class.cc @@ -167,7 +167,7 @@ class TvmVMModulePack { const auto runtime_create = *tvm::runtime::Registry::Get("runtime._VirtualMachine"); vm_ = runtime_create(exe_); auto init_func = vm_.GetFunction("init", false); - auto alloc_type = static_cast(tvm::runtime::vm::AllocatorType::kPooled); + auto alloc_type = static_cast(tvm::runtime::memory::AllocatorType::kPooled); if (device_type != kDLCPU) { // CPU is required for executing shape functions init_func(static_cast(kDLCPU), 0, alloc_type, device_type, device_id, alloc_type); diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc index c77996cf67b6..3e1c7e7c0edf 100644 --- a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc @@ -29,7 +29,7 @@ #include #include "../../../runtime/graph_executor/graph_executor_factory.h" -#include "../../support/base64.h" +#include "../../../support/base64.h" #include "runtime_bridge.h" namespace tvm { @@ -209,10 +209,10 @@ inline void b64decode(const std::string b64str, uint8_t* ret) { size_t index = 0; const auto length = b64str.size(); for (size_t i = 0; i < length; i += 4) { - int8_t ch0 = base64::DecodeTable[(int32_t)b64str[i]]; - int8_t ch1 = base64::DecodeTable[(int32_t)b64str[i + 1]]; - int8_t ch2 = base64::DecodeTable[(int32_t)b64str[i + 2]]; - int8_t ch3 = base64::DecodeTable[(int32_t)b64str[i + 3]]; + int8_t ch0 = tvm::support::base64::DecodeTable[(int32_t)b64str[i]]; + int8_t ch1 = tvm::support::base64::DecodeTable[(int32_t)b64str[i + 1]]; + int8_t ch2 = tvm::support::base64::DecodeTable[(int32_t)b64str[i + 2]]; + int8_t ch3 = tvm::support::base64::DecodeTable[(int32_t)b64str[i + 3]]; uint8_t st1 = (ch0 << 2) + (ch1 >> 4); ret[index++] = st1; if (b64str[i + 2] != '=') {