diff --git a/src/relay/op/contrib/ethosu/convolution.cc b/src/relay/op/contrib/ethosu/convolution.cc
index 2cc6c3416e276..9471f88ac3769 100644
--- a/src/relay/op/contrib/ethosu/convolution.cc
+++ b/src/relay/op/contrib/ethosu/convolution.cc
@@ -123,12 +123,28 @@ bool EthosuConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attr
   if (ifm == nullptr || weight == nullptr) return false;
   const auto* param = attrs.as<EthosuConv2DAttrs>();
   CHECK(param != nullptr) << "EthosuConv2DAttrs cannot be nullptr.";
-  CHECK(ifm->dtype == DataType::UInt(8) || ifm->dtype == DataType::Int(8))
-      << "Expected ethosu_conv2d type(uint8) or type(int8) for ifm but was " << ifm->dtype;
-  CHECK(weight->dtype == DataType::UInt(8) || weight->dtype == DataType::Int(8))
-      << "Expected ethosu_conv2d type(uint8) or type(int8) for weight but was " << weight->dtype;
-  CHECK(scale_bias->dtype == DataType::UInt(8))
-      << "Expected ethosu_conv2d type(uint8) for scale_bias but was " << scale_bias->dtype;
+
+  if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) {
+    reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
+                                     << "Invalid operator: expected ethosu_conv2d input data type "
+                                     << "of type(uint8) or type(int8) but was " << ifm->dtype);
+    return false;
+  }
+
+  if (weight->dtype != DataType::UInt(8) && weight->dtype != DataType::Int(8)) {
+    reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
+                                     << "Invalid operator: expected ethosu_conv2d weight data type "
+                                     << "of type(uint8) or type(int8) but was " << weight->dtype);
+    return false;
+  }
+
+  if (scale_bias->dtype != DataType::UInt(8)) {
+    reporter->GetDiagCtx().EmitFatal(
+        Diagnostic::Error(reporter->GetSpan())
+        << "Invalid operator: expected ethosu_conv2d scale bias data type "
+        << "of type(uint8) but was " << scale_bias->dtype);
+    return false;
+  }
 
   // The scale_bias should be provided as a tensor of size {ofm_channels, 10}
   reporter->Assign(types[2], TensorType({weight->shape[0], 10}, DataType::UInt(8)));
diff --git a/src/relay/op/contrib/ethosu/depthwise.cc b/src/relay/op/contrib/ethosu/depthwise.cc
index 5ff27de51b2fe..7918285ce1b75 100644
--- a/src/relay/op/contrib/ethosu/depthwise.cc
+++ b/src/relay/op/contrib/ethosu/depthwise.cc
@@ -123,15 +123,30 @@ bool EthosuDepthwiseConv2DRel(const Array<Type>& types, int num_inputs, const At
 
   const auto* param = attrs.as<EthosuDepthwiseConv2DAttrs>();
   ICHECK(param != nullptr) << "EthosuDepthwiseConv2DAttrs cannot be nullptr.";
-  ICHECK(ifm->dtype == DataType::UInt(8) || ifm->dtype == DataType::Int(8))
-      << "Expected ethosu_depthwise_conv2d type(uint8) or type(int8) for ifm but was "
-      << ifm->dtype;
-  ICHECK(weight->dtype == DataType::UInt(8) || ifm->dtype == DataType::Int(8))
-      << "Expected ethosu_depthwise_conv2d type(uint8) or type(int8) for weight but was "
-      << weight->dtype;
-  ICHECK(scale_bias->dtype == DataType::UInt(8))
-      << "Expected ethosu_depthwise_conv2d type(uint8) for scale_bias but was "
-      << scale_bias->dtype;
+
+  if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) {
+    reporter->GetDiagCtx().EmitFatal(
+        Diagnostic::Error(reporter->GetSpan())
+        << "Invalid operator: expected ethosu_depthwise_conv2d input data type "
+        << "of type(uint8) or type(int8) but was " << ifm->dtype);
+    return false;
+  }
+
+  if (weight->dtype != DataType::UInt(8) && weight->dtype != DataType::Int(8)) {
+    reporter->GetDiagCtx().EmitFatal(
+        Diagnostic::Error(reporter->GetSpan())
+        << "Invalid operator: expected ethosu_depthwise_conv2d weight data type "
+        << "of type(uint8) or type(int8) but was " << weight->dtype);
+    return false;
+  }
+
+  if (scale_bias->dtype != DataType::UInt(8)) {
+    reporter->GetDiagCtx().EmitFatal(
+        Diagnostic::Error(reporter->GetSpan())
+        << "Invalid operator: expected ethosu_depthwise_conv2d scale bias data type "
+        << "of type(uint8) but was " << scale_bias->dtype);
+    return false;
+  }
 
   // Collect the ifm, weight and ofm tensors for using in the inference function
   Array<Type> tensor_types = {types[0], types[1], types[4]};
diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py
index 19f546a6f974b..58862c5f5faa3 100644
--- a/tests/python/contrib/test_ethosu/infra.py
+++ b/tests/python/contrib/test_ethosu/infra.py
@@ -382,14 +382,15 @@ def make_ethosu_conv2d(
     ifm_layout="NHWC",
     ofm_layout="NHWC",
     weight_dtype="int8",
+    scale_bias_dtype="uint8",
 ):
     # conv params
     weight_shape = (ofm_channels, kernel_shape[0], kernel_shape[1], ifm_channels)
     padding = get_pad_tuple(padding, kernel_shape)
 
-    scale_bias_data = generate_weights_data((weight_shape[0], 10), "uint8")
-    scale_bias = relay.const(scale_bias_data, dtype="uint8")
-    weight_data = generate_weights_data(weight_shape, "int8")
+    scale_bias_data = generate_weights_data((weight_shape[0], 10), scale_bias_dtype)
+    scale_bias = relay.const(scale_bias_data, dtype=scale_bias_dtype)
+    weight_data = generate_weights_data(weight_shape, weight_dtype)
     weight = relay.const(weight_data, dtype=weight_dtype)
     conv = ethosu_ops.ethosu_conv2d(
         ifm,
@@ -427,13 +428,14 @@ def make_ethosu_depthwise_conv2d(
     ifm_layout="NHWC",
     ofm_layout="NHWC",
     weight_dtype="int8",
+    scale_bias_dtype="uint8",
 ):
     # params
     weight_shape = (channels, kernel_shape[0], kernel_shape[1], 1)
     padding = get_pad_tuple(padding, kernel_shape)
 
-    scale_bias_data = generate_weights_data((weight_shape[0], 10), "uint8")
-    scale_bias = relay.const(scale_bias_data, dtype="uint8")
+    scale_bias_data = generate_weights_data((weight_shape[0], 10), scale_bias_dtype)
+    scale_bias = relay.const(scale_bias_data, dtype=scale_bias_dtype)
     weight_data = generate_weights_data(weight_shape, weight_dtype)
     weight = relay.const(weight_data, dtype=weight_dtype)
     depthwise = ethosu_ops.ethosu_depthwise_conv2d(
diff --git a/tests/python/contrib/test_ethosu/test_type_inference.py b/tests/python/contrib/test_ethosu/test_type_inference.py
index 9b041392c732c..ecbe31b3cbd35 100644
--- a/tests/python/contrib/test_ethosu/test_type_inference.py
+++ b/tests/python/contrib/test_ethosu/test_type_inference.py
@@ -61,6 +61,34 @@ def test_ethosu_conv2d_type_inference(
     assert tuple(func.body.checked_type.shape) == ofm_shape
 
 
+@pytest.mark.parametrize(
+    "ifm_dtype,weight_dtype,scale_bias_dtype",
+    [("float32", "int8", "uint8"), ("int8", "float32", "uint8"), ("int8", "int8", "float32")],
+)
+def test_ethosu_conv2d_invalid_dtypes(ifm_dtype, weight_dtype, scale_bias_dtype):
+    ifm_channels = 55
+    ofm_channels = 122
+    kernel_shape = (3, 2)
+    padding = (0, 1, 2, 3)
+    strides = (1, 2)
+    dilation = (2, 1)
+    ifm = relay.var("ifm", shape=(1, 56, 72, 55), dtype=ifm_dtype)
+    conv2d = make_ethosu_conv2d(
+        ifm,
+        ifm_channels,
+        ofm_channels,
+        kernel_shape,
+        padding,
+        strides,
+        dilation,
+        weight_dtype=weight_dtype,
+        scale_bias_dtype=scale_bias_dtype,
+    )
+    func = relay.Function([ifm], conv2d)
+    with pytest.raises(TVMError):
+        run_opt_pass(func, relay.transform.InferType())
+
+
 @pytest.mark.parametrize(
     "ifm_shape, ifm_layout", [((1, 46, 71, 55), "NHWC"), ((1, 46, 4, 71, 16), "NHCWB16")]
 )
@@ -94,6 +122,32 @@ def test_ethosu_depthwise_conv2d_type_inference(
     assert tuple(func.body.checked_type.shape) == ofm_shape
 
 
+@pytest.mark.parametrize(
+    "ifm_dtype,weight_dtype,scale_bias_dtype",
+    [("float32", "int8", "uint8"), ("int8", "float32", "uint8"), ("int8", "int8", "float32")],
+)
+def test_ethosu_depthwise_conv2d_invalid_dtypes(ifm_dtype, weight_dtype, scale_bias_dtype):
+    channels = 55
+    kernel_shape = (3, 2)
+    padding = (0, 1, 2, 3)
+    strides = (1, 2)
+    dilation = (2, 1)
+    ifm = relay.var("ifm", shape=(1, 56, 72, 55), dtype=ifm_dtype)
+    depthwise_conv2d = make_ethosu_depthwise_conv2d(
+        ifm,
+        channels,
+        kernel_shape,
+        padding,
+        strides,
+        dilation,
+        weight_dtype=weight_dtype,
+        scale_bias_dtype=scale_bias_dtype,
+    )
+    func = relay.Function([ifm], depthwise_conv2d)
+    with pytest.raises(TVMError):
+        run_opt_pass(func, relay.transform.InferType())
+
+
 @pytest.mark.parametrize(
     "ifm_shape, ifm_layout", [((1, 56, 72, 55), "NHWC"), ((1, 56, 4, 72, 16), "NHCWB16")]
 )
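For reference, a minimal sketch of what this patch changes for a caller, assuming a TVM build with the Ethos-U contrib ops and that the test helper above is importable (the import path is illustrative): an unsupported ifm dtype now surfaces as a diagnostic error raised through TVMError during type inference, instead of aborting the process on a CHECK/ICHECK failure.

import pytest
import tvm
from tvm import relay

# Illustrative import path; the in-tree tests get this helper from infra.py above.
from tests.python.contrib.test_ethosu.infra import make_ethosu_conv2d

# A float32 ifm is rejected by EthosuConv2DRel via EmitFatal, so InferType
# raises a catchable TVMError rather than crashing on an ICHECK.
ifm = relay.var("ifm", shape=(1, 56, 72, 55), dtype="float32")
conv2d = make_ethosu_conv2d(ifm, 55, 122, (3, 2), (0, 1, 2, 3), (1, 2), (2, 1))
mod = tvm.IRModule.from_expr(relay.Function([ifm], conv2d))
with pytest.raises(tvm.TVMError):
    relay.transform.InferType()(mod)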