diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 9815786ba7f6..faa453529e84 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -793,8 +793,15 @@ MXNET_DLL int MXAutogradGetSymbol(NDArrayHandle handle, SymbolHandle *out);
 /*!
  * \brief create cached operator
  */
-MXNET_DLL int MXCreateCachedOp(SymbolHandle handle,
-                               CachedOpHandle *out);
+MXNET_DLL int MXCreateCachedOp(SymbolHandle handle, CachedOpHandle *out);
+/*!
+ * \brief create cached operator
+ */
+MXNET_DLL int MXCreateCachedOpEx(SymbolHandle handle,
+                                 int num_params,
+                                 const char** keys,
+                                 const char** vals,
+                                 CachedOpHandle *out);
 /*!
  * \brief free cached operator
  */
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index 88a9f4d597ef..d605e9d47ca0 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -28,11 +28,30 @@
 #include
 #include
 #include
+#include
+#include
 #include
 #include "./ndarray.h"
 
 namespace mxnet {
+/*! \brief CachedOp Parameters */
+struct CachedOpParam : public dmlc::Parameter<CachedOpParam> {
+  uint32_t inline_limit;
+  uint32_t forward_bulk_size;
+  uint32_t backward_bulk_size;
+  DMLC_DECLARE_PARAMETER(CachedOpParam) {
+    DMLC_DECLARE_FIELD(inline_limit)
+    .set_default(2)
+    .describe("Maximum number of operators that can be inlined.");
+    DMLC_DECLARE_FIELD(forward_bulk_size)
+    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
+    .describe("Segment size of bulk execution during forward pass.");
+    DMLC_DECLARE_FIELD(backward_bulk_size)
+    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
+    .describe("Segment size of bulk execution during backward pass.");
+  }
+};
 /*! \brief runtime functions for NDArray */
 class Imperative {
  public:
@@ -77,7 +96,8 @@ class Imperative {
   };
   class CachedOp {
    public:
-    explicit CachedOp(const nnvm::Symbol& sym);
+    CachedOp(const nnvm::Symbol& sym,
+             const std::vector<std::pair<std::string, std::string> >& kwargs);
     uint32_t num_inputs() {
       return fwd_graph_.indexed_graph().input_nodes().size();
     }
@@ -103,8 +123,9 @@ class Imperative {
         const std::vector<NDArray*>& inputs);
     std::vector<nnvm::NodeEntry> Gradient(const nnvm::NodePtr& node,
                                           const std::vector<nnvm::NodeEntry>& ograds);
-    OpStatePtr Forward(const std::vector<NDArray*>& inputs,
-                       const std::vector<NDArray*>& outputs);
+    void Forward(const std::shared_ptr<CachedOp>& op_ptr,
+                 const std::vector<NDArray*>& inputs,
+                 const std::vector<NDArray*>& outputs);
     void Backward(const bool retain_graph,
                   const OpStatePtr& state,
                   const std::vector<NDArray*>& inputs,
@@ -117,9 +138,11 @@ class Imperative {
       std::vector<OpStatePtr> states;
     };
     std::mutex mutex_;
+    CachedOpParam param_;
     nnvm::Graph fwd_graph_;
     nnvm::Graph grad_graph_;
     nnvm::Graph full_graph_;
+    bool inlining_;
     std::vector<nnvm::NodeEntry> ograd_entries_;
     std::vector curr_grad_req_;
     std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
@@ -182,7 +205,11 @@ class Imperative {
  private:
   friend class NDArray;
   /*! \brief make constructor protected. */
-  Imperative() {}
+  Imperative() {
+    if (dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", 1)) {
+      backward_bulk_size_ = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15);
+    }
+  }
   /*! \brief find the input/output ndarrays that are needed for backward */
   void GetBackwardDependency(
       const nnvm::NodePtr& node,
@@ -210,6 +237,8 @@ class Imperative {
   std::atomic<uint64_t> node_count_{0};
   /*! \brief variable count used for naming */
   std::atomic<uint64_t> variable_count_{0};
+  /*! \brief default backward bulk size */
+  int backward_bulk_size_{0};
 };
 
 using CachedOpPtr = std::shared_ptr<Imperative::CachedOp>;
diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py
index a0c01a6e069b..20ad2bfbc555 100644
--- a/python/mxnet/_ctypes/ndarray.py
+++ b/python/mxnet/_ctypes/ndarray.py
@@ -105,10 +105,13 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
 class CachedOp(object):
     """Cached operator handle."""
     __slots__ = ["handle"]
-    def __init__(self, sym):
+    def __init__(self, sym, flags=()):
         self.handle = CachedOpHandle()
-        check_call(_LIB.MXCreateCachedOp(
+        check_call(_LIB.MXCreateCachedOpEx(
             sym.handle,
+            len(flags),
+            c_str_array([key for key, _ in flags]),
+            c_str_array([str(val) for _, val in flags]),
             ctypes.byref(self.handle)))
 
     def __del__(self):
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 466f87fade7f..37734ac3894d 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -274,7 +274,7 @@ def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False):
         """
         self.collect_params().initialize(init, ctx, verbose)
 
-    def hybridize(self, active=True):
+    def hybridize(self, active=True, **kwargs):
         """Activates or deactivates :py:class:`HybridBlock` s recursively. Has no effect on
         non-hybrid children.
 
@@ -282,9 +282,11 @@ def hybridize(self, active=True):
         ----------
         active : bool, default True
             Whether to turn hybrid on or off.
+        **kwargs : string
+            Additional flags for hybridized operator.
         """
         for cld in self._children:
-            cld.hybridize(active)
+            cld.hybridize(active, **kwargs)
 
     def cast(self, dtype):
         """Cast this Block to use another data type.
@@ -343,6 +345,7 @@ def __init__(self, prefix=None, params=None):
         self._out_format = None
         self._in_format = None
         self._active = False
+        self._flags = {}
 
     def __setattr__(self, name, value):
         """Registers parameters."""
@@ -378,7 +381,7 @@ def _get_graph(self, *args):
     def _build_cache(self, *args):
         inputs, out = self._get_graph(*args)
         input_idx = {var.name: i for i, var in enumerate(inputs)}
-        self._cached_op = ndarray.CachedOp(out)
+        self._cached_op = ndarray.CachedOp(out, self._flags)
         params = dict(self.collect_params().items())
 
         # verify graph inputs
@@ -437,9 +440,11 @@ def register_child(self, block):
         super(HybridBlock, self).register_child(block)
         self._clear_cached_op()
 
-    def hybridize(self, active=True):
+    def hybridize(self, active=True, **kwargs):
         self._active = active
-        super(HybridBlock, self).hybridize(active)
+        self._flags = kwargs.items()
+        self._clear_cached_op()
+        super(HybridBlock, self).hybridize(active, **kwargs)
 
     def cast(self, dtype):
         self._clear_cached_op()
@@ -615,5 +620,10 @@ def forward(self, x, *args):
             ret._compose(**{k.name: v for k, v in zip(self._cached_graph[0], args)})
             return _regroup(list(ret), self._out_format)[0]
 
+    def _clear_cached_op(self):
+        tmp = self._cached_graph
+        super(SymbolBlock, self)._clear_cached_op()
+        self._cached_graph = tmp
+
     def hybrid_forward(self, F, x, *args, **kwargs):
         raise NotImplementedError
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index c0b4b52382f5..ab5d5e167f87 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -68,7 +68,7 @@ def __getitem__(self, key):
     def __len__(self):
         return len(self._children)
 
-    def hybridize(self, active=True):
+    def hybridize(self, active=True, **kwargs):
         """Activates or deactivates `HybridBlock`s recursively. Has no effect on
         non-hybrid children.
 
@@ -76,11 +76,13 @@ def hybridize(self, active=True):
         ----------
         active : bool, default True
             Whether to turn hybrid on or off.
+        **kwargs : string
+            Additional flags for hybridized operator.
         """
         if self._children and all(isinstance(c, HybridBlock) for c in self._children):
             warnings.warn('All children of this Sequential layer are HybridBlocks. Consider ' \
                           'using HybridSequential for the best performance.')
-        super(Sequential, self).hybridize(active)
+        super(Sequential, self).hybridize(active, **kwargs)
 
 
 class HybridSequential(HybridBlock):
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index 2c4a30501147..51f30e223198 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -157,7 +157,25 @@ int MXCreateCachedOp(SymbolHandle handle,
 
   API_BEGIN();
   *out = new std::shared_ptr<Imperative::CachedOp>(
-      new Imperative::CachedOp(*sym));
+      new Imperative::CachedOp(
+          *sym, std::vector<std::pair<std::string, std::string> >()));
+  API_END();
+}
+
+int MXCreateCachedOpEx(SymbolHandle handle,
+                       int num_params,
+                       const char** keys,
+                       const char** vals,
+                       CachedOpHandle *out) {
+  nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(handle);
+
+  API_BEGIN();
+  std::vector<std::pair<std::string, std::string> > kwargs;
+  for (int i = 0; i < num_params; ++i) {
+    kwargs.push_back({keys[i], vals[i]});
+  }
+  *out = new std::shared_ptr<Imperative::CachedOp>(
+      new Imperative::CachedOp(*sym, kwargs));
   API_END();
 }
 
@@ -198,16 +216,7 @@ int MXInvokeCachedOp(CachedOpHandle handle,
     }
   }
 
-  OpStatePtr state = op->Forward(ndinputs, ndoutputs);
-  if (Imperative::Get()->is_recording()) {
-    nnvm::NodeAttrs attrs;
-    attrs.op = cached_op;
-    attrs.name = "_cachedop";
-    attrs.parsed = op;
-    Imperative::Get()->RecordOp(
-        std::move(attrs), ndinputs, ndoutputs, state,
-        &op->save_inputs(), &op->save_outputs());
-  }
+  op->Forward(op, ndinputs, ndoutputs);
 
   if (*outputs == nullptr) {
     ret->ret_handles.clear();
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 5717327f87cf..eaa95a5f2418 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -22,12 +22,18 @@
 
 namespace mxnet {
 
-Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
+DMLC_REGISTER_PARAMETER(CachedOpParam);
+
+Imperative::CachedOp::CachedOp(
+    const nnvm::Symbol& sym,
+    const std::vector<std::pair<std::string, std::string> >& kwargs) {
   using namespace nnvm;
   using namespace imperative;
   static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")};
   static const auto _copy = Op::Get("_copy");
 
+  param_.Init(kwargs);
+
   // construct forward graph
   {
     NodeEntryMap dedup_out;
@@ -59,6 +65,8 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
 
     fwd_graph_.attrs["forward_ref_count"] =
         std::make_shared<dmlc::any>(std::move(ref_count));
+
+    inlining_ = (idx.num_nodes() - idx.input_nodes().size()) <= param_.inline_limit;
   }
 
   // construct backward graph
@@ -321,13 +329,16 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
   return g;
 }
 
-OpStatePtr Imperative::CachedOp::Forward(const std::vector<NDArray*>& inputs,
-                                         const std::vector<NDArray*>& outputs) {
+void Imperative::CachedOp::Forward(
+    const std::shared_ptr<CachedOp>& op_ptr,
+    const std::vector<NDArray*>& inputs,
+    const std::vector<NDArray*>& outputs) {
   using namespace nnvm;
   using namespace imperative;
+  static const auto cached_op = nnvm::Op::Get("_CachedOp");
 
-  bool recording = Imperative::Get()->set_is_recording(false);
   // Initialize
+  bool recording = Imperative::Get()->is_recording();
   nnvm::Graph g = GetForwardGraph(recording, inputs);
   const auto& idx = g.indexed_graph();
   size_t num_inputs = idx.input_nodes().size();
@@ -381,10 +392,16 @@ OpStatePtr Imperative::CachedOp::Forward(const std::vector<NDArray*>& inputs,
   const auto& dispatch_modes =
g.GetAttr("dispatch_mode"); + if (recording && !inlining_) Imperative::Get()->set_is_recording(false); + int prev_bulk_size = Engine::Get()->set_bulk_size(param_.forward_bulk_size); + Imperative::Get()->RunGraph( false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs), std::move(ref_count), &states, dispatch_modes); + Engine::Get()->set_bulk_size(prev_bulk_size); + Imperative::Get()->set_is_recording(recording); + for (size_t i = 0; i < idx.num_node_entries(); ++i) { if (arrays[i] == &buff[i]) continue; buff[i].shape_ = arrays[i]->shape_; @@ -392,9 +409,15 @@ OpStatePtr Imperative::CachedOp::Forward(const std::vector& inputs, buff[i].storage_type_ = arrays[i]->storage_type_; } - Imperative::Get()->set_is_recording(recording); - - return op_state_ptr; + if (recording && !inlining_) { + nnvm::NodeAttrs attrs; + attrs.op = cached_op; + attrs.name = "_cachedop"; + attrs.parsed = op_ptr; + Imperative::Get()->RecordOp( + std::move(attrs), inputs, outputs, op_state_ptr, + &save_inputs(), &save_outputs()); + } } @@ -452,10 +475,14 @@ void Imperative::CachedOp::Backward( const auto& dispatch_modes = g.GetAttr("dispatch_mode"); + int prev_bulk_size = Engine::Get()->set_bulk_size(param_.backward_bulk_size); + Imperative::Get()->RunGraph( retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(), std::move(array_reqs), std::move(ref_count), &states, dispatch_modes); + Engine::Get()->set_bulk_size(prev_bulk_size); + if (retain_graph) { buff.resize(num_forward_entries); } else { diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc index 361b971a2da3..fbbaf82d1770 100644 --- a/src/imperative/imperative.cc +++ b/src/imperative/imperative.cc @@ -288,8 +288,6 @@ void Imperative::RunGraph( DTypeVector arg_dtypes; std::vector req; - int prev_bulk_size = Engine::Get()->set_bulk_size(10); - for (size_t i = node_start; i < node_end; ++i) { const nnvm::IndexedGraph::Node& node = idx[i]; if (node.source->op() == nullptr) continue; @@ -353,8 +351,6 @@ void Imperative::RunGraph( if (ref_count[eid] == 0) arrays[eid]->ptr_.reset(); } } - - Engine::Get()->set_bulk_size(prev_bulk_size); } @@ -367,6 +363,7 @@ std::vector Imperative::Backward( using namespace nnvm; using namespace imperative; static const std::vector zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")}; + static const Op* copy_op = Op::Get("_copy"); // Construct forward graph Graph graph; @@ -439,7 +436,14 @@ std::vector Imperative::Backward( zero_ops, "_copy"); CHECK_EQ(g_graph.outputs.size(), xs.size()); for (const auto &e : g_graph.outputs) { - graph.outputs.push_back(e); + if (e.node->op() == nullptr) { + auto node = Node::Create(); + node->attrs.op = copy_op; + node->inputs.push_back(e); + graph.outputs.push_back(NodeEntry{node, 0, 0}); + } else { + graph.outputs.push_back(e); + } } const auto& idx = graph.indexed_graph(); // get number of nodes used in forward pass @@ -575,10 +579,12 @@ std::vector Imperative::Backward( bool prev_recording = set_is_recording(create_graph); bool prev_training = set_is_training(is_train); + int prev_bulk_size = Engine::Get()->set_bulk_size(backward_bulk_size_); RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(), std::move(array_reqs), std::move(ref_count), &states, dispatch_modes); + Engine::Get()->set_bulk_size(prev_bulk_size); set_is_recording(prev_recording); set_is_training(prev_training); diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index df9f78e0ce9a..c619056c11c6 100644 --- a/tests/python/unittest/test_gluon.py +++ 
+++ b/tests/python/unittest/test_gluon.py
@@ -23,6 +23,7 @@
 from nose.tools import raises
 from copy import deepcopy
 import warnings
+import json
 
 
 def test_parameter():
@@ -256,7 +257,6 @@ def test_deconv():
     #     #  check_layer_forward(layer, (1, 10, 10, 10, 4))
 
 
-
 def test_pool():
     layers1d = [
         nn.MaxPool1D(),
@@ -611,6 +611,31 @@ def test_fill_shape_load():
     assert net2[2].weight.shape[1] == 3072, net2[2].weight.shape[1]
 
 
+def test_inline():
+    net = mx.gluon.nn.HybridSequential()
+    with net.name_scope():
+        net.add(mx.gluon.nn.Dense(10))
+        net.add(mx.gluon.nn.Dense(10))
+        net.add(mx.gluon.nn.Dense(10))
+
+    net.initialize()
+    net.hybridize(inline_limit=3)
+    with mx.autograd.record():
+        y = net(mx.nd.zeros((1,10)))
+
+    len_1 = len(json.loads(mx.autograd.get_symbol(y).tojson())['nodes'])
+    y.backward()
+
+    net.hybridize(inline_limit=0)
+    with mx.autograd.record():
+        y = net(mx.nd.zeros((1,10)))
+
+    len_2 = len(json.loads(mx.autograd.get_symbol(y).tojson())['nodes'])
+    y.backward()
+
+    assert len_1 == len_2 + 2
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
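
Note: a minimal usage sketch of the flag path this patch adds, for reviewers. The flag names
(inline_limit, forward_bulk_size, backward_bulk_size) come from CachedOpParam above; the toy
network, shapes, and flag values below are illustrative only and are not part of the patch.

    import mxnet as mx
    from mxnet.gluon import nn

    net = nn.HybridSequential()
    with net.name_scope():
        net.add(nn.Dense(10))
        net.add(nn.Dense(10))
    net.initialize()

    # hybridize(**kwargs) stores the flags in HybridBlock._flags, which are passed to
    # ndarray.CachedOp and on to MXCreateCachedOpEx / CachedOpParam when the cache is built.
    net.hybridize(inline_limit=2, forward_bulk_size=15, backward_bulk_size=15)

    with mx.autograd.record():
        y = net(mx.nd.zeros((1, 10)))
    y.backward()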