From 7e025234fea42fdab4c71e069775801da9e68a6e Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Mon, 16 Jan 2017 14:53:26 -0800 Subject: [PATCH] [RUNTIME] Add interface header of runtime (#15) * [RUNTIME] Add interface header of runtime * fix mac build --- HalideIR | 2 +- Makefile | 2 +- include/tvm/c_api.h | 84 ++++---------- include/tvm/c_runtime_api.h | 205 +++++++++++++++++++++++++++++++++++ python/tvm/_base.py | 4 - python/tvm/_ctypes/_api.py | 39 ++++--- src/README.md | 1 + src/c_api/c_api.cc | 46 +++----- src/c_api/c_api_common.h | 25 +---- src/c_api/c_api_registry.h | 3 + src/runtime/error_handle.cc | 22 ++++ src/runtime/runtime_common.h | 36 ++++++ 12 files changed, 328 insertions(+), 141 deletions(-) create mode 100644 include/tvm/c_runtime_api.h create mode 100644 src/runtime/error_handle.cc create mode 100644 src/runtime/runtime_common.h diff --git a/HalideIR b/HalideIR index 3278103721cf..6375e6b76f6b 160000 --- a/HalideIR +++ b/HalideIR @@ -1 +1 @@ -Subproject commit 3278103721cfabf7435f1e9ba1fd75a7c38f13c9 +Subproject commit 6375e6b76f6b70d58f66b357d946c971843f3169 diff --git a/Makefile b/Makefile index e7dcebc3c586..86368ab2038f 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ export LDFLAGS = -pthread -lm -export CFLAGS = -std=c++11 -Wall -O2 -Wno-unknown-pragmas -funroll-loops\ +export CFLAGS = -std=c++11 -Wall -O2\ -Iinclude -Idmlc-core/include -IHalideIR/src -fPIC # specify tensor path diff --git a/include/tvm/c_api.h b/include/tvm/c_api.h index 5f788508764e..2e7bb8545a38 100644 --- a/include/tvm/c_api.h +++ b/include/tvm/c_api.h @@ -6,74 +6,28 @@ #ifndef TVM_C_API_H_ #define TVM_C_API_H_ -#ifdef __cplusplus -#define TVM_EXTERN_C extern "C" -#else -#define TVM_EXTERN_C -#endif - -/*! \brief TVM_DLL prefix for windows */ -#ifdef _WIN32 -#ifdef TVM_EXPORTS -#define TVM_DLL __declspec(dllexport) -#else -#define TVM_DLL __declspec(dllimport) -#endif -#else -#define TVM_DLL -#endif +#include "./c_runtime_api.h" TVM_EXTERN_C { /*! \brief handle to functions */ -typedef void* FunctionHandle; +typedef void* APIFunctionHandle; /*! \brief handle to node */ typedef void* NodeHandle; -/*! - * \brief union type for returning value of attributes - * Attribute type can be identified by id - */ -typedef union { - long v_long; // NOLINT(*) - double v_double; - const char* v_str; - NodeHandle v_handle; -} ArgVariant; - -/*! \brief attribute types */ -typedef enum { - kNull = 0, - kLong = 1, - kDouble = 2, - kStr = 3, - kNodeHandle = 4 -} ArgVariantID; - -/*! - * \brief return str message of the last error - * all function in this file will return 0 when success - * and -1 when an error occured, - * NNGetLastError can be called to retrieve the error - * - * this function is threadsafe and can be called by different thread - * \return error info - */ -TVM_DLL const char *TVMGetLastError(void); - /*! * \brief List all the node function name * \param out_size The number of functions * \param out_array The array of function names. */ -TVM_DLL int TVMListFunctionNames(int *out_size, +TVM_DLL int TVMListAPIFunctionNames(int *out_size, const char*** out_array); /*! * \brief get function handle by name * \param name The name of function * \param handle The returning function handle */ -TVM_DLL int TVMGetFunctionHandle(const char* name, - FunctionHandle *handle); +TVM_DLL int TVMGetAPIFunctionHandle(const char* name, + APIFunctionHandle *handle); /*! * \brief Get the detailed information about function. @@ -88,14 +42,14 @@ TVM_DLL int TVMGetFunctionHandle(const char* name, * \param return_type Return type of the function, if any. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMGetFunctionInfo(FunctionHandle handle, - const char **real_name, - const char **description, - int *num_doc_args, - const char ***arg_names, - const char ***arg_type_infos, - const char ***arg_descriptions, - const char **return_type); +TVM_DLL int TVMGetAPIFunctionInfo(APIFunctionHandle handle, + const char **real_name, + const char **description, + int *num_doc_args, + const char ***arg_names, + const char ***arg_type_infos, + const char ***arg_descriptions, + const char **return_type); /*! * \brief Push an argument to the function calling stack. @@ -104,8 +58,8 @@ TVM_DLL int TVMGetFunctionInfo(FunctionHandle handle, * \param arg number of attributes * \param type_id The typeid of attributes. */ -TVM_DLL int TVMPushStack(ArgVariant arg, - int type_id); +TVM_DLL int TVMAPIPushStack(TVMArg arg, + int type_id); /*! * \brief call a function by using arguments in the stack. @@ -115,9 +69,9 @@ TVM_DLL int TVMPushStack(ArgVariant arg, * \param ret_val The return value. * \param ret_typeid the type id of return value. */ -TVM_DLL int TVMFunctionCall(FunctionHandle handle, - ArgVariant* ret_val, - int* ret_typeid); +TVM_DLL int TVMAPIFunctionCall(APIFunctionHandle handle, + TVMArg* ret_val, + int* ret_typeid); /*! * \brief free the node handle @@ -135,7 +89,7 @@ TVM_DLL int TVMNodeFree(NodeHandle handle); */ TVM_DLL int TVMNodeGetAttr(NodeHandle handle, const char* key, - ArgVariant* out_value, + TVMArg* out_value, int* out_typeid, int* out_success); diff --git a/include/tvm/c_runtime_api.h b/include/tvm/c_runtime_api.h new file mode 100644 index 000000000000..f99d83e60f2a --- /dev/null +++ b/include/tvm/c_runtime_api.h @@ -0,0 +1,205 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file c_runtime_api.h + * \brief TVM runtime library. + * + * The philosophy of TVM project is to customize the compilation + * stage to generate code that can used by other projects transparently. + * + * So this is a minimum runtime code gluing, and some limited + * memory management code to enable quick testing. + */ +#ifndef TVM_C_RUNTIME_API_H_ +#define TVM_C_RUNTIME_API_H_ + +#ifdef __cplusplus +#define TVM_EXTERN_C extern "C" +#else +#define TVM_EXTERN_C +#endif + +/*! \brief TVM_DLL prefix for windows */ +#ifdef _WIN32 +#ifdef TVM_EXPORTS +#define TVM_DLL __declspec(dllexport) +#else +#define TVM_DLL __declspec(dllimport) +#endif +#else +#define TVM_DLL +#endif + +#include + + +TVM_EXTERN_C { +/*! \brief type of array index. */ +typedef unsigned tvm_index_t; + +/*! + * \brief union type for arguments and return values + * in both runtime API and TVM API calls + */ +typedef union { + long v_long; // NOLINT(*) + double v_double; + const char* v_str; + void* v_handle; +} TVMArg; + +/*! + * \brief The type index in TVM. + */ +typedef enum { + kNull = 0, + kLong = 1, + kDouble = 2, + kStr = 3, + kNodeHandle = 4, + kArrayHandle = 5 +} TVMArgTypeID; + +/*! + * \brief The device type + */ +typedef enum { + /*! \brief CPU device */ + kCPU = 1, + /*! \brief NVidia GPU device(CUDA) */ + kGPU = 2, + /*! \brief opencl device */ + KOpenCL = 4 +} TVMDeviceMask; + +/*! + * \brief The Device information, abstract away common device types. + */ +typedef struct { + /*! \brief The device type mask */ + int dev_mask; + /*! \brief the device id */ + int dev_id; +} TVMDevice; + +/*! \brief The type code in TVMDataType */ +typedef enum { + kInt = 0U, + kUInt = 1U, + kFloat = 2U +} TVMTypeCode; + +/*! + * \brief the data type + * Examples + * - float: type_code = 2, bits = 32, lanes=1 + * - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4 + * - int8: type_code = 0, bits = 8, lanes=1 + */ +typedef struct { + /*! \brief type code, in TVMTypeCode */ + uint8_t type_code; + /*! \brief number of bits of the type */ + uint8_t bits; + /*! \brief number of lanes, */ + uint16_t lanes; +} TVMDataType; + +/*! + * \brief Data structure representing a n-dimensional array(tensor). + * This is used to pass data specification into TVM. + */ +typedef struct { + /*! \brief The data field pointer on specified device */ + void* data; + /*! \brief The shape pointers of the array */ + const tvm_index_t* shape; + /*! + * \brief The stride data about each dimension of the array, can be NULL + * When strides is NULL, it indicates that the array is empty. + */ + const tvm_index_t* strides; + /*! \brief number of dimensions of the array */ + tvm_index_t ndim; + /*! \brief The data type flag */ + TVMDataType dtype; + /*! \brief The device this array sits on */ + TVMDevice device; +} TVMArray; + +/*! + * \brief The stream that is specific to device + * can be NULL, which indicates the default one. + */ +typedef void* TVMStreamHandle; +/*! + * \brief Pointer to function handle that points to + * a generated TVM function. + */ +typedef void* TVMFunctionHandle; +/*! \brief the array handle */ +typedef TVMArray* TVMArrayHandle; + +/*! + * \brief return str message of the last error + * all function in this file will return 0 when success + * and -1 when an error occured, + * TVMGetLastError can be called to retrieve the error + * + * this function is threadsafe and can be called by different thread + * \return error info + */ +TVM_DLL const char *TVMGetLastError(void); + +/*! + * \brief Allocate a nd-array's memory, + * including space of shape, of given spec. + * + * \param shape The shape of the array, the data content will be copied to out + * \param ndim The number of dimension of the array. + * \param dtype The array data type. + * \param device The device this array sits on. + * \param out The output handle. + * \return Whether the function is successful. + */ +TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape, + tvm_index_t ndim, + int dtype, + TVMDevice device, + TVMArrayHandle* out); +/*! + * \brief Free the TVM Array. + * \param handle The array handle to be freed. + */ +TVM_DLL int TVMArrayFree(TVMArrayHandle handle); + +/*! + * \brief Copy the array, both from and to must be valid during the copy. + * \param from The array to be copied from. + * \param to The target space. + * \param stream The stream where the copy happens, can be NULL. + */ +TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from, + TVMArrayHandle to, + TVMStreamHandle stream); +/*! + * \brief Wait until all computations on stream completes. + * \param stream the stream to be synchronized. + */ +TVM_DLL int TVMSynchronize(TVMStreamHandle stream); + +/*! + * \brief Launch a generated TVM function + * \param func function handle to be launched. + * \param args The arguments + * \param arg_type_ids The type id of the arguments + * \param num_args Number of arguments. + * \param stream The stream this function to be launched on. + */ +TVM_DLL int TVMLaunch(TVMFunctionHandle func, + TVMArg* args, + int* arg_type_ids, + int num_args, + TVMStreamHandle stream); +} // TVM_EXTERN_C + +#endif // TVM_C_RUNTIME_API_H_ diff --git a/python/tvm/_base.py b/python/tvm/_base.py index b67275cb4030..cfa85a55ac1c 100644 --- a/python/tvm/_base.py +++ b/python/tvm/_base.py @@ -41,10 +41,6 @@ def _load_lib(): # library instance of nnvm _LIB = _load_lib() -# type definitions -FunctionHandle = ctypes.c_void_p -NodeHandle = ctypes.c_void_p - #---------------------------- # helper function definition #---------------------------- diff --git a/python/tvm/_ctypes/_api.py b/python/tvm/_ctypes/_api.py index cf5df619001c..3ad5aa330468 100644 --- a/python/tvm/_ctypes/_api.py +++ b/python/tvm/_ctypes/_api.py @@ -10,17 +10,20 @@ from .._base import _LIB from .._base import c_str, py_str, string_types -from .._base import FunctionHandle, NodeHandle from .._base import check_call, ctypes2docstring from .. import _function_internal -class ArgVariant(ctypes.Union): - """ArgVariant in C API""" +class TVMArg(ctypes.Union): + """TVMArg in C API""" _fields_ = [("v_long", ctypes.c_long), ("v_double", ctypes.c_double), ("v_str", ctypes.c_char_p), ("v_handle", ctypes.c_void_p)] +# type definitions +APIFunctionHandle = ctypes.c_void_p +NodeHandle = ctypes.c_void_p + kNull = 0 kLong = 1 kDouble = 2 @@ -34,7 +37,7 @@ def _return_node(x): handle = x.v_handle if not isinstance(handle, NodeHandle): handle = NodeHandle(handle) - ret_val = ArgVariant() + ret_val = TVMArg() ret_typeid = ctypes.c_int() ret_success = ctypes.c_int() check_call(_LIB.TVMNodeGetAttr( @@ -77,7 +80,7 @@ def __del__(self): check_call(_LIB.TVMNodeFree(self.handle)) def __getattr__(self, name): - ret_val = ArgVariant() + ret_val = TVMArg() ret_typeid = ctypes.c_int() ret_success = ctypes.c_int() check_call(_LIB.TVMNodeGetAttr( @@ -169,21 +172,21 @@ def convert(value): def _push_arg(arg): - a = ArgVariant() + a = TVMArg() if arg is None: - _LIB.TVMPushStack(a, ctypes.c_int(kNull)) + _LIB.TVMAPIPushStack(a, ctypes.c_int(kNull)) elif isinstance(arg, NodeBase): a.v_handle = arg.handle - _LIB.TVMPushStack(a, ctypes.c_int(kNodeHandle)) + _LIB.TVMAPIPushStack(a, ctypes.c_int(kNodeHandle)) elif isinstance(arg, int): a.v_long = ctypes.c_long(arg) - _LIB.TVMPushStack(a, ctypes.c_int(kLong)) + _LIB.TVMAPIPushStack(a, ctypes.c_int(kLong)) elif isinstance(arg, Number): a.v_double = ctypes.c_double(arg) - _LIB.TVMPushStack(a, ctypes.c_int(kDouble)) + _LIB.TVMAPIPushStack(a, ctypes.c_int(kDouble)) elif isinstance(arg, string_types): a.v_str = c_str(arg) - _LIB.TVMPushStack(a, ctypes.c_int(kStr)) + _LIB.TVMAPIPushStack(a, ctypes.c_int(kStr)) else: raise TypeError("Don't know how to handle type %s" % type(arg)) @@ -198,7 +201,7 @@ def _make_function(handle, name): arg_descs = ctypes.POINTER(ctypes.c_char_p)() ret_type = ctypes.c_char_p() - check_call(_LIB.TVMGetFunctionInfo( + check_call(_LIB.TVMGetAPIFunctionInfo( handle, ctypes.byref(real_name), ctypes.byref(desc), ctypes.byref(num_args), ctypes.byref(arg_names), @@ -232,9 +235,9 @@ def func(*args): for arg in cargs: _push_arg(arg) - ret_val = ArgVariant() + ret_val = TVMArg() ret_typeid = ctypes.c_int() - check_call(_LIB.TVMFunctionCall( + check_call(_LIB.TVMAPIFunctionCall( handle, ctypes.byref(ret_val), ctypes.byref(ret_typeid))) return RET_SWITCH[ret_typeid.value](ret_val) @@ -267,8 +270,8 @@ def _init_function_module(root_namespace): plist = ctypes.POINTER(ctypes.c_char_p)() size = ctypes.c_uint() - check_call(_LIB.TVMListFunctionNames(ctypes.byref(size), - ctypes.byref(plist))) + check_call(_LIB.TVMListAPIFunctionNames(ctypes.byref(size), + ctypes.byref(plist))) op_names = [] for i in range(size.value): op_names.append(py_str(plist[i])) @@ -282,8 +285,8 @@ def _init_function_module(root_namespace): } for name in op_names: - hdl = FunctionHandle() - check_call(_LIB.TVMGetFunctionHandle(c_str(name), ctypes.byref(hdl))) + hdl = APIFunctionHandle() + check_call(_LIB.TVMGetAPIFunctionHandle(c_str(name), ctypes.byref(hdl))) fname = name target_module = module_internal if name.startswith('_') else module_obj for k, v in namespace_match.items(): diff --git a/src/README.md b/src/README.md index d1d8201e02d8..59cf081d81b1 100644 --- a/src/README.md +++ b/src/README.md @@ -4,3 +4,4 @@ - lang The definition of DSL related data structure - schedule The operations on the schedule graph before converting to IR. - pass The optimization pass on the IR structure +- runtime The runtime related codes. \ No newline at end of file diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 45ae12601c28..e6ce2a5a9c7e 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -9,8 +9,6 @@ /*! \brief entry to to easily hold returning information */ struct TVMAPIThreadLocalEntry { - /*! \brief hold last error */ - std::string last_error; /*! \brief result holder for returning strings */ std::vector ret_vec_str; /*! \brief result holder for returning string pointers */ @@ -99,16 +97,8 @@ struct APIAttrDir : public AttrVisitor { } }; -const char *TVMGetLastError() { - return TVMAPIThreadLocalStore::Get()->last_error.c_str(); -} - -void TVMAPISetLastError(const char* msg) { - TVMAPIThreadLocalStore::Get()->last_error = msg; -} - -int TVMListFunctionNames(int *out_size, - const char*** out_array) { +int TVMListAPIFunctionNames(int *out_size, + const char*** out_array) { API_BEGIN(); TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); ret->ret_vec_str = dmlc::Registry::ListAllNames(); @@ -121,23 +111,23 @@ int TVMListFunctionNames(int *out_size, API_END(); } -int TVMGetFunctionHandle(const char* fname, - FunctionHandle* out) { +int TVMGetAPIFunctionHandle(const char* fname, + APIFunctionHandle* out) { API_BEGIN(); const APIFunctionReg* reg = dmlc::Registry::Find(fname); CHECK(reg != nullptr) << "cannot find function " << fname; - *out = (FunctionHandle)reg; + *out = (APIFunctionHandle)reg; API_END(); } -int TVMGetFunctionInfo(FunctionHandle handle, - const char **real_name, - const char **description, - int *num_doc_args, - const char ***arg_names, - const char ***arg_type_infos, - const char ***arg_descriptions, - const char **return_type) { +int TVMGetAPIFunctionInfo(APIFunctionHandle handle, + const char **real_name, + const char **description, + int *num_doc_args, + const char ***arg_names, + const char ***arg_type_infos, + const char ***arg_descriptions, + const char **return_type) { const auto *op = static_cast(handle); TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); @@ -162,8 +152,8 @@ int TVMGetFunctionInfo(FunctionHandle handle, API_END(); } -int TVMPushStack(ArgVariant arg, - int type_id) { +int TVMAPIPushStack(ArgVariant arg, + int type_id) { TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); API_BEGIN(); ret->arg_stack.resize(ret->arg_stack.size() + 1); @@ -181,9 +171,9 @@ int TVMPushStack(ArgVariant arg, API_END_HANDLE_ERROR(ret->Clear()); } -int TVMFunctionCall(FunctionHandle handle, - ArgVariant* ret_val, - int* ret_typeid) { +int TVMAPIFunctionCall(APIFunctionHandle handle, + ArgVariant* ret_val, + int* ret_typeid) { TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); API_BEGIN(); const auto *op = static_cast(handle); diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h index 100f7e6513fe..c5ec453b9c26 100644 --- a/src/c_api/c_api_common.h +++ b/src/c_api/c_api_common.h @@ -14,29 +14,6 @@ #include #include #include "./c_api_registry.h" - -/*! \brief macro to guard beginning and end section of all functions */ -#define API_BEGIN() try { -/*! \brief every function starts with API_BEGIN(); - and finishes with API_END() or API_END_HANDLE_ERROR */ -#define API_END() } catch(std::runtime_error &_except_) { return TVMAPIHandleException(_except_); } return 0; // NOLINT(*) -/*! - * \brief every function starts with API_BEGIN(); - * and finishes with API_END() or API_END_HANDLE_ERROR - * The finally clause contains procedure to cleanup states when an error happens. - */ -#define API_END_HANDLE_ERROR(Finalize) } catch(std::runtime_error &_except_) { Finalize; return TVMAPIHandleException(_except_); } return 0; // NOLINT(*) - -void TVMAPISetLastError(const char* msg); - -/*! - * \brief handle exception throwed out - * \param e the exception - * \return the return value of API after exception is handled - */ -inline int TVMAPIHandleException(const std::runtime_error &e) { - TVMAPISetLastError(e.what()); - return -1; -} +#include "../runtime/runtime_common.h" #endif // TVM_C_API_C_API_COMMON_H_ diff --git a/src/c_api/c_api_registry.h b/src/c_api/c_api_registry.h index 0baa1fbd9a30..835368f500ce 100644 --- a/src/c_api/c_api_registry.h +++ b/src/c_api/c_api_registry.h @@ -15,6 +15,9 @@ #include #include "../base/common.h" +using ArgVariant = TVMArg; +using ArgVariantID = TVMArgTypeID; + namespace tvm { inline const char* TypeId2Str(ArgVariantID type_id) { diff --git a/src/runtime/error_handle.cc b/src/runtime/error_handle.cc new file mode 100644 index 000000000000..dad261e8e607 --- /dev/null +++ b/src/runtime/error_handle.cc @@ -0,0 +1,22 @@ +/*! + * Copyright (c) 2016 by Contributors + * Implementation of error handling API + * \file error_handle.cc + */ +#include +#include +#include "./runtime_common.h" + +struct TVMErrorEntry { + std::string last_error; +}; + +typedef dmlc::ThreadLocalStore TVMAPIErrorStore; + +const char *TVMGetLastError() { + return TVMAPIErrorStore::Get()->last_error.c_str(); +} + +void TVMAPISetLastError(const char* msg) { + TVMAPIErrorStore::Get()->last_error = msg; +} diff --git a/src/runtime/runtime_common.h b/src/runtime/runtime_common.h new file mode 100644 index 000000000000..998fd39857c9 --- /dev/null +++ b/src/runtime/runtime_common.h @@ -0,0 +1,36 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file c_runtime_common.h + * \brief Common fields of all C APIs + */ +#ifndef TVM_RUNTIME_RUNTIME_COMMON_H_ +#define TVM_RUNTIME_RUNTIME_COMMON_H_ + +#include +#include + +/*! \brief macro to guard beginning and end section of all functions */ +#define API_BEGIN() try { +/*! \brief every function starts with API_BEGIN(); + and finishes with API_END() or API_END_HANDLE_ERROR */ +#define API_END() } catch(std::runtime_error &_except_) { return TVMAPIHandleException(_except_); } return 0; // NOLINT(*) +/*! + * \brief every function starts with API_BEGIN(); + * and finishes with API_END() or API_END_HANDLE_ERROR + * The finally clause contains procedure to cleanup states when an error happens. + */ +#define API_END_HANDLE_ERROR(Finalize) } catch(std::runtime_error &_except_) { Finalize; return TVMAPIHandleException(_except_); } return 0; // NOLINT(*) + +void TVMAPISetLastError(const char* msg); + +/*! + * \brief handle exception throwed out + * \param e the exception + * \return the return value of API after exception is handled + */ +inline int TVMAPIHandleException(const std::runtime_error &e) { + TVMAPISetLastError(e.what()); + return -1; +} + +#endif // TVM_RUNTIME_RUNTIME_COMMON_H_