[ETHOSN] Transpose fully connected weights
The NPU driver stack expects weights in IO (HWIO) format; however, Relay
uses an OI representation. Although the shape of the weight tensor was
correctly changed during codegen, the values in the weights tensor were
not being transposed. This led to an output mismatch when the output
"units" was > 1. The tests didn't catch this due to using a weights
tensor of all 1's.
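
A quick way to see why all-ones weights hide a missing transpose: a dense layer computes x @ W.T for OI-layout weights, and if only the tensor's shape is fixed up (as the old codegen did) without moving the values, the result is wrong whenever W differs from its own transpose. The sketch below is plain NumPy, independent of the TVM/NPU code, with illustrative shapes:

    import numpy as np

    np.random.seed(0)
    units, in_ch = 4, 8                                   # output "units" > 1
    x = np.random.randint(0, 10, size=(1, in_ch))
    w_oi = np.random.randint(0, 10, size=(units, in_ch))  # Relay dense: OI layout

    expected = x @ w_oi.T                            # what dense should compute
    w_io_wrong = w_oi.reshape(in_ch, units)          # shape fixed, values not moved
    w_io_right = w_oi.T                              # values actually transposed

    print((x @ w_io_right == expected).all())        # True
    print((x @ w_io_wrong == expected).all())        # False in general for units > 1

    ones = np.ones((units, in_ch), dtype=np.int64)   # the old test's weights:
    assert (x @ ones.reshape(in_ch, units) == x @ ones.T).all()  # bug invisible

With units == 1, the reshape and the transpose of a (1, in_ch) tensor coincide, which matches the observation that the mismatch only appeared for more than one output unit.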

Change-Id: I51b2bcd14b677280ef3b6a6845d56b7dfacc7d6a
lhutton1 committed Oct 5, 2022
1 parent af01526 commit 1b3d315
Showing 4 changed files with 60 additions and 55 deletions.
4 changes: 2 additions & 2 deletions src/relay/backend/contrib/ethosn/codegen.cc
@@ -412,8 +412,8 @@ EthosnError ConstructNetworkVisitor::MakeFullyConnectedLayer(const Call& call,
     return err;
   }
 
-  auto weights = AddConstant(network_, params.weights_info, params.raw_weights).tensor;
-  auto bias = AddConstant(network_, params.bias_info, params.raw_bias).tensor;
+  auto weights = AddConstant(network_, params.weights_info, params.raw_weights->data).tensor;
+  auto bias = AddConstant(network_, params.bias_info, params.raw_bias->data).tensor;
   try {
     auto input =
         AddReshape(network_, *operand_table_[call->args[0]][0], params.input_info.m_Dimensions)
51 changes: 26 additions & 25 deletions src/relay/backend/contrib/ethosn/ethosn_api.cc
@@ -171,6 +171,19 @@ EthosnError EthosnAPI::QnnConv2d(const Expr& expr, ConvolutionParams* params) {
   return err;
 }
 
+Constant TransposeWeights(const Constant& data, const std::string& input_layout,
+                          const std::string& target_layout) {
+  Array<Integer> transpose_matrix;
+  for (const char& c : target_layout) {
+    int pos = input_layout.find(c);
+    transpose_matrix.push_back(pos);
+  }
+  Expr transpose = MakeTranspose(data, transpose_matrix);
+  transpose = InferType(FoldConstantExpr(transpose));
+  Constant transposed_data = Downcast<Constant>(transpose);
+  return transposed_data;
+}
+
 EthosnError EthosnAPI::QnnFullyConnected(const Expr& expr, FullyConnectedParams* params) {
   Call requantize = Downcast<Call>(expr);
   Call bias_add = Downcast<Call>(requantize->args[0]);
@@ -197,7 +210,8 @@ EthosnError EthosnAPI::QnnFullyConnected(const Expr& expr, FullyConnectedParams*
   sl::QuantizationInfo output_q_info;
   err += Tvm2Npu(input_zero_point, input_scale, &data_q_info);
   err += Tvm2Npu(kernel_zero_point, kernel_scale, &weights_q_info);
-  err += Tvm2Npu(0, data_q_info.GetScale() * weights_q_info.GetScale(), &bias_q_info);
+  std::valarray<float> bias = data_q_info.GetScale() * weights_q_info.GetScales();
+  err += Tvm2Npu(0, bias, 3, &bias_q_info);
   err += Tvm2Npu(output_zero_point, output_scale, &output_q_info);
 
   // Create fc info
@@ -213,27 +227,29 @@ EthosnError EthosnAPI::QnnFullyConnected(const Expr& expr, FullyConnectedParams*
                                        data_data_type, sl::DataFormat::NHWC, data_q_info);
 
   // Create weights info
-  const auto* weights_dtype = dense->args[1]->checked_type().as<TensorTypeNode>();
+  Constant weights_data = Downcast<Constant>(dense->args[1]);
+  weights_data = TransposeWeights(weights_data, "OI", "IO");
+  const auto* weights_ttype = weights_data->checked_type().as<TensorTypeNode>();
   sl::TensorShape weights_tensor_shape;
   sl::DataType weights_data_type;
   sl::DataFormat weights_data_format;
   // Ignore the error here because weights don't have a batch axis
-  Tvm2Npu(weights_dtype->shape, &weights_tensor_shape);
-  err += Tvm2Npu(weights_dtype->dtype, &weights_data_type);
+  Tvm2Npu(weights_ttype->shape, &weights_tensor_shape);
+  err += Tvm2Npu(weights_ttype->dtype, &weights_data_type);
   err += Tvm2Npu("HWIO", &weights_data_format);
-  params->weights_info = sl::TensorInfo({1, 1, weights_tensor_shape[1], weights_tensor_shape[0]},
+  params->weights_info = sl::TensorInfo({1, 1, weights_tensor_shape[0], weights_tensor_shape[1]},
                                         weights_data_type, weights_data_format, weights_q_info);
-  params->raw_weights = dense->args[1].as<ConstantNode>()->data->data;
+  params->raw_weights = weights_data->data;
 
   // Create bias info
   params->bias_info =
-      sl::TensorInfo({1, 1, 1, weights_tensor_shape[0]}, sl::DataType::INT32_QUANTIZED,
+      sl::TensorInfo({1, 1, 1, weights_tensor_shape[1]}, sl::DataType::INT32_QUANTIZED,
                      sl::DataFormat::NHWC, bias_q_info);
-  params->raw_bias = bias_add->args[1].as<ConstantNode>()->data->data;
+  params->raw_bias = bias_add->args[1].as<ConstantNode>()->data;
 
   sl::TensorInfo output_tensor_info;
   err += Tvm2Npu(requantize->checked_type(), &output_tensor_info);
-  output_tensor_info.m_Dimensions = {data_tensor_shape[0], 1, 1, weights_tensor_shape[0]};
+  output_tensor_info.m_Dimensions = {data_tensor_shape[0], 1, 1, weights_tensor_shape[1]};
   output_tensor_info.m_QuantizationInfo = output_q_info;
   params->output_info = output_tensor_info;
 
@@ -449,21 +465,6 @@ EthosnError EthosnAPI::Mean(const Expr& expr, MeanParams* params) {
   return err;
 }
 
-Constant TransposeWeights(const Constant& data, const std::string& input_layout) {
-  int pos_h = input_layout.find("H");
-  int pos_w = input_layout.find("W");
-  int pos_i = input_layout.find("I");
-  int pos_o = input_layout.find("O");
-
-  // Currently the expected target layout is HWIO only.
-  Array<Integer> target_shape = {pos_h, pos_w, pos_i, pos_o};
-
-  Expr transpose = MakeTranspose(data, target_shape);
-  transpose = InferType(FoldConstantExpr(transpose));
-  Constant transposed_data = Downcast<Constant>(transpose);
-  return transposed_data;
-}
-
 EthosnError EthosnAPI::QnnConv2dTranspose(const Expr& expr, QnnConv2dTransposeParams* params) {
   Call requantize = Downcast<Call>(expr);
   Call bias;
@@ -530,7 +531,7 @@ EthosnError EthosnAPI::QnnConv2dTranspose(const Expr& expr, QnnConv2dTransposePa
   // Create weights info
   Constant weights_data = Downcast<Constant>(conv2d_transpose->args[1]);
   if (conv_attr->kernel_layout != "HWIO") {
-    weights_data = TransposeWeights(weights_data, conv_attr->kernel_layout);
+    weights_data = TransposeWeights(weights_data, conv_attr->kernel_layout, "HWIO");
   }
   const auto* weights_ttype = weights_data->checked_type().as<TensorTypeNode>();
   sl::TensorShape weights_tensor_shape;
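
The generalized TransposeWeights above derives the permutation by looking up each target-layout axis letter in the input layout. A rough Python equivalent of that index computation (hypothetical helper name, NumPy standing in for Relay's transpose plus constant folding):

    import numpy as np

    def transpose_axes(input_layout: str, target_layout: str) -> list:
        # For each axis letter in the target layout, find where that axis
        # currently sits in the input layout.
        return [input_layout.index(c) for c in target_layout]

    assert transpose_axes("OI", "IO") == [1, 0]            # fully connected weights
    assert transpose_axes("OIHW", "HWIO") == [2, 3, 1, 0]  # a conv kernel layout

    w_oi = np.arange(2 * 3).reshape(2, 3)                  # O=2, I=3
    assert np.transpose(w_oi, transpose_axes("OI", "IO")).shape == (3, 2)

This lets one helper serve both the fully connected case ("OI" to "IO") and the conv2d_transpose case (an arbitrary kernel layout to "HWIO").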
4 changes: 2 additions & 2 deletions src/relay/backend/contrib/ethosn/ethosn_api.h
@@ -66,8 +66,8 @@ struct FullyConnectedParams {
   sl::TensorInfo weights_info;
   sl::TensorInfo bias_info;
   sl::TensorInfo output_info;
-  void* raw_weights = nullptr;
-  void* raw_bias = nullptr;
+  runtime::NDArray raw_weights;
+  runtime::NDArray raw_bias;
 };
 
 struct MaxPool2DParams {
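
Two side notes on the changes above. Holding a reference-counted runtime::NDArray rather than a void* is presumably what keeps the freshly transposed constant's buffer alive for as long as the params are used; a raw pointer into that temporary could dangle. And the switch from GetScale() * GetScale() to GetScale() * GetScales() in ethosn_api.cc follows the usual quantization convention that each output channel's bias scale is the input scale times that channel's kernel scale, stored as INT32 with zero point 0. A minimal NumPy sketch of that convention (illustrative values, not taken from the driver stack):

    import numpy as np

    input_sc = 0.58                                # activation scale
    kernel_sc = np.array([1.498, 0.18, 0.75])      # per-output-channel kernel scales
    bias_sc = input_sc * kernel_sc                 # one bias scale per channel

    bias_fp = np.array([0.25, -1.5, 3.0])          # float bias values
    bias_q = np.round(bias_fp / bias_sc).astype(np.int32)  # quantized, zero point 0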
56 changes: 30 additions & 26 deletions tests/python/contrib/test_ethosn/test_fullyconnected.py
Expand Up @@ -19,9 +19,11 @@

import numpy as np
import pytest

import tvm
from tvm import relay
from tvm.testing import requires_ethosn

from . import infrastructure as tei


@@ -30,7 +32,11 @@ def _get_model(
 ):
     """Return a model an any parameters it may have"""
     a = relay.var("a", shape=shape, dtype=dtype)
-    weights_array = tvm.nd.array(np.ones(weight_shape, dtype))
+    weights_array = tvm.nd.array(
+        np.random.randint(
+            np.iinfo(dtype).min, high=np.iinfo(dtype).max, size=weight_shape, dtype=dtype
+        )
+    )
     weights = relay.const(weights_array, dtype)
     dense = relay.qnn.op.dense(
         a,
@@ -66,26 +72,24 @@ def _get_model(
         ((1, 1280), 1000),
     ],
 )
-@pytest.mark.parametrize(
-    "dtype,input_zp,input_sc,kernel_zp,kernel_sc",
-    [
-        ("uint8", 71, 0.580, 176, 1.498),
-        ("uint8", 166, 1.724, 138, 0.180),
-        ("int8", 71, 0.580, 0, 1.498),
-        ("int8", 120, 1.724, 0, 0.180),
-    ],
-)
-def test_fullyconnected(shape, out_channels, dtype, input_zp, input_sc, kernel_zp, kernel_sc):
+@pytest.mark.parametrize("dtype", ["uint8", "int8"])
+def test_fullyconnected(shape, out_channels, dtype):
     """Compare Fully Connected output with TVM."""
 
     np.random.seed(0)
+    iinfo = np.iinfo(dtype)
+    data_min = iinfo.min
+    data_max = iinfo.max
+
     inputs = {
-        "a": tvm.nd.array(
-            np.random.randint(np.iinfo(dtype).min, np.iinfo(dtype).max + 1, size=shape, dtype=dtype)
-        ),
+        "a": tvm.nd.array(np.random.randint(data_min, data_max + 1, size=shape, dtype=dtype)),
     }
 
     outputs = []
+
+    input_zp = np.random.randint(data_min, data_max)
+    input_sc = np.random.random() * 2
+    kernel_zp = np.random.randint(data_min, data_max)
+    kernel_sc = np.random.random() * 2
     output_zp, output_sc = tei.get_conv2d_qnn_params(
         dtype,
         input_zp,
@@ -96,18 +100,18 @@ def test_fullyconnected(shape, out_channels, dtype, input_zp, input_sc, kernel_z
         shape[1],
         1,
     )
-    model, params = _get_model(
-        shape,
-        (out_channels, shape[1]),
-        input_zp,
-        input_sc,
-        kernel_zp,
-        kernel_sc,
-        output_zp,
-        output_sc,
-        dtype,
-    )
     for npu in [False, True]:
+        model, params = _get_model(
+            shape,
+            (out_channels, shape[1]),
+            input_zp,
+            input_sc,
+            kernel_zp,
+            kernel_sc,
+            output_zp,
+            output_sc,
+            dtype,
+        )
         mod = tei.make_module(model, params)
         outputs.append(tei.build_and_run(mod, inputs, 1, params, npu=npu))
     tei.verify(outputs, dtype, 1)
