Skip to content

Commit

Permalink
[DLPack][runtime] Update DLPack to v0.7 (apache#13177)
Browse files Browse the repository at this point in the history
- Update the `3rdparty/dlpack` git submodule from v0.5 to v0.7, so that
the `DLDeviceType` enumeration has an explicitly-stated underlying
storage type.  This addresses a compiler warning generated by clang
15.0.3.

- Remove `kDLHexagon` and `kDLWebGPU` from `TVMDeviceExtType`, because
those enumerators are now provided by `DLDeviceType`.

- Renumber the members of `TVMDeviceExtType` to reduce the chance of
unnoticed collision with members of `DLDeviceType`.
  • Loading branch information
Christian Convey authored Nov 1, 2022
1 parent 3259580 commit 9cdc97f
Show file tree
Hide file tree
Showing 13 changed files with 239 additions and 137 deletions.
73 changes: 65 additions & 8 deletions include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,74 @@ extern "C" {
/*! \brief type of array index. */
typedef int64_t tvm_index_t;

/*! \brief Extension device types in TVM */
/*! \brief Extension device types in TVM
*
* Additional enumerators to supplement those provided by
* DLPack's `DLDeviceType` enumeration.
*
* MAINTAINERS NOTE #1: We need to ensure that the two devices
* are identified by the same integer.
* Currently this requires manual verification.
* Discussed here: https://github.com/dmlc/dlpack/issues/111
* As of DLPack v0.7, the highest-valued enumerator in
* `DLDeviceType` is kDLHexagon = 16.
*
* MAINTAINERS NOTE #2: As of DLPack v0.7, the definition for
* `DLDeviceType` specifies an underlying storage type of
* `int32_t`. That guarantees a variable of type
* `DLDeviceType` is capable of holding any integers provided
* by *either* of these enumerations.
*
* However, the `int32_t` specification only applies when the
* header file is compiled as C++, and this header file is also
* meant to work as C code. So the unspecified storage type
* could be a latent bug when compiled as C.
*/
#ifdef __cplusplus
typedef enum : int32_t {
#else
typedef enum {
kDLAOCL = 5,
kDLSDAccel = 6,
kOpenGL = 11,
kDLMicroDev = 13,
kDLHexagon = 14,
kDLWebGPU = 15
// AddExtraTVMType which is not in DLPack here
#endif
// To help avoid accidental conflicts between `DLDeviceType`
// and this enumeration, start numbering the new enumerators
// a little higher than (currently) seems necessary.
kDLAOCL = 32,
kDLSDAccel,
kOpenGL,
kDLMicroDev,
TVMDeviceExtType_End, // sentinel value
} TVMDeviceExtType;

#ifdef __cplusplus
// Some other parts of TVM hardcode the integer identifier for
// some DLPack / TVM devices, rather then using the symbolic
// enumerator. E.g., `2` rather than `kDLCUDA`.
// These asserts should alert us when that mapping breaks.
#define TVM_HARCODED_INTEGER_CHANGED_MSG \
"Change in compile-time integer. Make sure hardcoded uses of this integer throughout TVM are " \
"updated."
static_assert(kDLCPU == 1, TVM_HARCODED_INTEGER_CHANGED_MSG);
static_assert(kDLCUDA == 2, TVM_HARCODED_INTEGER_CHANGED_MSG);
static_assert(kDLCUDAHost == 3, TVM_HARCODED_INTEGER_CHANGED_MSG);
static_assert(kDLOpenCL == 4, TVM_HARCODED_INTEGER_CHANGED_MSG);
static_assert(kDLVulkan == 7, TVM_HARCODED_INTEGER_CHANGED_MSG);
static_assert(kDLMetal == 8, TVM_HARCODED_INTEGER_CHANGED_MSG);
static_assert(kDLVPI == 9, TVM_HARCODED_INTEGER_CHANGED_MSG);
static_assert(kDLROCM == 10, TVM_HARCODED_INTEGER_CHANGED_MSG);
static_assert(kDLROCMHost == 11, TVM_HARCODED_INTEGER_CHANGED_MSG);
static_assert(kDLExtDev == 12, TVM_HARCODED_INTEGER_CHANGED_MSG);
static_assert(kDLCUDAManaged == 13, TVM_HARCODED_INTEGER_CHANGED_MSG);
static_assert(kDLOneAPI == 14, TVM_HARCODED_INTEGER_CHANGED_MSG);
static_assert(kDLWebGPU == 15, TVM_HARCODED_INTEGER_CHANGED_MSG);
static_assert(kDLHexagon == 16, TVM_HARCODED_INTEGER_CHANGED_MSG);

static_assert(kDLAOCL == 32, TVM_HARCODED_INTEGER_CHANGED_MSG);
static_assert(kDLSDAccel == 33, TVM_HARCODED_INTEGER_CHANGED_MSG);
static_assert(kOpenGL == 34, TVM_HARCODED_INTEGER_CHANGED_MSG);
static_assert(kDLMicroDev == 35, TVM_HARCODED_INTEGER_CHANGED_MSG);
#undef TVM_HARCODED_INTEGER_CHANGED_MSG
#endif

/*!
* \brief The type code in used and only used in TVM FFI for argument passing.
*
Expand Down
11 changes: 11 additions & 0 deletions include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ class TVM_DLL DeviceAPI {

/*! \brief The device type bigger than this is RPC device */
constexpr int kRPCSessMask = 128;
static_assert(kRPCSessMask >= TVMDeviceExtType_End);

/*!
* \brief The name of Device API factory.
Expand All @@ -248,6 +249,8 @@ inline const char* DeviceName(int type) {
return "cuda";
case kDLCUDAHost:
return "cuda_host";
case kDLCUDAManaged:
return "cuda_managed";
case kDLOpenCL:
return "opencl";
case kDLSDAccel:
Expand All @@ -262,12 +265,20 @@ inline const char* DeviceName(int type) {
return "vpi";
case kDLROCM:
return "rocm";
case kDLROCMHost:
return "rocm_host";
case kDLExtDev:
return "ext_dev";
case kDLOneAPI:
return "oneapi";
case kDLWebGPU:
return "webgpu";
case kDLHexagon:
return "hexagon";
case kOpenGL:
return "opengl";
case kDLMicroDev:
return "microdev";
default:
LOG(FATAL) << "unknown type =" << type;
return "Unknown";
Expand Down
79 changes: 45 additions & 34 deletions jvm/core/src/main/java/org/apache/tvm/Device.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,40 @@

package org.apache.tvm;

import org.apache.tvm.rpc.RPC;

import java.util.HashMap;
import java.util.Map;
import org.apache.tvm.rpc.RPC;

public class Device {
/**
* Provides the same information as the C++ enums DLDeviceType and
* TVMDeviceExtType.
*/
static final int kDLCPU = 1, kDLCUDA = 2, kDLCUDAHost = 3, kDLOpenCL = 4, kDLVulkan = 7,
kDLMetal = 8, kDLVPI = 9, kDLROCM = 10, kDLROCMHost = 11, kDLExtDev = 12,
kDLCUDAManaged = 13, kDLOneAPI = 14, kDLWebGPU = 15, kDLHexagon = 16,
kDLAOCL = 32, kDLSDAccel = 33, kOpenGL = 34, kDLMicroDev = 35;

private static final Map<Integer, String> MASK2STR = new HashMap<Integer, String>();
private static final Map<String, Integer> STR2MASK = new HashMap<String, Integer>();

static {
MASK2STR.put(1, "cpu");
MASK2STR.put(2, "cuda");
MASK2STR.put(4, "opencl");
MASK2STR.put(7, "vulkan");
MASK2STR.put(8, "metal");
MASK2STR.put(9, "vpi");
MASK2STR.put(14, "hexagon");

STR2MASK.put("cpu", 1);
STR2MASK.put("cuda", 2);
STR2MASK.put("cl", 4);
STR2MASK.put("opencl", 4);
STR2MASK.put("vulkan", 7);
STR2MASK.put("metal", 8);
STR2MASK.put("vpi", 9);
STR2MASK.put("hexagon", 14);
MASK2STR.put(kDLCPU, "cpu");
MASK2STR.put(kDLCUDA, "cuda");
MASK2STR.put(kDLOpenCL, "opencl");
MASK2STR.put(kDLVulkan, "vulkan");
MASK2STR.put(kDLMetal, "metal");
MASK2STR.put(kDLVPI, "vpi");
MASK2STR.put(kDLHexagon, "hexagon");

STR2MASK.put("cpu", kDLCPU);
STR2MASK.put("cuda", kDLCUDA);
STR2MASK.put("cl", kDLOpenCL);
STR2MASK.put("opencl", kDLOpenCL);
STR2MASK.put("vulkan", kDLVulkan);
STR2MASK.put("metal", kDLMetal);
STR2MASK.put("vpi", kDLVPI);
STR2MASK.put("hexagon", kDLHexagon);
}

/**
Expand All @@ -51,7 +59,7 @@ public class Device {
* @return The created device
*/
public static Device cpu(int devId) {
return new Device(1, devId);
return new Device(kDLCPU, devId);
}

public static Device cpu() {
Expand All @@ -64,7 +72,7 @@ public static Device cpu() {
* @return The created device
*/
public static Device cuda(int devId) {
return new Device(2, devId);
return new Device(kDLCUDA, devId);
}

public static Device cuda() {
Expand All @@ -77,7 +85,7 @@ public static Device cuda() {
* @return The created device
*/
public static Device opencl(int devId) {
return new Device(4, devId);
return new Device(kDLOpenCL, devId);
}

public static Device opencl() {
Expand All @@ -90,7 +98,7 @@ public static Device opencl() {
* @return The created device
*/
public static Device vulkan(int devId) {
return new Device(7, devId);
return new Device(kDLVulkan, devId);
}

public static Device vulkan() {
Expand All @@ -103,7 +111,7 @@ public static Device vulkan() {
* @return The created device
*/
public static Device metal(int devId) {
return new Device(8, devId);
return new Device(kDLMetal, devId);
}

public static Device metal() {
Expand All @@ -116,7 +124,7 @@ public static Device metal() {
* @return The created device
*/
public static Device vpi(int devId) {
return new Device(9, devId);
return new Device(kDLVPI, devId);
}

public static Device vpi() {
Expand All @@ -129,7 +137,7 @@ public static Device vpi() {
* @return The created device
*/
public static Device hexagon(int devId) {
return new Device(14, devId);
return new Device(kDLHexagon, devId);
}

public static Device hexagon() {
Expand All @@ -153,8 +161,8 @@ public Device(String deviceType, int deviceId) {
* @return true if exists.
*/
public boolean exist() {
TVMValue ret = APIInternal.get("_GetDeviceAttr")
.pushArg(deviceType).pushArg(deviceId).pushArg(0).invoke();
TVMValue ret =
APIInternal.get("_GetDeviceAttr").pushArg(deviceType).pushArg(deviceId).pushArg(0).invoke();
return ((TVMValueLong) ret).value != 0;
}

Expand All @@ -163,8 +171,8 @@ public boolean exist() {
* @return the maximum thread number.
*/
public long maxThreadsPerBlock() {
TVMValue ret = APIInternal.get("_GetDeviceAttr")
.pushArg(deviceType).pushArg(deviceId).pushArg(1).invoke();
TVMValue ret =
APIInternal.get("_GetDeviceAttr").pushArg(deviceType).pushArg(deviceId).pushArg(1).invoke();
return ((TVMValueLong) ret).value;
}

Expand All @@ -173,8 +181,8 @@ public long maxThreadsPerBlock() {
* @return the thread number.
*/
public long warpSize() {
TVMValue ret = APIInternal.get("_GetDeviceAttr")
.pushArg(deviceType).pushArg(deviceId).pushArg(2).invoke();
TVMValue ret =
APIInternal.get("_GetDeviceAttr").pushArg(deviceType).pushArg(deviceId).pushArg(2).invoke();
return ((TVMValueLong) ret).value;
}

Expand All @@ -185,19 +193,22 @@ public void sync() {
Base.checkCall(Base._LIB.tvmSynchronize(deviceType, deviceId));
}

@Override public int hashCode() {
@Override
public int hashCode() {
return (deviceType << 16) | deviceId;
}

@Override public boolean equals(Object other) {
@Override
public boolean equals(Object other) {
if (other != null && other instanceof Device) {
Device obj = (Device) other;
return deviceId == obj.deviceId && deviceType == obj.deviceType;
}
return false;
}

@Override public String toString() {
@Override
public String toString() {
if (deviceType >= RPC.RPC_SESS_MASK) {
int tblId = deviceType / RPC.RPC_SESS_MASK - 1;
int devType = deviceType % RPC.RPC_SESS_MASK;
Expand Down
Loading

0 comments on commit 9cdc97f

Please sign in to comment.