Kezhan/execute graph refactoring #1553
@@ -344,6 +344,10 @@ class PlannerImpl {
  Status ComputeUseCounts() {
    // Note: for every ml-value, its definition must appear before all its uses in a topological sort of a valid model
    std::unordered_set<std::string> graph_inputs;
    for (auto& graph_input : graph_viewer_.GetInputsIncludingInitializers()) {
      graph_inputs.insert(graph_input->Name());
    }

    for (auto graph_input : graph_viewer_.GetInputs()) {
      OrtValueIndex index = Index(graph_input->Name());
@@ -368,15 +372,7 @@ class PlannerImpl {
    for (SequentialExecutionPlan::NodeExecutionPlan& step : plan_.execution_plan) {
      auto pnode = graph_viewer_.GetNode(step.node_index);
      if (pnode == nullptr) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Can not find the node ", step.node_index);
      for (auto node_input : pnode->InputDefs()) {
        if (node_input->Exists())
          UseCount(node_input->Name())++;
      }

      for (auto node_input : pnode->ImplicitInputDefs()) {
        if (node_input->Exists())
          UseCount(node_input->Name())++;
      }
      // Identify where each output of this node should be allocated.
      // This is determined by the opkernel bound to the node.
      const KernelCreateInfo* kernel_create_info = nullptr;
@@ -391,31 +387,49 @@ class PlannerImpl {
        if (!pnode->Name().empty()) errormsg << " (node " << pnode->Name() << ")";
        return Status(ONNXRUNTIME, FAIL, errormsg.str());
      }

      auto exec_provider = execution_providers_.Get(*pnode);
      if (exec_provider == nullptr) {
        return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Can not find the execution provider ",
                               pnode->GetExecutionProviderType());
      }

      auto& default_allocator_info = exec_provider->GetAllocator(0, OrtMemTypeDefault)->Info();
      auto inputs = pnode->InputDefs();
      auto num_inputs = inputs.size();
      for (size_t i = 0; i < num_inputs; ++i) {
        if (inputs[i]->Exists()) {
          UseCount(inputs[i]->Name())++;
          if (graph_inputs.end() != graph_inputs.find(inputs[i]->Name())) {
            // If it's a graph input, set its plan.
Review comment: It would be easier to do quite a few things (especially in the optimizers) if we had edges from the graph inputs (and initializers) to nodes, and edges from nodes to graph outputs. Is that something we should consider (in a separate PR)? It would avoid this sort of "iterate all nodes to find the small number that involve graph inputs/outputs" logic.
            // NOTE: Copy nodes should have already been added if a graph input is fed as inputs of nodes assigned to different providers.
            OrtValueIndex index = Index(inputs[i]->Name());
            plan_.SetLocation(static_cast<size_t>(index), exec_provider->GetAllocator(0, p_kernelDef->InputMemoryType(i))->Info());
          }
        }
      }

      auto implicit_inputs = pnode->ImplicitInputDefs();
      auto num_implicit_inputs = implicit_inputs.size();
      for (size_t i = 0; i < num_implicit_inputs; ++i) {
        if (implicit_inputs[i]->Exists()) {
          UseCount(implicit_inputs[i]->Name())++;
          if (graph_inputs.end() != graph_inputs.find(implicit_inputs[i]->Name())) {
            // If it's a graph input, set its plan.
            // NOTE: Copy nodes should have already been added if a graph input is fed as inputs of nodes assigned to different providers.
            OrtValueIndex index = Index(implicit_inputs[i]->Name());
            plan_.SetLocation(static_cast<size_t>(index), exec_provider->GetAllocator(0, p_kernelDef->InputMemoryType(i))->Info());
          }
        }
      }
      auto outputs = pnode->OutputDefs();
      auto num_outputs = outputs.size();

      for (size_t i = 0; i < num_outputs; ++i) {
        auto* node_output = outputs[i];
        if (!node_output->Exists()) continue;
        OrtValueIndex index = Index(node_output->Name());
        ProcessDef(index, node_output);
        ++UseCount(index);
        if (strcmp(default_allocator_info.name, CPU) != 0) {
          // By default, outputs of this node are allocated on the default device allocator,
          // except for outputs marked for allocation in MemoryType:
          auto memory_type = p_kernelDef->OutputMemoryType(i);
          plan_.SetLocation(static_cast<size_t>(index), memory_type == OrtMemTypeDefault
                                                            ? default_allocator_info
                                                            : exec_provider->GetAllocator(0, memory_type)->Info());
        }
        plan_.SetLocation(static_cast<size_t>(index), exec_provider->GetAllocator(0, p_kernelDef->OutputMemoryType(i))->Info());
      }
      // if sync is needed, mark allocation plan as create_fence_if_async=true
      // note that the input arg may come from an execution provider (i.e. CPU) that does not support async,
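Regarding the review comment above about edges from graph inputs/initializers to nodes: a rough sketch of the kind of lookup that suggestion would enable. GetConsumerNodes() is a hypothetical helper here, not an API introduced by this PR; the point is only that explicit edges would let the planner visit just the consumers of each graph input instead of scanning every node.

```cpp
// Hypothetical sketch only: assumes a GetConsumerNodes(name) helper that does
// not exist in this PR. With explicit input->node edges, the planner could do:
for (const auto* graph_input : graph_viewer_.GetInputsIncludingInitializers()) {
  for (const Node* consumer : graph_viewer_.GetConsumerNodes(graph_input->Name())) {
    // Only the nodes that actually consume this graph input are visited,
    // instead of iterating plan_.execution_plan and checking every node's
    // InputDefs()/ImplicitInputDefs() against a graph_inputs set.
    UseCount(graph_input->Name())++;
    // ...set the planned location from the consumer's kernel def, as above...
  }
}
```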
@@ -23,9 +23,18 @@ AllocatorPtr GetAllocator(const SessionState& session_state, const OrtAllocatorI
  return session_state.GetExecutionProviders().GetAllocator(allocator_info);
}

common::Status AllocateHelper(const IExecutionProvider& execution_provider, int device_id, const Tensor& fetched_tensor,
bool ProviderIsCpuBased(const std::string& provider_type) {
  return provider_type == onnxruntime::kCpuExecutionProvider ||
         provider_type == onnxruntime::kMklDnnExecutionProvider ||
         provider_type == onnxruntime::kNGraphExecutionProvider ||
         provider_type == onnxruntime::kNupharExecutionProvider ||
         provider_type == onnxruntime::kOpenVINOExecutionProvider ||
         provider_type == onnxruntime::kNnapiExecutionProvider;
}

common::Status AllocateHelper(const IExecutionProvider& execution_provider, const OrtDevice* device, const Tensor& fetched_tensor,
Review comment: Given you have to pass a value here, please pass a reference instead of a pointer.
                              OrtValue& output_mlvalue) {
  auto allocator = execution_provider.GetAllocator(device_id, OrtMemTypeDefault);
  auto allocator = execution_provider.GetAllocator(device->Id(), OrtMemTypeDefault);
  if (!allocator) {
    return Status(common::ONNXRUNTIME, common::FAIL, "invalid allocator");
  }
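A small sketch of what the reviewer's suggestion would look like: because a valid device must always be supplied here, the parameter can be taken by const reference rather than by pointer. This is a sketch of the suggested change, not the committed code.

```cpp
// Sketch of the review suggestion: OrtDevice by const reference expresses the
// "never null" contract in the signature itself.
common::Status AllocateHelper(const IExecutionProvider& execution_provider, const OrtDevice& device,
                              const Tensor& fetched_tensor, OrtValue& output_mlvalue) {
  auto allocator = execution_provider.GetAllocator(device.Id(), OrtMemTypeDefault);
  if (!allocator) {
    return Status(common::ONNXRUNTIME, common::FAIL, "invalid allocator");
  }
  // ...allocate output_mlvalue with the same shape/type as fetched_tensor, as in the diff above...
  return Status::OK();
}
```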
@@ -62,20 +71,15 @@ static Status CopyMLValue(const DataTransferManager& data_transfer_mgr,
                          const FeedsFetchesManager::MLValueCopyInfo& copy_info,
                          const OrtValue& source_mlvalue,
                          OrtValue& target_mlvalue) {
  if (copy_info.copy_provider == nullptr) {
    target_mlvalue = source_mlvalue;
  } else {
    auto& source_tensor = source_mlvalue.Get<Tensor>();

    if (!target_mlvalue.IsAllocated()) {
      ORT_RETURN_IF_ERROR(utils::AllocateHelper(*copy_info.allocation_provider, copy_info.allocation_device_id,
                                                source_tensor, target_mlvalue));
    }
  auto& source_tensor = source_mlvalue.Get<Tensor>();
  if (!target_mlvalue.IsAllocated()) {
    ORT_RETURN_IF_ERROR(utils::AllocateHelper(*copy_info.allocation_provider, copy_info.target_device,
                                              source_tensor, target_mlvalue));
  }

    Tensor* p_output_tensor = target_mlvalue.GetMutable<Tensor>();
  Tensor* p_output_tensor = target_mlvalue.GetMutable<Tensor>();

    ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(source_tensor, *p_output_tensor));
  }
  ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(source_tensor, *p_output_tensor));

  return Status::OK();
}
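Because the old and new versions are interleaved above, here is a consolidated sketch of what the updated CopyMLValue appears to reduce to once the copy_info.copy_provider == nullptr branch is removed. It is reconstructed from the added lines, so treat it as an approximation rather than the exact committed code.

```cpp
// Approximate post-change CopyMLValue, reconstructed from the added diff lines:
// always allocate on copy_info.target_device if needed, then copy the tensor.
static Status CopyMLValue(const DataTransferManager& data_transfer_mgr,
                          const FeedsFetchesManager::MLValueCopyInfo& copy_info,
                          const OrtValue& source_mlvalue,
                          OrtValue& target_mlvalue) {
  auto& source_tensor = source_mlvalue.Get<Tensor>();
  if (!target_mlvalue.IsAllocated()) {
    ORT_RETURN_IF_ERROR(utils::AllocateHelper(*copy_info.allocation_provider, copy_info.target_device,
                                              source_tensor, target_mlvalue));
  }

  Tensor* p_output_tensor = target_mlvalue.GetMutable<Tensor>();
  ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(source_tensor, *p_output_tensor));

  return Status::OK();
}
```

The "no copy needed" case appears to move to the callers, which now compare devices and skip CopyMLValue entirely; that is why the nullptr copy_provider branch could be dropped here.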
@@ -86,8 +90,6 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons
                                          FeedsFetchesManager::MLValueCopyInfo& copy_info) {
  needed_copy = false;

  //TODO: make it configurable
  const int target_device_id = 0;
  std::vector<SessionState::NodeInfo> node_info_vec;
  ORT_RETURN_IF_ERROR(session_state.GetInputNodeInfo(input_name, node_info_vec));
@@ -111,51 +113,23 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons
      break;
    }

    auto& required_provider_type = GetNodeInputProviderType(node_info);
    auto& input_tensor = orig_mlvalue.Get<Tensor>();
    auto& input_tensor_loc = input_tensor.Location();

    auto* p_input_provider = exec_providers.Get(input_tensor_loc);
    if (!p_input_provider) {
      p_input_provider = exec_providers.Get(onnxruntime::kCpuExecutionProvider);
      ORT_ENFORCE(p_input_provider);
    }

    //no copy for TRT and nGraph
    if (required_provider_type == onnxruntime::kTensorrtExecutionProvider || required_provider_type == onnxruntime::kNGraphExecutionProvider) {
      new_mlvalue = orig_mlvalue;
      break;
    }

    auto input_provider_type = p_input_provider->Type();
    if (input_provider_type == required_provider_type && input_tensor_loc.mem_type == OrtMemTypeDefault) {
      new_mlvalue = orig_mlvalue;
      break;
    }

    // If a node requires input on cpu and input tensor is allocated with pinned memory allocator, don't do copy
    if (required_provider_type == onnxruntime::kCpuExecutionProvider &&
        input_tensor_loc.mem_type == OrtMemTypeCPU) {
    auto& required_device = *node_info.device;
    auto& input_tensor_device = orig_mlvalue.Get<Tensor>().Location().device;
    if (required_device == input_tensor_device) {
      // No copy needed for same device.
      new_mlvalue = orig_mlvalue;
      break;
    }

    auto& required_provider_type = GetNodeInputProviderType(node_info);
    auto* required_provider = exec_providers.Get(required_provider_type);
    ORT_ENFORCE(required_provider);

    auto* p_copy_provider = (required_provider_type != onnxruntime::kCpuExecutionProvider)
                                ? required_provider
                                : p_input_provider;

    copy_info.allocation_device_id = target_device_id;
    copy_info.target_device = &required_device;
    copy_info.allocation_provider = required_provider;
    copy_info.copy_provider = p_copy_provider;
    ORT_RETURN_IF_ERROR(CopyMLValue(session_state.GetDataTransferMgr(), copy_info, orig_mlvalue, new_mlvalue));

    needed_copy = true;

    // } loop of node_info_vec
  } while (false);

  return Status::OK();
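Again, because removed and added lines are interleaved above, here is a consolidated sketch of the new per-input decision, reconstructed from the added lines (an approximation, not the exact committed code): the provider-type and mem-type checks are replaced by a direct device comparison.

```cpp
// Approximate new logic, reconstructed from the added diff lines: compare the
// device the node requires with the device the input tensor already lives on.
auto& required_device = *node_info.device;
auto& input_tensor_device = orig_mlvalue.Get<Tensor>().Location().device;
if (required_device == input_tensor_device) {
  // Same device: reuse the existing OrtValue, no copy needed.
  new_mlvalue = orig_mlvalue;
  break;
}

auto& required_provider_type = GetNodeInputProviderType(node_info);
auto* required_provider = exec_providers.Get(required_provider_type);
ORT_ENFORCE(required_provider);

// Record the copy target; CopyMLValue allocates on target_device if necessary.
copy_info.target_device = &required_device;
copy_info.allocation_provider = required_provider;

ORT_RETURN_IF_ERROR(CopyMLValue(session_state.GetDataTransferMgr(), copy_info, orig_mlvalue, new_mlvalue));
needed_copy = true;
```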
@@ -344,43 +318,26 @@ static common::Status CopyOutputsAcrossDevices(const SessionState& session_state
      continue;
    }

    auto& fetched_tensor = fetched_mlvalue.Get<Tensor>();
    auto& fetched_tensor_location = fetched_tensor.Location();
    auto* p_fetched_provider = execution_providers.Get(fetched_tensor_location);
    if (!p_fetched_provider) {
      p_fetched_provider = cpu_execution_provider;
    }

    auto fetched_provider_type = p_fetched_provider->Type();
    auto& output_mlvalue = user_fetches[idx];

    const IExecutionProvider* p_output_provider = nullptr;

    auto target_device = OrtDevice();
    auto& output_mlvalue = user_fetches[idx];
    if (output_mlvalue.IsAllocated()) {
      Tensor* p_output_tensor = output_mlvalue.GetMutable<Tensor>();
      target_device = p_output_tensor->Location().device;
      p_output_provider = execution_providers.Get(p_output_tensor->Location());
    }
    auto fetch_result_device = fetched_mlvalue.Get<Tensor>().Location().device;
    if (target_device == fetch_result_device) {
      user_fetches[idx] = fetched_mlvalue;
      continue;
    }

    if (!p_output_provider) {
      p_output_provider = cpu_execution_provider;
    }

    auto output_provider_type = p_output_provider->Type();

    if (fetched_provider_type == output_provider_type ||
        (p_output_provider == cpu_execution_provider && fetched_tensor_location.mem_type == OrtMemTypeCPUOutput)) {
      user_fetches[idx] = fetched_mlvalue;
      continue;
    }

    needed_copy = true;

    auto* p_copy_provider = (fetched_provider_type != onnxruntime::kCpuExecutionProvider)
                                ? p_fetched_provider
                                : p_output_provider;

    const int device_id = 0;  // TODO: As per comment in the copy input code, make this configurable.
    FeedsFetchesManager::MLValueCopyInfo copy_info{device_id, p_output_provider, p_copy_provider};
    FeedsFetchesManager::MLValueCopyInfo copy_info{&target_device, p_output_provider};
    ORT_RETURN_IF_ERROR(CopyMLValue(session_state.GetDataTransferMgr(), copy_info, fetched_mlvalue, output_mlvalue));

    if (copiers) {
@@ -410,11 +367,7 @@ static common::Status CachedCopyOutputsAcrossDevices(
static DeviceCopyCheck CheckExecutionProviders(const ExecutionProviders& execution_providers) {
  for (const auto& execution_provider : execution_providers) {
    if (execution_provider->Type() != onnxruntime::kCpuExecutionProvider &&
        execution_provider->Type() != onnxruntime::kMklDnnExecutionProvider &&
        execution_provider->Type() != onnxruntime::kNGraphExecutionProvider &&
        execution_provider->Type() != onnxruntime::kNupharExecutionProvider &&
        execution_provider->Type() != onnxruntime::kOpenVINOExecutionProvider) {
    if (!ProviderIsCpuBased(execution_provider->Type())) {
      return DeviceCopyCheck::Unknown;
    }
  }
Review comment: Is memcmp more efficient?

Reply: That would assume that object padding is always the same.

Reply: @yuslepukhin For the same compiler, same build, isn't it?
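For context on the padding concern in this thread: a minimal, self-contained sketch using a made-up DeviceId struct (not the real OrtDevice layout) showing why memcmp over whole objects is risky, while member-wise comparison stays well defined.

```cpp
#include <cstdint>
#include <cstring>

// Made-up struct purely to illustrate the padding concern; not OrtDevice.
struct DeviceId {
  int8_t device_type;  // 1 byte; compilers typically insert 3 padding bytes after it
  int32_t id;          // 4 bytes, aligned to 4
};

// memcmp compares raw bytes, including padding, whose values are indeterminate
// unless the object was zero-initialized. Two logically equal values can
// therefore compare unequal.
bool EqualByMemcmp(const DeviceId& a, const DeviceId& b) {
  return std::memcmp(&a, &b, sizeof(DeviceId)) == 0;
}

// Member-wise comparison only reads the named fields, so it is well defined
// regardless of padding, and compilers optimize it well in practice.
bool operator==(const DeviceId& a, const DeviceId& b) {
  return a.device_type == b.device_type && a.id == b.id;
}
```

Any speed advantage from memcmp would also depend on the exact layout staying byte-comparable, so the safety trade-off rarely pays off for small structs like this.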