[LPT] Legacy transformations improvements to BERT support on GPU #8

Open: wants to merge 4 commits into base: es/lpt/lpt_to_ngraph_integration
@@ -153,14 +153,14 @@ ngraph::graph_rewrite_callback get_callback() {
 
         auto res = check_constant(const_node, data_node.get_partial_shape());
 
-        if (res == CONVERSION_RESULT::NONE || (res == CONVERSION_RESULT::SCALE_SHIFT && output_shape_rank < 4)) {
+        bool is_dequantization = lin_op->get_rt_info().count("DEQUANTIZATION") != 0;
+
+        if (!is_dequantization && (res == CONVERSION_RESULT::NONE || (res == CONVERSION_RESULT::SCALE_SHIFT && output_shape_rank < 4))) {
             return convert_to_eltwise<T>(lin_op,
                                          lin_op->input(0).get_source_output(),
                                          lin_op->input(1).get_source_output());
         }
 
-        bool is_dequantization = lin_op->get_rt_info().count("DEQUANTIZATION") != 0;
-
         // TODO: if all values in Constant are equal the best way is to convert this Eltwise to Power
         if (res == CONVERSION_RESULT::SCALE_SHIFT || is_dequantization) {
             auto weights_et = const_node->get_element_type();
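Note on the hunk above: the rt_info lookup is moved ahead of the early Eltwise fallback, so Multiply/Add nodes that carry the dequantization marker stay on the ScaleShift path instead of being converted to Eltwise. A minimal sketch of the check this relies on, assuming only a generic node pointer (the helper name is illustrative; the key and the count-based lookup come from the diff):

// Illustrative helper, not part of this PR: true when LPT has marked a node
// as a dequantization operation via its runtime info.
inline bool has_dequantization_attr(const std::shared_ptr<ngraph::Node>& node) {
    return node->get_rt_info().count("DEQUANTIZATION") != 0;
}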
@@ -137,15 +137,14 @@ void ngraph::pass::ConvertMulAddToScaleShiftOrPower::convert_mul_add_to_scaleshi
     const auto output_shape = add_node->get_output_partial_shape(0);
     const auto output_shape_rank = output_shape.rank().get_length();
 
+    bool is_dequantization =
+            (add_node->get_rt_info().count("DEQUANTIZATION") != 0 || mul_node->get_rt_info().count("DEQUANTIZATION") != 0);
+
     if (res1 == CONVERSION_RESULT::NONE || res2 == CONVERSION_RESULT::NONE ||
-        ((res1 == CONVERSION_RESULT::SCALE_SHIFT || res2 == CONVERSION_RESULT::SCALE_SHIFT) &&
-        (output_shape_rank == 1 || output_shape_rank > 4))) {
+        ((res1 == CONVERSION_RESULT::SCALE_SHIFT || res2 == CONVERSION_RESULT::SCALE_SHIFT) && !is_dequantization && output_shape_rank < 4)) {
         return false;
     }
 
-    bool is_dequantization =
-            (add_node->get_rt_info().count("DEQUANTIZATION") != 0 || mul_node->get_rt_info().count("DEQUANTIZATION") != 0);
-
     // TODO: in case if scale and shift constants has equal values the best way is to convert them to Power
     if (res1 == CONVERSION_RESULT::SCALE_SHIFT || res2 == CONVERSION_RESULT::SCALE_SHIFT || is_dequantization) {
         NodeVector new_ops;
@@ -176,7 +175,7 @@ void ngraph::pass::ConvertMulAddToScaleShiftOrPower::convert_mul_add_to_scaleshi
         new_ops.push_back(biases_in);
     }
 
-    auto scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, weights_in, biases_in);
+    auto scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, weights_in, biases_in, weights_in->get_element_type());
     new_ops.push_back(scaleshift);
 
     scaleshift->set_friendly_name(add_node->get_friendly_name());
@@ -52,7 +52,8 @@ ngraph::pass::ReshapeFullyConnected::ReshapeFullyConnected() {
         auto fc_new = std::make_shared<op::FullyConnected>(reshape,
                                                            fc->input_value(1),
                                                            fc->input_value(2),
-                                                           output_shape_new);
+                                                           output_shape_new,
+                                                           fc->get_output_type());
         new_ops.push_back(fc_new);
 
         if (output_shape != output_shape_new) {
@@ -73,4 +74,4 @@ ngraph::pass::ReshapeFullyConnected::ReshapeFullyConnected() {
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(fc, "ReshapeFullyConnected");
     this->register_matcher(m, callback);
-}
\ No newline at end of file
+}
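Both ScaleShiftIE and FullyConnected are now constructed with an explicit output element type rather than inheriting it implicitly, presumably so the dequantization path keeps its intended precision. A hedged sketch of the pattern with placeholder operands (argument order mirrors the two hunks above; the surrounding variable names are taken from the shown context, nothing else is assumed):

// Sketch only: pass the desired output precision explicitly when building the
// legacy ops, as the preceding hunks do.
auto scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(
    data_node, weights_in, biases_in,
    weights_in->get_element_type());   // output precision follows the weights

auto fc_new = std::make_shared<op::FullyConnected>(
    reshape, fc->input_value(1), fc->input_value(2),
    output_shape_new,
    fc->get_output_type());            // keep the FullyConnected output type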
@@ -0,0 +1,117 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "layer_transformation.hpp"

#include <string>
#include <sstream>
#include <memory>

#include <gtest/gtest.h>

#include <transformations/utils/utils.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/convert_opset1_to_legacy/convert_mul_or_add_finally.hpp>
#include <ngraph/pass/constant_folding.hpp>

#include "common_test_utils/ngraph_test_utils.hpp"

#include "ngraph_functions/low_precision_transformations/convert_mul_or_add_finally_with_dequantization_function.hpp"

using namespace testing;
using namespace ngraph::pass;

namespace {

inline std::ostream& operator<<(std::ostream& os, const std::vector<float>& values) {
os << "{ ";
for (size_t i = 0; i < values.size(); ++i) {
os << values[i];
if (i != (values.size() - 1ul)) {
os << ", ";
}
}
os << " }";
return os;
}

class ConvertMulOrAddFinallyTransformationWithDequantizationTestValues {
public:
std::vector<float> multiplyConstValues;
ngraph::Shape inputShape;
ngraph::element::Type inputPrecision;
ngraph::pass::low_precision::LayerTransformation::Params params;
};

using TestValuesType = ConvertMulOrAddFinallyTransformationWithDequantizationTestValues;

class ConvertMulOrAddFinallyTransformationWithDequantization : public LayerTransformation, public testing::WithParamInterface<TestValuesType> {
public:
void SetUp() override {
using namespace ngraph::builder::subgraph;
const ConvertMulOrAddFinallyTransformationWithDequantizationTestValues testValues = GetParam();

actualFunction = ConvertMulOrAddWithDequantizationFunction::getOriginal(testValues.inputShape,
testValues.inputPrecision,
testValues.multiplyConstValues);

ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertMulOrAddFinally>();
manager.register_pass<ngraph::pass::ConstantFolding>();

manager.run_passes(actualFunction);

referenceFunction = ConvertMulOrAddWithDequantizationFunction::getReference(testValues.inputShape,
testValues.inputPrecision,
testValues.multiplyConstValues);
}

static std::string getTestCaseName(testing::TestParamInfo<ConvertMulOrAddFinallyTransformationWithDequantizationTestValues> obj) {
const ConvertMulOrAddFinallyTransformationWithDequantizationTestValues testValues = obj.param;
std::ostringstream result;
result << LayerTransformation::getTestCaseNameByParams(testValues.inputPrecision, testValues.inputShape, testValues.params) << "_" <<
testValues.multiplyConstValues;
return result.str();
}
};

TEST_P(ConvertMulOrAddFinallyTransformationWithDequantization, CompareFunctions) {
actualFunction->validate_nodes_and_infer_types();
auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
ASSERT_TRUE(res.first) << res.second;
}

std::vector<ConvertMulOrAddFinallyTransformationWithDequantizationTestValues> testValues = {
{
{ -1.0 },
{ 1, 1000 },
ngraph::element::f32,
LayerTransformation::createParamsU8I8()
},
{
{ 128.0 },
{ 1, 10 },
ngraph::element::f32,
LayerTransformation::createParamsU8I8()
},
{
{ -64.5 },
{ 1, 10 },
ngraph::element::i8,
LayerTransformation::createParamsU8I8()
},
{
{ 1.2 },
{ 1, 100 },
ngraph::element::u8,
LayerTransformation::createParamsI8I8()
}
};

INSTANTIATE_TEST_CASE_P(
LPT,
ConvertMulOrAddFinallyTransformationWithDequantization,
::testing::ValuesIn(testValues),
ConvertMulOrAddFinallyTransformationWithDequantization::getTestCaseName);
} // namespace
@@ -0,0 +1,31 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <sstream>
#include <vector>
#include <ngraph/ngraph.hpp>
#include "ngraph_functions/low_precision_transformations/common/fake_quantize_on_data.hpp"
#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"

namespace ngraph {
namespace builder {
namespace subgraph {

class ConvertMulOrAddWithDequantizationFunction {
public:
static std::shared_ptr<ngraph::Function> getOriginal(
const ngraph::Shape& inputShape,
const ngraph::element::Type inputPrecision,
const std::vector<float>& multiplyConst);

static std::shared_ptr<ngraph::Function> getReference(
const ngraph::Shape& inputShape,
const ngraph::element::Type inputPrecision,
const std::vector<float>& multiplyConst);
};
} // namespace subgraph
} // namespace builder
} // namespace ngraph
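A hedged usage sketch of the two builders declared above, mirroring how the test fixture earlier in this PR drives them (shape, precision, and constant value are copied from the test's parameter list; no new API is assumed):

// Sketch: build the original and reference functions, run the legacy
// conversion plus constant folding on the original, then compare the two.
using ngraph::builder::subgraph::ConvertMulOrAddWithDequantizationFunction;

std::shared_ptr<ngraph::Function> actual = ConvertMulOrAddWithDequantizationFunction::getOriginal(
    ngraph::Shape{ 1, 1000 }, ngraph::element::f32, { -1.0f });
std::shared_ptr<ngraph::Function> reference = ConvertMulOrAddWithDequantizationFunction::getReference(
    ngraph::Shape{ 1, 1000 }, ngraph::element::f32, { -1.0f });

ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertMulOrAddFinally>();
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.run_passes(actual);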
@@ -0,0 +1,75 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "ngraph_functions/low_precision_transformations/convert_mul_or_add_finally_with_dequantization_function.hpp"

#include <memory>
#include <vector>
#include <ngraph/ngraph.hpp>


#include <ngraph/opsets/opset1.hpp>
#include <ngraph_ops/fully_connected.hpp>
#include "ngraph_functions/subgraph_builders.hpp"
#include "ngraph_functions/low_precision_transformations/common/builders.hpp"
#include "transformations/low_precision/network_helper.hpp"
#include "ngraph_ops/scaleshift.hpp"
#include "transformations/low_precision/common/dequantization_op.hpp"

namespace ngraph {
namespace builder {
namespace subgraph {

using namespace ngraph::pass;

std::shared_ptr<ngraph::Function> ConvertMulOrAddWithDequantizationFunction::getOriginal(
const ngraph::Shape& inputShape,
const ngraph::element::Type inputPrecision,
const std::vector<float>& multiplyConst) {
const auto input = std::make_shared<ngraph::opset1::Parameter>(inputPrecision, inputShape);
const auto reluOriginal = ngraph::opset1::Relu(
ngraph::op::TemporaryReplaceOutputType(input, element::f32).get());

std::shared_ptr<ngraph::opset1::Relu> relu = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Relu>>(
reluOriginal,
std::vector<element::Type>{ element::f32, element::f32 },
std::vector<element::Type>{});


const auto multiply = std::make_shared<ngraph::pass::low_precision::DequantizationMultiply>(relu,
std::make_shared<opset1::Constant>(element::f32, inputShape, multiplyConst));

multiply->set_friendly_name("output");

ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(multiply) };
return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input },
"ConvertMulOrAddTransformationWithDequantization");
}

std::shared_ptr<ngraph::Function> ConvertMulOrAddWithDequantizationFunction::getReference(
const ngraph::Shape& inputShape,
const ngraph::element::Type inputPrecision,
const std::vector<float>& multiplyConst) {
const auto input = std::make_shared<ngraph::opset1::Parameter>(inputPrecision, inputShape);
const auto reluOriginal = ngraph::opset1::Relu(
ngraph::op::TemporaryReplaceOutputType(input, element::f32).get());

std::shared_ptr<ngraph::opset1::Relu> relu = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Relu>>(
reluOriginal,
std::vector<element::Type>{ element::f32, element::f32 },
std::vector<element::Type>{});

const auto weights = std::make_shared<opset1::Constant>(element::f32, inputShape, multiplyConst);
const auto bias = std::make_shared<opset1::Constant>(element::f32, inputShape, 0.0);
const auto scaleShift = std::make_shared<ngraph::op::ScaleShiftIE>(relu, weights, bias);

scaleShift->set_friendly_name("output");

ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(scaleShift) };
return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "ConvertMulOrAddTransformationWithDequantization");
}

} // namespace subgraph
} // namespace builder
} // namespace ngraph