Skip to content

Commit

Permalink
[NNAPI EP] Add NNAPI Split (microsoft#18702)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->

Adds support for the ONNX Split operator to the NNAPI execution provider.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

The YOLOv8 model requires the Split operator, which was previously unsupported by the NNAPI EP.

---------

Co-authored-by: rachguo <[email protected]>
Co-authored-by: Edward Chen <[email protected]>
  • Loading branch information
3 people authored Dec 6, 2023
1 parent c4b8120 commit 7762f3f
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 12 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <onnx/onnx_pb.h>
#include <algorithm>

#include "core/common/logging/logging.h"
#include "core/common/safeint.h"
#include "core/framework/tensorprotoutils.h"
#include "core/graph/graph_viewer.h"
#include "core/providers/common.h"
#include "core/optimizer/initializer.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/nnapi/nnapi_builtin/builders/helper.h"
#include "core/providers/nnapi/nnapi_builtin/builders/model_builder.h"
#include "core/providers/nnapi/nnapi_builtin/builders/op_builder_factory.h"
#include "core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h"
#include "core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h"

using namespace android::nn::wrapper;

namespace onnxruntime {
namespace nnapi {

using namespace op_builder_helpers;

// Builds the NNAPI SPLIT operation from an ONNX Split node.
class SplitOpBuilder : public BaseOpBuilder {
  // Add operator related
 public:
  void AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const override;

 private:
  Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override;

  // Operator support related
  bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit,
                         const OpSupportCheckParams& params) const override;

  // Split opset 12 and below carries the splits in a "split" attribute, which is not supported here,
  // so require opset 13 or later.
  int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 13; }

  // ANEURALNETWORKS_SPLIT is only available from NNAPI feature level 3 onwards.
  int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& /* node_unit */,
                                           const OpSupportCheckParams& /* params */) const override {
    return ANEURALNETWORKS_FEATURE_LEVEL_3;
  }
};

// Add operator related

void SplitOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const {
  // The optional second input "split" is consumed while building the model, so its
  // initializer does not need to be converted into an NNAPI operand.
  const auto& inputs = node_unit.Inputs();
  const bool has_split_input = inputs.size() > 1 && inputs[1].node_arg.Exists();
  if (has_split_input) {
    model_builder.AddInitializerToSkip(inputs[1].node_arg.Name());
  }
}

// Adds an NNAPI SPLIT operation for this Split node.
// Splits the single input along `axis` into `num_outputs` equally sized outputs.
Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const {
  const auto& input_name = node_unit.Inputs()[0].node_arg.Name();
  const auto& outputs = node_unit.Outputs();

  NodeAttrHelper helper(node_unit);
  const auto axis = helper.Get("axis", 0);

  int32_t num_outputs;
  if (node_unit.SinceVersion() >= 18) {
    // For opset 18+, "num_outputs" is only required when the optional "split" input is absent.
    // IsOpSupportedImpl accepts nodes that provide "split" (with equal values) instead, so do not
    // unconditionally dereference the attribute — fall back to the actual output count.
    const auto num_outputs_attr = helper.GetInt("num_outputs");
    num_outputs = num_outputs_attr.has_value() ? SafeInt<int32_t>(*num_outputs_attr)
                                               : SafeInt<int32_t>(outputs.size());
  } else {
    num_outputs = SafeInt<int32_t>(outputs.size());
  }

  std::vector<std::string> output_names;
  output_names.reserve(num_outputs);
  for (int32_t i = 0; i < num_outputs; ++i) {
    output_names.push_back(outputs[i].node_arg.Name());
  }

  ORT_RETURN_IF_ERROR(op_builder_helpers::AddNnapiSplit(model_builder, input_name, axis, output_names));

  return Status::OK();
}

// Operator support related

