diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 1128975fa31c..15aebfe9ba13 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -1348,72 +1348,100 @@ void GraphExecutor::InitOpSegs() {
   bool prefer_bulk_exec_inference = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_INFERENCE", true);
   // Whether to perform bulk exec for training
   bool prefer_bulk_exec = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", 1);
+
+  bool is_training = num_forward_nodes_ != total_num_nodes;
+
+  if (prefer_bulk_exec && is_training) {
+    this->BulkTrainingOpSegs(total_num_nodes);
+  }
+
+  if (prefer_bulk_exec_inference && !is_training) {
+    this->BulkInferenceOpSegs();
+  }
+
+  return;
+}
+
+void GraphExecutor::BulkTrainingOpSegs(size_t total_num_nodes) {
   // The maximum number of node in a segment executed in bulk
   size_t num_nodes_threshold = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15);
-  if (prefer_bulk_exec_inference && num_forward_nodes_ == total_num_nodes) {
-    // Bulk the whole graph for inference
-    cached_seg_opr_[0] = this->CreateCachedSegOpr(0, num_forward_nodes_);
-    return;
-  }
-
-  if (prefer_bulk_exec) {
-    // create forward segments for training
-    size_t topo_start = 0;
-    for (size_t nid = 0; nid < num_forward_nodes_; nid++) {
-      auto &node = graph_.indexed_graph()[nid].source;
-      auto &op_node = op_nodes_[nid];
-      // check if the segment relies on external input, or exceeds maxinum number of node,
-      // or requires async ops
-      if (node->is_variable() || nid - topo_start > num_nodes_threshold ||
-          op_node.exec->exec_type() != ExecType::kSync) {
-        // create a new segment for the previous nodes if the current one cannot be bulked
-        cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
-        topo_start = nid + 1;
-      }
-    }
-    // the last segmenet
-    if (topo_start != num_forward_nodes_) {
-      cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, num_forward_nodes_);
+
+  // create forward segments for training
+  size_t topo_start = 0;
+  for (size_t nid = 0; nid < num_forward_nodes_; nid++) {
+    auto &node = graph_.indexed_graph()[nid].source;
+    auto &op_node = op_nodes_[nid];
+    // check if the segment relies on external input, or exceeds maxinum number of node,
+    // or requires async ops
+    if (node->is_variable() || nid - topo_start > num_nodes_threshold ||
+        op_node.exec->exec_type() != ExecType::kSync) {
+      // create a new segment for the previous nodes if the current one cannot be bulked
+      cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
+      topo_start = nid + 1;
     }
+  }
+  // the last segment
+  if (topo_start != num_forward_nodes_) {
+    cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, num_forward_nodes_);
+  }
 
-    // create backward segments for training
-    // get all gradient variables
-    std::unordered_set<engine::VarHandle> grad_vars;
-    for (auto &kv : grad_store_) {
-      grad_vars.insert(kv.second.var());
+  // create backward segments for training
+  // get all gradient variables
+  std::unordered_set<engine::VarHandle> grad_vars;
+  for (auto &kv : grad_store_) {
+    grad_vars.insert(kv.second.var());
+  }
+  auto &idx = graph_.indexed_graph();
+  topo_start = num_forward_nodes_;
+  for (size_t nid = num_forward_nodes_; nid < total_num_nodes; nid++) {
+    auto &op_node = op_nodes_[nid];
+    if (op_node.skip_exec_node || op_node.exec == nullptr) {
+      continue;
     }
-    auto &idx = graph_.indexed_graph();
-    topo_start = num_forward_nodes_;
-    for (size_t nid = num_forward_nodes_; nid < total_num_nodes; nid++) {
-      auto &op_node = op_nodes_[nid];
-      if (op_node.skip_exec_node || op_node.exec == nullptr) {
-        continue;
+    if (idx[nid].source->is_variable() || nid - topo_start > num_nodes_threshold ||
+        op_node.exec->exec_type() != ExecType::kSync) {
+      cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
+      topo_start = nid + 1;
+    } else {
+      // If it produces output gradient, don't include it in the segment
+      bool output_gradient = false;
+      for (auto &out_arr : op_node.exec->out_array) {
+        if (grad_vars.find(out_arr.var()) != grad_vars.end()) {
+          output_gradient = true;
+        }
       }
-      if (idx[nid].source->is_variable() || nid - topo_start > num_nodes_threshold ||
-          op_node.exec->exec_type() != ExecType::kSync) {
+      if (output_gradient) {
         cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
         topo_start = nid + 1;
-      } else {
-        // If it produces output gradient, don't include it in the segment
-        bool output_gradient = false;
-        for (auto &out_arr : op_node.exec->out_array) {
-          if (grad_vars.find(out_arr.var()) != grad_vars.end()) {
-            output_gradient = true;
-          }
-        }
-        if (output_gradient) {
-          cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
-          topo_start = nid + 1;
-        }
       }
     }
-    // last segment for backward
-    if (topo_start < total_num_nodes) {
-      cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, total_num_nodes);
-    }
   }
+  // last segment for backward
+  if (topo_start < total_num_nodes) {
+    cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, total_num_nodes);
+  }
+}
 
-  return;
+void GraphExecutor::BulkInferenceOpSegs() {
+  // Attempt to bulk the whole graph for inference. We will only create new segments when
+  // required for non-kSync operations.
+  size_t topo_start = 0;
+  for (size_t nid = 0; nid < num_forward_nodes_; nid++) {
+    auto &node = graph_.indexed_graph()[nid].source;
+    auto &op_node = op_nodes_[nid];
+
+    // Variables do not need to be segmented at inference time.
+    if (node->is_variable()) continue;
+
+    if (op_node.exec->exec_type() != ExecType::kSync) {
+      cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
+      topo_start = nid + 1;
+    }
+  }
+  // The last segment
+  if (topo_start != num_forward_nodes_) {
+    cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, num_forward_nodes_);
+  }
 }
 
 void GraphExecutor::ExecuteMonCallback(size_t nid) {
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index 0e5ef3298945..ee32db72cebd 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -197,6 +197,10 @@ class GraphExecutor : public Executor {
   CachedSegOpr CreateCachedSegOpr(size_t topo_start, size_t topo_end);
   // run the monitor callback for node `nid`
   void ExecuteMonCallback(size_t nid);
+  // peform bulking and segmentation on an inference graph
+  void BulkInferenceOpSegs();
+  // perform bulking and segmentation on a training graph
+  void BulkTrainingOpSegs(size_t total_num_nodes);
   // internal graph
   nnvm::Graph graph_;