Optimize Fence checking performance (#1593)
* For the majority of nodes, no fence check is needed; only nodes that synchronize memory between CPU and GPU require one. Yet we were paying the fence-check cost for every single node and for every one of its inputs and outputs.

This change minimizes fence checking by performing it only when necessary.
ybrnathan authored Aug 9, 2019
1 parent 1c5b15c commit 9b83545
Showing 4 changed files with 139 additions and 69 deletions.
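
In short, the planner now precomputes a per-node flag while building the allocation plan, and both the sequential and parallel executors consult that flag before entering the fence-synchronization loops. A minimal sketch of the before/after pattern, condensed from the diffs below (ctx, exec_plan, provider_type, and queue_id are stand-ins for the real ORT objects):

// Before: every node walked these loops, even when no value had a fence.
for (int input_index = 0; input_index < ctx.InputCount(); ++input_index) {
  Fence_t fence = ctx.InputFence(input_index);
  if (fence) {
    fence->BeforeUsingAsInput(provider_type, queue_id);
  }
}

// After: the loops run only for nodes the planner flagged as fence-bearing.
if (exec_plan.NodeHasFence(node_index)) {
  for (int input_index = 0; input_index < ctx.InputCount(); ++input_index) {
    Fence_t fence = ctx.InputFence(input_index);
    if (fence) {
      fence->BeforeUsingAsInput(provider_type, queue_id);
    }
  }
}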
51 changes: 51 additions & 0 deletions onnxruntime/core/framework/allocation_planner.cc
@@ -338,6 +338,9 @@ class PlannerImpl {
// Initialize execution plan:
plan_.execution_plan.reserve(num_graph_nodes);

// Initialize node_has_fence.
plan_.node_has_fence.resize(graph_viewer_.MaxNodeIndex());

// Initialize allocation plan:
plan_.allocation_plan.resize(num_ml_values);
}
@@ -585,6 +588,51 @@ class PlannerImpl {
return Status::OK();
}

// Returns whether a given NodeArg has a fence.
// If its buffer is reused, we also need to check whether the original OrtValue has a fence.
bool HasFence(const onnxruntime::NodeArg* arg) {
bool has_fence = false;
if (arg && arg->Exists()) {
OrtValueIndex index = Index(arg->Name());
AllocPlanPerValue& value_plan = AllocPlan(index);

has_fence = value_plan.create_fence_if_async;
if (value_plan.alloc_kind == AllocKind::kReuse) {
  // Buffer is reused; check the original buffer to see if its fence is shared.
  has_fence = has_fence || AllocPlan(value_plan.reused_buffer).create_fence_if_async;
}
}

return has_fence;
}

// Compute the fence check: set the has_fence flag if any input, implicit input, or output of a given node has a fence.
Status ComputeFenceCheck() {
for (SequentialExecutionPlan::NodeExecutionPlan& step : plan_.execution_plan) {
auto pnode = graph_viewer_.GetNode(step.node_index);
if (pnode == nullptr) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Can not find the node ", step.node_index);

bool has_fence = false;
for (auto node_input : pnode->InputDefs()) {
has_fence = has_fence || HasFence(node_input);
}

for (auto node_input : pnode->ImplicitInputDefs()) {
has_fence = has_fence || HasFence(node_input);
}

for (auto node_output : pnode->OutputDefs()) {
has_fence = has_fence || HasFence(node_output);
}

plan_.node_has_fence[step.node_index] = has_fence;
}

return Status::OK();
}

// Convert information in a freelist (about which ml-value becomes free when) into
// a deallocation plan in the format required in an ExecutionPlan
void GenerateDeallocationPlan() {
@@ -642,6 +690,9 @@ Status PlannerImpl::CreatePlan() {
// determine sharing/reuse among ml-values
ORT_RETURN_IF_ERROR(ComputeReusePlan());

// Determine which nodes need a fence check. This must be done after ComputeUseCounts and ComputeReusePlan.
ORT_RETURN_IF_ERROR(ComputeFenceCheck());

// convert information in the freelist_ into a deallocation plan in required format
GenerateDeallocationPlan();

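Note how HasFence follows one level of buffer reuse: a value whose alloc_kind is AllocKind::kReuse inherits the fence of the value it aliases. A self-contained model of that lookup, using simplified stand-in types rather than the actual ORT structs:

#include <vector>

struct ValuePlan {           // stand-in for AllocPlanPerValue
  bool create_fence_if_async = false;
  bool is_reuse = false;     // stand-in for alloc_kind == AllocKind::kReuse
  int reused_buffer = -1;    // index of the aliased value when is_reuse is true
};

// Mirrors the logic of PlannerImpl::HasFence: a value has a fence if it
// creates one itself, or if it reuses a buffer whose original value does.
bool HasFence(const std::vector<ValuePlan>& plans, int index) {
  const ValuePlan& plan = plans[index];
  bool has_fence = plan.create_fence_if_async;
  if (plan.is_reuse) {
    has_fence = has_fence || plans[plan.reused_buffer].create_fence_if_async;
  }
  return has_fence;
}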
76 changes: 41 additions & 35 deletions onnxruntime/core/framework/parallel_executor.cc
@@ -122,6 +122,7 @@ Status ParallelExecutor::RunNodeAsync(size_t p_node_index,
TimePoint sync_time_begin;
TimePoint kernel_begin_time;
const bool f_profiler_enabled = session_state.Profiler().IsEnabled();
const SequentialExecutionPlan& exec_plan = *session_state.GetExecutionPlan();

// Avoid context switching if possible.
while (keep_running) {
@@ -149,33 +150,34 @@
}
// sync before compute
int queue_id = p_op_kernel->KernelDef().ExecQueueId();

for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) {
Fence_t fence = op_kernel_context.InputFence(input_index);
if (fence) {
auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType();
if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) {
execution_provider_type = kCpuExecutionProvider;
if (exec_plan.NodeHasFence(node_index)) {
for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) {
Fence_t fence = op_kernel_context.InputFence(input_index);
if (fence) {
auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType();
if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) {
execution_provider_type = kCpuExecutionProvider;
}
fence->BeforeUsingAsInput(execution_provider_type, queue_id);
}
fence->BeforeUsingAsInput(execution_provider_type, queue_id);
}
}

for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) {
Fence_t fence = op_kernel_context.ImplicitInputFence(input_index);
if (fence) {
auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType();
if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) {
execution_provider_type = kCpuExecutionProvider;
for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) {
Fence_t fence = op_kernel_context.ImplicitInputFence(input_index);
if (fence) {
auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType();
if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) {
execution_provider_type = kCpuExecutionProvider;
}
fence->BeforeUsingAsInput(execution_provider_type, queue_id);
}
fence->BeforeUsingAsInput(execution_provider_type, queue_id);
}
}

for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) {
Fence_t fence = op_kernel_context.OutputFence(output_index);
if (fence) {
fence->BeforeUsingAsOutput(p_op_kernel->Node().GetExecutionProviderType(), queue_id);
for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) {
Fence_t fence = op_kernel_context.OutputFence(output_index);
if (fence) {
fence->BeforeUsingAsOutput(p_op_kernel->Node().GetExecutionProviderType(), queue_id);
}
}
}

@@ -209,32 +211,36 @@ Status ParallelExecutor::RunNodeAsync(size_t p_node_index,
sync_time_begin = session_state.Profiler().StartTime();
}
// sync after compute for outputs
for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) {
Fence_t fence = op_kernel_context.InputFence(input_index);
if (fence) {
fence->AfterUsedAsInput(queue_id);
if (exec_plan.NodeHasFence(node_index)) {
for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) {
Fence_t fence = op_kernel_context.InputFence(input_index);
if (fence) {
fence->AfterUsedAsInput(queue_id);
}
}
}

for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) {
Fence_t fence = op_kernel_context.ImplicitInputFence(input_index);
if (fence) {
fence->AfterUsedAsInput(queue_id);
for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) {
Fence_t fence = op_kernel_context.ImplicitInputFence(input_index);
if (fence) {
fence->AfterUsedAsInput(queue_id);
}
}
}

for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) {
Fence_t fence = op_kernel_context.OutputFence(output_index);
if (fence) {
fence->AfterUsedAsOutput(queue_id);
for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) {
Fence_t fence = op_kernel_context.OutputFence(output_index);
if (fence) {
fence->AfterUsedAsOutput(queue_id);
}
}
}

if (f_profiler_enabled) {
session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
p_op_kernel->Node().Name() + "_fence_after",
sync_time_begin,
{{"op_name", p_op_kernel->KernelDef().OpName()}});
}

//std::cout << "Run async node finish: " << p_node_index << std::endl;

keep_running = false;
9 changes: 9 additions & 0 deletions onnxruntime/core/framework/sequential_execution_plan.h
@@ -66,6 +66,9 @@ struct SequentialExecutionPlan : public ExecutionPlanBase {
// Execution_plan: represents the nodes in the sequential order to be executed
std::vector<NodeExecutionPlan> execution_plan;

// Records whether a given node has a fence on any of its inputs or outputs; indexed by node index.
std::vector<bool> node_has_fence;

// to_be_freed: vector elements represent indices of ml-values to be freed (as described above)
std::vector<OrtValueIndex> to_be_freed;

@@ -84,6 +87,12 @@ struct SequentialExecutionPlan : public ExecutionPlanBase {
}
return locations;
}

// Whether a given node needs a fence check.
bool NodeHasFence(onnxruntime::NodeIndex node_index) const {
return node_has_fence[node_index];
}

};

// Output details of an execution plan:
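Because the planner sizes node_has_fence with graph_viewer_.MaxNodeIndex(), indexing with any valid NodeIndex stays in bounds, so NodeHasFence is a single O(1) lookup per node, and std::vector<bool> packs the flags into a bitset. A runnable toy model of this planner/executor contract (stand-in types, not the ORT classes):

#include <vector>

struct MiniPlan {             // stand-in for SequentialExecutionPlan
  std::vector<bool> node_has_fence;
  bool NodeHasFence(size_t node_index) const {
    return node_has_fence[node_index];
  }
};

int main() {
  MiniPlan plan;
  plan.node_has_fence.resize(8);  // planner: sized to cover every NodeIndex
  plan.node_has_fence[3] = true;  // planner: node 3 touches a fenced value

  // Executor: skip the fence-sync loops entirely unless the node was flagged.
  for (size_t node = 0; node < plan.node_has_fence.size(); ++node) {
    if (plan.NodeHasFence(node)) {
      // ... BeforeUsingAsInput / BeforeUsingAsOutput calls would run here ...
    }
  }
  return 0;
}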
72 changes: 38 additions & 34 deletions onnxruntime/core/framework/sequential_executor.cc
@@ -71,32 +71,34 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:

// sync before compute
int queue_id = p_op_kernel->KernelDef().ExecQueueId();
for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) {
Fence_t fence = op_kernel_context.InputFence(input_index);
if (fence) {
auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType();
if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) {
execution_provider_type = kCpuExecutionProvider;
if (seq_exec_plan.NodeHasFence(node_index)) {
for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) {
Fence_t fence = op_kernel_context.InputFence(input_index);
if (fence) {
auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType();
if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) {
execution_provider_type = kCpuExecutionProvider;
}
fence->BeforeUsingAsInput(execution_provider_type, queue_id);
}
fence->BeforeUsingAsInput(execution_provider_type, queue_id);
}
}

for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) {
Fence_t fence = op_kernel_context.ImplicitInputFence(input_index);
if (fence) {
auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType();
if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) {
execution_provider_type = kCpuExecutionProvider;
for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) {
Fence_t fence = op_kernel_context.ImplicitInputFence(input_index);
if (fence) {
auto execution_provider_type = p_op_kernel->Node().GetExecutionProviderType();
if (OrtMemTypeCPUInput == p_op_kernel->KernelDef().InputMemoryType(input_index)) {
execution_provider_type = kCpuExecutionProvider;
}
fence->BeforeUsingAsInput(execution_provider_type, queue_id);
}
fence->BeforeUsingAsInput(execution_provider_type, queue_id);
}
}

for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) {
Fence_t fence = op_kernel_context.OutputFence(output_index);
if (fence) {
fence->BeforeUsingAsOutput(p_op_kernel->Node().GetExecutionProviderType(), queue_id);
for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) {
Fence_t fence = op_kernel_context.OutputFence(output_index);
if (fence) {
fence->BeforeUsingAsOutput(p_op_kernel->Node().GetExecutionProviderType(), queue_id);
}
}
}

@@ -138,24 +140,26 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:
}

// sync after compute for outputs
for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) {
Fence_t fence = op_kernel_context.InputFence(input_index);
if (fence) {
fence->AfterUsedAsInput(queue_id);
if (seq_exec_plan.NodeHasFence(node_index)) {
for (int input_index = 0; input_index < op_kernel_context.InputCount(); ++input_index) {
Fence_t fence = op_kernel_context.InputFence(input_index);
if (fence) {
fence->AfterUsedAsInput(queue_id);
}
}
}

for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) {
Fence_t fence = op_kernel_context.ImplicitInputFence(input_index);
if (fence) {
fence->AfterUsedAsInput(queue_id);
for (int input_index = 0; input_index < op_kernel_context.ImplicitInputCount(); ++input_index) {
Fence_t fence = op_kernel_context.ImplicitInputFence(input_index);
if (fence) {
fence->AfterUsedAsInput(queue_id);
}
}
}

for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) {
Fence_t fence = op_kernel_context.OutputFence(output_index);
if (fence) {
fence->AfterUsedAsOutput(queue_id);
for (int output_index = 0; output_index < op_kernel_context.OutputCount(); ++output_index) {
Fence_t fence = op_kernel_context.OutputFence(output_index);
if (fence) {
fence->AfterUsedAsOutput(queue_id);
}
}
}
