diff --git a/python/tvm/relay/qnn/op/_qnn.py b/python/tvm/relay/qnn/op/_qnn.py index c9c4c86e8b47..e2157a051abb 100644 --- a/python/tvm/relay/qnn/op/_qnn.py +++ b/python/tvm/relay/qnn/op/_qnn.py @@ -93,7 +93,16 @@ def alter_op_layout_qnn_conv2d(attrs, inputs, tinfos, out_type): # qnn.dense register_strategy("qnn.dense", strategy.qnn_dense_strategy) -register_pattern("qnn.dense", OpPattern.OUT_ELEMWISE_FUSABLE) + + +@register_alter_op_layout("qnn.dense") +def alter_op_layout_qnn_dense(attrs, inputs, tinfos, out_type): + """Alternate the layout of qnn.dense""" + return topi.nn.qnn_dense_alter_layout(attrs, inputs, tinfos, out_type) + + +# qnn.contrib_dense_pack +register_strategy("qnn.contrib_dense_pack", strategy.qnn_dense_pack_strategy) # qnn.batch_matmul register_strategy("qnn.batch_matmul", strategy.qnn_batch_matmul_strategy) diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index ef368a016e0c..53cb41c2fb2f 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -340,6 +340,62 @@ def helper_change_dtypes_to_int8(attrs, inputs, types, relay_op): ) +def helper_change_dtypes_to_uint8(attrs, inputs, types, relay_op): + """Helper function to change dtypes to uint8 x uint8. + Legalizes QNN dense op for Hexagon DSP. It supports fast u8 x u8 vrmpy instruction. + + Converting from int8 to uint8 can be done in following manner: + + Original equation + scale * (QA - zp_a) + scale * (QA + 128 - 128 - zp_a) + scale * ( (QA + 128) - (zp_a + 128)) + + Replacing QA + 128 with QA' and (zp_a + 128) with zp_a' + We get our new quantized uint8 tensor - scale * (QA' - zp_a') + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + # Collect the dtypes. + data_dtype = types[0].dtype + kernel_dtype = types[1].dtype + + # Do nothing since it is already uint8. + if data_dtype == "uint8" and kernel_dtype == "uint8": + return None + + # Collect the input exprs. + data, kernel, input_zero_point, kernel_zero_point, input_scale, kernel_scale = inputs + + # Shift input if necessary. + if data_dtype == "int8": + # Compute (QA + 128) and (zp_a + 128) + data, input_zero_point = _shift(data, input_zero_point, "uint8") + + # Shift kernel if necessary. + if kernel_dtype == "int8": + # Compute (QA + 128) and (zp_a + 128) + kernel, kernel_zero_point = _shift(kernel, kernel_zero_point, "uint8") + + # Call qnn.conv2d/qnn.dense with modified inputs and zero points. + new_attrs = dict(attrs) + return relay_op( + data, kernel, input_zero_point, kernel_zero_point, input_scale, kernel_scale, **new_attrs + ) + + # Helper function to change dtypes to be same. ARM dotprod instructions prefer this setting. def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op): """Sometimes MxNet + MLDNN can lead to uint8 x int8 datatypes for the conv inputs. 
However, @@ -520,7 +576,7 @@ def _qnn_conv2d_legalize_hexagon(attrs, inputs, types): out_channel = kernel_tensor.shape[0].value ic_modified = False oc_modified = False - data, kernel, input_zp, output_zp, input_scale, output_scale = inputs + data, kernel, data_zp, kernel_zp, data_scale, kernel_scale = inputs if in_channel % IN_CHANNEL_VECTOR_LENGTH != 0: new_in_channel = helper_align_up(in_channel, IN_CHANNEL_VECTOR_LENGTH) @@ -537,21 +593,93 @@ def _qnn_conv2d_legalize_hexagon(attrs, inputs, types): kernel = relay.nn.pad(kernel, pad_width=((0, diff), (0, 0), (0, 0), (0, 0))) oc_modified = True + # Pad kernel zero point by 'diff' elements of 0 if it is not scalar + kernel_zp_tensor = types[3] + if len(kernel_zp_tensor.shape) != 0: + assert isinstance(kernel_zp, relay.Constant) + padded_kernel_zp_np = np.append(kernel_zp.data.numpy(), [0] * diff) + kernel_zp = relay.const(padded_kernel_zp_np) + + # Pad kernel scale by 'diff' elements of 1.0 if it is not scalar + kernel_scale_tensor = types[5] + if len(kernel_scale_tensor.shape) != 0: + assert isinstance(kernel_scale, relay.Constant) + padded_kernel_scale_np = np.append(kernel_scale.data.numpy(), [1.0] * diff) + kernel_scale = relay.const(padded_kernel_scale_np) + if ic_modified is True or oc_modified is True: new_attrs = dict(attrs) if oc_modified: new_attrs["channels"] = new_out_channel out = relay.qnn.op.conv2d( - data, kernel, input_zp, output_zp, input_scale, output_scale, **new_attrs + data, kernel, data_zp, kernel_zp, data_scale, kernel_scale, **new_attrs ) output_tensor = types[6] original_out_shape = list(output_tensor.shape) out = relay.strided_slice(out, begin=[0, 0, 0, 0], end=original_out_shape) else: out = relay.qnn.op.conv2d( - data, kernel, input_zp, output_zp, input_scale, output_scale, **new_attrs + data, kernel, data_zp, kernel_zp, data_scale, kernel_scale, **new_attrs ) return out return None + + +@qnn_dense_legalize.register("hexagon") +def _qnn_dense_legalize_hexagon(attrs, inputs, types): + """Legalize qnn.dense op for vrmpy tensorization. + + N dimension of weights should be aligned on vector length. If not, then N dimension is padded to + be a multiple of 32. + """ + assert len(types) == 7 + assert len(inputs) == 6 + + data_tensor, kernel_tensor = types[0], types[1] + if "int8" not in data_tensor.dtype or "int8" not in kernel_tensor.dtype: + return None + + N, _ = kernel_tensor.shape + + if N % OUT_CHANNEL_VECTOR_LENGTH != 0: + N_padded = helper_align_up(N, OUT_CHANNEL_VECTOR_LENGTH) + diff = N_padded - N + + data, kernel, data_zp, kernel_zp, data_scale, kernel_scale = inputs + + # Pad weights by 'diff' + padded_kernel = relay.nn.pad(kernel, pad_width=((0, diff), (0, 0))) + + kernel_zp_tensor, kernel_scale_tensor = types[3], types[5] + + # Pad kernel zero point by 'diff' elements of 0 if it is not scalar + if len(kernel_zp_tensor.shape) != 0: + assert isinstance(kernel_zp, relay.Constant) + assert isinstance(diff, tvm.tir.IntImm) + padded_kernel_zp_np = np.append(kernel_zp.data.numpy(), [0] * diff.value) + kernel_zp = relay.const(padded_kernel_zp_np) + + # Pad kernel scale by 'diff' elements of 1.0 if it is not scalar + if len(kernel_scale_tensor.shape) != 0: + assert isinstance(kernel_scale, relay.Constant) + assert isinstance(diff, tvm.tir.IntImm) + padded_kernel_scale_np = np.append(kernel_scale.data.numpy(), [1.0] * diff.value) + kernel_scale = relay.const(padded_kernel_scale_np) + + # If units is explicitly specified, it is used to compute the output shape. 
+ # We need to update units after padding to prevent a type error. + new_attrs = dict(attrs) + if attrs["units"] is not None: + new_attrs["units"] = N + diff + + new_inputs = (data, padded_kernel, data_zp, kernel_zp, data_scale, kernel_scale) + + out = relay.qnn.op.dense(*new_inputs, **new_attrs) + + output_tensor = types[6] + out = relay.strided_slice(out, begin=[0, 0], end=list(output_tensor.shape)) + return out + + return None diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index 78d6669413ca..6c0248d40d92 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -718,6 +718,70 @@ def dense( ) +def contrib_dense_pack( + data, + weight, + input_zero_point, + kernel_zero_point, + input_scale, + kernel_scale, + kernel_layout="NC", + units=None, + out_dtype="int32", +): + """Qnn contrib_dense_pack operator. + Applies a quantized linear transformation + + .. math:: + + `Y = X * W` + + If doing Per-channel quantization, qnn expects the kernel_zero_scale + and optionally the kernel_zero_point will be 1-D vectors instead of scalars. + + Parameters + ---------- + data : tvm.relay.Expr + The quantized input data to the operator. + weight : tvm.relay.Expr + The quantized weight expressions. + input_zero_point: tvm.relay.Expr + The input zero point. + kernel_zero_point: tvm.relay.Expr + The kernel zero point. + input_scale: tvm.relay.Expr + The scale for the input tensor. + kernel_scale: tvm.relay.Expr + The scale for the weight tensor. The scale for the weight tensor is + stored for access to this during relay. This information is not + needed in the pass pipeline after qnn.conv2d is lowered to the + sequence of steps as in nn.conv2d. See also input_scale in Requantize. + kernel_layout: str + The layout of weight, such as "NC" or "NC32n4c". + units : int, optional + Number of hidden units of the dense transformation. + out_dtype : str, optional + Specifies the output data type for mixed precision dense can be int32 or int16. + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + + return _make.contrib_dense_pack( + data, + weight, + input_zero_point, + kernel_zero_point, + input_scale, + kernel_scale, + kernel_layout, + units, + out_dtype, + ) + + def mul( lhs, rhs, diff --git a/python/tvm/relay/qnn/strategy/generic.py b/python/tvm/relay/qnn/strategy/generic.py index 8275cf7f755e..3ebf8edd3665 100644 --- a/python/tvm/relay/qnn/strategy/generic.py +++ b/python/tvm/relay/qnn/strategy/generic.py @@ -267,6 +267,12 @@ def qnn_dense_strategy(attrs, inputs, out_type, target): ) +@override_native_generic_func("qnn_dense_pack_strategy") +def qnn_dense_pack_strategy(attrs, inputs, out_type, target): + """qnn.contrib_dense_pack generic strategy""" + raise RuntimeError("qnn.contrib_dense_pack is currently only supported with Hexagon. 
") + + @override_native_generic_func("qnn_batch_matmul_strategy") def qnn_batch_matmul_strategy(attrs, inputs, out_type, target): """qnn.batch_matmul generic strategy""" diff --git a/python/tvm/relay/qnn/strategy/hexagon.py b/python/tvm/relay/qnn/strategy/hexagon.py index c25c96f8edb4..d17b0da6cf0a 100644 --- a/python/tvm/relay/qnn/strategy/hexagon.py +++ b/python/tvm/relay/qnn/strategy/hexagon.py @@ -173,6 +173,24 @@ def qnn_dense_strategy_hexagon(attrs, inputs, out_type, target): return strategy +@qnn_dense_pack_strategy.register("hexagon") +def qnn_dense_pack_strategy_hexagon(attrs, inputs, out_type, target): + """qnn.contrib_dense_pack strategy for Hexagon""" + strategy = _op.OpStrategy() + if ( + "uint8" in inputs[0].dtype + and "int8" in inputs[1].dtype + and attrs["weight_layout"] == "NC32n4c" + ): + # uint8 + uint8|int8 case + strategy.add_implementation( + wrap_topi_qnn_dense(topi.hexagon.qnn_dense_pack_vrmpy), + wrap_topi_schedule(topi.hexagon.schedule_qnn_dense_pack_vrmpy), + name="qnn_dense_pack_vrmpy.hexagon", + ) + return strategy + + @qnn_batch_matmul_strategy.register("hexagon") def qnn_batch_matmul_strategy_hexagon(attrs, inputs, out_type, target): """qnn.batch_matmul strategy for Hexagon""" diff --git a/python/tvm/topi/hexagon/qnn/__init__.py b/python/tvm/topi/hexagon/qnn/__init__.py index 022a552c9d54..ba7d64b6b56d 100644 --- a/python/tvm/topi/hexagon/qnn/__init__.py +++ b/python/tvm/topi/hexagon/qnn/__init__.py @@ -20,6 +20,7 @@ from .adaptive_avg_pool1d import * from .avg_pool2d import qnn_avg_pool2d_compute, qnn_avg_pool2d_schedule from .conv2d_alter_op import * +from .dense_alter_op import * from .dequantize import dequantize_compute, dequantize_schedule from .global_avg_pool2d import * from .nn import * diff --git a/python/tvm/topi/hexagon/qnn/dense_alter_op.py b/python/tvm/topi/hexagon/qnn/dense_alter_op.py new file mode 100644 index 000000000000..1935bbda036e --- /dev/null +++ b/python/tvm/topi/hexagon/qnn/dense_alter_op.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""QNN Dense alter op functions for Hexagon""" + +from tvm import relay +from ..dense_alter_op import check_vrmpy_applicable +from ...nn import qnn_dense_alter_layout + + +@qnn_dense_alter_layout.register("hexagon") +def _alter_qnn_dense_layout(_attrs, inputs, tinfos, out_type): + data_tensor = tinfos[0] + weight_tensor = tinfos[1] + + if check_vrmpy_applicable(data_tensor, weight_tensor): + weight_layout = "NC32n4c" + return relay.qnn.op.contrib_dense_pack(*inputs, weight_layout, None, out_type.dtype) + else: + return None diff --git a/python/tvm/topi/hexagon/qnn/nn.py b/python/tvm/topi/hexagon/qnn/nn.py index aabdf2a63b8b..5702be2e1a33 100644 --- a/python/tvm/topi/hexagon/qnn/nn.py +++ b/python/tvm/topi/hexagon/qnn/nn.py @@ -17,12 +17,13 @@ """Hexagon QNN operators""" # pylint: disable=invalid-name +from typing import Union import numpy as np import tvm from tvm import te, topi from ..utils import saturate, get_fixed_point_value -from ...utils import get_const_tuple +from ...utils import get_const_tuple, get_const_int, get_const_float from ...nn.utils import get_pad_tuple from ...nn.pad import pad from ... import tag, nn @@ -37,11 +38,25 @@ def clip_cast(val, dtype): return te.max(tvm.te.min(val, const_max), const_min).astype(dtype) -# Return True if given Tensor is scalar constant value. -def is_constant(tensor: te.Tensor): - return tensor.ndim == 0 and ( - isinstance(tensor.op.body[0], (tvm.tir.expr.FloatImm, tvm.tir.expr.IntImm)) - ) +# Return True if given expression is scalar constant value. +def is_scalar(expr): + if isinstance(expr, te.Tensor): + return expr.ndim == 0 and (isinstance(expr.op.body[0], (tvm.tir.FloatImm, tvm.tir.IntImm))) + return isinstance(expr, (tvm.tir.FloatImm, tvm.tir.IntImm)) + + +def get_const_int_value(expr): + if isinstance(expr, te.Tensor): + assert isinstance(expr.op.body[0], tvm.tir.IntImm) + return expr.op.body[0].value + return get_const_int(expr) + + +def get_const_float_value(expr): + if isinstance(expr, te.Tensor): + assert isinstance(expr.op.body[0], tvm.tir.FloatImm) + return expr.op.body[0].value + return get_const_float(expr) def get_qnn_param(param, indices, axis): @@ -53,6 +68,28 @@ def get_qnn_param(param, indices, axis): return param[param_idx] +def subtract_zero_point( + tensor: te.Tensor, + zero_point: Union[te.Tensor, tvm.tir.IntImm], + name: str, +): + """ + Subtract zero point from given tensor. If zero point is scalar constant and is equal to 0, then + it can be optimized and return tensor as it is. + This new block is marked with 'meta_schedule.inline_rule = disable' attribute to disable inline. + Otherwise, inline prevents from tensorization and leveraging vrmpy intrinsic + """ + if is_scalar(zero_point) and get_const_int_value(zero_point) == 0: + return tensor + else: + return te.compute( + tensor.shape, + lambda *i: te.subtract(tensor(*i), zero_point).astype(tensor.dtype), + name=name, + attrs={"meta_schedule.inline_rule": "disable"}, + ) + + def default_schedule(outs): """Simple default schedule for QNN ops. @@ -172,9 +209,9 @@ def qnn_requantize( TODO: support 'rounding' and 'compute_dtype' arguments. 
""" - if is_constant(input_scale) and is_constant(output_scale): - iscale = input_scale.op.body[0].value - oscale = output_scale.op.body[0].value + if is_scalar(input_scale) and is_scalar(output_scale): + iscale = get_const_float_value(input_scale) + oscale = get_const_float_value(output_scale) scale = iscale / oscale scale_fixed_point, rsh = get_fixed_point_value(scale, "int16") @@ -187,7 +224,7 @@ def _compute(*indices): # Add output zero point + clip + cast: return saturate(te.add(mul, output_zp), out_dtype).astype(out_dtype) - return te.compute(data.shape, _compute) + return te.compute(data.shape, _compute, name="requantize") else: @@ -196,12 +233,14 @@ def _compute(*indices): iscale = get_qnn_param(input_scale, indices, axis) oscale = get_qnn_param(output_scale, indices, axis) + # Subtract input zero point: sub = te.subtract(value, input_zp) mul = te.div(iscale, oscale) val = te.add(te.round(te.multiply(mul, sub)), output_zp) + # clip + cast: return saturate(val, out_dtype).astype(out_dtype) - return te.compute(data.shape, _compute) + return te.compute(data.shape, _compute, name="requantize") def schedule_qnn_requantize(outs): @@ -245,7 +284,7 @@ def _compute_const(x: te.Tensor, iscale, input_zp): ) def _compute_tensor(x: te.Tensor, input_scale, input_zp): - if is_constant(input_scale) and is_constant(output_scale): + if is_scalar(input_scale) and is_scalar(output_scale): iscale = input_scale.op.body[0].value oscale = output_scale.op.body[0].value scale = iscale / oscale @@ -263,12 +302,12 @@ def _compute_tensor(x: te.Tensor, input_scale, input_zp): ).astype("int32"), ) - if is_constant(lhs): + if is_scalar(lhs): lhs_tensor = _compute_const(lhs, lhs_scale, lhs_zp) else: lhs_tensor = _compute_tensor(lhs, lhs_scale, lhs_zp) - if is_constant(rhs): + if is_scalar(rhs): rhs_tensor = _compute_const(rhs, rhs_scale, rhs_zp) else: rhs_tensor = _compute_tensor(rhs, rhs_scale, rhs_zp) @@ -334,7 +373,16 @@ def schedule_qnn_subtract(outs): return default_schedule(outs) -def qnn_mul(lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output_zp): +def qnn_mul( + lhs: te.Tensor, + rhs: te.Tensor, + lhs_scale: te.Tensor, + lhs_zp: te.Tensor, + rhs_scale: te.Tensor, + rhs_zp: te.Tensor, + output_scale: te.Tensor, + output_zp: te.Tensor, +): """Compute for qnn.mul mul = (lhs_input - lhs_zp) * (rhs_input - rhs_zp) @@ -343,20 +391,25 @@ def qnn_mul(lhs, rhs, lhs_scale, lhs_zp, rhs_scale, rhs_zp, output_scale, output assert lhs.dtype == rhs.dtype odtype = lhs.dtype - if is_constant(lhs): - lhs_tensor = lhs - lhs_zp - else: - lhs_tensor = te.compute(lhs.shape, lambda *i: te.subtract(lhs(*i), lhs_zp)) + def _compute_tensor(tensor, zero_point): + if is_scalar(tensor): + return tensor - zero_point + else: + return te.compute(tensor.shape, lambda *i: te.subtract(tensor(*i), zero_point)) - if is_constant(rhs): - rhs_tensor = rhs - rhs_zp - else: - rhs_tensor = te.compute(rhs.shape, lambda *i: te.subtract(rhs(*i), rhs_zp)) + lhs_tensor = _compute_tensor(lhs, lhs_zp) + rhs_tensor = _compute_tensor(rhs, rhs_zp) # Multiply with broadcasting. 
mul = topi.multiply(lhs_tensor, rhs_tensor) - iscale = lhs_scale * rhs_scale + if is_scalar(lhs_scale) and is_scalar(rhs_scale): + assert isinstance(lhs_scale, te.Tensor) + assert isinstance(rhs_scale, te.Tensor) + iscale = lhs_scale.op.body[0] * rhs_scale.op.body[0] + else: + iscale = lhs_scale * rhs_scale + return qnn_requantize(mul, iscale, tvm.tir.const(0), output_scale, output_zp, out_dtype=odtype) @@ -616,23 +669,10 @@ def qnn_conv2d_NCHWc_int8( # Conv2d inputs odtype, ): """Compute for qnn.conv2d with NCHWc layout.""" - # Subtract zero point from weights. Need to disable inline of this block - # (meta_schedule.inline_rule = disable). Otherwise, inline prevents from tensorization. - weight = te.compute( - weight.shape, - lambda *i: te.subtract(weight(*i), kernel_zero_point).astype(weight.dtype), - name="weight_zp", - attrs={"meta_schedule.inline_rule": "disable"}, - ) - # Subtract zero point from input. Again need to disable inline of this block - # (meta_schedule.inline_rule = disable). Otherwise, inline prevents from tensorization. - data = te.compute( - data.shape, - lambda *i: te.subtract(data(*i), input_zero_point).astype(data.dtype), - name="data_zp", - attrs={"meta_schedule.inline_rule": "disable"}, - ) + # Subtract zero point from input and weights. + weight = subtract_zero_point(weight, kernel_zero_point, "weight_zp") + data = subtract_zero_point(data, input_zero_point, "data_zp") strides = get_const_tuple(strides) padding = get_const_tuple(padding) @@ -644,7 +684,9 @@ def qnn_conv2d_NCHWc_int8( # Conv2d inputs assert len(out.shape) == len(bias.shape) assert bias.shape[2] == 1 and bias.shape[3] == 1 out = te.compute( - out.shape, lambda n, c, h, w, ci: out[n, c, h, w, ci] + bias[n, c, 0, 0, ci] + out.shape, + lambda n, c, h, w, ci: out[n, c, h, w, ci] + bias[n, c, 0, 0, ci], + name="bias_add", ) # Requantize output of convolution @@ -874,6 +916,98 @@ def schedule_qnn_dense(outs): return default_schedule(outs) +def qnn_dense_pack_vrmpy( + data: te.Tensor, + weight: te.Tensor, + # Dense quantization params: + input_zero_point: te.Tensor, + kernel_zero_point: te.Tensor, + _input_scale: te.Tensor, + _kernel_scale: te.Tensor, + # bias + bias: te.Tensor, + # Requantization params: + rq_input_scale: te.Tensor, + rq_input_zero_point: te.Tensor, + rq_output_scale: te.Tensor, + rq_output_zero_point: te.Tensor, + out_dtype: str, +): + """Compute for qnn.contrib_dense_pack + + Output data type should be specified through the 'odtype' parameter. qnn.dense leverages int32 + type to store intermediate results. If 'odtype' differs from int32, you need to specify + requantization parameters. + """ + # Subtract zero point from input and weights. 
+ weight = subtract_zero_point(weight, kernel_zero_point, "weight_zp") + data = subtract_zero_point(data, input_zero_point, "data_zp") + + # Required for vrmpy intrinsic + assert "int8" in weight.dtype and "int8" in data.dtype + + M, K = get_const_tuple(data.shape) + N_O, _, N_I, _ = get_const_tuple(weight.shape) + k = te.reduce_axis((0, K), "k") + out = te.compute( + (M, N_O * N_I), + lambda m, n: te.sum( + data[m, k].astype("int32") + * weight[ + tvm.tir.indexdiv(n, 32), + tvm.tir.indexdiv(k, 4), + tvm.tir.indexmod(n, 32), + tvm.tir.indexmod(k, 4), + ].astype("int32"), + axis=k, + ), + name="qnn_dense_pack", + ) + + # Add bias + if bias is not None: + assert bias.ndim == 2 + out = te.compute(out.shape, lambda n, c: out[n, c] + bias[0, c]) + + # Requantize output of qnn.contrib_dense_pack + if rq_input_scale is not None and rq_output_scale is not None: + # Now supported only scalar and 1D quantization parameters + assert rq_input_scale.ndim == 0 or rq_input_scale.ndim == 1 + assert rq_output_scale.ndim == 0 or rq_output_scale.ndim == 1 + axis = -1 + if rq_input_scale.ndim == 1 or rq_output_scale.ndim == 1: + axis = 1 # Axis param should correspond to 'C' dimension. + + return qnn_requantize( + out, + rq_input_scale, + rq_input_zero_point, + rq_output_scale, + rq_output_zero_point, + axis, + out_dtype, + ) + + return out + + +def schedule_qnn_dense_pack_vrmpy(outs): + """Schedule for qnn.contrib_dense_pack + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of qnn.dense + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + return default_schedule(outs) + + def qnn_batch_matmul( tensor_a, tensor_b, diff --git a/python/tvm/topi/nn/qnn.py b/python/tvm/topi/nn/qnn.py index 7a29266b087c..9aaa452a7392 100644 --- a/python/tvm/topi/nn/qnn.py +++ b/python/tvm/topi/nn/qnn.py @@ -255,3 +255,22 @@ def qnn_conv2d_alter_layout(_attrs, _inputs, _tinfos, _out_type): The output type """ return None + + +@tvm.target.generic_func +def qnn_dense_alter_layout(_attrs, _inputs, _tinfos, _out_type): + """Change qnn.dense layout. 
+ Not to change by default + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current dense op + inputs : tvm.relay.Expr + Grouped input symbols + tinfos : list + Input shape and dtype + out_type: type + The output type + """ + return None diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index c680c5a77e04..94710b28a4fe 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -134,6 +134,7 @@ class QnnPatternMatcher { QnnPatternMatcher() : qnn_conv2d_op_(Op::Get("qnn.conv2d")), qnn_dense_op_(Op::Get("qnn.dense")), + qnn_dense_pack_op_(Op::Get("qnn.contrib_dense_pack")), qnn_requantize_op_(Op::Get("qnn.requantize")), bias_add_op_(Op::Get("add")) {} @@ -153,6 +154,10 @@ class QnnPatternMatcher { registered_ops_.push_front(P_QDense); ICHECK(anchor_op_ == nullptr); anchor_op_ = call_node; + } else if (op == qnn_dense_pack_op_) { + registered_ops_.push_front(P_QDensePack); + ICHECK(anchor_op_ == nullptr); + anchor_op_ = call_node; } else { registered_ops_.push_front(P_Opaque); } @@ -163,7 +168,7 @@ class QnnPatternMatcher { if (registered_ops_.empty()) return false; if (op == qnn_conv2d_op_ || op == qnn_requantize_op_ || op == bias_add_op_ || - op == qnn_dense_op_) { + op == qnn_dense_op_ || op == qnn_dense_pack_op_) { for (const auto& pat : supported_patterns_) { auto it = std::search(registered_ops_.begin(), registered_ops_.end(), pat.begin(), pat.end()); @@ -183,21 +188,24 @@ class QnnPatternMatcher { private: const Op& qnn_conv2d_op_; const Op& qnn_dense_op_; + const Op& qnn_dense_pack_op_; const Op& qnn_requantize_op_; const Op& bias_add_op_; // Main (complicated) operation in the primitive (for example qnn.conv2d, qnn.dense etc.). const CallNode* anchor_op_ = nullptr; - enum POper { P_QConv2d, P_QDense, P_BiasAdd, P_QRequantize, P_Opaque }; + enum POper { P_QConv2d, P_QDense, P_QDensePack, P_BiasAdd, P_QRequantize, P_Opaque }; std::deque registered_ops_; const std::vector> supported_patterns_ = { - {P_QDense, P_BiasAdd, P_QRequantize}, // Pattern qnn.dense -> bias_add -> qnn.requantize - {P_QDense, P_QRequantize}, // Patter qnn.dense -> qnn.requantize - {P_QConv2d, P_BiasAdd, P_QRequantize}, // Pattern qnn.conv2d -> bias_add -> qnn.requantize - {P_QConv2d, P_QRequantize} // Patter qnn.conv2d -> qnn.requantize + {P_QDense, P_BiasAdd, P_QRequantize}, // qnn.dense -> bias_add -> qnn.requantize + {P_QDense, P_QRequantize}, // qnn.dense -> qnn.requantize + {P_QDensePack, P_BiasAdd, P_QRequantize}, // qnn.contrib_dense_pack -> bias -> qnn.requantize + {P_QDensePack, P_QRequantize}, // qnn.contrib_dense_pack -> qnn.requantize + {P_QConv2d, P_BiasAdd, P_QRequantize}, // qnn.conv2d -> bias_add -> qnn.requantize + {P_QConv2d, P_QRequantize} // qnn.conv2d -> qnn.requantize }; }; diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index cf601ff5f11b..3ebef29776cd 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -220,6 +220,11 @@ bool BatchMatmulRel(const Array& types, int num_inputs, const Attrs& attrs return true; } +InferCorrectLayoutOutput DenseInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types); + } // namespace relay } // namespace tvm #endif // TVM_RELAY_OP_NN_NN_H_ diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc index 09d51e3c9ce7..48f2a813d0e7 100644 --- a/src/relay/qnn/op/dense.cc +++ b/src/relay/qnn/op/dense.cc @@ -22,10 +22,8 @@ * \brief Property def of qnn dense operator. 
*/ -#include #include #include -#include #include "../../op/nn/nn.h" #include "../../transforms/pattern_utils.h" @@ -75,6 +73,27 @@ bool QnnDenseRel(const Array& types, int num_inputs, const Attrs& attrs, return MatmulRel(tensor_types, 3, attrs, reporter); } +InferCorrectLayoutOutput QnnDenseInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { + // Use Relay Dense Infer correct layout. + auto dense_new_layouts = + DenseInferCorrectLayout(attrs, new_in_layouts, old_in_layouts, old_in_types); + + // Fill the layouts of remaining input tensors - scales and zero points. The layouts of these + // tensors can be treated as channel layout. + Layout channel_layout = Layout("N"); + Array input_layouts = {dense_new_layouts->input_layouts[0], + dense_new_layouts->input_layouts[1], + channel_layout, + channel_layout, + channel_layout, + channel_layout}; + Array output_layouts = dense_new_layouts->output_layouts; + return InferCorrectLayoutOutput(input_layouts, output_layouts, attrs); +} + // Positional relay function to create quantized dense operator used by frontend FFI. Expr MakeQuantizedDense(Expr data, Expr weight, Expr input_zero_point, Expr kernel_zero_point, Expr input_scale, Expr kernel_scale, IndexExpr units, DataType out_dtype) { @@ -223,11 +242,91 @@ RELAY_REGISTER_OP("qnn.dense") "The quantization zero_point of the weight tensor.") .set_support_level(11) .add_type_rel("QDense", QnnDenseRel) + .set_attr("FInferCorrectLayout", QnnDenseInferCorrectLayout) .set_attr("TNonComputational", true) - .set_attr("FTVMQnnCanonicalize", QnnDenseCanonicalize); + .set_attr("FTVMQnnCanonicalize", QnnDenseCanonicalize) + .set_attr("TOpPattern", kOutEWiseFusable); TVM_REGISTER_GLOBAL("relay.qnn.op._make.dense").set_body_typed(MakeQuantizedDense); +// ------------------- relay.qnn.op.contrib_dense_pack + +bool QnnDensePackRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // Expected types: data, weight, input_zero_point, weight_zero_point, input_scale, weight_scale, + // out_type + ICHECK_EQ(types.size(), 7); + const auto* data = types[0].as(); + const auto* weight = types[1].as(); + if (data == nullptr || weight == nullptr) return false; + + const DensePackAttrs* param = attrs.as(); + ICHECK(param != nullptr); + + ICHECK_EQ(data->shape.size(), 2) << "Only 2D data is supported"; + ICHECK(weight->shape.size() == 4) << "Expect weight to be 4D tensor"; + + Array oshape = data->shape; + oshape.Set(1, weight->shape[0] * weight->shape[2]); + + ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0."; + // assign output type + reporter->Assign(types[6], TensorType(oshape, param->out_dtype)); + return true; +} + +InferCorrectLayoutOutput QnnDensePackInferCorrectLayout( + const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, + const Array& old_in_types) { + auto params = attrs.as(); + ICHECK(params); + return InferCorrectLayoutOutput({"NC", params->weight_layout, "N", "N", "N", "N"}, {"NC"}, attrs); +} + +Expr QnnDensePackCanonicalize(const Attrs& attrs, const Array& new_args, + const Array& arg_types) { + LOG(FATAL) << "Canonicalization function for qnn.contrib_dense_pack is not implemented"; + return Expr(); +} + +Expr MakeQuantizedDensePack(Expr data, Expr weight, Expr input_zero_point, Expr kernel_zero_point, + Expr input_scale, Expr kernel_scale, tvm::String weight_layout, + IndexExpr units, DataType out_dtype) { + auto attrs = 
make_object(); + attrs->units = std::move(units); + attrs->out_dtype = out_dtype; + attrs->weight_layout = weight_layout; + static const Op& op = Op::Get("qnn.contrib_dense_pack"); + return Call(op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, + Attrs(attrs), {}); +} + +RELAY_REGISTER_OP("qnn.contrib_dense_pack") + .describe(R"code(Applies a linear transformation: :math:`Y = XW^T`. +- **data**: quantized(int8, uint8) `(x1, x2, ..., xn, input_dim)` +- **weight**: quantized(int8, uint8) `(units, input_dim)` +- **out**: quantized(int32) `(x1, x2, ..., xn, units)`. +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(6) + .add_argument("data", "quantized nD Tensor", "Input data.") + .add_argument("weight", "quantized 2D Tensor", "Weight matrix.") + .add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") + .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") + .add_argument("weight_scale", "Tensor", "The quantization scale of the weight tensor.") + .add_argument("weight_zero_point", "Tensor", + "The quantization zero_point of the weight tensor.") + .set_support_level(11) + .add_type_rel("QnnDensePack", QnnDensePackRel) + .set_attr("FInferCorrectLayout", QnnDensePackInferCorrectLayout) + .set_attr("TNonComputational", true) + .set_attr("FTVMQnnCanonicalize", QnnDensePackCanonicalize) + .set_attr("TOpPattern", kOutEWiseFusable); + +TVM_REGISTER_GLOBAL("relay.qnn.op._make.contrib_dense_pack").set_body_typed(MakeQuantizedDensePack); + +// ------------------- relay.qnn.op.contrib_dense_pack + } // namespace qnn } // namespace relay } // namespace tvm diff --git a/tests/python/contrib/test_arm_compute_lib/test_dense.py b/tests/python/contrib/test_arm_compute_lib/test_dense.py index fa6057dd9a63..bbcfc4abe6a9 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_dense.py +++ b/tests/python/contrib/test_arm_compute_lib/test_dense.py @@ -151,11 +151,7 @@ def _get_expected_codegen(shape, weight_shape, units, dtype, has_bias=False): if has_bias: bias_dtype = "int32" if dtype in qnn_dtypes else "float32" - bias_shape = ( - [1, weight_shape[0]] - if dtype == "float32" and weight_shape[0] != 1 - else [weight_shape[0]] - ) + bias_shape = [1, weight_shape[0]] if weight_shape[0] != 1 else [weight_shape[0]] inputs.append( { "op": "const", diff --git a/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py b/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py index e583b1b5eac8..f4342f5814df 100644 --- a/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py +++ b/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py @@ -23,6 +23,7 @@ from tvm.contrib.hexagon.session import Session from tvm.contrib.hexagon.pytest_plugin import HEXAGON_AOT_LLVM_TARGET from tvm.relay.backend import Executor +from tvm.relay.testing import run_opt_pass, run_infer_type @tvm.testing.requires_hexagon @@ -51,6 +52,28 @@ def test_no_qnn_pass(): assert "qnn.dequantize" in opt_mod_2.astext(show_meta_data=False) +def test_alter_layout_qnn_dense(): + """Test weights layout transformation of qnn.dense with int8 weights""" + data = relay.var("data", shape=(128, 16), dtype="uint8") + weight = relay.var("weight", shape=(64, 16), dtype="int8") + zero = relay.const(0) + iscale = relay.const(0.15) + wscale = relay.const(0.37) + + def before(): + return relay.qnn.op.dense(data, weight, zero, zero, iscale, wscale, units=None) + + def expected(): + op0 = 
relay.layout_transform(weight, src_layout="NC", dst_layout="NC32n4c") + return relay.qnn.op.contrib_dense_pack(data, op0, zero, zero, iscale, wscale, "NC32n4c") + + target = tvm.target.hexagon("v68") + with tvm.target.Target(target): + a = run_opt_pass(before(), tvm.relay.transform.AlterOpLayout()) + b = run_infer_type(expected()) + tvm.ir.assert_structural_equal(a, b) + + def execute(mod_executor, inputs: dict): for input_name, input_data in inputs.items(): mod_executor.set_input(input_name, input_data) @@ -130,59 +153,116 @@ def test_qnn_conv2d_rq(hexagon_session: Session): np.testing.assert_equal(hexagon_output, llvm_out) -@tvm.testing.requires_hexagon -def test_qnn_dense_bias_rq(hexagon_session: Session): - """QNN dense with bias test.""" - data_shape = [8, 8] - weight_shape = [16, 8] - bias_shape = [16] - data = relay.var("data", shape=data_shape, dtype="float32") - weight = relay.var("weight", shape=weight_shape, dtype="float32") - bias = relay.var("bias", shape=bias_shape, dtype="float32") +class TestQnnDense: + """QNN dense op test class.""" - op0 = relay.qnn.op.quantize(data, relay.const(0.08), relay.const(0), out_dtype="int8") - op1 = relay.qnn.op.quantize(weight, relay.const(0.07), relay.const(0), out_dtype="int8") - op2 = relay.qnn.op.dense( - op0, - op1, - input_zero_point=relay.const(0), - kernel_zero_point=relay.const(0), - input_scale=relay.const(0.08), - kernel_scale=relay.const(0.07), - units=None, - ) - op3 = relay.qnn.op.quantize(bias, relay.const(0.5), relay.const(0), out_dtype="int32") - op4 = relay.nn.bias_add(op2, op3) - op5 = relay.qnn.op.requantize( - op4, - input_scale=relay.const(0.05), - input_zero_point=relay.const(0), - output_scale=relay.const(0.212), - output_zero_point=relay.const(10), - out_dtype="int8", - ) - relay_mod = tvm.IRModule.from_expr(op5) + dtype = tvm.testing.parameter("uint8", "int8") + n_dim = tvm.testing.parameter(64, 60) - # Compile for Hexagon - hexagon_lowered = build_hexagon_module(relay_mod) + @tvm.testing.requires_hexagon + def test_qnn_dense_add_requantize(self, hexagon_session: Session, dtype, n_dim): + """Check lowering of qnn.dense + bias_add + qnn.requantize + dtype: type of weights + n_dim: N dimension of weights, need to check cases when it is multiple of 32 and not. 
+ """ + data_shape = [128, 32] + weight_shape = [n_dim, 32] + bias_shape = [n_dim] + data = relay.var("data", shape=data_shape, dtype="uint8") + weight = relay.var("weight", shape=weight_shape, dtype=dtype) + bias = relay.var("bias", shape=bias_shape, dtype="int32") - # Reference compilation - llvm_lowered = build_ref_module(relay_mod) + op0 = relay.qnn.op.dense( + data, + weight, + input_zero_point=relay.const(2), + kernel_zero_point=relay.const(0), + input_scale=relay.const(0.08), + kernel_scale=relay.const(0.07), + units=None, + ) + op1 = relay.nn.bias_add(op0, bias) + op2 = relay.qnn.op.requantize( + op1, + input_scale=relay.const(1.3), + input_zero_point=relay.const(4), + output_scale=relay.const(3.7), + output_zero_point=relay.const(1), + out_dtype="uint8", + ) + relay_mod = tvm.IRModule.from_expr(op2) - data_np = np.random.rand(*data_shape) - 0.5 - weight_np = np.random.rand(*weight_shape) - 0.5 - bias_np = np.random.rand(*bias_shape) - inputs = {"data": data_np, "weight": weight_np, "bias": bias_np} + # Compile for Hexagon + hexagon_lowered = build_hexagon_module(relay_mod) - hx_m = hexagon_session.get_executor_from_factory(hexagon_lowered) - hexagon_output = execute(hx_m, inputs) + # Reference compilation + llvm_lowered = build_ref_module(relay_mod) - dev = tvm.cpu(0) - llvm_m = tvm.runtime.executor.AotModule(llvm_lowered["default"](dev)) - llvm_out = execute(llvm_m, inputs) + np.random.seed(0) + + data_np = np.random.randint(2, 8, size=data_shape, dtype="uint8") + weight_np = np.random.randint(0, 8, size=weight_shape, dtype=dtype) + bias_np = np.random.randint(-10, 10, size=bias_shape, dtype="int32") + inputs = {"data": data_np, "weight": weight_np, "bias": bias_np} - # Diff by 1 is Ok. - tvm.testing.assert_allclose(hexagon_output, llvm_out, atol=1) + hx_m = hexagon_session.get_executor_from_factory(hexagon_lowered) + hexagon_output = execute(hx_m, inputs) + + llvm_m = tvm.runtime.executor.AotModule(llvm_lowered["default"](tvm.cpu(0))) + llvm_out = execute(llvm_m, inputs) + + # Diff by 1 is Ok. 
+ tvm.testing.assert_allclose(hexagon_output, llvm_out, atol=1) + + @tvm.testing.requires_hexagon + def test_qnn_dense_requantize(self, hexagon_session: Session): + """Check lowering of qnn.dense + qnn.requantize + Checkint the case: data type = "uint8", weight type = "int8", input zp = 0 and kernel zp = 0 + """ + data_shape = [128, 32] + weight_shape = [64, 32] + data = relay.var("data", shape=data_shape, dtype="uint8") + weight = relay.var("weight", shape=weight_shape, dtype="int8") + + op0 = relay.qnn.op.dense( + data, + weight, + input_zero_point=relay.const(0), + kernel_zero_point=relay.const(0), + input_scale=relay.const(0.06), + kernel_scale=relay.const(0.19), + units=64, + ) + op1 = relay.qnn.op.requantize( + op0, + input_scale=relay.const(0.1), + input_zero_point=relay.const(0), + output_scale=relay.const(0.24), + output_zero_point=relay.const(64), + out_dtype="uint8", + ) + relay_mod = tvm.IRModule.from_expr(op1) + + # Compile for Hexagon + hexagon_lowered = build_hexagon_module(relay_mod) + + # Reference compilation + llvm_lowered = build_ref_module(relay_mod) + + np.random.seed(0) + + data_np = np.random.randint(0, 8, size=data_shape, dtype="uint8") + weight_np = np.random.randint(-4, 4, size=weight_shape, dtype="int8") + inputs = {"data": data_np, "weight": weight_np} + + hx_m = hexagon_session.get_executor_from_factory(hexagon_lowered) + hexagon_output = execute(hx_m, inputs) + + llvm_m = tvm.runtime.executor.AotModule(llvm_lowered["default"](tvm.cpu(0))) + llvm_out = execute(llvm_m, inputs) + + # Diff by 1 is Ok. + tvm.testing.assert_allclose(hexagon_output, llvm_out, atol=1) class TestQnnBinaryOp: diff --git a/tests/python/relay/test_pass_qnn_legalize.py b/tests/python/relay/test_pass_qnn_legalize.py index c64b30a2128b..73ba9c22082f 100644 --- a/tests/python/relay/test_pass_qnn_legalize.py +++ b/tests/python/relay/test_pass_qnn_legalize.py @@ -23,6 +23,7 @@ from tvm.contrib import graph_executor from tvm.relay import transform, analysis from tvm.relay.testing.temp_op_attr import TempOpAttr +from tvm.relay.testing import run_infer_type def alpha_equal(x, y): @@ -296,7 +297,98 @@ def _get_mod(data_dtype, kernel_dtype): assert "cast" in legalized_mod.astext() and "qnn" in legalized_mod.astext() +def test_qnn_legalize_qnn_conv2d_non_scalar_qnn_params(): + """ + Test QNN legalization for qnn.conv2d op for Hexagon target when kernel zero point and kernel + scale are vectors of scalars. 
+ """ + data_shape = (1, 29, 16, 16) + weights_shape = (60, 29, 3, 3) + O, I = weights_shape[0], weights_shape[1] + data = relay.var("data", shape=data_shape, dtype="uint8") + weights = relay.var("weight", shape=weights_shape, dtype="int8") + data_zp = relay.const(2) + data_scale = relay.const(0.15) + + def before(): + op = relay.qnn.op.conv2d( + data, + weights, + input_zero_point=data_zp, + kernel_zero_point=relay.const([1] * O), + input_scale=data_scale, + kernel_scale=relay.const([0.17] * O), + padding=[0, 0, 0, 0], + channels=O, + kernel_size=[3, 3], + ) + return op + + def expected(): + in_diff = 3 + out_diff = 4 + op0 = relay.nn.pad(weights, pad_width=[[0, 0], [0, in_diff], [0, 0], [0, 0]]) + op1 = relay.nn.pad(data, pad_width=[[0, 0], [0, in_diff], [0, 0], [0, 0]]) + op2 = relay.nn.pad(op0, pad_width=[[0, out_diff], [0, 0], [0, 0], [0, 0]]) + op3 = relay.qnn.op.conv2d( + op1, + op2, + input_zero_point=data_zp, + kernel_zero_point=relay.const([1] * O + [0] * out_diff), + input_scale=data_scale, + kernel_scale=relay.const([0.17] * O + [1.0] * out_diff), + padding=[0, 0, 0, 0], + channels=(O + out_diff), + kernel_size=[3, 3], + ) + op4 = relay.strided_slice(op3, begin=[0, 0, 0, 0], end=[1, 60, 14, 14], strides=[1]) + return op4 + + target = tvm.target.hexagon("v68") + with tvm.target.Target(target): + a = run_opt_pass(before(), relay.qnn.transform.Legalize()) + b = run_infer_type(expected()) + tvm.ir.assert_structural_equal(a, b) + + +def test_qnn_legalize_qnn_dense_non_scalar_qnn_params(): + """ + Test QNN legalization for qnn.dense op for Hexagon target when kernel zero point and kernel + scale are vectors of scalars. + """ + data_shape = (4, 16) + weights_shape = (58, 16) + N = weights_shape[0] + data = relay.var("data", shape=data_shape, dtype="uint8") + weights = relay.var("weight", shape=weights_shape, dtype="int8") + data_zp = relay.const(2) + data_scale = relay.const(0.15) + + def before(): + wzp = relay.const([1] * N) + wscale = relay.const([0.17] * N) + op = relay.qnn.op.dense(data, weights, data_zp, wzp, data_scale, wscale, units=N) + return op + + def expected(): + diff = 6 + wzp = relay.const([1] * N + [0] * diff) + wscale = relay.const([0.17] * N + [1.0] * diff) + op0 = relay.nn.pad(weights, pad_width=[[0, diff], [0, 0]]) + op1 = relay.qnn.op.dense(data, op0, data_zp, wzp, data_scale, wscale, units=(N + diff)) + op2 = relay.strided_slice(op1, begin=[0, 0], end=[data_shape[0], N], strides=[1], axes=None) + return op2 + + target = tvm.target.hexagon("v68") + with tvm.target.Target(target): + a = run_opt_pass(before(), relay.qnn.transform.Legalize()) + b = run_infer_type(expected()) + tvm.ir.assert_structural_equal(a, b) + + if __name__ == "__main__": test_qnn_legalize() test_qnn_legalize_qnn_conv2d() test_qnn_legalize_qnn_dense() + test_qnn_legalize_qnn_conv2d_non_scalar_qnn_params() + test_qnn_legalize_qnn_dense_non_scalar_qnn_params()