[WebNN EP] Fix issues of GRU operator (#22123)
### Description
This PR fixes the spelling of the GRU operator's key in the op map used by the
`GetSupportedNodes` function (`Gru` -> `GRU`) and removes the data type check for
the operator's fifth input (`sequence_lens`), whose data type is fixed to int32 by the ONNX spec.

PTAL, thanks!
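For context on the first fix: `GetSupportedNodes` resolves each node by looking up its ONNX op type in the op map, and that lookup is an exact, case-sensitive string match. The ONNX op type is spelled `GRU`, so a `Gru` key could never match and every GRU node was reported as unsupported by the WebNN EP. A minimal sketch of the failure mode (the two maps below are simplified stand-ins, not the actual ORT code):

```cpp
#include <iostream>
#include <string>
#include <unordered_map>

// Simplified stand-ins for the op map in helper.h (ONNX op type -> WebNN op).
static const std::unordered_map<std::string, std::string> old_op_map = {{"Gru", "gru"}};
static const std::unordered_map<std::string, std::string> new_op_map = {{"GRU", "gru"}};

int main() {
  // GetSupportedNodes looks nodes up by their ONNX op type, which is "GRU".
  const std::string op_type = "GRU";
  std::cout << (old_op_map.count(op_type) != 0) << "\n";  // 0: never matched
  std::cout << (new_op_map.count(op_type) != 0) << "\n";  // 1: matches after the fix
  return 0;
}
```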
miaobin authored Nov 13, 2024
1 parent a9b62fa commit a15381d
Showing 2 changed files with 52 additions and 26 deletions.
onnxruntime/core/providers/webnn/builders/helper.h (1 addition, 1 deletion)
@@ -227,7 +227,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"GlobalLpPool", "l2Pool2d"},
{"Greater", "greater"},
{"GreaterOrEqual", "greaterOrEqual"},
{"Gru", "gru"},
{"GRU", "gru"},
{"HardSigmoid", "hardSigmoid"},
{"HardSwish", "hardSwish"},
{"Identity", "identity"},
onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc (51 additions, 25 deletions)
@@ -28,6 +28,8 @@ class GruOpBuilder : public BaseOpBuilder {
                          const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override;
   bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
                               const logging::Logger& logger) const override;
+  bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits,
+                               const logging::Logger& logger) const override;
 };
 
 void GruOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
@@ -189,40 +191,64 @@ bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input0_type = 0; // input data type
int32_t input1_type = 0; // weight data type
int32_t input2_type = 0; // recurrentWeight data type
int32_t input3_type = 0; // bias data type
int32_t input4_type = 0; // recurrentBias data type
int32_t input5_type = 0; // initialHiddenState data type
bool has_input3 = input_defs.size() > 3 && input_defs[3]->Exists();
bool has_input4 = input_defs.size() > 4 && input_defs[4]->Exists();
bool has_input5 = input_defs.size() > 5 && input_defs[5]->Exists();

if (!GetType(*input_defs[0], input0_type, logger) ||
!GetType(*input_defs[1], input1_type, logger) ||
!GetType(*input_defs[2], input2_type, logger) ||
(has_input3 && !GetType(*input_defs[3], input3_type, logger)) ||
(has_input4 && !GetType(*input_defs[4], input4_type, logger)) ||
(has_input5 && !GetType(*input_defs[5], input5_type, logger))) {
int32_t input_X_type = 0; // input data type
int32_t input_W_type = 0; // weight data type
int32_t input_R_type = 0; // recurrent weight data type
int32_t input_B_type = 0; // bias data type
int32_t input_initial_h_type = 0; // initial hidden state data type
bool has_input_B = input_defs.size() > 3 && input_defs[3]->Exists();
bool has_input_initial_h = input_defs.size() > 5 && input_defs[5]->Exists();

if (!GetType(*input_defs[0], input_X_type, logger) ||
!GetType(*input_defs[1], input_W_type, logger) ||
!GetType(*input_defs[2], input_R_type, logger) ||
(has_input_B && !GetType(*input_defs[3], input_B_type, logger)) ||
// input_defs[4] refers to sequence_lens and is a fixed data type of int32.
(has_input_initial_h && !GetType(*input_defs[5], input_initial_h_type, logger))) {
return false;
}

InlinedVector<int32_t, 6> input_types = {input0_type, input1_type, input2_type};
if (has_input3) {
input_types.push_back(input3_type);
InlinedVector<int32_t, 5> input_types = {input_X_type, input_W_type, input_R_type};
if (has_input_B) {
input_types.push_back(input_B_type);
}
if (has_input4) {
input_types.push_back(input4_type);
}
if (has_input5) {
input_types.push_back(input5_type);
if (has_input_initial_h) {
input_types.push_back(input_initial_h_type);
}
if (!AreInputDataTypesSame(op_type, input_types, logger)) {
return false;
}

return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger);
return IsDataTypeSupportedByOp(op_type, input_X_type, wnn_limits, "input", "X", logger);
}

bool GruOpBuilder::HasSupportedOutputsImpl(const Node& node,
const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& output_defs = node.OutputDefs();
const auto& op_type = node.OpType();
int32_t Y_type = 0;
int32_t Y_h_type = 0;
bool has_Y = output_defs.size() > 0 && output_defs[0]->Exists();
bool has_Y_h = output_defs.size() > 1 && output_defs[1]->Exists();

bool Y_supported = has_Y && GetType(*output_defs[0], Y_type, logger);
bool Y_h_supported = has_Y_h && GetType(*output_defs[1], Y_h_type, logger);

if (Y_supported && !Y_h_supported) {
return IsDataTypeSupportedByOp(op_type, Y_type, wnn_limits, "outputs", "Y", logger);
} else if (!Y_supported && Y_h_supported) {
return IsDataTypeSupportedByOp(op_type, Y_h_type, wnn_limits, "outputs", "Y_h", logger);
} else if (Y_supported && Y_h_supported) {
if (Y_type != Y_h_type) {
LOGS(logger, VERBOSE) << "[GRU] Output data types must be the same.";
return false;
}
return IsDataTypeSupportedByOp(op_type, Y_type, wnn_limits, "outputs", "Y", logger);
} else {
LOGS(logger, VERBOSE) << "[GRU] No output found.";
return false;
}
}

void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
Expand Down
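For context on the removed check: per the ONNX GRU spec, `sequence_lens` is always int32, while `X`, `W`, `R`, `B`, and `initial_h` share a floating-point type, so including it in the `AreInputDataTypesSame` comparison rejected well-formed GRU nodes. A minimal sketch of the effect, using a simplified stand-in for that helper and ONNX `TensorProto` type enum values (`FLOAT = 1`, `INT32 = 6`):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Simplified stand-in for AreInputDataTypesSame: all collected types must match.
bool AllSameType(const std::vector<int32_t>& types) {
  for (size_t i = 1; i < types.size(); ++i) {
    if (types[i] != types[0]) return false;
  }
  return true;
}

int main() {
  // Old behavior: sequence_lens (int32 = 6) was collected alongside the
  // float inputs (float = 1), so a valid float GRU node failed the check.
  std::cout << AllSameType({1, 1, 1, 1, 6, 1}) << "\n";  // 0: wrongly rejected
  // New behavior: sequence_lens is skipped, and the remaining inputs agree.
  std::cout << AllSameType({1, 1, 1, 1, 1}) << "\n";     // 1: accepted
  return 0;
}
```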
