diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 4dd858a51c4b..55c26bc980b2 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -987,6 +987,11 @@ MXNET_DLL int MXCreateCachedOpEx(SymbolHandle handle,
                                  int num_flags,
                                  const char** keys,
                                  const char** vals,
+                                 int num_inputs,
+                                 const char** input_names,
+                                 int num_params,
+                                 const char** param_names,
+                                 NDArrayHandle* params,
                                  CachedOpHandle *out);
 /*!
  * \brief free cached operator
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index 7ea60df33028..758ce8513213 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -35,6 +35,23 @@
 #include "./ndarray.h"
 
 namespace mxnet {
+/*! \brief CachedOp Parameters */
+struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
+  uint32_t inline_limit;
+  uint32_t forward_bulk_size;
+  uint32_t backward_bulk_size;
+  DMLC_DECLARE_PARAMETER(CachedOpConfig) {
+    DMLC_DECLARE_FIELD(inline_limit)
+    .set_default(2)
+    .describe("Maximum number of operators that can be inlined.");
+    DMLC_DECLARE_FIELD(forward_bulk_size)
+    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
+    .describe("Segment size of bulk execution during forward pass.");
+    DMLC_DECLARE_FIELD(backward_bulk_size)
+    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
+    .describe("Segment size of bulk execution during backward pass.");
+  }
+};
 /*! \brief runtime functions for NDArray */
 class Imperative {
  public:
@@ -77,6 +94,67 @@ class Imperative {
            && info.out_grads.size() == 1;
     }
   };
+  class CachedOp {
+   public:
+    CachedOp(
+        const nnvm::Symbol& sym,
+        const std::vector<std::pair<std::string, std::string> >& flags,
+        const std::vector<std::string> arg_names,
+        const std::unordered_map<std::string, std::vector<NDArray> >& params);
+    uint32_t num_inputs() {
+      return fwd_graph_.indexed_graph().input_nodes().size();
+    }
+    uint32_t num_outputs() {
+      return fwd_graph_.outputs.size();
+    }
+    uint32_t num_backward_inputs() {
+      return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
+    }
+    std::vector<bool>& save_inputs() {
+      return save_inputs_;
+    }
+    std::vector<bool>& save_outputs() {
+      return save_outputs_;
+    }
+    const std::unordered_set<uint32_t>& mutable_input_nodes() {
+      return fwd_graph_.indexed_graph().mutable_input_nodes();
+    }
+    nnvm::Graph GetForwardGraph(const bool recording,
+                                const std::vector<NDArray*>& inputs);
+    nnvm::Graph GetBackwardGraph(const OpStatePtr& state,
+                                 const std::vector<OpReqType>& reqs,
+                                 const std::vector<NDArray*>& inputs);
+    std::vector<nnvm::NodeEntry> Gradient(const nnvm::NodePtr& node,
+                                          const std::vector<nnvm::NodeEntry>& ograds);
+    void Forward(const std::shared_ptr<CachedOp>& op_ptr,
+                 const std::vector<NDArray*>& args,
+                 const std::vector<NDArray*>& outputs);
+    void Backward(const bool retain_graph,
+                  const OpStatePtr& state,
+                  const std::vector<NDArray*>& inputs,
+                  const std::vector<OpReqType>& reqs,
+                  const std::vector<NDArray*>& outputs);
+
+   private:
+    struct CachedOpState {
+      std::vector<NDArray> buff;
+      std::vector<OpStatePtr> states;
+    };
+    std::mutex mutex_;
+    CachedOpConfig config_;
+    nnvm::Graph fwd_graph_;
+    nnvm::Graph grad_graph_;
+    nnvm::Graph full_graph_;
+    std::unordered_map<Context, std::vector<NDArray> > params_;
+    bool inlining_;
+    std::vector<nnvm::NodeEntry> ograd_entries_;
+    std::vector<bool> curr_grad_req_;
+    std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
+    std::vector<uint32_t> fwd_args_idx_;
+    std::vector<uint32_t> fwd_params_idx_;
+    std::vector<uint32_t> bwd_input_eid_;
+    std::vector<bool> save_inputs_, save_outputs_;
+  };
   /*! \brief whether operator recording is on.
*/ bool is_training() const { return is_train_; @@ -144,6 +222,15 @@ class Imperative { uint32_t num_inputs, uint32_t num_outputs, std::vector *p_save_inputs, std::vector *p_save_outputs); + void RunGraph( + const bool retain_graph, + const nnvm::IndexedGraph& idx, + const std::vector arrays, + size_t node_start, size_t node_end, + std::vector&& array_reqs, + std::vector&& ref_count, + std::vector *p_states, + const DispatchModeVector& dispatch_modes); /*! \brief indicate whether is training. */ #if DMLC_CXX11_THREAD_LOCAL static thread_local bool is_train_; @@ -160,5 +247,7 @@ class Imperative { int backward_bulk_size_{0}; }; +using CachedOpPtr = std::shared_ptr; + } // namespace mxnet #endif // MXNET_IMPERATIVE_H_ diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index ae96fd87b0db..e243eb71c477 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -155,14 +155,6 @@ class NDArray { return byte_offset_ > 0 || shape() != ptr_->storage_shape; } - /* \brief Check whether the two arrays are the same array */ - inline bool IsSame(const NDArray& other) { - return ptr_ == other.ptr_ && - shape_ == other.shape_ && - byte_offset_ == other.byte_offset_ && - dtype_ == other.dtype_; - } - /*! * \return the shape of current NDArray. */ diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index f4694efad297..3969d8445be1 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -126,36 +126,25 @@ class OpStatePtr { template static OpStatePtr Create(Args&&... args) { OpStatePtr ret; - auto state = new T(std::forward(args)...); - auto var = Engine::Get()->NewVariable(); - ret.ptr_.reset( - new OpState(var, state), - [](OpState* p) { - Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), p->var); - delete reinterpret_cast(p->state); - delete p; - }); + ret.ptr_ = std::make_shared(); + ret.ptr_->var_ = Engine::Get()->NewVariable(); + ret.ptr_->state_.construct(std::forward(args)...); return ret; } /* \brief Get engine variable associated with this state */ engine::VarHandle get_var() const { - return ptr_->var; + return ptr_->var_; } /* \brief Get state of type T */ template T& get_state() const { - return *reinterpret_cast(ptr_->state); + return dmlc::get(ptr_->state_); } /* \brief clear state */ void reset() { ptr_.reset(); } - /* \brief checks whether the managed object is managed only by the current - OpStatePtr instance */ - bool unique() const { - return ptr_.unique(); - } /* \brief Whether state is empty */ explicit operator bool() const { return ptr_ ? 
true : false; @@ -164,12 +153,16 @@ class OpStatePtr { private: /* \brief state structure */ struct OpState { - engine::VarHandle var; - void* state; - - OpState(engine::VarHandle var_, void* state_) : var(var_), state(state_) {} + OpState() {} OpState(const OpState& other) = delete; OpState& operator=(const OpState& other) = delete; + + ~OpState() { + Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), var_); + } + + engine::VarHandle var_; + dmlc::any state_; }; /* \brief shared pointer to state */ std::shared_ptr ptr_; diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py index f324545a2352..d2cae0c45aaa 100644 --- a/python/mxnet/_ctypes/ndarray.py +++ b/python/mxnet/_ctypes/ndarray.py @@ -105,14 +105,28 @@ def _imperative_invoke(handle, ndargs, keys, vals, out): class CachedOp(object): """Cached operator handle.""" __slots__ = ["handle"] - def __init__(self, sym, flags=()): + def __init__(self, sym, flags=(), inputs=None, params=None): self.handle = CachedOpHandle() + param_names = [] + param_arrays = [] + if inputs is None: + assert params is None, "When inputs is None params must also be None." + inputs = sym.list_inputs() + elif params is not None: + for name, arrs in params.items(): + param_arrays.extend(arrs) + param_names.extend([name] * len(arrs)) check_call(_LIB.MXCreateCachedOpEx( sym.handle, len(flags), c_str_array([key for key, _ in flags]), c_str_array([str(val) for _, val in flags]), + len(inputs), + c_str_array(inputs), + len(param_names), + c_str_array(param_names), + c_handle_array(param_arrays), ctypes.byref(self.handle))) def __del__(self): diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 293fafab487b..3b97c0578cae 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -502,16 +502,8 @@ def hybridize(self, active=True, **kwargs): ---------- active : bool, default True Whether to turn hybrid on or off. - static_alloc : bool, default False - Statically allocate memory to improve speed. Memory usage may increase. - static_shape : bool, default False - Optimize for invariant input shapes between iterations. Must also - set static_alloc to True. Change of input shapes is still allowed - but slower. - forward_bulk_size : int, default 15 - Segment size of bulk execution during forward pass. - backward_bulk_size : int, default 15 - Segment size of bulk execution during backward pass. + **kwargs : string + Additional flags for hybridized operator. 
""" for cld in self._children.values(): cld.hybridize(active, **kwargs) @@ -704,7 +696,7 @@ def __init__(self, prefix=None, params=None): self._out_format = None self._in_format = None self._active = False - self._flags = [] + self._flags = {} def __setattr__(self, name, value): """Registers parameters.""" @@ -731,43 +723,39 @@ def _get_graph(self, *args): return self._cached_graph def _build_cache(self, *args): - data, out = self._get_graph(*args) - data_names = {data.name : i for i, data in enumerate(data)} - params = self.collect_params() - input_names = out.list_inputs() + inputs, out = self._get_graph(*args) + input_names = [i.name for i in inputs] + params = self.collect_params() param_names = set(params.keys()) - expected_names = set(input_names) + expected_names = set(out.list_inputs()) for name in expected_names: - assert name in param_names or name in data_names, \ + assert name in param_names or name in input_names, \ "Unknown input to HybridBlock: %s"%name - used_data_names = [i for i in data_names if i in expected_names] - if len(used_data_names) != len(data_names): - unused = ', '.join(['%d-th'%i for name, i in data_names.items() + used_input_names = [i for i in input_names if i in expected_names] + if len(used_input_names) != len(input_names): + unused = ', '.join(['%d-th'%i for i, name in enumerate(input_names) if name not in expected_names]) warnings.warn("The %s input to HybridBlock is not used by any " "computation. Is this intended?"%unused, stacklevel=4) - used_param_names = [i for i in param_names if i in expected_names] + used_param_names = set(i for i in param_names if i in expected_names) if len(used_param_names) != len(param_names): - unused = ', '.join(list(param_names - set(used_param_names))) + unused = ', '.join(list(param_names - used_param_names)) warnings.warn("Parameter %s is not used by any computation. 
" "Is this intended?"%unused, stacklevel=4) - data_indices = [] - param_indices = [] - self._cached_op_args = [] - for i, name in enumerate(input_names): - if name in data_names: - data_indices.append(i) - self._cached_op_args.append((True, data_names[name])) - else: - param_indices.append(i) - self._cached_op_args.append((False, params[name])) - flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \ - self._flags - self._cached_op = ndarray.CachedOp(out, flags) + used_params = {k: params[k] for k in used_param_names} + try: + param_dict = {k: v.list_data() for k, v in used_params.items()} + except DeferredInitializationError: + self._deferred_infer_shape(*args) + for i in used_params.values(): + i._finish_deferred_init() + param_dict = {k: v.list_data() for k, v in used_params.items()} + + self._cached_op = ndarray.CachedOp(out, self._flags, input_names, param_dict) def _deferred_infer_shape(self, *args): try: @@ -783,19 +771,7 @@ def _call_cached_op(self, *args): args, fmt = _flatten(args, "input") assert fmt == self._in_format, "Invalid input format" - try: - cargs = [args[i] if is_arg else i.data() - for is_arg, i in self._cached_op_args] - except DeferredInitializationError: - self._deferred_infer_shape(*args) - cargs = [] - for is_arg, i in self._cached_op_args: - if is_arg: - cargs.append(args[i]) - else: - i._finish_deferred_init() - cargs.append(i.data()) - out = self._cached_op(*cargs) + out = self._cached_op(*args) if isinstance(out, NDArray): out = [out] return _regroup(out, self._out_format)[0] @@ -816,7 +792,7 @@ def register_child(self, block, name=None): def hybridize(self, active=True, **kwargs): self._active = active - self._flags = list(kwargs.items()) + self._flags = kwargs.items() self._clear_cached_op() if active and self._forward_hooks or self._forward_pre_hooks: warnings.warn('"{}" is being hybridized while still having forward hook/pre-hook. 
' diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index 34bd4b20aa54..9aabe04656e5 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -36,7 +36,6 @@ #include "../common/utils.h" #include "../common/exec_utils.h" #include "../imperative/imperative_utils.h" -#include "../imperative/cached_op.h" using namespace mxnet; @@ -161,8 +160,12 @@ int MXCreateCachedOp(SymbolHandle handle, std::vector input_names; input_names.reserve(inputs.size()); for (const auto& i : inputs) input_names.push_back(i->attrs.name); - *out = new CachedOpPtr(new CachedOp( - *sym, std::vector >())); + *out = new std::shared_ptr( + new Imperative::CachedOp( + *sym, + std::vector >(), + input_names, + std::unordered_map >())); API_END(); } @@ -170,6 +173,11 @@ int MXCreateCachedOpEx(SymbolHandle handle, int num_flags, const char** keys, const char** vals, + int num_args, + const char** arg_names, + int num_params, + const char** param_names, + NDArrayHandle* params, CachedOpHandle *out) { nnvm::Symbol* sym = static_cast(handle); @@ -178,7 +186,17 @@ int MXCreateCachedOpEx(SymbolHandle handle, for (int i = 0; i < num_flags; ++i) { flags.push_back({keys[i], vals[i]}); } - *out = new CachedOpPtr(new CachedOp(*sym, flags)); + std::vector args; + for (int i = 0; i < num_args; ++i) { + args.push_back(arg_names[i]); + } + std::unordered_map > param_dict; + for (int i = 0; i < num_params; ++i) { + param_dict[param_names[i]].emplace_back( + *reinterpret_cast(params[i])); + } + *out = new std::shared_ptr( + new Imperative::CachedOp(*sym, flags, args, param_dict)); API_END(); } diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index e70cc197c0c3..dc0436e02a8e 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -278,8 +278,6 @@ void ThreadedEngine::DeleteOperator(OprHandle op) { } void ThreadedEngine::Push(OprHandle op, Context exec_ctx, int priority, bool profiling) { - BulkFlush(); - ThreadedOpr* threaded_opr = ThreadedOpr::CastFromBase(op); OprBlock* opr_block = OprBlock::New(); opr_block->opr = threaded_opr; @@ -325,6 +323,7 @@ void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx, << device_count_; } #endif + BulkFlush(); ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name, wait); opr->temporary = true; const bool profiling = profiler_->IsProfiling(profiler::Profiler::kImperative); diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc index 72919d90c620..697e4869a049 100644 --- a/src/executor/attach_op_execs_pass.cc +++ b/src/executor/attach_op_execs_pass.cc @@ -134,10 +134,6 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor { return state_.get_var(); } - OpStatePtr state() const override { - return state_; - } - explicit StatefulComputeExecutor(const OpStatePtr& state, const FStatefulCompute& fcompute, ExecType exec_type, @@ -146,6 +142,7 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor { state_(state), fcompute_(fcompute), exec_type_(exec_type) {} private: + friend Graph AttachOpExecs(Graph g); OpStatePtr state_; FStatefulCompute fcompute_; ExecType exec_type_; @@ -173,16 +170,13 @@ class StatefulComputeExExecutor : public OpExecutor { return state_.get_var(); } - OpStatePtr state() const override { - return state_; - } - explicit StatefulComputeExExecutor(const OpStatePtr& state, const FStatefulComputeEx& fcompute, ExecType exec_type) : state_(state), fcompute_(fcompute), exec_type_(exec_type) {} private: + friend 
Graph AttachOpExecs(Graph g); OpStatePtr state_; FStatefulComputeEx fcompute_; ExecType exec_type_; @@ -247,15 +241,16 @@ class FComputeExExecutor : public OpExecutor { ExecType exec_type_; }; -void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) { +// pass to attach operator executors +Graph AttachOpExecs(Graph g) { using nnvm::DTypeVector; using nnvm::ShapeVector; using nnvm::FMutateInputs; - static auto& fcreate_op_state = nnvm::Op::GetAttr("FCreateOpState"); - static auto& fmutate_inputs = nnvm::Op::GetAttr("FMutateInputs"); - static auto& fexec_type = nnvm::Op::GetAttr("FExecType"); - static auto& is_layer_backward = nnvm::Op::GetAttr("TIsLayerOpBackward"); + auto& fcreate_op_state = nnvm::Op::GetAttr("FCreateOpState"); + auto& fmutate_inputs = nnvm::Op::GetAttr("FMutateInputs"); + auto& fexec_type = nnvm::Op::GetAttr("FExecType"); + auto& is_layer_backward = nnvm::Op::GetAttr("TIsLayerOpBackward"); const auto& vdtype = g.GetAttr("dtype"); const auto& vshape = g.GetAttr("shape"); @@ -264,88 +259,82 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) { // get the graph const auto& idx = g.indexed_graph(); - OpExecVector& ret = *p_ret; + std::vector > ret(idx.num_nodes()); // initialize the nodes - const auto& inode = idx[i]; - if (inode.source->is_variable()) return; - const nnvm::Op *op = inode.source->op(); - ExecType exec_type = ExecType::kSync; - std::vector mutate_index; - if (fmutate_inputs.count(op)) { - mutate_index = fmutate_inputs[op](inode.source->attrs); - } - if (fexec_type.count(op)) { - exec_type = fexec_type[op](inode.source->attrs); - } - CHECK(dispatch_modes[i] != DispatchMode::kUndefined); - if (fcreate_op_state.count(op)) { - std::vector ishape; - std::vector itype; - for (const auto& e : inode.inputs) { - ishape.emplace_back(vshape[idx.entry_id(e)]); - itype.emplace_back(vdtype[idx.entry_id(e)]); - } - - OpStatePtr state = fcreate_op_state[op]( - inode.source->attrs, vctx[i], ishape, itype); - FStatefulComputeEx fcompute_ex = common::GetFCompute( - op, "FStatefulComputeEx", vctx[i]); - // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx - if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) { - ret[i] = std::make_shared(state, fcompute_ex, exec_type); - } else { - FStatefulCompute fcompute = common::GetFCompute( - op, "FStatefulCompute", vctx[i]); - CHECK(fcompute != nullptr) - << "One of FStatefulCompute and FStatefulComputeEx must be registered " - << "for stateful operator " << op->name; - ret[i] = std::make_shared(state, fcompute, - exec_type, mutate_index); + for (size_t i = 0; i < idx.num_nodes(); ++i) { + const auto& inode = idx[i]; + if (inode.source->is_variable()) continue; + const nnvm::Op *op = inode.source->op(); + ExecType exec_type = ExecType::kSync; + std::vector mutate_index; + if (fmutate_inputs.count(op)) { + mutate_index = fmutate_inputs[op](inode.source->attrs); } - } else if (is_layer_backward.get(op, false)) { - CHECK_GE(inode.control_deps.size(), 1); - uint32_t fwd_id = inode.control_deps[0]; - CHECK(vctx[fwd_id] == vctx[i]); - CHECK(ret[fwd_id] != nullptr); - FStatefulComputeEx fcompute_ex = common::GetFCompute( - op, "FStatefulComputeEx", vctx[i]); - // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx - if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) { - ret[i] = std::make_shared( - ret[fwd_id].get()->state(), fcompute_ex, exec_type); - } else { - FStatefulCompute fcompute = 
common::GetFCompute( - op, "FStatefulCompute", vctx[i]); - CHECK(fcompute != nullptr) - << "One of FStatefulCompute and FStatefulComputeEx must be registered " - << "for stateful operator " << op->name; - ret[i] = std::make_shared( - ret[fwd_id].get()->state(), fcompute, exec_type, mutate_index); + if (fexec_type.count(op)) { + exec_type = fexec_type[op](inode.source->attrs); } - } else { - FCompute fcompute = common::GetFCompute(op, "FCompute", vctx[i]); - FComputeEx fcomp_ex = common::GetFCompute(op, "FComputeEx", vctx[i]); - if (fcomp_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) { - ret[i] = std::make_shared( - inode.source->attrs, fcomp_ex, exec_type); - } else if (fcompute != nullptr) { - ret[i] = std::make_shared( - inode.source->attrs, fcompute, exec_type, mutate_index); + CHECK(dispatch_modes[i] != DispatchMode::kUndefined); + if (fcreate_op_state.count(op)) { + std::vector ishape; + std::vector itype; + for (const auto& e : inode.inputs) { + ishape.emplace_back(vshape[idx.entry_id(e)]); + itype.emplace_back(vdtype[idx.entry_id(e)]); + } + + OpStatePtr state = fcreate_op_state[op]( + inode.source->attrs, vctx[i], ishape, itype); + FStatefulComputeEx fcompute_ex = common::GetFCompute( + op, "FStatefulComputeEx", vctx[i]); + // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx + if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) { + ret[i] = std::make_shared(state, fcompute_ex, exec_type); + } else { + FStatefulCompute fcompute = common::GetFCompute( + op, "FStatefulCompute", vctx[i]); + CHECK(fcompute != nullptr) + << "One of FStatefulCompute and FStatefulComputeEx must be registered " + << "for stateful operator " << op->name; + ret[i] = std::make_shared(state, fcompute, + exec_type, mutate_index); + } + } else if (is_layer_backward.get(op, false)) { + CHECK_GE(inode.control_deps.size(), 1); + uint32_t fwd_id = inode.control_deps[0]; + CHECK(vctx[fwd_id] == vctx[i]); + CHECK(ret[fwd_id] != nullptr); + FStatefulComputeEx fcompute_ex = common::GetFCompute( + op, "FStatefulComputeEx", vctx[i]); + // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx + if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) { + ret[i] = std::make_shared( + dynamic_cast(ret[fwd_id].get())->state_, + fcompute_ex, exec_type); + } else { + FStatefulCompute fcompute = common::GetFCompute( + op, "FStatefulCompute", vctx[i]); + CHECK(fcompute != nullptr) + << "One of FStatefulCompute and FStatefulComputeEx must be registered " + << "for stateful operator " << op->name; + ret[i] = std::make_shared( + dynamic_cast(ret[fwd_id].get())->state_, + fcompute, exec_type, mutate_index); + } } else { - LOG(INFO) << "Neither FCompute nor FComputeEx registered " << op->name; + FCompute fcompute = common::GetFCompute(op, "FCompute", vctx[i]); + FComputeEx fcomp_ex = common::GetFCompute(op, "FComputeEx", vctx[i]); + if (fcomp_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) { + ret[i] = std::make_shared( + inode.source->attrs, fcomp_ex, exec_type); + } else if (fcompute != nullptr) { + ret[i] = std::make_shared( + inode.source->attrs, fcompute, exec_type, mutate_index); + } else { + LOG(INFO) << "Neither FCompute nor FComputeEx registered " << op->name; + } } } -} - - -// pass to attach operator executors -Graph AttachOpExecs(Graph g) { - const auto& idx = g.indexed_graph(); - OpExecVector ret(idx.num_nodes()); - for (size_t i = 0; i < idx.num_nodes(); ++i) { - 
CreateOpExecs(g, &ret, i); - } g.attrs["op_execs"] = std::make_shared(ret); return g; } diff --git a/src/executor/attach_op_resource_pass.cc b/src/executor/attach_op_resource_pass.cc index 56122cda6ff0..681866296e1c 100644 --- a/src/executor/attach_op_resource_pass.cc +++ b/src/executor/attach_op_resource_pass.cc @@ -30,15 +30,12 @@ namespace mxnet { namespace exec { -void AttachOpResources( - const Graph& g, - const OpExecVector& op_execs, - size_t start_nid, - size_t end_nid) { +Graph AttachOpResources(Graph g) { static auto& fresource = nnvm::Op::GetAttr("FResourceRequest"); static auto& fresource_ex = nnvm::Op::GetAttr("FResourceRequestEx"); + auto& op_execs = nnvm::get(*g.attrs.at("op_execs")); const auto& vctx = g.GetAttr("context"); const auto& vdispatch = g.GetAttr("dispatch_mode"); const auto& dev_masks = g.GetAttr("dev_mask"); @@ -46,7 +43,7 @@ void AttachOpResources( // Use global resource pool for each executor for now. std::map cached_temp; // Resource allocation - for (uint32_t nid = start_nid; nid < end_nid; ++nid) { + for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { const auto& inode = idx[nid]; if (inode.source->is_variable()) continue; const Context &ctx = vctx[nid]; @@ -87,12 +84,7 @@ void AttachOpResources( requested.push_back(ResourceManager::Get()->Request(ctx, ResourceRequest::kTempSpace)); } } + return g; } - -void AttachOpResources(const Graph& g) { - const auto& op_execs = g.GetAttr("op_execs"); - AttachOpResources(g, op_execs, 0, g.indexed_graph().num_nodes()); -} - } // namespace exec } // namespace mxnet diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index 26a249118940..99b1b162eaee 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -82,10 +82,6 @@ class OpExecutor { virtual engine::VarHandle var() const { return nullptr; } - /*! \return return operator state */ - virtual OpStatePtr state() const { - return OpStatePtr(); - } }; /*! @@ -106,14 +102,6 @@ using ContextVector = std::vector; */ using DevMaskVector = std::vector; -/*! - * \brief create OpExecutor for a node in graph - * - * \param g input graph - * \param p_ret OpExecVector for input and output - * \param i the id of the node - */ -void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i); /*! * \brief Attach OpExecutor to the graph attributes. * @@ -127,20 +115,12 @@ Graph AttachOpExecs(Graph g); * \brief Attach Resource to the OpExecVector of the graph. * * \param g input graph need to contain op_exec attribute. - */ -void AttachOpResources(const Graph& g); -/*! - * \brief Attach Resource to the OpExecVector * - * \param g input graph - * \param op_execs OpExecutor vector - * \param start_nid starting node id - * \param end_nid end node id + * \return graph with new attribute "op_exec" of type OpExecVector + * The fields on the OpExecVector are not yet been setup. */ -void AttachOpResources(const Graph& g, - const OpExecVector& op_execs, - size_t start_nid, - size_t end_nid); +Graph AttachOpResources(Graph g); + /*! * \brief Discover chance of inplace addto operators. * i.e. z = plus(z, source_op), and encourage it to become z += source_op. 
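
A minimal sketch of how the two executor-attachment passes compose after this change, assuming the exec:: namespace and the "op_execs" graph attribute used elsewhere in this patch; the helper name FinishAttach is hypothetical and only for illustration:

// Sketch only (not part of the patch): AttachOpExecs and AttachOpResources
// both take and return a Graph again, so callers chain them functionally,
// as graph_executor.cc does in the hunk below.
nnvm::Graph FinishAttach(nnvm::Graph g) {
  g = exec::AttachOpExecs(g);      // fills g.attrs["op_execs"] with one OpExecutor per node
  g = exec::AttachOpResources(g);  // reads "op_execs" and requests resources for each executor
  return g;
}
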
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 831b5f900237..e28867d5488e 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -912,7 +912,7 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol, } g = AttachOpExecs(g); - AttachOpResources(g); + g = AttachOpResources(g); graph_ = std::move(g); if (shared_exec != nullptr) { diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index b40605bd25e2..140b5a5d81e0 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -19,78 +19,16 @@ #include #include #include "./imperative_utils.h" -#include "./cached_op.h" -#include "../executor/exec_pass.h" -#include "../profiler/profiler.h" - namespace mxnet { DMLC_REGISTER_PARAMETER(CachedOpConfig); -struct CachedOp::GraphInfo { - nnvm::Graph fwd_graph; - nnvm::Graph full_graph; - std::vector bwd_output_reqs; - std::vector bwd_input_eid; -}; - -struct CachedOp::DynamicRuntime { - GraphInfo info; - std::vector buff; - std::vector op_states; -}; - -struct CachedOp::CachedOpState { - CachedOpState(const Context& context_, - const nnvm::Graph& fwd_graph_, - const nnvm::Graph& full_graph_) { - context = context_; - info.fwd_graph = fwd_graph_; - info.full_graph = full_graph_; - - size_t max_nodes = info.full_graph.indexed_graph().num_nodes(); - size_t max_entries = info.full_graph.indexed_graph().num_node_entries(); - info.fwd_graph.attrs["context"] = std::make_shared( - std::vector(info.fwd_graph.indexed_graph().num_nodes(), context)); - info.full_graph.attrs["context"] = std::make_shared( - std::vector(max_nodes, context)); - - buff.resize(max_entries); - arrays.resize(max_entries); - array_reqs.resize(max_entries); - dynamic_entries.resize(max_entries, false); - op_states.resize(max_nodes); - execs.resize(max_nodes); - opr_segs.resize(max_nodes); - } - - std::mutex mutex; - Context context; - GraphInfo info; - - bool recording = false; - bool fwd_alloc = false; - bool bwd_alloc = false; - bool fwd_exec_init = false; - bool bwd_exec_init = false; - - std::vector buff; - std::vector arrays; - std::vector array_reqs; - - std::vector op_states; - std::vector > execs; - std::vector opr_segs; - - std::vector dynamic_entries; - std::multimap fwd_reuse_pool; - std::multimap bwd_reuse_pool; -}; - -CachedOp::CachedOp( +Imperative::CachedOp::CachedOp( const nnvm::Symbol& sym, - const std::vector >& flags) { + const std::vector >& flags, + const std::vector arg_names, + const std::unordered_map >& params) { using namespace nnvm; using namespace imperative; static const std::vector zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")}; @@ -130,22 +68,34 @@ CachedOp::CachedOp( fwd_graph_.attrs["forward_ref_count"] = std::make_shared(std::move(ref_count)); - inlining_ = !config_.static_alloc && - (idx.num_nodes() - idx.input_nodes().size()) <= config_.inline_limit; + inlining_ = (idx.num_nodes() - idx.input_nodes().size()) <= config_.inline_limit; } // Set params { const auto& idx = fwd_graph_.indexed_graph(); - if (config_.data_indices.ndim() || config_.param_indices.ndim()) { - CHECK_EQ(config_.data_indices.ndim() + config_.param_indices.ndim(), - idx.input_nodes().size()); - } else { - std::vector tmp; - for (size_t i = 0; i < idx.input_nodes().size(); ++i) { - tmp.push_back(i); + std::unordered_map arg_name_to_id; + for (size_t i = 0; i < idx.input_nodes().size(); ++i) { + const auto& name = idx[idx.input_nodes()[i]].source->attrs.name; + auto iter = params.find(name); + if (iter == params.end()) { + 
arg_name_to_id[name] = i; + continue; + } + fwd_params_idx_.push_back(i); + for (const auto& param : iter->second) { + params_[param.ctx()].emplace_back(param); } - config_.data_indices.assign(tmp.begin(), tmp.end()); + } + + CHECK_EQ(arg_name_to_id.size(), arg_names.size()) + << "CachedOp expects " << arg_name_to_id.size() + << " inputs, given " << arg_names.size(); + + for (const auto& name : arg_names) { + auto iter = arg_name_to_id.find(name); + CHECK(iter != arg_name_to_id.end()) << "Unexpected input name " << name; + fwd_args_idx_.push_back(iter->second); } } @@ -157,14 +107,9 @@ CachedOp::CachedOp( } std::vector xs; - const auto& idx = fwd_graph_.indexed_graph(); - for (size_t i = 0; i < idx.input_nodes().size(); ++i) { - auto nid = idx.input_nodes()[i]; - if (idx.mutable_input_nodes().count(nid)) continue; - fwd_input_to_grad_output_[i] = xs.size(); - xs.emplace_back(NodeEntry{idx[nid].weak_ref.lock(), 0, 0}); - } - + std::vector args = sym.ListInputs(Symbol::kReadOnlyArgs); + xs.reserve(args.size()); + for (const auto& i : args) xs.emplace_back(NodeEntry{i, 0, 0}); CHECK_GT(xs.size(), 0) << "There are no inputs in computation graph that require gradients."; @@ -180,7 +125,7 @@ CachedOp::CachedOp( size_t num_forward_entries = fwd_graph_.indexed_graph().num_node_entries(); full_graph_.outputs = fwd_graph_.outputs; - bwd_output_reqs_ = std::vector(grad_graph_.outputs.size(), kWriteTo); + curr_grad_req_ = std::vector(grad_graph_.outputs.size(), true); for (const auto& i : grad_graph_.outputs) full_graph_.outputs.emplace_back(i); const auto& idx = full_graph_.indexed_graph(); @@ -224,10 +169,7 @@ CachedOp::CachedOp( } } -CachedOp::~CachedOp() { -} - -std::vector CachedOp::Gradient( +std::vector Imperative::CachedOp::Gradient( const nnvm::NodePtr& node, const std::vector& ograds) { using namespace nnvm; @@ -264,15 +206,13 @@ std::vector CachedOp::Gradient( return ret; } - -bool CachedOp::SetForwardGraph( - GraphInfo* info, - const bool recording, - const std::vector& inputs) { +nnvm::Graph Imperative::CachedOp::GetForwardGraph( + const bool recording, const std::vector& inputs) { using namespace nnvm; using namespace imperative; + std::lock_guard lock(mutex_); CHECK_EQ(inputs.size(), num_inputs()); - nnvm::Graph& g = info->fwd_graph; + nnvm::Graph& g = fwd_graph_; ShapeVector shape_inputs; DTypeVector dtype_inputs; @@ -297,22 +237,18 @@ bool CachedOp::SetForwardGraph( g.attrs.erase("forward_mem_plan"); g.attrs.erase("full_mem_plan"); } else if (g.attrs.count(recording ? "full_mem_plan" : "forward_mem_plan")) { - return true; + return g; } const auto& idx = g.indexed_graph(); StorageVector storage(idx.num_node_entries(), exec::kBadStorageID); + for (const auto i : idx.input_nodes()) storage[idx.entry_id(i, 0)] = exec::kExternalStorageID; const auto& stypes = g.GetAttr("storage_type"); CHECK_EQ(stypes.size(), storage.size()); for (size_t i = 0; i < stypes.size(); i++) { - if (stypes[i] != kDefaultStorage) storage[i] = exec::kDynamicStorageID; - } - for (const auto i : idx.input_nodes()) { - storage[idx.entry_id(i, 0)] = exec::kExternalStorageID; - } - for (size_t i = 0; i < idx.outputs().size(); ++i) { - storage[idx.entry_id(idx.outputs()[i])] = exec::kExternalStorageID; + if (stypes[i] != kDefaultStorage) + storage[i] = exec::kDynamicStorageID; } auto mem_plan = PlanMemory( @@ -321,50 +257,51 @@ bool CachedOp::SetForwardGraph( g.attrs[recording ? 
"full_mem_plan" : "forward_mem_plan"] = std::make_shared(std::move(mem_plan)); - return false; + return g; } -bool CachedOp::SetBackwardGraph( - GraphInfo* info, +nnvm::Graph Imperative::CachedOp::GetBackwardGraph( + const OpStatePtr& op_state, const std::vector& reqs, - const std::vector& inputs, - bool detect_inplace_addto) { + const std::vector& inputs) { using namespace nnvm; using namespace imperative; std::lock_guard lock(mutex_); - Context default_ctx = inputs[0]->ctx(); - nnvm::Graph& g = info->full_graph; - - if (info->bwd_output_reqs != reqs) { - info->bwd_output_reqs = reqs; - info->bwd_input_eid.clear(); + nnvm::Graph& g = full_graph_; + auto& state = op_state.get_state(); + bool req_match = true; + for (size_t i = 0; i < reqs.size(); ++i) { + if (curr_grad_req_[i] != (reqs[i] != kNullOp)) { + curr_grad_req_[i] = reqs[i] != kNullOp; + req_match = false; + } + } + if (!req_match) { g = nnvm::Graph(); g.outputs = fwd_graph_.outputs; for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) { - if (info->bwd_output_reqs[i] == kNullOp) continue; - g.outputs.emplace_back(grad_graph_.outputs[i]); + if (curr_grad_req_[i]) g.outputs.emplace_back(grad_graph_.outputs[i]); } - g.attrs["context"] = std::make_shared( - std::vector(g.indexed_graph().num_nodes(), default_ctx)); + bwd_input_eid_.clear(); } const auto& idx = g.indexed_graph(); - if (info->bwd_input_eid.size() != inputs.size()) { - info->bwd_input_eid.clear(); + if (bwd_input_eid_.size() != inputs.size()) { + bwd_input_eid_.clear(); for (const auto& i : bwd_ograd_dep_) { auto eid = idx.entry_id(ograd_entries_[i]); - info->bwd_input_eid.push_back(eid); + bwd_input_eid_.push_back(eid); } for (const auto& i : bwd_in_dep_) { auto eid = idx.entry_id(idx.input_nodes()[i], 0); - info->bwd_input_eid.push_back(eid); + bwd_input_eid_.push_back(eid); } for (const auto& i : bwd_out_dep_) { auto eid = idx.entry_id(idx.outputs()[i]); - info->bwd_input_eid.push_back(eid); + bwd_input_eid_.push_back(eid); } - CHECK_EQ(inputs.size(), info->bwd_input_eid.size()); + CHECK_EQ(inputs.size(), bwd_input_eid_.size()); } size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes(); @@ -375,22 +312,25 @@ bool CachedOp::SetBackwardGraph( for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) { for (const auto& j : idx[i].inputs) ++ref_count[idx.entry_id(j)]; } - for (size_t i = 0; i < inputs.size(); ++i) ++ref_count[info->bwd_input_eid[i]]; + for (size_t i = 0; i < inputs.size(); ++i) ++ref_count[bwd_input_eid_[i]]; for (const auto& i : idx.outputs()) ++ref_count[idx.entry_id(i)]; g.attrs["backward_ref_count"] = std::make_shared(std::move(ref_count)); } - auto shapes = info->fwd_graph.GetAttr("shape"); - shapes.resize(idx.num_node_entries(), TShape()); - auto dtypes = info->fwd_graph.GetAttr("dtype"); - dtypes.resize(idx.num_node_entries(), -1); - auto stypes = info->fwd_graph.GetAttr("storage_type"); - stypes.resize(idx.num_node_entries(), -1); + ShapeVector shapes(idx.num_node_entries(), TShape()); + DTypeVector dtypes(idx.num_node_entries(), -1); + StorageTypeVector stypes(idx.num_node_entries(), -1); + + for (size_t i = 0; i < num_forward_entries; ++i) { + shapes[i] = state.buff[i].shape(); + dtypes[i] = state.buff[i].dtype(); + stypes[i] = state.buff[i].storage_type(); + } for (size_t i = 0; i < inputs.size(); ++i) { - shapes[info->bwd_input_eid[i]] = inputs[i]->shape(); - dtypes[info->bwd_input_eid[i]] = inputs[i]->dtype(); - stypes[info->bwd_input_eid[i]] = inputs[i]->storage_type(); + shapes[bwd_input_eid_[i]] = inputs[i]->shape(); + 
dtypes[bwd_input_eid_[i]] = inputs[i]->dtype(); + stypes[bwd_input_eid_[i]] = inputs[i]->storage_type(); } std::pair node_range, entry_range; @@ -402,353 +342,79 @@ bool CachedOp::SetBackwardGraph( node_range, entry_range); match &= CheckAndInferType(&g, std::move(dtypes), false, node_range, entry_range); - exec::DevMaskVector dev_mask(idx.num_nodes(), default_ctx.dev_mask()); + exec::DevMaskVector dev_mask(idx.num_nodes(), inputs[0]->ctx().dev_mask()); match &= CheckAndInferStorageType(&g, std::move(dev_mask), std::move(stypes), false, node_range, entry_range); if (!match) { g.attrs.erase("backward_mem_plan"); } else if (g.attrs.count("backward_mem_plan")) { - return true; + return g; } StorageVector storage(idx.num_node_entries(), exec::kBadStorageID); - const auto& bwd_stypes = g.GetAttr("storage_type"); - for (size_t i = 0; i < bwd_stypes.size(); i++) { - if (bwd_stypes[i] != kDefaultStorage) storage[i] = exec::kDynamicStorageID; - } for (size_t i = 0; i < num_forward_entries; ++i) storage[i] = exec::kExternalStorageID; for (const auto i : idx.input_nodes()) storage[idx.entry_id(i, 0)] = exec::kExternalStorageID; for (const auto i : idx.outputs()) storage[idx.entry_id(i)] = exec::kExternalStorageID; + for (size_t i = 0; i < stypes.size(); i++) { + if (stypes[i] != kDefaultStorage) + storage[i] = exec::kDynamicStorageID; + } auto mem_plan = PlanMemory( &g, std::move(storage), g.GetAttr >("backward_ref_count"), - {num_forward_nodes, idx.num_nodes()}, - {num_forward_entries, idx.num_node_entries()}, - detect_inplace_addto); + {num_forward_nodes, idx.num_nodes()}, {num_forward_entries, idx.num_node_entries()}); g.attrs["backward_mem_plan"] = std::make_shared(std::move(mem_plan)); - return false; -} - -OpStatePtr CachedOp::GetCachedOpState( - const Context& ctx) { - std::lock_guard lock(mutex_); - for (const auto& i : cached_op_states_[ctx]) { - // only create one state per device when not using static memory - if (!config_.static_alloc || i.unique()) { - return i; - } - } - auto state_ptr = OpStatePtr::Create(ctx, fwd_graph_, full_graph_); - - cached_op_states_[ctx].push_back(state_ptr); - return state_ptr; -} - -void CachedOp::StaticAllocMemory( - const OpStatePtr& state_ptr, - bool recording, - bool keep_fwd) { - using namespace nnvm; - using namespace imperative; - - auto& state = state_ptr.get_state(); - const auto& default_ctx = state.context; - nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph; - const auto& idx = g.indexed_graph(); - const auto& vstorage_inplace = g.GetAttr >("storage_inplace_index"); - const auto& mem_plan = g.GetAttr( - keep_fwd ? "backward_mem_plan" : (recording ? "full_mem_plan" : "forward_mem_plan")); - std::vector addto_entry; - if (g.attrs.count("addto_entry")) { - addto_entry = g.GetAttr >("addto_entry"); - } - size_t start_eid = - keep_fwd ? 
state.info.fwd_graph.indexed_graph().num_node_entries() : 0; - size_t end_eid = idx.num_node_entries(); - - if (!keep_fwd) state.fwd_alloc = false; - state.bwd_alloc = false; - for (size_t i = start_eid; i < state.buff.size(); ++i) { - state.buff[i] = NDArray(); - state.arrays[i] = &state.buff[i]; - state.array_reqs[i] = kNullOp; - state.dynamic_entries[i] = false; - } - - for (auto i : idx.input_nodes()) { - auto eid = idx.entry_id(i, 0); - if (eid >= start_eid) state.dynamic_entries[eid] = true; - } - for (auto i : idx.outputs()) { - auto eid = idx.entry_id(i); - if (eid >= start_eid) state.dynamic_entries[eid] = true; - } - - for (size_t i = start_eid; i < end_eid; ++i) { - if (addto_entry.size() && addto_entry[i]) { - state.array_reqs[i] = kAddTo; - } else if (vstorage_inplace[i] >= 0) { - state.array_reqs[i] = kWriteInplace; - } else if (vstorage_inplace[i] == -2) { - // -2 indicate that the entry is never referenced. - state.array_reqs[i] = kNullOp; - } else { - state.array_reqs[i] = kWriteTo; - } - } - - auto& reuse_pool = keep_fwd ? state.bwd_reuse_pool : state.fwd_reuse_pool; - reuse_pool = imperative::AllocateMemory( - g, idx, default_ctx, start_eid, end_eid, mem_plan, - state.arrays, &state.array_reqs, std::move(reuse_pool)); - - state.recording = recording; - if (keep_fwd) { - state.bwd_alloc = true; - } else { - state.fwd_alloc = true; - } + return g; } -void CachedOp::StaticInitExec( - const OpStatePtr& state_ptr, - bool recording, - bool keep_fwd) { +void Imperative::CachedOp::Forward( + const std::shared_ptr& op_ptr, + const std::vector& args, + const std::vector& outputs) { using namespace nnvm; using namespace imperative; + static const auto cached_op = nnvm::Op::Get("_CachedOp"); - auto& state = state_ptr.get_state(); - const auto& default_ctx = state.context; - nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph; - const auto& idx = g.indexed_graph(); - std::vector skip_plus_node; - if (g.attrs.count("skip_plus_node")) { - skip_plus_node = g.GetAttr >("skip_plus_node"); - } - size_t start_nid = - keep_fwd ? state.info.fwd_graph.indexed_graph().num_nodes() : 0; - size_t end_nid = idx.num_nodes(); - - if (!keep_fwd) state.fwd_exec_init = false; - state.bwd_exec_init = false; - - for (size_t i = start_nid; i < state.execs.size(); ++i) { - state.execs[i].reset(); - state.opr_segs[i] = EngineOprSeg(); - } - - if (!config_.static_shape) { - for (size_t i = start_nid; i < end_nid; ++i) { - state.opr_segs[i].next_nid = i + 1; - state.opr_segs[i].skip = skip_plus_node.size() && skip_plus_node[i]; - } - } else { - for (size_t i = start_nid; i < end_nid; ++i) { - exec::CreateOpExecs(g, &state.execs, i); - } - exec::AttachOpResources(g, state.execs, start_nid, end_nid); - - for (size_t i = start_nid; i < end_nid; ++i) { - bool skip = idx[i].source->is_variable(); - for (size_t j = 0; !skip && j < idx[i].inputs.size(); ++j) { - skip = state.dynamic_entries[idx.entry_id(idx[i].inputs[j])]; - } - for (size_t j = 0; !skip && j < idx[i].source->num_outputs(); ++j) { - skip = state.dynamic_entries[idx.entry_id(i, j)]; - } - if (skip) continue; - SetupOpExec(g, i, state.execs[i], state.arrays, state.array_reqs); - } + CHECK_EQ(args.size(), fwd_args_idx_.size()) + << "CachedOp requires " << fwd_args_idx_.size() + << " inputs but got " << args.size(); - size_t bulk_size = idx.num_nodes(); - std::unordered_set excludes; - if (recording || keep_fwd) { - bulk_size = keep_fwd ? 
config_.backward_bulk_size : config_.forward_bulk_size; - for (const auto& i : idx.outputs()) excludes.insert(idx.entry_id(i)); - for (const auto& i : idx.input_nodes()) excludes.insert(idx.entry_id(i, 0)); - } + Context default_ctx = args[0]->ctx(); - CreateEngineOpSeg(idx, default_ctx, start_nid, end_nid, bulk_size, excludes, - state.execs, skip_plus_node, &state.opr_segs); - } - if (keep_fwd) { - state.bwd_exec_init = true; - } else { - state.fwd_exec_init = true; + std::vector inputs(num_inputs()); + for (index_t i = 0; i < fwd_args_idx_.size(); ++i) { + inputs[fwd_args_idx_[i]] = args[i]; } -} - -void CachedOp::StaticRunOps( - const Context& default_ctx, - const nnvm::Graph& g, - const OpStatePtr& state_ptr, - size_t start_nid, - size_t end_nid) { - static auto& createop = nnvm::Op::GetAttr("FCreateOpState"); - static auto& is_layer_backward = Op::GetAttr("TIsLayerOpBackward"); - - bool profiling = profiler::Profiler::Get()->GetState() == profiler::Profiler::kRunning; - bool is_training = Imperative::Get()->is_training(); - auto& state = state_ptr.get_state(); - const auto& idx = g.indexed_graph(); - const auto& dispatch_modes = g.GetAttr("dispatch_mode"); - const auto& op_execs = state.execs; - - std::vector ndinputs, ndoutputs; - nnvm::ShapeVector arg_shapes; - nnvm::DTypeVector arg_dtypes; - std::vector req; + if (fwd_params_idx_.size()) { + CHECK(params_.find(default_ctx) != params_.end()) + << "CachedOp is not initialized on context " << default_ctx; - for (size_t i = start_nid; config_.static_shape && i < end_nid; ++i) { - if (op_execs[i]) op_execs[i]->op_ctx.is_train = is_training; - } - - for (size_t i = start_nid; i < end_nid; i = state.opr_segs[i].next_nid) { - const auto& opr_seg = state.opr_segs[i]; - if (opr_seg.skip) continue; - if (opr_seg.opr != nullptr) { - Engine::Get()->Push(opr_seg.opr.get(), default_ctx, 0, profiling); - } else { - const nnvm::IndexedGraph::Node& node = idx[i]; - if (node.source->is_variable()) continue; - auto num_outputs = node.source->num_outputs(); - ndinputs.clear(); - ndinputs.reserve(node.inputs.size()); - for (const auto& j : node.inputs) { - ndinputs.emplace_back(state.arrays[idx.entry_id(j)]); - CHECK(!ndinputs.back()->is_none()); - } - ndoutputs.clear(); - ndoutputs.reserve(num_outputs); - req.clear(); - req.reserve(num_outputs); - for (size_t j = 0; j < num_outputs; ++j) { - size_t eid = idx.entry_id(i, j); - ndoutputs.emplace_back(state.arrays[eid]); - req.push_back(state.array_reqs[eid]); - CHECK(req.back() == kNullOp || !ndoutputs.back()->is_none()); - } - const DispatchMode dispatch_mode = dispatch_modes[i]; - if (createop.count(node.source->op())) { - arg_shapes.clear(); - arg_dtypes.clear(); - arg_shapes.reserve(ndinputs.size()); - arg_dtypes.reserve(ndinputs.size()); - for (size_t i = 0; i < ndinputs.size(); ++i) { - arg_shapes.emplace_back(ndinputs[i]->shape()); - arg_dtypes.emplace_back(ndinputs[i]->dtype()); - } - state.op_states[i] = createop[node.source->op()]( - node.source->attrs, default_ctx, arg_shapes, arg_dtypes); - Imperative::Get()->InvokeOp( - default_ctx, node.source->attrs, ndinputs, ndoutputs, req, - dispatch_mode, state.op_states[i]); - } else if (is_layer_backward.get(node.source->op(), false)) { - nnvm::Node* fwd_node = node.source->control_deps[0].get(); - auto fwd_node_id = idx.node_id(fwd_node); - Imperative::Get()->InvokeOp( - default_ctx, node.source->attrs, ndinputs, ndoutputs, - req, dispatch_mode, state.op_states[fwd_node_id]); - } else { - Imperative::Get()->InvokeOp( - default_ctx, 
node.source->attrs, ndinputs, ndoutputs, req, - dispatch_mode); - } + for (size_t i = 0; i < fwd_params_idx_.size(); ++i) { + inputs[fwd_params_idx_[i]] = ¶ms_[default_ctx][i]; } } -} - -OpStatePtr CachedOp::StaticForward( - const Context& default_ctx, - const std::vector& inputs, - const std::vector& outputs) { - using namespace nnvm; - using namespace imperative; + // Initialize bool recording = Imperative::Get()->is_recording(); - auto state_ptr = GetCachedOpState(default_ctx); - auto& state = state_ptr.get_state(); - std::lock_guard lock(state.mutex); - - bool match = SetForwardGraph(&state.info, recording, inputs); - match = match && state.recording != recording; - - nnvm::Graph& g = state.info.fwd_graph; + nnvm::Graph g = GetForwardGraph(recording, inputs); const auto& idx = g.indexed_graph(); - if (!state.fwd_alloc || !match) { - StaticAllocMemory(state_ptr, recording, false); - } - - if (config_.static_shape) { - for (auto i : config_.param_indices) { - auto nid = idx.input_nodes()[i]; - if (!state.arrays[idx.entry_id(nid, 0)]->IsSame(*inputs[i])) { - match = false; - auto ptr = &state.buff[idx.entry_id(nid, 0)]; - CHECK_EQ(state.arrays[idx.entry_id(nid, 0)], ptr); - *state.arrays[idx.entry_id(nid, 0)] = *inputs[i]; - state.dynamic_entries[idx.entry_id(nid, 0)] = false; - } - } - for (auto i : config_.data_indices) { - auto eid = idx.entry_id(idx.input_nodes()[i], 0); - state.arrays[eid] = inputs[i]; - } - } else { - for (size_t i = 0; i < num_inputs(); ++i) { - auto nid = idx.input_nodes()[i]; - state.arrays[idx.entry_id(nid, 0)] = inputs[i]; - } - } - - if (!state.fwd_exec_init || !match) { - StaticInitExec(state_ptr, recording, false); - } - - const auto& dtypes = g.GetAttr("dtype"); - const auto& shapes = g.GetAttr("shape"); - const auto& stypes = g.GetAttr("storage_type"); + size_t num_inputs = idx.input_nodes().size(); - for (size_t i = 0; i < outputs.size(); ++i) { - auto eid = idx.entry_id(idx.outputs()[i]); - state.arrays[eid] = outputs[i]; - if (!outputs[i]->is_none()) continue; - *outputs[i] = NDArray(static_cast(stypes[eid]), - shapes[eid], default_ctx, true, dtypes[eid]); + for (size_t i = 0; i < inputs.size(); ++i) { + CHECK_EQ(inputs[i]->ctx(), default_ctx) + << "CachedOp requires all inputs to live on the same context. But " + << idx[idx.input_nodes()[0]].source->attrs.name << " is on " << default_ctx + << " while " << idx[idx.input_nodes()[i]].source->attrs.name << " is on " + << inputs[i]->ctx(); } - StaticRunOps(default_ctx, g, state_ptr, 0, idx.num_nodes()); - - return recording ? 
state_ptr : OpStatePtr(); -} - - -OpStatePtr CachedOp::DynamicForward( - const Context& default_ctx, - const std::vector& inputs, - const std::vector& outputs) { - using namespace nnvm; - using namespace imperative; - - // Initialize - bool recording = Imperative::Get()->is_recording(); - auto op_state = OpStatePtr::Create(); - auto& runtime = op_state.get_state(); - { - auto state_ptr = GetCachedOpState(default_ctx); - auto& state = state_ptr.get_state(); - std::lock_guard lock(state.mutex); - SetForwardGraph(&state.info, recording, inputs); - runtime.info.fwd_graph = state.info.fwd_graph; - } - nnvm::Graph& g = runtime.info.fwd_graph; - const auto& idx = g.indexed_graph(); - size_t num_inputs = idx.input_nodes().size(); - auto& buff = runtime.buff; - auto& states = runtime.op_states; + auto op_state_ptr = OpStatePtr::Create(); + auto& cached_op_state = op_state_ptr.get_state(); + auto& buff = cached_op_state.buff; + auto& states = cached_op_state.states; // Allocate entries states.resize(idx.num_nodes()); @@ -780,98 +446,57 @@ OpStatePtr CachedOp::DynamicForward( AllocateMemory(g, idx, default_ctx, 0, idx.num_node_entries(), mem_plan, arrays, &array_reqs); - const auto& dtypes = g.GetAttr("dtype"); - const auto& shapes = g.GetAttr("shape"); - const auto& stypes = g.GetAttr("storage_type"); - - for (size_t i = 0; i < outputs.size(); ++i) { - auto eid = idx.entry_id(idx.outputs()[i]); - arrays[eid] = outputs[i]; - if (!outputs[i]->is_none()) continue; - *outputs[i] = NDArray(static_cast(stypes[eid]), - shapes[eid], default_ctx, true, dtypes[eid]); - } - const auto& dispatch_modes = g.GetAttr("dispatch_mode"); if (recording && !inlining_) Imperative::Get()->set_is_recording(false); + int prev_bulk_size = Engine::Get()->set_bulk_size(config_.forward_bulk_size); - RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs), - std::move(ref_count), &states, dispatch_modes); + Imperative::Get()->RunGraph( + false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs), + std::move(ref_count), &states, dispatch_modes); + Engine::Get()->set_bulk_size(prev_bulk_size); Imperative::Get()->set_is_recording(recording); - return op_state; -} - -void CachedOp::Forward( - const std::shared_ptr& op_ptr, - const std::vector& inputs, - const std::vector& outputs) { - static const auto cached_op = nnvm::Op::Get("_CachedOp"); - - CHECK_EQ(inputs.size(), num_inputs()); - - Context default_ctx = inputs[0]->ctx(); - - const auto& idx = fwd_graph_.indexed_graph(); - for (size_t i = 0; i < inputs.size(); ++i) { - CHECK_EQ(inputs[i]->ctx(), default_ctx) - << "CachedOp requires all inputs to live on the same context. 
But " - << idx[idx.input_nodes()[0]].source->attrs.name - << " is on " << default_ctx << " while " - << idx[idx.input_nodes()[i]].source->attrs.name - << " is on " << inputs[i]->ctx(); - } - - int prev_bulk_size = Engine::Get()->set_bulk_size(config_.forward_bulk_size); - - OpStatePtr op_state; - if (config_.static_alloc) { - op_state = StaticForward(default_ctx, inputs, outputs); - } else { - op_state = DynamicForward(default_ctx, inputs, outputs); + for (size_t i = 0; i < idx.num_node_entries(); ++i) { + if (arrays[i] == &buff[i]) continue; + buff[i].shape_ = arrays[i]->shape_; + buff[i].dtype_ = arrays[i]->dtype_; + buff[i].storage_type_ = arrays[i]->storage_type_; } - Engine::Get()->set_bulk_size(prev_bulk_size); - - if (Imperative::Get()->is_recording() && !inlining_) { + if (recording && !inlining_) { nnvm::NodeAttrs attrs; attrs.op = cached_op; attrs.name = "_cachedop"; attrs.parsed = op_ptr; Imperative::Get()->RecordOp( - std::move(attrs), inputs, outputs, op_state, + std::move(attrs), inputs, outputs, op_state_ptr, &save_inputs(), &save_outputs()); } } -void CachedOp::DynamicBackward( +void Imperative::CachedOp::Backward( const bool retain_graph, - const OpStatePtr& op_state, + const OpStatePtr& state, const std::vector& inputs, const std::vector& reqs, const std::vector& outputs) { using namespace nnvm; using namespace imperative; + CHECK(!Imperative::Get()->is_recording()) + << "CachedOp does not support higher order gradients. " + << "If you want to do backward with create_graph=True please " + << "do not use hybridize."; // Initialize - Context default_ctx = outputs[0]->ctx(); - auto& runtime = op_state.get_state(); - { - auto state_ptr = GetCachedOpState(default_ctx); - auto& state = state_ptr.get_state(); - std::lock_guard lock(state.mutex); - state.info.fwd_graph = runtime.info.fwd_graph; - SetBackwardGraph(&state.info, reqs, inputs); - runtime.info.full_graph = state.info.full_graph; - runtime.info.bwd_input_eid = state.info.bwd_input_eid; - } - nnvm::Graph& g = runtime.info.full_graph; + nnvm::Graph g = GetBackwardGraph(state, reqs, inputs); const auto& idx = g.indexed_graph(); - auto& buff = runtime.buff; - auto& states = runtime.op_states; + + auto& cached_op_state = state.get_state(); + auto& buff = cached_op_state.buff; + auto& states = cached_op_state.states; size_t num_forward_outputs = fwd_graph_.outputs.size(); size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes(); @@ -881,7 +506,7 @@ void CachedOp::DynamicBackward( arrays.reserve(buff.size()); for (size_t i = 0; i < buff.size(); ++i) arrays.push_back(&buff[i]); for (size_t i = 0; i < inputs.size(); ++i) { - arrays[runtime.info.bwd_input_eid[i]] = inputs[i]; + arrays[bwd_input_eid_[i]] = inputs[i]; } for (size_t i = 0, j = num_forward_outputs; i < reqs.size(); ++i) { if (reqs[i] == kNullOp) continue; @@ -905,14 +530,20 @@ void CachedOp::DynamicBackward( if (ref_count[i] == 0) array_reqs[i] = kNullOp; } + Context default_ctx = outputs[0]->ctx(); const auto& mem_plan = g.GetAttr("backward_mem_plan"); AllocateMemory(g, idx, default_ctx, num_forward_entries, idx.num_node_entries(), mem_plan, arrays, &array_reqs); const auto& dispatch_modes = g.GetAttr("dispatch_mode"); - RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(), - std::move(array_reqs), std::move(ref_count), &states, dispatch_modes); + int prev_bulk_size = Engine::Get()->set_bulk_size(config_.backward_bulk_size); + + Imperative::Get()->RunGraph( + retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(), + 
std::move(array_reqs), std::move(ref_count), &states, dispatch_modes); + + Engine::Get()->set_bulk_size(prev_bulk_size); if (retain_graph) { buff.resize(num_forward_entries); @@ -922,99 +553,6 @@ void CachedOp::DynamicBackward( } } -void CachedOp::StaticBackward( - const bool retain_graph, - const OpStatePtr& state_ptr, - const std::vector& inputs, - const std::vector& reqs, - const std::vector& outputs) { - using namespace nnvm; - using namespace imperative; - - Context default_ctx = outputs[0]->ctx(); - - auto& state = state_ptr.get_state(); - std::lock_guard lock(state.mutex); - - bool match = SetBackwardGraph(&state.info, reqs, inputs, true); - - nnvm::Graph& g = state.info.full_graph; - const auto& idx = g.indexed_graph(); - auto num_forward_nodes = state.info.fwd_graph.indexed_graph().num_nodes(); - - if (!state.bwd_alloc || !match) { - StaticAllocMemory(state_ptr, true, true); - } - - if (config_.static_shape) { - for (auto i : config_.param_indices) { - const auto iter = fwd_input_to_grad_output_.find(i); - if (iter == fwd_input_to_grad_output_.end()) continue; - auto entry = grad_graph_.outputs[iter->second]; - if (!idx.exist(entry.node.get())) continue; - auto eid = idx.entry_id(entry); - if (!state.arrays[eid]->IsSame(*outputs[iter->second]) || - !(state.array_reqs[eid] == reqs[iter->second])) { - match = false; - state.array_reqs[eid] = reqs[iter->second]; - *state.arrays[eid] = *outputs[iter->second]; - state.dynamic_entries[eid] = false; - } - } - for (auto i : config_.data_indices) { - const auto iter = fwd_input_to_grad_output_.find(i); - if (iter == fwd_input_to_grad_output_.end()) continue; - auto entry = grad_graph_.outputs[iter->second]; - if (!idx.exist(entry.node.get())) continue; - auto eid = idx.entry_id(entry); - state.array_reqs[eid] = reqs[iter->second]; - state.arrays[eid] = outputs[iter->second]; - } - } else { - for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) { - auto entry = grad_graph_.outputs[i]; - if (!idx.exist(entry.node.get())) continue; - auto eid = idx.entry_id(entry); - state.array_reqs[eid] = reqs[i]; - state.arrays[eid] = outputs[i]; - } - } - - if (!state.bwd_exec_init || !match) { - StaticInitExec(state_ptr, true, true); - } - - for (size_t i = 0; i < state.info.bwd_input_eid.size(); ++i) { - auto eid = state.info.bwd_input_eid[i]; - if (state.dynamic_entries[eid]) state.arrays[eid] = inputs[i]; - } - - StaticRunOps(default_ctx, g, state_ptr, num_forward_nodes, idx.num_nodes()); -} - -void CachedOp::Backward( - const bool retain_graph, - const OpStatePtr& state, - const std::vector& inputs, - const std::vector& reqs, - const std::vector& outputs) { - using namespace imperative; - CHECK(!Imperative::Get()->is_recording()) - << "CachedOp does not support higher order gradients. 
" - << "If you want to do backward with create_graph=True please " - << "do not use hybridize."; - - int prev_bulk_size = Engine::Get()->set_bulk_size(config_.backward_bulk_size); - - if (config_.static_alloc) { - StaticBackward(retain_graph, state, inputs, reqs, outputs); - } else { - DynamicBackward(retain_graph, state, inputs, reqs, outputs); - } - - Engine::Get()->set_bulk_size(prev_bulk_size); -} - NNVM_REGISTER_OP(_CachedOp) .set_num_inputs([](const NodeAttrs& attrs) { diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h deleted file mode 100644 index 60a40c5e4a52..000000000000 --- a/src/imperative/cached_op.h +++ /dev/null @@ -1,174 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef MXNET_IMPERATIVE_CACHED_OP_H_ -#define MXNET_IMPERATIVE_CACHED_OP_H_ - -#include -#include -#include -#include -#include -#include - -namespace mxnet { -/*! \brief CachedOp Parameters */ -struct CachedOpConfig : public dmlc::Parameter { - uint32_t inline_limit; - uint32_t forward_bulk_size; - uint32_t backward_bulk_size; - bool static_alloc; - bool static_shape; - nnvm::Tuple data_indices; - nnvm::Tuple param_indices; - DMLC_DECLARE_PARAMETER(CachedOpConfig) { - DMLC_DECLARE_FIELD(static_alloc) - .set_default(false) - .describe("Statically allocate memory to improve speed. " - "Memory usage may increase."); - DMLC_DECLARE_FIELD(static_shape) - .set_default(false) - .describe("Optimize for invariant input shapes between iterations. " - "Must also set static_alloc to True. 
" - "Change of input shapes is still allowed but slower."); - DMLC_DECLARE_FIELD(inline_limit) - .set_default(2) - .describe("Maximum number of operators that can be inlined."); - DMLC_DECLARE_FIELD(forward_bulk_size) - .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15)) - .describe("Segment size of bulk execution during forward pass."); - DMLC_DECLARE_FIELD(backward_bulk_size) - .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15)) - .describe("Segment size of bulk execution during backward pass."); - DMLC_DECLARE_FIELD(data_indices) - .set_default(nnvm::Tuple()) - .describe("Position of argument variables."); - DMLC_DECLARE_FIELD(param_indices) - .set_default(nnvm::Tuple()) - .describe("Position of parameters."); - } -}; - -class CachedOp { - public: - CachedOp( - const nnvm::Symbol& sym, - const std::vector >& flags); - ~CachedOp(); - uint32_t num_inputs() { - return fwd_graph_.indexed_graph().input_nodes().size(); - } - uint32_t num_outputs() { - return fwd_graph_.outputs.size(); - } - uint32_t num_backward_inputs() { - return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size(); - } - std::vector& save_inputs() { - return save_inputs_; - } - std::vector& save_outputs() { - return save_outputs_; - } - const std::unordered_set& mutable_input_nodes() { - return fwd_graph_.indexed_graph().mutable_input_nodes(); - } - std::vector Gradient( - const nnvm::NodePtr& node, - const std::vector& ograds); - void Forward( - const std::shared_ptr& op_ptr, - const std::vector& inputs, - const std::vector& outputs); - void Backward( - const bool retain_graph, - const OpStatePtr& state, - const std::vector& inputs, - const std::vector& reqs, - const std::vector& outputs); - - private: - struct GraphInfo; - struct DynamicRuntime; - struct CachedOpState; - - OpStatePtr GetCachedOpState(const Context& ctx); - bool SetForwardGraph( - GraphInfo* info, - const bool recording, - const std::vector& inputs); - bool SetBackwardGraph( - GraphInfo* info, - const std::vector& reqs, - const std::vector& inputs, - bool detect_inplace_addto = false); - OpStatePtr DynamicForward( - const Context& default_ctx, - const std::vector& inputs, - const std::vector& outputs); - void DynamicBackward( - const bool retain_graph, - const OpStatePtr& op_state, - const std::vector& inputs, - const std::vector& reqs, - const std::vector& outputs); - void StaticAllocMemory( - const OpStatePtr& state_ptr, - bool recording, - bool keep_fwd); - void StaticInitExec( - const OpStatePtr& state_ptr, - bool recording, - bool keep_fwd); - void StaticRunOps( - const Context& default_ctx, - const nnvm::Graph& g, - const OpStatePtr& state_ptr, - size_t start_nid, - size_t end_nid); - OpStatePtr StaticForward( - const Context& default_ctx, - const std::vector& inputs, - const std::vector& outputs); - void StaticBackward( - const bool retain_graph, - const OpStatePtr& state_ptr, - const std::vector& inputs, - const std::vector& reqs, - const std::vector& outputs); - - CachedOpConfig config_; - nnvm::Graph fwd_graph_; - nnvm::Graph grad_graph_; - nnvm::Graph full_graph_; - bool inlining_; - std::vector ograd_entries_; - std::vector bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_; - std::unordered_map fwd_input_to_grad_output_; - std::vector save_inputs_, save_outputs_; - std::vector bwd_output_reqs_; - - std::mutex mutex_; - std::unordered_map > cached_op_states_; -}; - -using CachedOpPtr = std::shared_ptr; - -} // namespace mxnet -#endif // MXNET_IMPERATIVE_CACHED_OP_H_ diff --git 
a/src/imperative/imperative.cc b/src/imperative/imperative.cc index e1654259a2fb..7caf305eac75 100644 --- a/src/imperative/imperative.cc +++ b/src/imperative/imperative.cc @@ -19,7 +19,6 @@ #include #include #include "./imperative_utils.h" -#include "./cached_op.h" namespace mxnet { #if DMLC_CXX11_THREAD_LOCAL @@ -267,6 +266,95 @@ void Imperative::RecordOp( } } +void Imperative::RunGraph( + const bool retain_graph, + const nnvm::IndexedGraph& idx, + const std::vector arrays, + size_t node_start, size_t node_end, + std::vector&& array_reqs, + std::vector&& ref_count, + std::vector *p_states, + const DispatchModeVector &dispatch_modes) { + using namespace nnvm; + using namespace imperative; + static auto& createop = nnvm::Op::GetAttr("FCreateOpState"); + static auto& is_layer_backward = Op::GetAttr("TIsLayerOpBackward"); + static const auto bwd_cached_op = Op::Get("_backward_CachedOp"); + + std::vector& states = *p_states; + bool recording = is_recording(); + + std::vector ndinputs, ndoutputs; + ShapeVector arg_shapes; + DTypeVector arg_dtypes; + std::vector req; + + for (size_t i = node_start; i < node_end; ++i) { + const nnvm::IndexedGraph::Node& node = idx[i]; + if (node.source->op() == nullptr) continue; + auto num_outputs = node.source->num_outputs(); + ndinputs.clear(); + ndinputs.reserve(node.inputs.size()); + for (const auto& j : node.inputs) { + ndinputs.emplace_back(arrays[idx.entry_id(j)]); + CHECK(!ndinputs.back()->is_none()) << idx[j.node_id].source->attrs.name << " " << j.index; + } + ndoutputs.clear(); + ndoutputs.reserve(num_outputs); + req.clear(); + req.reserve(num_outputs); + for (size_t j = 0; j < num_outputs; ++j) { + size_t eid = idx.entry_id(i, j); + ndoutputs.emplace_back(arrays[eid]); + req.push_back(array_reqs[eid]); + CHECK(!ndoutputs.back()->is_none()); + } + const Context& ctx = ndoutputs[0]->ctx(); + const DispatchMode dispatch_mode = dispatch_modes[i]; + if (node.source->op() == bwd_cached_op) { + const auto& cached_op = dmlc::get(node.source->attrs.parsed); + nnvm::Node* fwd_node = node.source->control_deps[0].get(); + auto fwd_node_id = idx.node_id(fwd_node); + cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req, ndoutputs); + } else if (createop.count(node.source->op())) { + arg_shapes.clear(); + arg_dtypes.clear(); + arg_shapes.reserve(ndinputs.size()); + arg_dtypes.reserve(ndinputs.size()); + for (size_t i = 0; i < ndinputs.size(); ++i) { + arg_shapes.emplace_back(ndinputs[i]->shape()); + arg_dtypes.emplace_back(ndinputs[i]->dtype()); + } + states[i] = createop[node.source->op()]( + node.source->attrs, ctx, arg_shapes, arg_dtypes); + InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, states[i]); + if (recording) RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[i]); + } else if (is_layer_backward.get(node.source->op(), false)) { + nnvm::Node* fwd_node = node.source->control_deps[0].get(); + auto fwd_node_id = idx.node_id(fwd_node); + InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, + req, dispatch_mode, states[fwd_node_id]); + if (recording) { + RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[fwd_node_id]); + } + } else { + InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode); + if (recording) RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs); + } + + for (const auto& j : node.inputs) { + size_t eid = idx.entry_id(j); + --ref_count[eid]; + if (ref_count[eid] == 0) arrays[eid]->ptr_.reset(); + } + for (size_t j = 0; j < ndoutputs.size(); 
++j) { + size_t eid = idx.entry_id(i, j); + if (ref_count[eid] == 0) arrays[eid]->ptr_.reset(); + } + } +} + + std::vector Imperative::Backward( const std::vector& outputs, const std::vector& ograds, diff --git a/src/imperative/imperative_utils.cc b/src/imperative/imperative_utils.cc deleted file mode 100644 index 464aefc220de..000000000000 --- a/src/imperative/imperative_utils.cc +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include "./imperative_utils.h" -#include "./cached_op.h" - -namespace mxnet { -namespace imperative { -void RunGraph( - const bool retain_graph, - const nnvm::IndexedGraph& idx, - const std::vector arrays, - size_t node_start, size_t node_end, - std::vector&& array_reqs, - std::vector&& ref_count, - std::vector *p_states, - const DispatchModeVector &dispatch_modes) { - using namespace nnvm; - using namespace imperative; - static auto& createop = nnvm::Op::GetAttr("FCreateOpState"); - static auto& is_layer_backward = Op::GetAttr("TIsLayerOpBackward"); - static const auto bwd_cached_op = Op::Get("_backward_CachedOp"); - - const auto imp = Imperative::Get(); - - std::vector& states = *p_states; - bool recording = imp->is_recording(); - - std::vector ndinputs, ndoutputs; - ShapeVector arg_shapes; - DTypeVector arg_dtypes; - std::vector req; - - for (size_t i = node_start; i < node_end; ++i) { - const nnvm::IndexedGraph::Node& node = idx[i]; - if (node.source->op() == nullptr) continue; - auto num_outputs = node.source->num_outputs(); - ndinputs.clear(); - ndinputs.reserve(node.inputs.size()); - for (const auto& j : node.inputs) { - ndinputs.emplace_back(arrays[idx.entry_id(j)]); - CHECK(!ndinputs.back()->is_none()) << idx[j.node_id].source->attrs.name << " " << j.index; - } - ndoutputs.clear(); - ndoutputs.reserve(num_outputs); - req.clear(); - req.reserve(num_outputs); - for (size_t j = 0; j < num_outputs; ++j) { - size_t eid = idx.entry_id(i, j); - ndoutputs.emplace_back(arrays[eid]); - req.push_back(array_reqs[eid]); - CHECK(array_reqs[eid] == kNullOp || !ndoutputs.back()->is_none()); - } - const Context& ctx = ndoutputs[0]->ctx(); - const DispatchMode dispatch_mode = dispatch_modes[i]; - if (node.source->op() == bwd_cached_op) { - const auto& cached_op = dmlc::get(node.source->attrs.parsed); - nnvm::Node* fwd_node = node.source->control_deps[0].get(); - auto fwd_node_id = idx.node_id(fwd_node); - cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req, ndoutputs); - } else if (createop.count(node.source->op())) { - arg_shapes.clear(); - arg_dtypes.clear(); - arg_shapes.reserve(ndinputs.size()); - arg_dtypes.reserve(ndinputs.size()); - for (size_t i = 0; i < ndinputs.size(); ++i) { - arg_shapes.emplace_back(ndinputs[i]->shape()); - 
arg_dtypes.emplace_back(ndinputs[i]->dtype()); - } - states[i] = createop[node.source->op()]( - node.source->attrs, ctx, arg_shapes, arg_dtypes); - imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, states[i]); - if (recording) { - imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[i]); - } - } else if (is_layer_backward.get(node.source->op(), false)) { - nnvm::Node* fwd_node = node.source->control_deps[0].get(); - auto fwd_node_id = idx.node_id(fwd_node); - imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, - req, dispatch_mode, states[fwd_node_id]); - if (recording) { - imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[fwd_node_id]); - } - } else { - imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode); - if (recording) { - imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs); - } - } - - for (const auto& j : node.inputs) { - size_t eid = idx.entry_id(j); - --ref_count[eid]; - if (ref_count[eid] == 0) *arrays[eid] = NDArray(); - } - for (size_t j = 0; j < ndoutputs.size(); ++j) { - size_t eid = idx.entry_id(i, j); - if (ref_count[eid] == 0) *arrays[eid] = NDArray(); - } - } -} - -} // namespace imperative -} // namespace mxnet diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index 726531d02994..06b7e058dd14 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -23,7 +23,6 @@ #include #include #include -#include #include #include "../executor/graph_executor.h" #include "../executor/exec_pass.h" @@ -39,24 +38,11 @@ namespace mxnet { namespace imperative { struct MemoryPlanInfo { - int storage_id; - uint32_t root; + uint32_t sid; size_t size; bool inplace; }; -struct EngineOprDeleter { - void operator()(engine::Opr* handle) { - Engine::Get()->DeleteOperator(handle); - } -}; - -struct EngineOprSeg { - bool skip; - size_t next_nid; - std::unique_ptr opr; -}; - using MemoryPlanVector = std::vector; inline Context GetContext(const nnvm::NodeAttrs& attrs, @@ -729,12 +715,10 @@ inline std::vector PlaceDevice(const nnvm::IndexedGraph& idx) { inline MemoryPlanVector PlanMemory( - nnvm::Graph* p_g, - nnvm::StorageVector&& storage, + nnvm::Graph* p_g, nnvm::StorageVector&& storage, const std::vector& ref_count, const std::pair& node_range = {0, 0}, - const std::pair& entry_range = {0, 0}, - bool detect_inplace_addto = false) { + const std::pair& entry_range = {0, 0}) { using namespace nnvm; nnvm::Graph& g = *p_g; const auto& idx = g.indexed_graph(); @@ -744,31 +728,31 @@ inline MemoryPlanVector PlanMemory( g.attrs["ref_count"] = std::make_shared(ref_count); g.attrs["storage"] = std::make_shared(std::move(storage)); g = nnvm::ApplyPass(g, "PlanMemory"); - if (detect_inplace_addto) g = exec::DetectInplaceAddTo(g); const auto& dtypes = g.GetAttr("dtype"); const auto& shapes = g.GetAttr("shape"); - const auto& storage_inplace = g.GetAttr >("storage_inplace_index"); - const auto& storage_ids = g.GetAttr("storage_id"); + const auto& stypes = g.GetAttr("storage_type"); + auto storage_ids = g.MoveCopyAttr("storage_id"); + auto storage_inplace = g.MoveCopyAttr >("storage_inplace_index"); uint32_t entry_start = entry_range.first; uint32_t entry_end = entry_range.second > entry_start ? 
entry_range.second : idx.num_node_entries(); MemoryPlanVector mem_plan(idx.num_node_entries()); - std::unordered_map sid_to_root; + std::unordered_map sid_to_loc; for (uint32_t i = entry_start; i < entry_end; ++i) { + if (stypes[i] != kDefaultStorage) continue; if (storage_ids[i] < 0) { - mem_plan[i] = {storage_ids[i], i, 0, false}; - } else if (!sid_to_root.count(storage_ids[i])) { + mem_plan[i] = {i, mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size(), false}; + } else if (!sid_to_loc.count(storage_ids[i])) { CHECK_LT(storage_inplace[i], 0); - sid_to_root[storage_ids[i]] = i; - mem_plan[i] = {storage_ids[i], i, - mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size(), - false}; + sid_to_loc[storage_ids[i]] = i; + mem_plan[i].sid = i; + mem_plan[i].size = mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size(); } else { - uint32_t root = sid_to_root[storage_ids[i]]; - mem_plan[i] = {storage_ids[i], root, 0, storage_inplace[i] >= 0}; - mem_plan[root].size = std::max(mem_plan[root].size, + uint32_t loc = sid_to_loc[storage_ids[i]]; + mem_plan[i] = {loc, 0, storage_inplace[i] >= 0}; + mem_plan[loc].size = std::max(mem_plan[loc].size, mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size()); } } @@ -777,213 +761,39 @@ inline MemoryPlanVector PlanMemory( } -inline std::multimap AllocateMemory( - const nnvm::Graph& g, - const nnvm::IndexedGraph& idx, - const Context& default_ctx, - const uint32_t entry_start, const uint32_t entry_end, - const MemoryPlanVector& mem_plan, - const std::vector& arrays, - std::vector *array_reqs, - std::multimap&& pool = std::multimap()) { +inline void AllocateMemory(const nnvm::Graph& g, + const nnvm::IndexedGraph& idx, + const Context& default_ctx, + const uint32_t entry_start, const uint32_t entry_end, + const MemoryPlanVector& mem_plan, + const std::vector& arrays, + std::vector *array_reqs) { using namespace nnvm; const auto& dtypes = g.GetAttr("dtype"); const auto& shapes = g.GetAttr("shape"); const auto& stypes = g.GetAttr("storage_type"); - std::multimap new_pool; - for (uint32_t i = entry_start; i < entry_end; ++i) { - if (mem_plan[i].storage_id == exec::kExternalStorageID) continue; - CHECK(arrays[i]->is_none()); - if (mem_plan[i].storage_id == exec::kDynamicStorageID) { - *arrays[i] = NDArray(static_cast(stypes[i]), - shapes[i], default_ctx, true, dtypes[i]); - continue; - } - CHECK_EQ(stypes[i], kDefaultStorage); - if (mem_plan[i].root == i) { - CHECK_GT(mem_plan[i].size, 0); - auto iter = pool.lower_bound(mem_plan[i].size); - if (iter != pool.end()) { - *arrays[i] = iter->second.AsArray(shapes[i], dtypes[i]); - new_pool.insert(*iter); - pool.erase(iter); - } else { + if (!arrays[i]->is_none()) continue; + if (stypes[i] == kDefaultStorage) { + if (mem_plan[i].sid == i) { + CHECK_GT(mem_plan[i].size, 0); NDArray buff(TShape({static_cast(mem_plan[i].size)}), default_ctx, true, mshadow::kUint8); *arrays[i] = buff.AsArray(shapes[i], dtypes[i]); - new_pool.insert({mem_plan[i].size, buff}); - } - } else { - CHECK_GE(mem_plan[mem_plan[i].root].storage_id, 0); - *arrays[i] = arrays[mem_plan[i].root]->AsArray(shapes[i], dtypes[i]); - if (mem_plan[i].inplace && array_reqs->at(i) == kWriteTo) { - array_reqs->at(i) = kWriteInplace; - } - } - } - - return new_pool; -} - -inline void SetupOpExec( - const nnvm::Graph& g, - size_t nid, - const std::shared_ptr& exec, - const std::vector arrays, - const std::vector array_reqs) { - const auto& idx = g.indexed_graph(); - const auto& inode = idx[nid]; - CHECK_EQ(exec->in_array.size(), 0U); - CHECK_EQ(exec->out_array.size(), 
0U); - for (const auto& e : inode.inputs) { - CHECK(!arrays[idx.entry_id(e)]->is_none()) << inode.source->attrs.name; - exec->in_array.push_back(*arrays[idx.entry_id(e)]); - } - for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) { - uint32_t eid = idx.entry_id(nid, index); - CHECK(!arrays[eid]->is_none()) << inode.source->attrs.name; - exec->out_array.push_back(*arrays[eid]); - exec->req.push_back(array_reqs[eid]); - } - - exec->Setup(); -} - -inline Engine::OprHandle CreateEngineOp( - const Context& default_ctx, - const std::vector >& execs) { - CHECK_GT(execs.size(), 0); - std::vector use_vars, mutate_vars; - - for (const auto& exec : execs) { - CHECK_GT(exec->out_array.size(), 0); - CHECK(execs.size() == 1 || exec->exec_type() == ExecType::kSync); - - // the variables - for (const auto& nd : exec->in_array) { - use_vars.push_back(nd.var()); - } - for (auto& r : exec->op_ctx.requested) { - mutate_vars.push_back(r.var); - } - for (auto& nd : exec->out_array) { - mutate_vars.push_back(nd.var()); - } - if (exec->var() != nullptr) { - mutate_vars.push_back(exec->var()); - } - } - - // dedup vars - Engine::Get()->DeduplicateVarHandle(&use_vars, &mutate_vars); - bool is_gpu = default_ctx.dev_mask() == gpu::kDevMask; - bool is_async = execs.size() > 1 ? false : execs[0]->exec_type() == ExecType::kAsync; - - auto exec_fun = [execs, is_async, is_gpu] ( - RunContext ctx, Engine::CallbackOnComplete on_complete) { - if (is_async) { - execs[0]->op_ctx.async_on_complete = on_complete; - } - for (const auto& exec : execs) exec->Run(ctx, is_gpu); - // call on complete only if it is async op - if (!is_async) { - if (is_gpu) { - #if MXNET_USE_CUDA - // Wait GPU kernel to finish. - ctx.get_stream()->Wait(); - #else - LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; - #endif - } - on_complete(); - } - }; - - return Engine::Get()->NewOperator( - exec_fun, use_vars, mutate_vars, FnProperty::kNormal); -} - -inline void CreateEngineOpSeg( - const nnvm::IndexedGraph& idx, - const Context default_ctx, - const size_t start_nid, - const size_t end_nid, - const size_t bulk_size, - const std::unordered_set& excludes, - const std::vector >& execs, - const std::vector skip_plus_node, - std::vector *opr_segs) { - size_t seg_start = start_nid; - std::vector > seg_execs; - for (size_t nid = start_nid; nid < end_nid; ++nid) { - const auto& node = idx[nid]; - if (node.source->is_variable()) continue; - if (skip_plus_node.size() && skip_plus_node[nid]) continue; - auto& exec = execs[nid]; - bool is_async = exec->exec_type() != ExecType::kSync; - bool valid = exec->out_array.size() > 0; - - // Stop at async nodes and invalid node (due to input/output is not allocated) - bool stop = is_async || !valid || seg_execs.size() >= bulk_size; - for (size_t i = 0; i < node.inputs.size() && !stop; ++i) { - if (excludes.count(idx.entry_id(node.inputs[i]))) stop = true; - } - auto num_outputs = node.source->num_outputs(); - for (size_t i = 0; i < num_outputs && !stop; ++i) { - if (excludes.count(idx.entry_id(nid, i))) stop = true; - } - - // Create opr segment for previous nodes. 
- if (stop && nid > seg_start) { - auto& seg = (*opr_segs)[seg_start]; - if (seg_execs.size()) { - seg = EngineOprSeg{false, nid}; - seg.opr.reset(CreateEngineOp(default_ctx, seg_execs)); } else { - seg = EngineOprSeg{true, nid, nullptr}; + *arrays[i] = arrays[mem_plan[i].sid]->AsArray(shapes[i], dtypes[i]); + if (mem_plan[i].inplace && array_reqs->at(i) == kWriteTo) { + array_reqs->at(i) = kWriteInplace; + } } - seg_start = nid; - seg_execs.clear(); - } - - seg_execs.push_back(exec); - - auto& seg = (*opr_segs)[nid]; - if (is_async) { - seg = EngineOprSeg{false, nid + 1}; - seg.opr.reset(CreateEngineOp(default_ctx, seg_execs)); - seg_execs.clear(); - seg_start = nid + 1; - } else if (!valid) { - seg = EngineOprSeg{false, nid + 1, nullptr}; - seg_execs.clear(); - seg_start = nid + 1; - } - } - // The last segment - if (end_nid > seg_start) { - auto& seg = (*opr_segs)[seg_start]; - if (seg_execs.size()) { - seg = EngineOprSeg{false, end_nid}; - seg.opr.reset(CreateEngineOp(default_ctx, seg_execs)); } else { - seg = EngineOprSeg{true, end_nid, nullptr}; + *arrays[i] = NDArray(static_cast(stypes[i]), + shapes[i], default_ctx, true, dtypes[i]); } } } - -void RunGraph(const bool retain_graph, - const nnvm::IndexedGraph& idx, - const std::vector arrays, - size_t node_start, size_t node_end, - std::vector&& array_reqs, - std::vector&& ref_count, - std::vector *p_states, - const DispatchModeVector &dispatch_modes); - } // namespace imperative } // namespace mxnet diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index bb61af127240..451fde2eb867 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -22,7 +22,6 @@ from mxnet.ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID from common import setup_module, with_seed, assertRaises, teardown import numpy as np -from numpy.testing import assert_array_equal from nose.tools import raises, assert_raises from copy import deepcopy import warnings @@ -1125,6 +1124,7 @@ def test_hybrid_multi_context(): net.hybridize() net(mx.nd.zeros((1, 3, 32, 32), ctx=mx.cpu(0))).asnumpy() + @with_seed() def test_zero_grad(): data = mx.nd.random.uniform(shape=(3,3)) @@ -1137,60 +1137,6 @@ def test_zero_grad(): grad = net.collect_params()['test_zero_grad_weight'].grad() assert_almost_equal(grad.asnumpy(), grad.asnumpy() * 0) -def check_hybrid_static_memory(**kwargs): - x = mx.nd.random.uniform(shape=(2, 3, 32, 32)) - x.attach_grad() - - net1 = gluon.model_zoo.vision.get_resnet( - 1, 18, pretrained=True, prefix='net_', ctx=mx.context.current_context()) - net2 = gluon.model_zoo.vision.get_resnet( - 1, 18, pretrained=True, prefix='net_', ctx=mx.context.current_context()) - net2.hybridize(**kwargs) - net1(x) - net2(x) - - def test(net, x): - with mx.autograd.record(): - y = net(x) + net(x) - y.backward() - - grads = {k: v.grad() for k, v in net.collect_params().items() if v.grad_req != 'null'} - - return y, grads - - y1, grads1 = test(net1, x) - y2, grads2 = test(net2, x) - - assert_almost_equal(y1.asnumpy(), y2.asnumpy(), rtol=1e-3, atol=1e-5) - for key in grads1: - assert_almost_equal(grads1[key].asnumpy(), grads2[key].asnumpy(), rtol=1e-3, atol=1e-5) - -def test_hybrid_static_memory(): - check_hybrid_static_memory() - check_hybrid_static_memory(static_alloc=True) - check_hybrid_static_memory(static_alloc=True, static_shape=True) - -def check_hybrid_static_memory_switching(**kwargs): - net = gluon.model_zoo.vision.get_resnet( - 1, 18, pretrained=True, ctx=mx.context.current_context()) - 
net.hybridize(**kwargs)
-
-    x = mx.nd.random.uniform(shape=(4, 3, 32, 32))
-    net(x)
-    with mx.autograd.record():
-        y = net(x)
-        y.backward()
-    x = mx.nd.random.uniform(shape=(2, 3, 32, 32))
-    net(x)
-    with mx.autograd.record():
-        y = net(x)
-        y.backward()
-    mx.nd.waitall()
-
-def test_hybrid_static_memory_switching():
-    check_hybrid_static_memory_switching()
-    check_hybrid_static_memory_switching(static_alloc=True)
-    check_hybrid_static_memory_switching(static_alloc=True, static_shape=True)

 @with_seed()
 def test_hook():
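
For reference, here is a minimal Python sketch (not part of the diff) of the path the C++ changes above restore: hybridizing a Gluon block wraps its symbolic graph in a CachedOp, the forward call goes through Imperative::CachedOp::Forward and is recorded as a _cachedop node via RecordOp, and the subsequent backward() runs Imperative::CachedOp::Backward through Imperative::RunGraph. The particular block, shapes, and print call below are illustrative assumptions only; any HybridBlock would exercise the same path.

import mxnet as mx
from mxnet import autograd, gluon

net = gluon.nn.Dense(4)   # any HybridBlock works; Dense is just an example
net.initialize()
net.hybridize()           # the block's graph is captured in a CachedOp

x = mx.nd.random.uniform(shape=(2, 8))
with autograd.record():   # the _cachedop node is recorded via RecordOp
    y = net(x)            # Imperative::CachedOp::Forward
y.backward()              # Imperative::CachedOp::Backward -> Imperative::RunGraph
print(net.weight.grad().shape)

As the CHECK added in Backward above states, recording while running backward (i.e. higher-order gradients with create_graph=True) is not supported for hybridized blocks.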