Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Re-enable all op segments when in batch mode #9055

Merged
merged 2 commits into from
Jan 15, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 82 additions & 53 deletions src/executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1348,71 +1348,100 @@ void GraphExecutor::InitOpSegs() {
bool prefer_bulk_exec_inference = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_INFERENCE", true);
// Whether to perform bulk exec for training
bool prefer_bulk_exec = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", 1);

bool is_training = num_forward_nodes_ != total_num_nodes;

if (prefer_bulk_exec && is_training) {
this->BulkTrainingOpSegs(total_num_nodes);
}

if (prefer_bulk_exec_inference && !is_training) {
this->BulkInferenceOpSegs();
}

return;
}

void GraphExecutor::BulkTrainingOpSegs(size_t total_num_nodes) {
// The maximum number of node in a segment executed in bulk
size_t num_nodes_threshold = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15);
if (prefer_bulk_exec_inference && num_forward_nodes_ == total_num_nodes) {
// bulk the whole graph for inference
num_nodes_threshold = std::numeric_limits<size_t>::max();
}

if (prefer_bulk_exec) {
// create forward segments for training
size_t topo_start = 0;
for (size_t nid = 0; nid < num_forward_nodes_; nid++) {
auto &node = graph_.indexed_graph()[nid].source;
auto &op_node = op_nodes_[nid];
// check if the segment relies on external input, or exceeds maxinum number of node,
// or requires async ops
if (node->is_variable() || nid - topo_start > num_nodes_threshold ||
op_node.exec->exec_type() != ExecType::kSync) {
// create a new segment for the previous nodes if the current one cannot be bulked
cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
topo_start = nid + 1;
}
}
// the last segmenet
if (topo_start != num_forward_nodes_) {
cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, num_forward_nodes_);

// create forward segments for training
size_t topo_start = 0;
for (size_t nid = 0; nid < num_forward_nodes_; nid++) {
auto &node = graph_.indexed_graph()[nid].source;
auto &op_node = op_nodes_[nid];
// check if the segment relies on external input, or exceeds maxinum number of node,
// or requires async ops
if (node->is_variable() || nid - topo_start > num_nodes_threshold ||
op_node.exec->exec_type() != ExecType::kSync) {
// create a new segment for the previous nodes if the current one cannot be bulked
cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
topo_start = nid + 1;
}
}
// the last segment
if (topo_start != num_forward_nodes_) {
cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, num_forward_nodes_);
}

// create backward segments for training
// get all gradient variables
std::unordered_set<engine::VarHandle> grad_vars;
for (auto &kv : grad_store_) {
grad_vars.insert(kv.second.var());
// create backward segments for training
// get all gradient variables
std::unordered_set<engine::VarHandle> grad_vars;
for (auto &kv : grad_store_) {
grad_vars.insert(kv.second.var());
}
auto &idx = graph_.indexed_graph();
topo_start = num_forward_nodes_;
for (size_t nid = num_forward_nodes_; nid < total_num_nodes; nid++) {
auto &op_node = op_nodes_[nid];
if (op_node.skip_exec_node || op_node.exec == nullptr) {
continue;
}
auto &idx = graph_.indexed_graph();
topo_start = num_forward_nodes_;
for (size_t nid = num_forward_nodes_; nid < total_num_nodes; nid++) {
auto &op_node = op_nodes_[nid];
if (op_node.skip_exec_node || op_node.exec == nullptr) {
continue;
if (idx[nid].source->is_variable() || nid - topo_start > num_nodes_threshold ||
op_node.exec->exec_type() != ExecType::kSync) {
cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
topo_start = nid + 1;
} else {
// If it produces output gradient, don't include it in the segment
bool output_gradient = false;
for (auto &out_arr : op_node.exec->out_array) {
if (grad_vars.find(out_arr.var()) != grad_vars.end()) {
output_gradient = true;
}
}
if (idx[nid].source->is_variable() || nid - topo_start > num_nodes_threshold ||
op_node.exec->exec_type() != ExecType::kSync) {
if (output_gradient) {
cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
topo_start = nid + 1;
} else {
// If it produces output gradient, don't include it in the segment
bool output_gradient = false;
for (auto &out_arr : op_node.exec->out_array) {
if (grad_vars.find(out_arr.var()) != grad_vars.end()) {
output_gradient = true;
}
}
if (output_gradient) {
cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
topo_start = nid + 1;
}
}
}
// last segment for backward
if (topo_start < total_num_nodes) {
cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, total_num_nodes);
}
}
// last segment for backward
if (topo_start < total_num_nodes) {
cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, total_num_nodes);
}
}

return;
void GraphExecutor::BulkInferenceOpSegs() {
// Attempt to bulk the whole graph for inference. We will only create new segments when
// required for non-kSync operations.
size_t topo_start = 0;
for (size_t nid = 0; nid < num_forward_nodes_; nid++) {
auto &node = graph_.indexed_graph()[nid].source;
auto &op_node = op_nodes_[nid];

// Variables do not need to be segmented at inference time.
if (node->is_variable()) continue;

if (op_node.exec->exec_type() != ExecType::kSync) {
cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
topo_start = nid + 1;
}
}
// The last segment
if (topo_start != num_forward_nodes_) {
cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, num_forward_nodes_);
}
}

void GraphExecutor::ExecuteMonCallback(size_t nid) {
Expand Down
4 changes: 4 additions & 0 deletions src/executor/graph_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,10 @@ class GraphExecutor : public Executor {
CachedSegOpr CreateCachedSegOpr(size_t topo_start, size_t topo_end);
// run the monitor callback for node `nid`
void ExecuteMonCallback(size_t nid);
// peform bulking and segmentation on an inference graph
void BulkInferenceOpSegs();
// perform bulking and segmentation on a training graph
void BulkTrainingOpSegs(size_t total_num_nodes);

// internal graph
nnvm::Graph graph_;
Expand Down