diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h
index c27e82ee7f54a..218677a1945d5 100644
--- a/onnxruntime/core/providers/webnn/builders/helper.h
+++ b/onnxruntime/core/providers/webnn/builders/helper.h
@@ -227,7 +227,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
     {"GlobalLpPool", "l2Pool2d"},
     {"Greater", "greater"},
     {"GreaterOrEqual", "greaterOrEqual"},
-    {"Gru", "gru"},
+    {"GRU", "gru"},
     {"HardSigmoid", "hardSigmoid"},
     {"HardSwish", "hardSwish"},
     {"Identity", "identity"},
diff --git a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc
index c92fe7366d494..ffb9b7fbf2e7a 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc
@@ -28,6 +28,8 @@ class GruOpBuilder : public BaseOpBuilder {
                          const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override;
   bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
                               const logging::Logger& logger) const override;
+  bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits,
+                               const logging::Logger& logger) const override;
 };
 
 void GruOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
@@ -189,40 +191,64 @@ bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
                                           const logging::Logger& logger) const {
   const auto& input_defs = node.InputDefs();
   const auto& op_type = node.OpType();
-  int32_t input0_type = 0;  // input data type
-  int32_t input1_type = 0;  // weight data type
-  int32_t input2_type = 0;  // recurrentWeight data type
-  int32_t input3_type = 0;  // bias data type
-  int32_t input4_type = 0;  // recurrentBias data type
-  int32_t input5_type = 0;  // initialHiddenState data type
-  bool has_input3 = input_defs.size() > 3 && input_defs[3]->Exists();
-  bool has_input4 = input_defs.size() > 4 && input_defs[4]->Exists();
-  bool has_input5 = input_defs.size() > 5 && input_defs[5]->Exists();
-
-  if (!GetType(*input_defs[0], input0_type, logger) ||
-      !GetType(*input_defs[1], input1_type, logger) ||
-      !GetType(*input_defs[2], input2_type, logger) ||
-      (has_input3 && !GetType(*input_defs[3], input3_type, logger)) ||
-      (has_input4 && !GetType(*input_defs[4], input4_type, logger)) ||
-      (has_input5 && !GetType(*input_defs[5], input5_type, logger))) {
+  int32_t input_X_type = 0;          // input data type
+  int32_t input_W_type = 0;          // weight data type
+  int32_t input_R_type = 0;          // recurrent weight data type
+  int32_t input_B_type = 0;          // bias data type
+  int32_t input_initial_h_type = 0;  // initial hidden state data type
+  bool has_input_B = input_defs.size() > 3 && input_defs[3]->Exists();
+  bool has_input_initial_h = input_defs.size() > 5 && input_defs[5]->Exists();
+
+  if (!GetType(*input_defs[0], input_X_type, logger) ||
+      !GetType(*input_defs[1], input_W_type, logger) ||
+      !GetType(*input_defs[2], input_R_type, logger) ||
+      (has_input_B && !GetType(*input_defs[3], input_B_type, logger)) ||
+      // input_defs[4] refers to sequence_lens and has a fixed data type of int32.
+      (has_input_initial_h && !GetType(*input_defs[5], input_initial_h_type, logger))) {
     return false;
   }
 
-  InlinedVector<int32_t> input_types = {input0_type, input1_type, input2_type};
-  if (has_input3) {
-    input_types.push_back(input3_type);
+  InlinedVector<int32_t> input_types = {input_X_type, input_W_type, input_R_type};
+  if (has_input_B) {
+    input_types.push_back(input_B_type);
   }
-  if (has_input4) {
-    input_types.push_back(input4_type);
-  }
-  if (has_input5) {
-    input_types.push_back(input5_type);
+  if (has_input_initial_h) {
+    input_types.push_back(input_initial_h_type);
   }
   if (!AreInputDataTypesSame(op_type, input_types, logger)) {
     return false;
   }
 
-  return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger);
+  return IsDataTypeSupportedByOp(op_type, input_X_type, wnn_limits, "input", "X", logger);
+}
+
+bool GruOpBuilder::HasSupportedOutputsImpl(const Node& node,
+                                           const emscripten::val& wnn_limits,
+                                           const logging::Logger& logger) const {
+  const auto& output_defs = node.OutputDefs();
+  const auto& op_type = node.OpType();
+  int32_t Y_type = 0;
+  int32_t Y_h_type = 0;
+  bool has_Y = output_defs.size() > 0 && output_defs[0]->Exists();
+  bool has_Y_h = output_defs.size() > 1 && output_defs[1]->Exists();
+
+  bool Y_supported = has_Y && GetType(*output_defs[0], Y_type, logger);
+  bool Y_h_supported = has_Y_h && GetType(*output_defs[1], Y_h_type, logger);
+
+  if (Y_supported && !Y_h_supported) {
+    return IsDataTypeSupportedByOp(op_type, Y_type, wnn_limits, "outputs", "Y", logger);
+  } else if (!Y_supported && Y_h_supported) {
+    return IsDataTypeSupportedByOp(op_type, Y_h_type, wnn_limits, "outputs", "Y_h", logger);
+  } else if (Y_supported && Y_h_supported) {
+    if (Y_type != Y_h_type) {
+      LOGS(logger, VERBOSE) << "[GRU] Output data types must be the same.";
+      return false;
+    }
+    return IsDataTypeSupportedByOp(op_type, Y_type, wnn_limits, "outputs", "Y", logger);
+  } else {
+    LOGS(logger, VERBOSE) << "[GRU] No output found.";
+    return false;
+  }
 }
 
 void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
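
Note on the helper.h change: Node::OpType() returns the operator name exactly as spelled in the ONNX schema, which is "GRU", so the old "Gru" key could never match and the op_map lookup would fail for every GRU node. A minimal standalone sketch of that lookup, using std::unordered_map in place of onnxruntime's InlinedHashMap (illustrative only, not code from this change):

    #include <iostream>
    #include <string>
    #include <unordered_map>

    int main() {
      // Mirrors the shape of op_map in helper.h: ONNX op type -> WebNN op name.
      const std::unordered_map<std::string, std::string> op_map = {
          {"GRU", "gru"},  // ONNX spells the op type "GRU"; a "Gru" key never matches.
      };
      const std::string op_type = "GRU";  // What Node::OpType() returns for a GRU node.
      auto it = op_map.find(op_type);
      std::cout << (it != op_map.end() ? it->second : "unsupported") << '\n';  // Prints "gru".
      return 0;
    }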