Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add MXEnginePushAsync and MXEnginePushSync C APIs #14615

Merged
merged 7 commits into from
Apr 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 59 additions & 2 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,22 @@ typedef void *CudaKernelHandle;
typedef void *ProfileHandle;
/*! \brief handle to DLManagedTensor*/
typedef void *DLManagedTensorHandle;

/*! \brief handle to Context */
typedef const void *ContextHandle;
/*! \brief handle to Engine FnProperty */
typedef const void *EngineFnPropertyHandle;
/*! \brief handle to Engine VarHandle */
typedef void *EngineVarHandle;

/*! \brief Engine asynchronous operation */
typedef void (*EngineAsyncFunc)(void*, void*, void*);
/*! \brief Engine synchronous operation */
typedef void (*EngineSyncFunc)(void*, void*);
/*! \brief Callback to free the param for EngineAsyncFunc/EngineSyncFunc */
typedef void (*EngineFuncParamDeleter)(void*);
typedef void (*ExecutorMonitorCallback)(const char*,
NDArrayHandle,
void *);
void*);

struct NativeOpInfo {
void (*forward)(int, float**, int*, unsigned**, int*, void*);
Expand Down Expand Up @@ -2541,6 +2553,51 @@ MXNET_DLL int MXNDArrayGetSharedMemHandle(NDArrayHandle handle, int* shared_pid,
MXNET_DLL int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const mx_uint *shape,
mx_uint ndim, int dtype, NDArrayHandle *out);

/*!
* \brief Push an asynchronous operation to the engine.
* \param async_func Execution function whici takes a parameter on_complete
* that must be called when the execution ompletes.
* \param func_param The parameter set on calling async_func, can be NULL.
* \param deleter The callback to free func_param, can be NULL.
* \param ctx_handle Execution context.
* \param const_vars_handle The variables that current operation will use
* but not mutate.
* \param num_const_vars The number of const_vars.
* \param mutable_vars_handle The variables that current operation will mutate.
* \param num_mutable_vars The number of mutable_vars.
* \param prop_handle Property of the function.
* \param priority Priority of the action, as hint to the engine.
* \param opr_name The operation name.
* \param wait Whether this is a WaitForVar operation.
*/
MXNET_DLL int MXEnginePushAsync(EngineAsyncFunc async_func, void* func_param,
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
EngineVarHandle const_vars_handle, int num_const_vars,
EngineVarHandle mutable_vars_handle, int num_mutable_vars,
EngineFnPropertyHandle prop_handle = NULL, int priority = 0,
const char* opr_name = NULL, bool wait = false);

/*!
* \brief Push a synchronous operation to the engine.
* \param sync_func Execution function that executes the operation.
* \param func_param The parameter set on calling sync_func, can be NULL.
* \param deleter The callback to free func_param, can be NULL.
yuxihu marked this conversation as resolved.
Show resolved Hide resolved
* \param ctx_handle Execution context.
* \param const_vars_handle The variables that current operation will use
* but not mutate.
* \param num_const_vars The number of const_vars.
* \param mutable_vars_handle The variables that current operation will mutate.
* \param num_mutable_vars The number of mutable_vars.
* \param prop_handle Property of the function.
* \param priority Priority of the action, as hint to the engine.
* \param opr_name The operation name.
*/
MXNET_DLL int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
EngineVarHandle const_vars_handle, int num_const_vars,
EngineVarHandle mutable_vars_handle, int num_mutable_vars,
EngineFnPropertyHandle prop_handle = NULL, int priority = 0,
const char* opr_name = NULL);

#ifdef __cplusplus
}
Expand Down
88 changes: 88 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1401,3 +1401,91 @@ int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const mx_uint *s
*out = new NDArray(shared_pid, shared_id, mxnet::TShape(shape, shape + ndim), dtype);
API_END();
}

typedef Engine::VarHandle VarHandle;
typedef Engine::CallbackOnComplete CallbackOnComplete;

void AssertValidNumberVars(int num_const_vars, int num_mutable_vars) {
CHECK_GE(num_const_vars, 0) << "Non-negative number of const vars expected.";
CHECK_GE(num_mutable_vars, 0) << "Non-negative number of mutable vars expected.";
}

int MXEnginePushAsync(EngineAsyncFunc async_func, void* func_param,
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
EngineVarHandle const_vars_handle, int num_const_vars,
EngineVarHandle mutable_vars_handle, int num_mutable_vars,
EngineFnPropertyHandle prop_handle, int priority,
const char* opr_name, bool wait) {
API_BEGIN();

auto exec_ctx = *static_cast<const Context*>(ctx_handle);
auto const_vars = static_cast<VarHandle*>(const_vars_handle);
auto mutable_vars = static_cast<VarHandle*>(mutable_vars_handle);
auto prop = FnProperty::kNormal;
if (prop_handle) {
prop = *static_cast<const FnProperty*>(prop_handle);
}

Engine::AsyncFn exec_fn;
if (deleter == nullptr) {
exec_fn = [async_func, func_param](RunContext rctx,
CallbackOnComplete on_complete) {
async_func(&rctx, &on_complete, func_param);
};
} else {
// Wrap func_param in a shared_ptr with deleter such that deleter
// will be called when the lambda goes out of scope.
std::shared_ptr<void> shared_func_param(func_param, deleter);
exec_fn = [async_func, shared_func_param](RunContext rctx,
CallbackOnComplete on_complete) {
async_func(&rctx, &on_complete, shared_func_param.get());
};
}

AssertValidNumberVars(num_const_vars, num_mutable_vars);
std::vector<VarHandle> const_var_vec(const_vars, const_vars + num_const_vars);
std::vector<VarHandle> mutable_var_vec(mutable_vars, mutable_vars + num_mutable_vars);
Engine::Get()->PushAsync(exec_fn, exec_ctx, const_var_vec, mutable_var_vec,
prop, priority, opr_name, wait);

API_END();
}