// Checks whether this Split node can be handled by NNAPI.
// NNAPI SPLIT only supports splitting into equally sized pieces, so both the
// "split" input path and the "num_outputs" attribute path must describe an even split.
bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit,
const OpSupportCheckParams& /* params */) const {
Shape input_shape;
if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape))
return false;

const auto& input_defs = node_unit.Inputs();
NodeAttrHelper helper(node_unit);
const auto axis = helper.Get("axis", 0);

// Size of the dimension being split (axis may be negative per the ONNX spec).
const auto split_dims_at_axis = input_shape[HandleNegativeAxis(axis, input_shape.size())];
if (input_defs.size() > 1 && input_defs[1].node_arg.Exists()) {
// if optional input `split` is provided
// The split sizes must be known at compile time to validate them here.
auto split_initializer_it = initializers.find(input_defs[1].node_arg.Name());
if (split_initializer_it == initializers.end()) {
LOGS_DEFAULT(VERBOSE) << "Optional input 'split' must be initializer if provided.";
return false;
}
const auto& splits_tensor = *split_initializer_it->second;
Initializer unpacked_tensor(splits_tensor);
auto splits_span = unpacked_tensor.DataAsSpan<int64_t>();
// The ONNX spec requires the split sizes to sum to the size of the split axis.
uint32_t sum_of_splits = std::accumulate(splits_span.begin(), splits_span.end(), SafeInt<uint32_t>(0));
if (sum_of_splits != split_dims_at_axis) {
LOGS_DEFAULT(VERBOSE) << "Sum of the 'split' input values must equal to the dim value at 'axis' specified. "
<< "dim value at 'axis' specified: "
<< split_dims_at_axis
<< ", sum of 'split' input values: "
<< sum_of_splits;
return false;
}

// adjacent_find with != finds the first pair of unequal neighbors; it returns end()
// only when all split sizes are identical, i.e. the split is even as NNAPI requires.
auto it = std::adjacent_find(splits_span.begin(), splits_span.end(), [](const auto& a, const auto& b) {
return a != b;
});
if (it != splits_span.end()) {
LOGS_DEFAULT(VERBOSE) << "NNAPI only supports the case that number of splits evenly divides split axis size";
return false;
}
} else {
uint32_t num_outputs;
if (node_unit.SinceVersion() >= 18) {
// Opset 18+ without a 'split' input must supply the 'num_outputs' attribute.
auto num_outputs_attr = helper.GetInt("num_outputs");
if (!num_outputs_attr.has_value()) {
LOGS_DEFAULT(VERBOSE) << "No 'num_outputs' provided. For split 18+, num_outputs is a required attribute.";
return false;
}
num_outputs = SafeInt<uint32_t>(*num_outputs_attr);
// num_outputs must match the node's actual output count and cannot exceed the axis size.
if (num_outputs != SafeInt<uint32_t>(node_unit.Outputs().size()) || num_outputs > split_dims_at_axis) {
LOGS_DEFAULT(VERBOSE) << "Invalid num_outputs provided. "
<< "The value should be less than or equal to the size of dimension being split "
<< "and align with the size of output nodes. Current num_outputs: "
<< num_outputs;
return false;
}
} else {
// Opset 13-17: the number of splits is implied by the output count.
num_outputs = SafeInt<uint32_t>(node_unit.Outputs().size());
}
// NNAPI only supports the case where axis can be evenly divided by num of splits
if (split_dims_at_axis % num_outputs != 0) {
LOGS_DEFAULT(VERBOSE) << "split count: " << num_outputs << " doesn't evenly divide split dimension: "
<< split_dims_at_axis;
return false;
}
}
return true;
}

