This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add deferred compute support #17530

Merged
merged 14 commits on Mar 23, 2020
42 changes: 42 additions & 0 deletions include/mxnet/c_api.h
@@ -1423,6 +1423,44 @@ MXNET_DLL int MXCachedOpRegisterOpHook(NDArrayHandle handle,
CachedOpMonitorCallback callback,
bool monitor_all);

/*!
* \brief Get current status of deferred compute mode
* \param curr returns the current status.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayIsDeferredCompute(int *curr);

/*!
* \brief Set whether to enable deferred compute mode
* \param deferred_compute_enabled 1 to enable, 0 to disable.
* \param prev returns the previous status before this set.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArraySetIsDeferredCompute(int deferred_compute_enabled, int *prev);

/*!
* \brief Associate variables with deferred compute arrays
* \param arrays ndarray handles to be matched with variables
* \param variables symbol handles of variables to be matched with ndarrays
* \param num number of arrays and variables respectively
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArraySetDeferredComputeVariable(NDArrayHandle *arrays,
SymbolHandle *variables,
int num);

/*!
* \brief Convert the graph constructed during deferred computation mode to a Symbol.
* \param output_handles ndarray handles of outputs
* \param num_outputs number of output ndarray handles
* \param out grouped output symbol handle
* \return 0 when success, -1 when failure happens
*
* Construct a Symbol for the deferred computation graph. output_handles
* specifies the outputs of interest which the returned symbol will compute.
*/
MXNET_DLL int MXNDArrayGetDeferredComputeSymbol(NDArrayHandle *output_handles,
int num_outputs,
SymbolHandle *out);
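
For orientation (not part of the diff): a minimal sketch of how these four entry points are meant to be combined, mirroring the intended frontend flow. `ExportDeferredGraph` is a hypothetical helper; it assumes `x` and `y` are valid `NDArrayHandle`s, that the operator calls producing `y` from `x` happen where indicated, and it uses the pre-existing `MXSymbolCreateVariable` from the symbol API. Error handling is reduced to early returns.

```c++
#include <mxnet/c_api.h>

// Sketch only: record operations on `x`, then export the recorded graph
// with `y` as its single output.
int ExportDeferredGraph(NDArrayHandle x, NDArrayHandle y, SymbolHandle *out) {
  int prev = 0, curr = 0;

  // Name `x` so that it appears as a variable of the exported Symbol.
  SymbolHandle var = nullptr;
  if (MXSymbolCreateVariable("data", &var) != 0) return -1;
  if (MXNDArraySetDeferredComputeVariable(&x, &var, 1) != 0) return -1;

  // Turn deferred compute on; subsequent operator calls are recorded
  // instead of being executed eagerly.
  if (MXNDArraySetIsDeferredCompute(1, &prev) != 0) return -1;
  if (MXNDArrayIsDeferredCompute(&curr) != 0 || curr != 1) return -1;

  // ... imperative operator invocations producing `y` from `x` go here ...

  // Restore the previous mode, then build a Symbol whose single output
  // corresponds to `y`.
  if (MXNDArraySetIsDeferredCompute(prev, &curr) != 0) return -1;
  return MXNDArrayGetDeferredComputeSymbol(&y, 1, out);
}
```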

//--------------------------------------------
// Part 3: symbolic configuration generation
//--------------------------------------------
@@ -1501,6 +1539,10 @@ MXNET_DLL int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
const char **return_type DEFAULT(NULL));
/*!
* \brief Create an AtomicSymbol.
*
* A Symbol is said to be atomic if it is not composed of other Symbols. Atomic
* Symbols can be composed.
*
* \param creator the AtomicSymbolCreator
* \param num_param the number of parameters
* \param keys the keys to the params
90 changes: 88 additions & 2 deletions include/mxnet/imperative.h
@@ -56,7 +56,8 @@ class Imperative {
OpReqType grad_req;
OpStatePtr state;
std::vector<NDArray> outputs;
std::vector<NDArray> out_grads;
std::vector<NDArray> out_grads; // used to hold gradient arrays the user is
// interested in (marked variables)
bool fresh_out_grad;

AGInfo() :
@@ -79,7 +80,7 @@
}

static bool IsNone(const NDArray& arr) {
return arr.entry_.node == nullptr || arr.entry_.node->info.empty();
return arr.autograd_entry_.node == nullptr || arr.autograd_entry_.node->info.empty();
}

static bool IsVariable(const nnvm::ObjectPtr& node) {
@@ -88,6 +89,73 @@
&& info.out_grads.size() == 1;
}
};

/*! \brief DCInfo data structure to enable deferred computation */
class DCInfo {
public:
explicit DCInfo(const std::vector<NDArray *> &inputs,
const std::vector<NDArray *> &outputs);

/*! \brief Compute the outputs of the associated operator. */
static void Compute(const NDArray &arr);

static DCInfo &Get(const nnvm::ObjectPtr &node) {
return dmlc::get<DCInfo>(node->info);
}

static bool IsNone(const NDArray &arr) {
return arr.deferredcompute_entry_.node == nullptr ||
arr.deferredcompute_entry_.node->info.empty();
}

static bool IsComputed(const NDArray &arr) {
return IsNone(arr) ||
dmlc::get<DCInfo>(arr.deferredcompute_entry_.node->info).is_computed_;
}

static DCInfo &Create(const nnvm::ObjectPtr &node,
const std::vector<NDArray *> &inputs,
const std::vector<NDArray *> &outputs);

private:
friend class Imperative;

/*! \brief Copies of input NDArrays
*
* If the respective input NDArray is deallocated on the frontend, we still
* need to keep a copy around to facilitate deferred computation of this
* array. The copies share the chunk.
*
* They are automatically deallocated after the computation has finished.
*/
std::vector<NDArray> inputs_;

/*! \brief Handles of input NDArrays used by the frontend
*
* The frontend may request conversion to a Symbol, specifying a list of
* NDArray handles corresponding to inputs and outputs of the Symbol. We store
* the handles used by the frontend to facilitate matching in
* GetDeferredComputeSymbol.
*
* Note that the frontend may have deallocated the NDArray* and the
* input_handles stored here may point to invalid memory.
*/
std::vector<const NDArray *> input_handles_;

/*! \brief Copies of output NDArrays
*
* If the respective output NDArray is deallocated on the frontend, we still
* need to keep a copy around to facilitate deferred computation of arrays
* relying on the output array. The copies share the chunk.
*
* They are automatically deallocated after the computation has finished.
*/
std::vector<NDArray> outputs_;

/*! \brief Remember if the outputs associated with this DCInfo have been computed already */
bool is_computed_ = false;
};

/*! \brief whether operator recording is on. */
bool is_training() const {
return is_train_;
@@ -108,6 +176,14 @@
is_recording_ = is_recording;
return old;
}
/*! \brief whether deferred compute mode is on. */
bool is_deferred_compute() const { return is_deferred_compute_; }
/*! \brief turn deferred compute mode on or off; returns the previous state. */
bool set_is_deferred_compute(bool is_deferred_compute) {
bool old = is_deferred_compute_;
is_deferred_compute_ = is_deferred_compute;
return old;
}
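
Side note (not part of the diff): like `is_recording_`, the deferred compute flag is thread-local, and callers are expected to restore the previous value returned by `set_is_deferred_compute`. A hypothetical RAII guard, sketched below under the assumption that `Imperative::Get()` is the usual singleton accessor, keeps that save/restore pattern exception-safe:

```c++
#include <mxnet/imperative.h>

// Hypothetical scope guard (not in the PR): enables deferred compute on
// construction and restores the previous thread-local state on destruction.
class DCScopeGuard {
 public:
  explicit DCScopeGuard(bool enable = true)
      : prev_(mxnet::Imperative::Get()->set_is_deferred_compute(enable)) {}
  ~DCScopeGuard() {
    mxnet::Imperative::Get()->set_is_deferred_compute(prev_);
  }

 private:
  bool prev_;  // previous mode, restored on scope exit
};
```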
/*! \brief return current numpy compatibility status,
* GlobalOn(2), ThreadLocalOn(1), Off(0).
* */
@@ -143,6 +219,14 @@
const OpStatePtr& state = OpStatePtr(),
std::vector<bool>* p_save_inputs = nullptr,
std::vector<bool>* p_save_outputs = nullptr);
/*! \brief record a deferred compute operator for the given inputs and outputs. */
void RecordDeferredCompute(nnvm::NodeAttrs&& attrs,
const std::vector<NDArray*>& inputs,
const std::vector<NDArray*>& outputs);
/*! \brief obtain symbol representation of deferred compute session. */
nnvm::Symbol GetDeferredComputeSymbol(const std::vector<NDArray *> &outputs);
/*! \brief associate arrays with variables for deferred compute */
void SetDeferredComputeVariable(NDArrayHandle *arrays, SymbolHandle *variables, const int num);
/*! \brief */
OpStatePtr Invoke(const Context& default_ctx,
const nnvm::NodeAttrs& attrs,
Expand Down Expand Up @@ -204,12 +288,14 @@ class Imperative {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local bool is_train_;
static thread_local bool is_recording_;
static thread_local bool is_deferred_compute_;
// TODO(junwu): Added numpy compatibility switch for backward compatibility.
// Delete it in the next major release.
static thread_local bool is_np_shape_thread_local_;
#else
static MX_THREAD_LOCAL bool is_train_;
static MX_THREAD_LOCAL bool is_recording_;
static MX_THREAD_LOCAL bool is_deferred_compute_;
// TODO(junwu): Added numpy compatibility switch for backward compatibility.
// Delete it in the next major release.
static MX_THREAD_LOCAL bool is_np_shape_thread_local_;
61 changes: 30 additions & 31 deletions include/mxnet/ndarray.h
@@ -83,7 +83,7 @@ class NDArray {
public:
/*! \brief default constructor */
NDArray()
: entry_(nullptr) {
: autograd_entry_(nullptr) {
}
/*!
* \brief constructs a new dynamic NDArray
@@ -98,7 +98,7 @@
shape_(shape),
dtype_(dtype),
storage_type_(kDefaultStorage),
entry_(nullptr) {
autograd_entry_(nullptr) {
}
/*! \brief constructor for NDArray with storage type
*/
@@ -117,7 +117,7 @@
shape_(),
dtype_(dtype),
storage_type_(kDefaultStorage),
entry_(nullptr) {
autograd_entry_(nullptr) {
}
/*!
* \brief constructing a static NDArray that shares data with TBlob
@@ -131,7 +131,7 @@
shape_(data.shape_),
dtype_(data.type_flag_),
storage_type_(kDefaultStorage),
entry_(nullptr) {
autograd_entry_(nullptr) {
}

/*!
@@ -149,7 +149,7 @@
}),
shape_(data.shape_),
dtype_(data.type_flag_), storage_type_(kDefaultStorage),
entry_(nullptr) {
autograd_entry_(nullptr) {
}

/*! \brief create ndarray from shared memory */
Expand All @@ -158,7 +158,7 @@ class NDArray {
shape_(shape),
dtype_(dtype),
storage_type_(kDefaultStorage),
entry_(nullptr) {
autograd_entry_(nullptr) {
}

/*!
@@ -177,7 +177,7 @@
shape_(shape),
dtype_(data.type_flag_),
storage_type_(stype),
entry_(nullptr) {
autograd_entry_(nullptr) {
}
/*!
* \brief initialize the NDArray, assuming it is not assigned a meaningful shape before
Expand All @@ -190,7 +190,7 @@ class NDArray {
/*!
* \brief set the correct shape of NDArray directly from the storage_shape of its own chunk.
*/
void SetShapeFromChunk();
void SetShapeFromChunk() const;
/*
* This indicates whether an array is a view of another array (created by
* reshape or slice). If an array is a view and the data is stored in
@@ -326,9 +326,9 @@
inline bool is_none() const {
return ptr_.get() == nullptr;
}
/*! \return updated grad state in entry_ */
/*! \return updated grad state in autograd_entry_ */
bool fresh_out_grad() const;
/*! \return updated grad state in entry_ */
/*! \return updated grad state in autograd_entry_ */
void set_fresh_out_grad(bool state) const;
/*! \brief Returns true if a sparse ndarray's aux_data and storage are initialized
* Throws an exception if the indices array shape is inconsistent
@@ -367,27 +367,19 @@
/*!
* \brief Block until all the pending write operations with respect
* to current NDArray are finished, and read can be performed.
*
* If the array has not been computed yet (deferred compute), this will
* trigger computation.
*/
inline void WaitToRead() const {
if (is_none()) return;
Engine::Get()->WaitForVar(ptr_->var);
}
void WaitToRead() const;
/*!
* \brief Block until all the pending read/write operations with respect
* to current NDArray are finished, and write can be performed.
*
* If the array has not been computed yet (deferred compute), this will
* trigger computation.
*/
inline void WaitToWrite() const {
if (is_none()) return;
/*!
* Push an empty mutable function to flush all preceding reads to the
* variable.
*/
Engine::Get()->PushAsync(
[](RunContext, Engine::CallbackOnComplete on_complete) {
on_complete();
}, Context{}, {}, {ptr_->var});
Engine::Get()->WaitForVar(ptr_->var);
}
void WaitToWrite() const;
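
Note (not part of the diff): because a deferred array is only materialized on demand, callers that need concrete data or a concrete shape synchronize first. A minimal sketch of that pattern, assuming `arr` was produced while deferred compute was enabled:

```c++
#include <mxnet/ndarray.h>

// Sketch: force materialization of a deferred-compute NDArray before
// inspecting it on the calling thread.
void Materialize(const mxnet::NDArray &arr) {
  arr.WaitToRead();                          // triggers computation if still deferred
  const mxnet::TShape &shape = arr.shape();  // shape is known after synchronization
  (void)shape;
}
```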
/*! \return the associated variable of the ndarray.*/
inline Engine::VarHandle var() const {
return ptr_->var;
@@ -648,11 +640,13 @@
*/
NDArray ReshapeWithRecord(const mxnet::TShape &shape);
/*!
* \brief Return a copy of this NDArray without autograd history
* \brief Return a copy of this NDArray without autograd and deferred compute
* history
*/
NDArray Detach() const {
NDArray ret(*this);
ret.entry_ = nnvm::NodeEntry(nullptr);
ret.autograd_entry_ = nnvm::NodeEntry(nullptr);
ret.deferredcompute_entry_ = nnvm::NodeEntry(nullptr);
return ret;
}

@@ -1100,8 +1094,11 @@ class NDArray {

/*! \brief internal data of NDArray */
std::shared_ptr<Chunk> ptr_{nullptr};
/*! \brief shape of current NDArray */
mxnet::TShape shape_;
/*! \brief shape of current NDArray
* \note the const methods WaitToRead and WaitToWrite will set the shape if it
* was previously unknown and the array is deferred computed.
*/
mutable mxnet::TShape shape_;
/*! \brief byte offset in chunk */
size_t byte_offset_ = 0;
/*! \brief type of data */
@@ -1111,7 +1108,9 @@
/*! \brief storage type of data */
NDArrayStorageType storage_type_ = kUndefinedStorage;
/*! \brief node entry for autograd */
nnvm::NodeEntry entry_;
nnvm::NodeEntry autograd_entry_;
/*! \brief node entry for deferred computation tracking */
nnvm::NodeEntry deferredcompute_entry_;
/*!
* \brief internal TBlob
* \note When user access tblob_ by some const methods like
2 changes: 2 additions & 0 deletions python/mxnet/__init__.py
@@ -87,6 +87,8 @@
from . import rnn
from . import gluon

from . import _deferred_compute

# With the native kvstore module (such as 'dist_sync_device'), the module launches a separate
# process when role is set to "server". This should be done after other modules are initialized.
# Otherwise this may result in errors when unpickling custom LR scheduler/optimizers.