int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
EngineVarHandle const_vars_handle, int num_const_vars,
EngineVarHandle mutable_vars_handle, int num_mutable_vars,
EngineFnPropertyHandle prop_handle, int priority,
const char* opr_name) {
API_BEGIN();

auto exec_ctx = *static_cast<const Context*>(ctx_handle);
auto const_vars = static_cast<VarHandle*>(const_vars_handle);
auto mutable_vars = static_cast<VarHandle*>(mutable_vars_handle);
auto prop = FnProperty::kNormal;
if (prop_handle) {
prop = *static_cast<const FnProperty*>(prop_handle);
}

Engine::SyncFn exec_fn;
if (deleter == nullptr) {
exec_fn = [sync_func, func_param](RunContext rctx) {
sync_func(&rctx, func_param);
};
} else {
// Wrap func_param in a shared_ptr with deleter such that deleter
// will be called when the lambda goes out of scope.
std::shared_ptr<void> shared_func_param(func_param, deleter);
exec_fn = [sync_func, shared_func_param](RunContext rctx) {
sync_func(&rctx, shared_func_param.get());
};
}

AssertValidNumberVars(num_const_vars, num_mutable_vars);
std::vector<VarHandle> const_var_vec(const_vars, const_vars + num_const_vars);
std::vector<VarHandle> mutable_var_vec(mutable_vars, mutable_vars + num_mutable_vars);
Engine::Get()->PushSync(exec_fn, exec_ctx, const_var_vec, mutable_var_vec,
prop, priority, opr_name);

API_END();
}
78 changes: 78 additions & 0 deletions tests/cpp/engine/threaded_engine_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <dmlc/thread_group.h>
#include <dmlc/omp.h>
#include <gtest/gtest.h>
#include <mxnet/c_api.h>
#include <mxnet/engine.h>
#include <dmlc/timer.h>
#include <cstdio>
Expand Down Expand Up @@ -176,6 +177,83 @@ TEST(Engine, RandSumExpr) {

void Foo(mxnet::RunContext, int i) { printf("The fox says %d\n", i); }

void FooAsyncFunc(void*, void* cb_ptr, void* param) {
if (param == nullptr) {
LOG(INFO) << "The fox asynchronously says receiving nothing.";
} else {
auto num = static_cast<int*>(param);
EXPECT_EQ(*num, 100);
LOG(INFO) << "The fox asynchronously says receiving " << *num;
}
auto cb = *static_cast<mxnet::engine::CallbackOnComplete*>(cb_ptr);
cb();
}

void FooSyncFunc(void*, void* param) {
if (param == nullptr) {
LOG(INFO) << "The fox synchronously says receiving nothing.";
} else {
auto num = static_cast<int*>(param);
EXPECT_EQ(*num, 101);
LOG(INFO) << "The fox synchronously says receiving " << *num;
}
}

void FooFuncDeleter(void* param) {
if (param != nullptr) {
auto num = static_cast<int*>(param);
LOG(INFO) << "The fox says deleting " << *num;
delete num;
}
}

TEST(Engine, PushFunc) {
auto var = mxnet::Engine::Get()->NewVariable();
auto ctx = mxnet::Context{};

// Test #1
LOG(INFO) << "===== Test #1: PushAsync param and deleter =====";
int* a = new int(100);
int res = MXEnginePushAsync(FooAsyncFunc, a, FooFuncDeleter, &ctx, &var, 1, nullptr, 0);
EXPECT_EQ(res, 0);

// Test #2
LOG(INFO) << "===== Test #2: PushAsync NULL param and NULL deleter =====";
res = MXEnginePushAsync(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &var, 0);
EXPECT_EQ(res, 0);

// Test #3
LOG(INFO) << "===== Test #3: PushAsync invalid number of const vars =====";
res = MXEnginePushAsync(FooAsyncFunc, nullptr, nullptr, &ctx, &var, -1, nullptr, 0);
EXPECT_EQ(res, -1);

// Test #4
LOG(INFO) << "===== Test #4: PushAsync invalid number of mutable vars =====";
res = MXEnginePushAsync(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &var, -1);
EXPECT_EQ(res, -1);

// Test #5
LOG(INFO) << "===== Test #5: PushSync param and deleter =====";
int* b = new int(101);
res = MXEnginePushSync(FooSyncFunc, b, FooFuncDeleter, &ctx, &var, 1, nullptr, 0);
EXPECT_EQ(res, 0);

// Test #6
LOG(INFO) << "===== Test #6: PushSync NULL param and NULL deleter =====";
res = MXEnginePushSync(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &var, 1);
EXPECT_EQ(res, 0);

// Test #7
LOG(INFO) << "===== Test #7: PushSync invalid number of const vars =====";
res = MXEnginePushSync(FooSyncFunc, nullptr, nullptr, &ctx, &var, -1, nullptr, 0);
EXPECT_EQ(res, -1);

// Test #8
LOG(INFO) << "===== Test #8: PushSync invalid number of mutable vars =====";
res = MXEnginePushSync(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &var, -1);
EXPECT_EQ(res, -1);
}

TEST(Engine, basics) {
auto&& engine = mxnet::Engine::Get();
auto&& var = engine->NewVariable();
Expand Down