// Registers a SplitOpBuilder instance and maps the given op type to it.
void CreateSplitOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
  auto builder = std::make_unique<SplitOpBuilder>();
  op_registrations.op_builder_map.emplace(op_type, builder.get());
  op_registrations.builders.push_back(std::move(builder));
}

} // namespace nnapi
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
CreateResizeOpBuilder("Resize", op_registrations);
CreateSliceOpBuilder("Slice", op_registrations);
CreateSoftMaxOpBuilder("Softmax", op_registrations);
CreateSplitOpBuilder("Split", op_registrations);
CreateSqueezeOpBuilder("Squeeze", op_registrations);
CreateTransposeOpBuilder("Transpose", op_registrations);
CreateUnsqueezeOpBuilder("Unsqueeze", op_registrations);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ void CreateReluOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_
void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateSplitOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateSoftMaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateSqueezeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
Expand Down
15 changes: 3 additions & 12 deletions onnxruntime/test/providers/cpu/tensor/split_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -706,9 +706,8 @@ TEST(SplitOperatorTest, Split18_NumOutputs_EvenSplit) {
7.f, 8.f}});

int64_t num_outputs = 2;
#ifdef USE_COREML

RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true, num_outputs, true);
#endif
RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true, num_outputs, false);
}

Expand All @@ -735,9 +734,8 @@ TEST(SplitOperatorTest, Split18_NumOutputs_UnevenSplit) {
outputs.push_back({{1, 2}, {9.f, 10.f}});

int64_t num_outputs = 3;
#ifdef USE_COREML

RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, false, true, num_outputs, true);
#endif
RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, false, true, num_outputs, false);
}

Expand All @@ -763,10 +761,8 @@ TEST(SplitOperatorTest, Split18_InvalidNumOutputs) {
};
RunTest<float>(axis, {}, input, outputs, excluded_providers, true, true, num_outputs, false,
"Attribute `num_outputs` value cannot be lower than 1");
#ifdef USE_COREML
RunTest<float>(axis, {}, input, outputs, excluded_providers, true, true, num_outputs, true,
"Attribute `num_outputs` value cannot be lower than 1");
#endif

outputs.clear();
outputs.push_back({{1, 2},
Expand All @@ -775,12 +771,11 @@ TEST(SplitOperatorTest, Split18_InvalidNumOutputs) {
{0.f, 0.f}});

num_outputs = 3;

RunTest<float>(axis, {}, input, outputs, excluded_providers, true, true, num_outputs, false,
"Invalid num_outputs value of 3. Size of dimension being split is 2");
#ifdef USE_COREML
RunTest<float>(axis, {}, input, outputs, excluded_providers, true, true, num_outputs, true,
"Invalid num_outputs value of 3. Size of dimension being split is 2");
#endif
}

TEST(SplitOperatorTest, Split18_NumOutputsEvenSplitAxis1) {
Expand All @@ -798,9 +793,7 @@ TEST(SplitOperatorTest, Split18_NumOutputsEvenSplitAxis1) {

int64_t num_outputs = 3;
RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true, num_outputs, false);
#ifdef USE_COREML
RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true, num_outputs);
#endif
}

TEST(SplitOperatorTest, Split18_NumOutputsUnevenSplitAxis1) {
Expand All @@ -818,9 +811,7 @@ TEST(SplitOperatorTest, Split18_NumOutputsUnevenSplitAxis1) {
outputs.push_back({{2, 1}, {3.f, 6.f}});

int64_t num_outputs = 2;
#ifdef USE_COREML
RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, false, true, num_outputs);
#endif
RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, false, true, num_outputs, false);
}

Expand Down
1 change: 1 addition & 0 deletions tools/ci_build/github/android/nnapi_supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Keep in sync with doco generated from /docs/execution-providers/NNAPI-ExecutionP
|ai.onnx:Sin||
|ai.onnx:Slice||
|ai.onnx:Softmax||
|ai.onnx:Split|Number of splits must evenly divide split axis size. Input split should be constant if provided.|
|ai.onnx:Sqrt||
|ai.onnx:Squeeze|Input axes should be constant.|
|ai.onnx:Sub||
Expand Down

0 comments on commit 7762f3f

Please sign in to comment.