Skip to content

Commit

Permalink
[microNPU] Replace ICHECK with diagnostic context in type inference (apache#9470)
Browse files Browse the repository at this point in the history

[microNPU] Replace ICHECK with diagnostic context in type inference

Convolution and depthwise convolution use the ICHECK format of
error checking during type inference. This PR updates these checks to
use the diagnostic context.
  • Loading branch information
lhutton1 authored and mehrdadh committed Dec 1, 2021
1 parent 4140d07 commit 614446d
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 20 deletions.
28 changes: 22 additions & 6 deletions src/relay/op/contrib/ethosu/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,28 @@ bool EthosuConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attr
if (ifm == nullptr || weight == nullptr) return false;
const auto* param = attrs.as<EthosuConv2DAttrs>();
CHECK(param != nullptr) << "EthosuConv2DAttrs cannot be nullptr.";
CHECK(ifm->dtype == DataType::UInt(8) || ifm->dtype == DataType::Int(8))
<< "Expected ethosu_conv2d type(uint8) or type(int8) for ifm but was " << ifm->dtype;
CHECK(weight->dtype == DataType::UInt(8) || weight->dtype == DataType::Int(8))
<< "Expected ethosu_conv2d type(uint8) or type(int8) for weight but was " << weight->dtype;
CHECK(scale_bias->dtype == DataType::UInt(8))
<< "Expected ethosu_conv2d type(uint8) for scale_bias but was " << scale_bias->dtype;

if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_conv2d input data type "
<< "of type(uint8) or type(int8) but was " << ifm->dtype);
return false;
}

if (weight->dtype != DataType::UInt(8) && weight->dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_conv2d weight data type "
<< "of type(uint8) or type(int8) but was " << weight->dtype);
return false;
}

if (scale_bias->dtype != DataType::UInt(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_conv2d scale bias data type "
<< "of type(uint8) but was " << scale_bias->dtype);
return false;
}

// The scale_bias should be provided as a tensor of size {ofm_channels, 10}
reporter->Assign(types[2], TensorType({weight->shape[0], 10}, DataType::UInt(8)));
Expand Down
33 changes: 24 additions & 9 deletions src/relay/op/contrib/ethosu/depthwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,30 @@ bool EthosuDepthwiseConv2DRel(const Array<Type>& types, int num_inputs, const At

const auto* param = attrs.as<EthosuDepthwiseConv2DAttrs>();
ICHECK(param != nullptr) << "EthosuDepthwiseConv2DAttrs cannot be nullptr.";
ICHECK(ifm->dtype == DataType::UInt(8) || ifm->dtype == DataType::Int(8))
<< "Expected ethosu_depthwise_conv2d type(uint8) or type(int8) for ifm but was "
<< ifm->dtype;
ICHECK(weight->dtype == DataType::UInt(8) || ifm->dtype == DataType::Int(8))
<< "Expected ethosu_depthwise_conv2d type(uint8) or type(int8) for weight but was "
<< weight->dtype;
ICHECK(scale_bias->dtype == DataType::UInt(8))
<< "Expected ethosu_depthwise_conv2d type(uint8) for scale_bias but was "
<< scale_bias->dtype;

if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_depthwise_conv2d input data type "
<< "of type(uint8) or type(int8) but was " << ifm->dtype);
return false;
}

if (weight->dtype != DataType::UInt(8) && weight->dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_depthwise_conv2d weight data type "
<< "of type(uint8) or type(int8) but was " << weight->dtype);
return false;
}

if (scale_bias->dtype != DataType::UInt(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_depthwise_conv2d scale bias data type "
<< "of type(uint8) but was " << scale_bias->dtype);
return false;
}

// Collect the ifm, weight and ofm tensors for using in the inference function
Array<Type> tensor_types = {types[0], types[1], types[4]};
Expand Down
12 changes: 7 additions & 5 deletions tests/python/contrib/test_ethosu/infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,14 +382,15 @@ def make_ethosu_conv2d(
ifm_layout="NHWC",
ofm_layout="NHWC",
weight_dtype="int8",
scale_bias_dtype="uint8",
):
# conv params
weight_shape = (ofm_channels, kernel_shape[0], kernel_shape[1], ifm_channels)
padding = get_pad_tuple(padding, kernel_shape)

