This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

add inline for cached op & fixed a bug when calling backward on variable (#8701)

* add inline for cached op & fixed a bug when calling backward on variable

* fix

* Update test_gluon.py
piiswrong authored Nov 28, 2017
1 parent 632c897 commit 2f8c1e8
Showing 9 changed files with 157 additions and 39 deletions.
11 changes: 9 additions & 2 deletions include/mxnet/c_api.h
@@ -793,8 +793,15 @@ MXNET_DLL int MXAutogradGetSymbol(NDArrayHandle handle, SymbolHandle *out);
/*!
* \brief create cached operator
*/
MXNET_DLL int MXCreateCachedOp(SymbolHandle handle,
CachedOpHandle *out);
MXNET_DLL int MXCreateCachedOp(SymbolHandle handle, CachedOpHandle *out);
/*!
* \brief create cached operator
*/
MXNET_DLL int MXCreateCachedOpEx(SymbolHandle handle,
int num_params,
const char** keys,
const char** vals,
CachedOpHandle *out);
/*!
* \brief free cached operator
*/
37 changes: 33 additions & 4 deletions include/mxnet/imperative.h
@@ -28,11 +28,30 @@
#include <nnvm/graph.h>
#include <vector>
#include <atomic>
#include <utility>
#include <string>
#include <unordered_map>

#include "./ndarray.h"

namespace mxnet {
/*! \brief CachedOp Parameters */
struct CachedOpParam : public dmlc::Parameter<CachedOpParam> {
uint32_t inline_limit;
uint32_t forward_bulk_size;
uint32_t backward_bulk_size;
DMLC_DECLARE_PARAMETER(CachedOpParam) {
DMLC_DECLARE_FIELD(inline_limit)
.set_default(2)
.describe("Maximum number of operators that can be inlined.");
DMLC_DECLARE_FIELD(forward_bulk_size)
.set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
.describe("Segment size of bulk execution during forward pass.");
DMLC_DECLARE_FIELD(backward_bulk_size)
.set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
.describe("Segment size of bulk execution during backward pass.");
}
};
/*! \brief runtime functions for NDArray */
class Imperative {
public:
@@ -77,7 +96,8 @@ class Imperative {
};
class CachedOp {
public:
explicit CachedOp(const nnvm::Symbol& sym);
CachedOp(const nnvm::Symbol& sym,
const std::vector<std::pair<std::string, std::string> >& kwargs);
uint32_t num_inputs() {
return fwd_graph_.indexed_graph().input_nodes().size();
}
@@ -103,8 +123,9 @@
const std::vector<NDArray*>& inputs);
std::vector<nnvm::NodeEntry> Gradient(const nnvm::NodePtr& node,
const std::vector<nnvm::NodeEntry>& ograds);
OpStatePtr Forward(const std::vector<NDArray*>& inputs,
const std::vector<NDArray*>& outputs);
void Forward(const std::shared_ptr<CachedOp>& op_ptr,
const std::vector<NDArray*>& inputs,
const std::vector<NDArray*>& outputs);
void Backward(const bool retain_graph,
const OpStatePtr& state,
const std::vector<NDArray*>& inputs,
@@ -117,9 +138,11 @@
std::vector<OpStatePtr> states;
};
std::mutex mutex_;
CachedOpParam param_;
nnvm::Graph fwd_graph_;
nnvm::Graph grad_graph_;
nnvm::Graph full_graph_;
bool inlining_;
std::vector<nnvm::NodeEntry> ograd_entries_;
std::vector<bool> curr_grad_req_;
std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
@@ -182,7 +205,11 @@ class Imperative {
private:
friend class NDArray;
/*! \brief make constructor protected. */
Imperative() {}
Imperative() {
if (dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", 1)) {
backward_bulk_size_ = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15);
}
}
/*! \brief find the input/output ndarrays that are needed for backward */
void GetBackwardDependency(
const nnvm::NodePtr& node,
@@ -210,6 +237,8 @@ class Imperative {
std::atomic<uint64_t> node_count_{0};
/*! \brief variable count used for naming */
std::atomic<uint64_t> variable_count_{0};
/*! \brief default backward bulk size */
int backward_bulk_size_{0};
};

using CachedOpPtr = std::shared_ptr<Imperative::CachedOp>;
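CachedOp::Forward now takes a shared_ptr to the op itself so that, when autograd is recording and the graph is not inlined, the op can record a single _CachedOp node referencing it (see src/imperative/cached_op.cc below). The user-facing Python flow is unchanged; a minimal sketch of a forward and backward pass through a cached op, assuming a build that includes this commit, that CachedOp is reachable as mx.nd.CachedOp (as block.py's ndarray.CachedOp usage suggests), and that its usual __call__ path is used; the symbol is an arbitrary example:

import mxnet as mx

# Arbitrary one-op graph; with no flags the defaults above apply
# (inline_limit=2, bulk sizes from MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN).
op = mx.nd.CachedOp(mx.sym.sin(mx.sym.var('data')))

x = mx.nd.ones((2, 3))
x.attach_grad()
with mx.autograd.record():
    y = op(x)      # CachedOp::Forward; records inlined ops or one _CachedOp node
y.backward()       # CachedOp::Backward replays the cached backward graph
print(x.grad)      # cos(1) everywhere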
7 changes: 5 additions & 2 deletions python/mxnet/_ctypes/ndarray.py
@@ -105,10 +105,13 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
class CachedOp(object):
"""Cached operator handle."""
__slots__ = ["handle"]
def __init__(self, sym):
def __init__(self, sym, flags=()):
self.handle = CachedOpHandle()
check_call(_LIB.MXCreateCachedOp(
check_call(_LIB.MXCreateCachedOpEx(
sym.handle,
len(flags),
c_str_array([key for key, _ in flags]),
c_str_array([str(val) for _, val in flags]),
ctypes.byref(self.handle)))

def __del__(self):
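A hedged sketch of the new flags argument: each key/value pair is stringified here and forwarded to MXCreateCachedOpEx, which initializes the matching CachedOpParam field. The flag names come from include/mxnet/imperative.h in this commit; the symbol and values below are arbitrary examples, and CachedOp is again assumed importable from mxnet.ndarray:

import mxnet as mx

data = mx.sym.var('data')
net = mx.sym.Activation(mx.sym.exp(data), act_type='relu')

# Values may be ints; CachedOp converts them with str() before the C API call.
flags = [('inline_limit', 0),         # never inline; always record one _CachedOp node
         ('forward_bulk_size', 10),   # bulk-execution segment size, forward pass
         ('backward_bulk_size', 10)]  # bulk-execution segment size, backward pass
op = mx.nd.CachedOp(net, flags)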
20 changes: 15 additions & 5 deletions python/mxnet/gluon/block.py
@@ -274,17 +274,19 @@ def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False):
"""
self.collect_params().initialize(init, ctx, verbose)

def hybridize(self, active=True):
def hybridize(self, active=True, **kwargs):
"""Activates or deactivates :py:class:`HybridBlock` s recursively. Has no effect on
non-hybrid children.
Parameters
----------
active : bool, default True
Whether to turn hybrid on or off.
**kwargs : string
Additional flags for hybridized operator.
"""
for cld in self._children:
cld.hybridize(active)
cld.hybridize(active, **kwargs)

def cast(self, dtype):
"""Cast this Block to use another data type.
@@ -343,6 +345,7 @@ def __init__(self, prefix=None, params=None):
self._out_format = None
self._in_format = None
self._active = False
self._flags = {}

def __setattr__(self, name, value):
"""Registers parameters."""
@@ -378,7 +381,7 @@ def _get_graph(self, *args):
def _build_cache(self, *args):
inputs, out = self._get_graph(*args)
input_idx = {var.name: i for i, var in enumerate(inputs)}
self._cached_op = ndarray.CachedOp(out)
self._cached_op = ndarray.CachedOp(out, self._flags)
params = dict(self.collect_params().items())

# verify graph inputs
@@ -437,9 +440,11 @@ def register_child(self, block):
super(HybridBlock, self).register_child(block)
self._clear_cached_op()

def hybridize(self, active=True):
def hybridize(self, active=True, **kwargs):
self._active = active
super(HybridBlock, self).hybridize(active)
self._flags = kwargs.items()
self._clear_cached_op()
super(HybridBlock, self).hybridize(active, **kwargs)

def cast(self, dtype):
self._clear_cached_op()
@@ -615,5 +620,10 @@ def forward(self, x, *args):
ret._compose(**{k.name: v for k, v in zip(self._cached_graph[0], args)})
return _regroup(list(ret), self._out_format)[0]

def _clear_cached_op(self):
tmp = self._cached_graph
super(SymbolBlock, self)._clear_cached_op()
self._cached_graph = tmp

def hybrid_forward(self, F, x, *args, **kwargs):
raise NotImplementedError
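The hybridize change above stores the keyword flags on the block, clears any existing cached op, and passes the flags to ndarray.CachedOp the next time _build_cache runs. A minimal usage sketch, assuming the flag names from CachedOpParam; the network itself is an arbitrary example:

import mxnet as mx
from mxnet.gluon import nn

net = nn.HybridSequential()
with net.name_scope():
    net.add(nn.Dense(16, activation='relu'))
    net.add(nn.Dense(2))
net.initialize()

# Flags become (key, str(value)) pairs on the CachedOp built at first forward.
net.hybridize(inline_limit=2, forward_bulk_size=15, backward_bulk_size=15)
out = net(mx.nd.ones((4, 8)))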
6 changes: 4 additions & 2 deletions python/mxnet/gluon/nn/basic_layers.py
@@ -68,19 +68,21 @@ def __getitem__(self, key):
def __len__(self):
return len(self._children)

def hybridize(self, active=True):
def hybridize(self, active=True, **kwargs):
"""Activates or deactivates `HybridBlock`s recursively. Has no effect on
non-hybrid children.
Parameters
----------
active : bool, default True
Whether to turn hybrid on or off.
**kwargs : string
Additional flags for hybridized operator.
"""
if self._children and all(isinstance(c, HybridBlock) for c in self._children):
warnings.warn('All children of this Sequential layer are HybridBlocks. Consider ' \
'using HybridSequential for the best performance.')
super(Sequential, self).hybridize(active)
super(Sequential, self).hybridize(active, **kwargs)


class HybridSequential(HybridBlock):
31 changes: 20 additions & 11 deletions src/c_api/c_api_ndarray.cc
@@ -157,7 +157,25 @@ int MXCreateCachedOp(SymbolHandle handle,

API_BEGIN();
*out = new std::shared_ptr<Imperative::CachedOp>(
new Imperative::CachedOp(*sym));
new Imperative::CachedOp(
*sym, std::vector<std::pair<std::string, std::string> >()));
API_END();
}

int MXCreateCachedOpEx(SymbolHandle handle,
int num_params,
const char** keys,
const char** vals,
CachedOpHandle *out) {
nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(handle);

API_BEGIN();
std::vector<std::pair<std::string, std::string> > kwargs;
for (int i = 0; i < num_params; ++i) {
kwargs.push_back({keys[i], vals[i]});
}
*out = new std::shared_ptr<Imperative::CachedOp>(
new Imperative::CachedOp(*sym, kwargs));
API_END();
}

@@ -198,16 +216,7 @@ int MXInvokeCachedOp(CachedOpHandle handle,
}
}

OpStatePtr state = op->Forward(ndinputs, ndoutputs);
if (Imperative::Get()->is_recording()) {
nnvm::NodeAttrs attrs;
attrs.op = cached_op;
attrs.name = "_cachedop";
attrs.parsed = op;
Imperative::Get()->RecordOp(
std::move(attrs), ndinputs, ndoutputs, state,
&op->save_inputs(), &op->save_outputs());
}
op->Forward(op, ndinputs, ndoutputs);

if (*outputs == nullptr) {
ret->ret_handles.clear();
41 changes: 34 additions & 7 deletions src/imperative/cached_op.cc
@@ -22,12 +22,18 @@

namespace mxnet {

Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
DMLC_REGISTER_PARAMETER(CachedOpParam);

Imperative::CachedOp::CachedOp(
const nnvm::Symbol& sym,
const std::vector<std::pair<std::string, std::string> >& kwargs) {
using namespace nnvm;
using namespace imperative;
static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")};
static const auto _copy = Op::Get("_copy");

param_.Init(kwargs);

// construct forward graph
{
NodeEntryMap<int> dedup_out;
@@ -59,6 +65,8 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {

fwd_graph_.attrs["forward_ref_count"] =
std::make_shared<dmlc::any>(std::move(ref_count));

inlining_ = (idx.num_nodes() - idx.input_nodes().size()) <= param_.inline_limit;
}

// construct backward graph
@@ -321,13 +329,16 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
return g;
}

OpStatePtr Imperative::CachedOp::Forward(const std::vector<NDArray*>& inputs,
const std::vector<NDArray*>& outputs) {
void Imperative::CachedOp::Forward(
const std::shared_ptr<CachedOp>& op_ptr,
const std::vector<NDArray*>& inputs,
const std::vector<NDArray*>& outputs) {
using namespace nnvm;
using namespace imperative;
static const auto cached_op = nnvm::Op::Get("_CachedOp");

bool recording = Imperative::Get()->set_is_recording(false);
// Initialize
bool recording = Imperative::Get()->is_recording();
nnvm::Graph g = GetForwardGraph(recording, inputs);
const auto& idx = g.indexed_graph();
size_t num_inputs = idx.input_nodes().size();
@@ -381,20 +392,32 @@ OpStatePtr Imperative::CachedOp::Forward(const std::vector<NDArray*>& inputs,

const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");

if (recording && !inlining_) Imperative::Get()->set_is_recording(false);
int prev_bulk_size = Engine::Get()->set_bulk_size(param_.forward_bulk_size);

Imperative::Get()->RunGraph(
false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
std::move(ref_count), &states, dispatch_modes);

Engine::Get()->set_bulk_size(prev_bulk_size);
Imperative::Get()->set_is_recording(recording);

for (size_t i = 0; i < idx.num_node_entries(); ++i) {
if (arrays[i] == &buff[i]) continue;
buff[i].shape_ = arrays[i]->shape_;
buff[i].dtype_ = arrays[i]->dtype_;
buff[i].storage_type_ = arrays[i]->storage_type_;
}

Imperative::Get()->set_is_recording(recording);

return op_state_ptr;
if (recording && !inlining_) {
nnvm::NodeAttrs attrs;
attrs.op = cached_op;
attrs.name = "_cachedop";
attrs.parsed = op_ptr;
Imperative::Get()->RecordOp(
std::move(attrs), inputs, outputs, op_state_ptr,
&save_inputs(), &save_outputs());
}
}


@@ -452,10 +475,14 @@ void Imperative::CachedOp::Backward(

const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");

int prev_bulk_size = Engine::Get()->set_bulk_size(param_.backward_bulk_size);

Imperative::Get()->RunGraph(
retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);

Engine::Get()->set_bulk_size(prev_bulk_size);

if (retain_graph) {
buff.resize(num_forward_entries);
} else {
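The inlining_ flag computed in the constructor drives the recording logic in Forward: when autograd is recording and the graph has more than inline_limit non-input nodes, recording is suspended for the graph run and a single _CachedOp node is recorded afterwards; smaller graphs stay inlined, so their operators are recorded individually. A hedged sketch of the user-visible knob, assuming the Python flag plumbing shown earlier; the block is an arbitrary example:

import mxnet as mx
from mxnet.gluon import nn

net = nn.Dense(1)
net.initialize()
# inline_limit=0 makes inlining_ true only for graphs with zero non-input
# nodes, so this block is recorded as one _CachedOp node under autograd.
net.hybridize(inline_limit=0)

x = mx.nd.ones((1, 4))
x.attach_grad()
with mx.autograd.record():
    y = net(x)
y.backward()
print(x.grad)   # gradient flows back through the recorded _CachedOp node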