Skip to content

Commit

Permalink
[QNN] Multiplication operator
Browse files Browse the repository at this point in the history
  • Loading branch information
tristan-arm committed Sep 10, 2019
1 parent 42195a4 commit 7f9dbb0
Show file tree
Hide file tree
Showing 3 changed files with 342 additions and 0 deletions.
41 changes: 41 additions & 0 deletions python/tvm/relay/qnn/op/qnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,3 +305,44 @@ def add(lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_s
lhs_scale, lhs_zero_point,
rhs_scale, rhs_zero_point,
output_scale, output_zero_point)

def mul(lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale,
output_zero_point):
"""Quantized multiplication with numpy-style broadcasting.
Parameters
----------
lhs : relay.Expr
The left hand side quantized input data.
rhs : relay.Expr
The right hand side quantized input data.
lhs_scale: float
The scale of the lhs quantized expr.
lhs_zero_point: int
The zero point of lhs quantized expr.
rhs_scale: float
The scale of the rhs quantized expr.
rhs_zero_point: int
The zero point of rhs quantized expr.
output_scale: float
The scale of the output quantized expr.
output_zero_point: int
The zero point of output quantized expr.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.mul(lhs, rhs,
lhs_scale, lhs_zero_point,
rhs_scale, rhs_zero_point,
output_scale, output_zero_point)
105 changes: 105 additions & 0 deletions src/relay/qnn/op/mul.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/qnn/op/mul.cc
* \brief QNN mul operator.
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>
#include "../../pass/pattern_util.h"
#include "../util.h"
#include "op_common.h"

namespace tvm {
namespace relay {
namespace qnn {

/*
* \brief Canonicalizes the QNN mul op.
* \param attrs The QNN concatenate attrs.
* \param new_args The new mutated args to the call node.
* \param arg_types The types of input and output.
* \return The sequence of Relay ops for mul op.
*/
Expr QnnMulCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
const Array<tvm::relay::Type>& arg_types) {
// Get the attrs.
CHECK_EQ(new_args.size(), 2);
auto& lhs = new_args[0];
auto& rhs = new_args[1];
const auto* binary_op_attrs = attrs.as<QnnBinaryOpAttrs>();
CHECK(binary_op_attrs != nullptr);
auto lhs_scale = binary_op_attrs->lhs_scale;
auto lhs_zero_point = binary_op_attrs->lhs_zero_point;
auto rhs_scale = binary_op_attrs->rhs_scale;
auto rhs_zero_point = binary_op_attrs->rhs_zero_point;
auto output_scale = binary_op_attrs->output_scale;
auto output_zero_point = binary_op_attrs->output_zero_point;

// Get the input dtype and shape.
CHECK_EQ(arg_types.size(), 3);
auto tensor_type = arg_types[0].as<TensorTypeNode>();
auto input_dtype = tensor_type->dtype;
auto input_shape = tensor_type->shape;

// Requantize LHS if necessary.
auto requantized_lhs = lhs;
if (lhs_scale != output_scale || lhs_zero_point != output_zero_point) {
requantized_lhs = Requantize(lhs, input_shape, lhs_scale, lhs_zero_point, output_scale,
output_zero_point, Int(32));
} else {
requantized_lhs = Cast(requantized_lhs, Int(32));
}

// Requantize RHS if necessary.
auto requantized_rhs = rhs;
if (rhs_scale != output_scale || rhs_zero_point != output_zero_point) {
requantized_rhs = Requantize(rhs, input_shape, rhs_scale, rhs_zero_point, output_scale,
output_zero_point, Int(32));
} else {
requantized_rhs = Cast(requantized_rhs, Int(32));
}

auto output = Multiply(requantized_lhs, requantized_rhs);

// Subtract zero point.
if (output_zero_point != 0) {
auto output_zp = MakeConstantScalar(Int(32), output_zero_point);
output = Subtract(output, output_zp);
}

// Go back to lower precision.
auto q_min = GetQmin(input_dtype);
auto q_max = GetQmax(input_dtype);
output = Clip(output, q_min, q_max);
return Cast(output, input_dtype);
}

// QNN Multiplication operator.
QNN_REGISTER_BINARY_OP("mul")
.describe("Elementwise mul with with broadcasting for quantized tensors.")
.set_support_level(11)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnMulCanonicalize);

} // namespace qnn
} // namespace relay
} // namespace tvm
196 changes: 196 additions & 0 deletions tests/python/relay/test_qnn_mul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
import numpy as np
from tvm import relay
from tvm.contrib import graph_runtime
import topi.testing

def test_tflite_same_io_qnn_params():
data_dtype = 'uint8'

x = relay.var("x", shape=(1, 4), dtype=data_dtype)
y = relay.var("y", shape=(1, 4), dtype=data_dtype)
z = relay.qnn.op.mul(lhs=x, rhs=y,
lhs_scale=0.00784314,
lhs_zero_point=127,
rhs_scale=0.00784314,
rhs_zero_point=127,
output_scale=0.00784314,
output_zero_point=127)

func = relay.Function([x, y], z)
mod = relay.Module.from_expr(func)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
func = mod["main"]

x_datas = [np.array((1, 153, 165, 178)).reshape((1,4)),
np.array((25, 1, 178, 216)).reshape((1,4)),
np.array((25, 153, 1, 165)).reshape((1,4))]
y_datas = [np.array((204, 178, 1, 140)).reshape((1,4)),
np.array((204, 178, 191, 1)).reshape((1,4)),
np.array((204, 178, 1, 191)).reshape((1,4))]
golden_outputs = [np.array((77, 255,38, 255)).reshape((1, 4)),
np.array((255, 51, 255, 89)).reshape((1,4)),
np.array((255, 255, 0, 255)).reshape((1,4))]

for i in range(0, 3):
x_data = x_datas[i]
y_data = y_datas[i]
golden_output = golden_outputs[i]

intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
op_res = intrp.evaluate(func)(x_data, y_data)
np.testing.assert_equal(op_res.asnumpy(), golden_output)


def test_tflite_different_io_qnn_params():
data_dtype = 'uint8'

x = relay.var("x", shape=(1, 4), dtype=data_dtype)
y = relay.var("y", shape=(1, 4), dtype=data_dtype)
z = relay.qnn.op.mul(lhs=x, rhs=y,
lhs_scale=0.0156863,
lhs_zero_point=127,
rhs_scale=0.0117647,
rhs_zero_point=85,
output_scale=0.0235294,
output_zero_point=128)

func = relay.Function([x, y], z)
mod = relay.Module.from_expr(func)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
func = mod["main"]

x_datas = [np.array((76, 140, 153, 172)).reshape((1,4)),
np.array((133, 140, 146, 153)).reshape((1,4)),
np.array((76, 140, 172, 146)).reshape((1,4))]
y_datas = [np.array((136, 119, 128, 17)).reshape((1,4)),
np.array((136, 119, 111, 94)).reshape((1,4)),
np.array((136, 119, 17, 128)).reshape((1,4))]
golden_outputs = [np.array((255, 255, 255, 255)).reshape((1, 4)),
np.array((255, 255, 255, 255)).reshape((1,4)),
np.array((255, 255, 255, 255)).reshape((1,4))]

for i in range(0, 3):
x_data = x_datas[i]
y_data = y_datas[i]
golden_output = golden_outputs[i]

intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
op_res = intrp.evaluate(func)(x_data, y_data)
np.testing.assert_equal(op_res.asnumpy(), golden_output)


def test_saturation():
# Same params
data_dtype = 'uint8'
x = relay.var("x", shape=(1, 4), dtype=data_dtype)
y = relay.var("y", shape=(1, 4), dtype=data_dtype)
z = relay.qnn.op.mul(lhs=x, rhs=y,
lhs_scale=0.125,
lhs_zero_point=0,
rhs_scale=0.125,
rhs_zero_point=0,
output_scale=0.125,
output_zero_point=0)

func = relay.Function([x, y], z)
mod = relay.Module.from_expr(func)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
func = mod["main"]

x_data = np.array((255, 1, 1, 0)).reshape((1,4))
y_data = np.array((255, 255, 128, 0)).reshape((1,4))
golden_output = np.array((255, 255, 128, 0)).reshape((1, 4))

intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
op_res = intrp.evaluate(func)(x_data, y_data)
np.testing.assert_equal(op_res.asnumpy(), golden_output)

# Same params, different scale
z = relay.qnn.op.mul(lhs=x, rhs=y,
lhs_scale=0.125,
lhs_zero_point=0,
rhs_scale=0.125,
rhs_zero_point=0,
output_scale=0.25,
output_zero_point=0)

func = relay.Function([x, y], z)
mod = relay.Module.from_expr(func)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
func = mod["main"]

x_data = np.array((255, 1, 1, 0)).reshape((1,4))
y_data = np.array((255, 255, 127, 0)).reshape((1,4))
golden_output = np.array((255, 128, 64, 0)).reshape((1, 4))

intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
op_res = intrp.evaluate(func)(x_data, y_data)
np.testing.assert_equal(op_res.asnumpy(), golden_output)

# Same io params, different output scale
z = relay.qnn.op.mul(lhs=x, rhs=y,
lhs_scale=0.125,
lhs_zero_point=0,
rhs_scale=0.125,
rhs_zero_point=0,
output_scale=0.25,
output_zero_point=0)

func = relay.Function([x, y], z)
mod = relay.Module.from_expr(func)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
func = mod["main"]

x_data = np.array((255, 1, 1, 0)).reshape((1,4))
y_data = np.array((255, 255, 127, 0)).reshape((1,4))
golden_output = np.array((255, 128, 64, 0)).reshape((1, 4))

intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
op_res = intrp.evaluate(func)(x_data, y_data)
np.testing.assert_equal(op_res.asnumpy(), golden_output)

# All params different
z = relay.qnn.op.mul(lhs=x, rhs=y,
lhs_scale=0.5,
lhs_zero_point=0,
rhs_scale=0.25,
rhs_zero_point=0,
output_scale=0.125,
output_zero_point=0)

func = relay.Function([x, y], z)
mod = relay.Module.from_expr(func)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
func = mod["main"]

x_data = np.array((255, 0, 1, 0)).reshape((1,4))
y_data = np.array((0, 128, 64, 0)).reshape((1,4))
golden_output = np.array((0, 0, 255, 0)).reshape((1, 4))

intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
op_res = intrp.evaluate(func)(x_data, y_data)
np.testing.assert_equal(op_res.asnumpy(), golden_output)


if __name__ == '__main__':
test_tflite_same_io_qnn_params()
test_tflite_different_io_qnn_params()
test_saturation()

0 comments on commit 7f9dbb0

Please sign in to comment.