scale_bias_data = generate_weights_data((weight_shape[0], 10), "uint8")
scale_bias = relay.const(scale_bias_data, dtype="uint8")
weight_data = generate_weights_data(weight_shape, "int8")
scale_bias_data = generate_weights_data((weight_shape[0], 10), scale_bias_dtype)
scale_bias = relay.const(scale_bias_data, dtype=scale_bias_dtype)
weight_data = generate_weights_data(weight_shape, weight_dtype)
weight = relay.const(weight_data, dtype=weight_dtype)
conv = ethosu_ops.ethosu_conv2d(
ifm,
Expand Down Expand Up @@ -427,13 +428,14 @@ def make_ethosu_depthwise_conv2d(
ifm_layout="NHWC",
ofm_layout="NHWC",
weight_dtype="int8",
scale_bias_dtype="uint8",
):
# params
weight_shape = (channels, kernel_shape[0], kernel_shape[1], 1)
padding = get_pad_tuple(padding, kernel_shape)

scale_bias_data = generate_weights_data((weight_shape[0], 10), "uint8")
scale_bias = relay.const(scale_bias_data, dtype="uint8")
scale_bias_data = generate_weights_data((weight_shape[0], 10), scale_bias_dtype)
scale_bias = relay.const(scale_bias_data, dtype=scale_bias_dtype)
weight_data = generate_weights_data(weight_shape, weight_dtype)
weight = relay.const(weight_data, dtype=weight_dtype)
depthwise = ethosu_ops.ethosu_depthwise_conv2d(
Expand Down
55 changes: 55 additions & 0 deletions tests/python/contrib/test_ethosu/test_type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,34 @@ def test_ethosu_conv2d_type_inference(
assert tuple(func.body.checked_type.shape) == ofm_shape


@pytest.mark.parametrize(
    "ifm_dtype,weight_dtype,scale_bias_dtype",
    [("float32", "int8", "uint8"), ("int8", "float32", "uint8"), ("int8", "int8", "float32")],
)
def test_ethosu_conv2d_invalid_dtypes(ifm_dtype, weight_dtype, scale_bias_dtype):
    """Each parametrized case makes exactly one of ifm/weight/scale_bias an
    unsupported dtype; type inference is expected to raise a TVMError."""
    # Operator configuration (values are arbitrary but fixed).
    ifm_channels = 55
    ofm_channels = 122
    kernel_shape = (3, 2)
    padding = (0, 1, 2, 3)
    strides = (1, 2)
    dilation = (2, 1)

    ifm = relay.var("ifm", shape=(1, 56, 72, 55), dtype=ifm_dtype)
    conv2d = make_ethosu_conv2d(
        ifm,
        ifm_channels,
        ofm_channels,
        kernel_shape,
        padding,
        strides,
        dilation,
        weight_dtype=weight_dtype,
        scale_bias_dtype=scale_bias_dtype,
    )
    func = relay.Function([ifm], conv2d)

    # The diagnostic context emits a fatal error, surfacing as TVMError.
    with pytest.raises(TVMError):
        run_opt_pass(func, relay.transform.InferType())


@pytest.mark.parametrize(
"ifm_shape, ifm_layout", [((1, 46, 71, 55), "NHWC"), ((1, 46, 4, 71, 16), "NHCWB16")]
)
Expand Down Expand Up @@ -94,6 +122,33 @@ def test_ethosu_depthwise_conv2d_type_inference(
assert tuple(func.body.checked_type.shape) == ofm_shape


@pytest.mark.parametrize(
    "ifm_dtype,weight_dtype,scale_bias_dtype",
    [("float32", "int8", "uint8"), ("int8", "float32", "uint8"), ("int8", "int8", "float32")],
)
def test_ethosu_depthwise_conv2d_invalid_dtypes(ifm_dtype, weight_dtype, scale_bias_dtype):
    """Each parametrized case makes exactly one of ifm/weight/scale_bias an
    unsupported dtype; type inference is expected to raise a TVMError."""
    # Operator configuration (values are arbitrary but fixed).
    channels = 55
    kernel_shape = (3, 2)
    padding = (0, 1, 2, 3)
    strides = (1, 2)
    dilation = (2, 1)  # fixed: the original assigned this twice in a row
    ifm = relay.var("ifm", shape=(1, 56, 72, 55), dtype=ifm_dtype)
    depthwise_conv2d = make_ethosu_depthwise_conv2d(
        ifm,
        channels,
        kernel_shape,
        padding,
        strides,
        dilation,
        weight_dtype=weight_dtype,
        scale_bias_dtype=scale_bias_dtype,
    )
    func = relay.Function([ifm], depthwise_conv2d)
    # The diagnostic context emits a fatal error, surfacing as TVMError.
    with pytest.raises(TVMError):
        run_opt_pass(func, relay.transform.InferType())


@pytest.mark.parametrize(
"ifm_shape, ifm_layout", [((1, 56, 72, 55), "NHWC"), ((1, 56, 4, 72, 16), "NHCWB16")]
)
Expand Down

0 comments on commit 614446d

Please sign in to comment.