diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 9a24b7516128..2f9d74dc5ba0 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -95,10 +95,22 @@ typedef void *CudaKernelHandle; typedef void *ProfileHandle; /*! \brief handle to DLManagedTensor*/ typedef void *DLManagedTensorHandle; - +/*! \brief handle to Context */ +typedef const void *ContextHandle; +/*! \brief handle to Engine FnProperty */ +typedef const void *EngineFnPropertyHandle; +/*! \brief handle to Engine VarHandle */ +typedef void *EngineVarHandle; + +/*! \brief Engine asynchronous operation */ +typedef void (*EngineAsyncFunc)(void*, void*, void*); +/*! \brief Engine synchronous operation */ +typedef void (*EngineSyncFunc)(void*, void*); +/*! \brief Callback to free the param for EngineAsyncFunc/EngineSyncFunc */ +typedef void (*EngineFuncParamDeleter)(void*); typedef void (*ExecutorMonitorCallback)(const char*, NDArrayHandle, - void *); + void*); struct NativeOpInfo { void (*forward)(int, float**, int*, unsigned**, int*, void*); @@ -2541,6 +2553,51 @@ MXNET_DLL int MXNDArrayGetSharedMemHandle(NDArrayHandle handle, int* shared_pid, MXNET_DLL int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const mx_uint *shape, mx_uint ndim, int dtype, NDArrayHandle *out); +/*! + * \brief Push an asynchronous operation to the engine. + * \param async_func Execution function whici takes a parameter on_complete + * that must be called when the execution ompletes. + * \param func_param The parameter set on calling async_func, can be NULL. + * \param deleter The callback to free func_param, can be NULL. + * \param ctx_handle Execution context. + * \param const_vars_handle The variables that current operation will use + * but not mutate. + * \param num_const_vars The number of const_vars. + * \param mutable_vars_handle The variables that current operation will mutate. + * \param num_mutable_vars The number of mutable_vars. + * \param prop_handle Property of the function. + * \param priority Priority of the action, as hint to the engine. + * \param opr_name The operation name. + * \param wait Whether this is a WaitForVar operation. + */ +MXNET_DLL int MXEnginePushAsync(EngineAsyncFunc async_func, void* func_param, + EngineFuncParamDeleter deleter, ContextHandle ctx_handle, + EngineVarHandle const_vars_handle, int num_const_vars, + EngineVarHandle mutable_vars_handle, int num_mutable_vars, + EngineFnPropertyHandle prop_handle = NULL, int priority = 0, + const char* opr_name = NULL, bool wait = false); + +/*! + * \brief Push a synchronous operation to the engine. + * \param sync_func Execution function that executes the operation. + * \param func_param The parameter set on calling sync_func, can be NULL. + * \param deleter The callback to free func_param, can be NULL. + * \param ctx_handle Execution context. + * \param const_vars_handle The variables that current operation will use + * but not mutate. + * \param num_const_vars The number of const_vars. + * \param mutable_vars_handle The variables that current operation will mutate. + * \param num_mutable_vars The number of mutable_vars. + * \param prop_handle Property of the function. + * \param priority Priority of the action, as hint to the engine. + * \param opr_name The operation name. + */ +MXNET_DLL int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param, + EngineFuncParamDeleter deleter, ContextHandle ctx_handle, + EngineVarHandle const_vars_handle, int num_const_vars, + EngineVarHandle mutable_vars_handle, int num_mutable_vars, + EngineFnPropertyHandle prop_handle = NULL, int priority = 0, + const char* opr_name = NULL); #ifdef __cplusplus } diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 70ba84b5f94b..45197aafe019 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -1401,3 +1401,91 @@ int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const mx_uint *s *out = new NDArray(shared_pid, shared_id, mxnet::TShape(shape, shape + ndim), dtype); API_END(); } + +typedef Engine::VarHandle VarHandle; +typedef Engine::CallbackOnComplete CallbackOnComplete; + +void AssertValidNumberVars(int num_const_vars, int num_mutable_vars) { + CHECK_GE(num_const_vars, 0) << "Non-negative number of const vars expected."; + CHECK_GE(num_mutable_vars, 0) << "Non-negative number of mutable vars expected."; +} + +int MXEnginePushAsync(EngineAsyncFunc async_func, void* func_param, + EngineFuncParamDeleter deleter, ContextHandle ctx_handle, + EngineVarHandle const_vars_handle, int num_const_vars, + EngineVarHandle mutable_vars_handle, int num_mutable_vars, + EngineFnPropertyHandle prop_handle, int priority, + const char* opr_name, bool wait) { + API_BEGIN(); + + auto exec_ctx = *static_cast(ctx_handle); + auto const_vars = static_cast(const_vars_handle); + auto mutable_vars = static_cast(mutable_vars_handle); + auto prop = FnProperty::kNormal; + if (prop_handle) { + prop = *static_cast(prop_handle); + } + + Engine::AsyncFn exec_fn; + if (deleter == nullptr) { + exec_fn = [async_func, func_param](RunContext rctx, + CallbackOnComplete on_complete) { + async_func(&rctx, &on_complete, func_param); + }; + } else { + // Wrap func_param in a shared_ptr with deleter such that deleter + // will be called when the lambda goes out of scope. + std::shared_ptr shared_func_param(func_param, deleter); + exec_fn = [async_func, shared_func_param](RunContext rctx, + CallbackOnComplete on_complete) { + async_func(&rctx, &on_complete, shared_func_param.get()); + }; + } + + AssertValidNumberVars(num_const_vars, num_mutable_vars); + std::vector const_var_vec(const_vars, const_vars + num_const_vars); + std::vector mutable_var_vec(mutable_vars, mutable_vars + num_mutable_vars); + Engine::Get()->PushAsync(exec_fn, exec_ctx, const_var_vec, mutable_var_vec, + prop, priority, opr_name, wait); + + API_END(); +} + +int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param, + EngineFuncParamDeleter deleter, ContextHandle ctx_handle, + EngineVarHandle const_vars_handle, int num_const_vars, + EngineVarHandle mutable_vars_handle, int num_mutable_vars, + EngineFnPropertyHandle prop_handle, int priority, + const char* opr_name) { + API_BEGIN(); + + auto exec_ctx = *static_cast(ctx_handle); + auto const_vars = static_cast(const_vars_handle); + auto mutable_vars = static_cast(mutable_vars_handle); + auto prop = FnProperty::kNormal; + if (prop_handle) { + prop = *static_cast(prop_handle); + } + + Engine::SyncFn exec_fn; + if (deleter == nullptr) { + exec_fn = [sync_func, func_param](RunContext rctx) { + sync_func(&rctx, func_param); + }; + } else { + // Wrap func_param in a shared_ptr with deleter such that deleter + // will be called when the lambda goes out of scope. + std::shared_ptr shared_func_param(func_param, deleter); + exec_fn = [sync_func, shared_func_param](RunContext rctx) { + sync_func(&rctx, shared_func_param.get()); + }; + } + + AssertValidNumberVars(num_const_vars, num_mutable_vars); + std::vector const_var_vec(const_vars, const_vars + num_const_vars); + std::vector mutable_var_vec(mutable_vars, mutable_vars + num_mutable_vars); + Engine::Get()->PushSync(exec_fn, exec_ctx, const_var_vec, mutable_var_vec, + prop, priority, opr_name); + + API_END(); +} diff --git a/tests/cpp/engine/threaded_engine_test.cc b/tests/cpp/engine/threaded_engine_test.cc index 6d669c19bcaa..405f3b30a176 100644 --- a/tests/cpp/engine/threaded_engine_test.cc +++ b/tests/cpp/engine/threaded_engine_test.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -176,6 +177,83 @@ TEST(Engine, RandSumExpr) { void Foo(mxnet::RunContext, int i) { printf("The fox says %d\n", i); } +void FooAsyncFunc(void*, void* cb_ptr, void* param) { + if (param == nullptr) { + LOG(INFO) << "The fox asynchronously says receiving nothing."; + } else { + auto num = static_cast(param); + EXPECT_EQ(*num, 100); + LOG(INFO) << "The fox asynchronously says receiving " << *num; + } + auto cb = *static_cast(cb_ptr); + cb(); +} + +void FooSyncFunc(void*, void* param) { + if (param == nullptr) { + LOG(INFO) << "The fox synchronously says receiving nothing."; + } else { + auto num = static_cast(param); + EXPECT_EQ(*num, 101); + LOG(INFO) << "The fox synchronously says receiving " << *num; + } +} + +void FooFuncDeleter(void* param) { + if (param != nullptr) { + auto num = static_cast(param); + LOG(INFO) << "The fox says deleting " << *num; + delete num; + } +} + +TEST(Engine, PushFunc) { + auto var = mxnet::Engine::Get()->NewVariable(); + auto ctx = mxnet::Context{}; + + // Test #1 + LOG(INFO) << "===== Test #1: PushAsync param and deleter ====="; + int* a = new int(100); + int res = MXEnginePushAsync(FooAsyncFunc, a, FooFuncDeleter, &ctx, &var, 1, nullptr, 0); + EXPECT_EQ(res, 0); + + // Test #2 + LOG(INFO) << "===== Test #2: PushAsync NULL param and NULL deleter ====="; + res = MXEnginePushAsync(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &var, 0); + EXPECT_EQ(res, 0); + + // Test #3 + LOG(INFO) << "===== Test #3: PushAsync invalid number of const vars ====="; + res = MXEnginePushAsync(FooAsyncFunc, nullptr, nullptr, &ctx, &var, -1, nullptr, 0); + EXPECT_EQ(res, -1); + + // Test #4 + LOG(INFO) << "===== Test #4: PushAsync invalid number of mutable vars ====="; + res = MXEnginePushAsync(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &var, -1); + EXPECT_EQ(res, -1); + + // Test #5 + LOG(INFO) << "===== Test #5: PushSync param and deleter ====="; + int* b = new int(101); + res = MXEnginePushSync(FooSyncFunc, b, FooFuncDeleter, &ctx, &var, 1, nullptr, 0); + EXPECT_EQ(res, 0); + + // Test #6 + LOG(INFO) << "===== Test #6: PushSync NULL param and NULL deleter ====="; + res = MXEnginePushSync(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &var, 1); + EXPECT_EQ(res, 0); + + // Test #7 + LOG(INFO) << "===== Test #7: PushSync invalid number of const vars ====="; + res = MXEnginePushSync(FooSyncFunc, nullptr, nullptr, &ctx, &var, -1, nullptr, 0); + EXPECT_EQ(res, -1); + + // Test #8 + LOG(INFO) << "===== Test #8: PushSync invalid number of mutable vars ====="; + res = MXEnginePushSync(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &var, -1); + EXPECT_EQ(res, -1); +} + TEST(Engine, basics) { auto&& engine = mxnet::Engine::Get(); auto&& var = engine->NewVariable();