Merge pull request openvinotoolkit#55 from sikorsl1/lsikorsk/add_aten_norm

Add aten::norm operator and layer test
slyalin authored Dec 6, 2022
2 parents fc0bc93 + d67e6e5 commit a5b077d
Showing 3 changed files with 91 additions and 0 deletions.
52 changes: 52 additions & 0 deletions src/frontends/pytorch/src/op/norm.cpp
@@ -0,0 +1,52 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/opsets/opset8.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_norm(NodeContext& context) {
    auto input_tensor = context.get_input(0);
    auto p = context.const_input<float>(1);
    auto dim = context.get_input(2);
    auto keep_dim = context.const_input<bool>(3);

    OutputVector res;

    if (p == 1) {
        // L1 norm: sum of absolute values along dim.
        auto reduce_l1 = context.mark_node(std::make_shared<opset8::ReduceL1>(input_tensor, dim, keep_dim));
        res.push_back(reduce_l1);
    } else if (p == 2) {
        // L2 norm: square root of the sum of squares along dim.
        auto reduce_l2 = context.mark_node(std::make_shared<opset8::ReduceL2>(input_tensor, dim, keep_dim));
        res.push_back(reduce_l2);
    } else if (p == std::numeric_limits<float>::infinity()) {
        // Infinity norm: maximum absolute value along dim.
        auto abs = context.mark_node(std::make_shared<opset8::Abs>(input_tensor));
        auto max = context.mark_node(std::make_shared<opset8::ReduceMax>(abs, dim, keep_dim));
        res.push_back(max);
    } else if (p == -std::numeric_limits<float>::infinity()) {
        // Negative-infinity norm: minimum absolute value along dim.
        auto abs = context.mark_node(std::make_shared<opset8::Abs>(input_tensor));
        auto min = context.mark_node(std::make_shared<opset8::ReduceMin>(abs, dim, keep_dim));
        res.push_back(min);
    } else {
        // General p-norm: (sum(|x|^p))^(1/p), built from Abs, Power and ReduceSum.
        auto const_p = context.mark_node(opset8::Constant::create(element::f64, Shape{1}, {p}));
        auto const_p_inv = context.mark_node(opset8::Constant::create(element::f64, Shape{1}, {1.0 / p}));
        auto abs = context.mark_node(std::make_shared<opset8::Abs>(input_tensor));
        auto pow = context.mark_node(std::make_shared<opset8::Power>(abs, const_p));
        auto sum = context.mark_node(std::make_shared<opset8::ReduceSum>(pow, dim, keep_dim));
        auto pow_inv = context.mark_node(std::make_shared<opset8::Power>(sum, const_p_inv));
        res.push_back(pow_inv);
    }

    return res;
}

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
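Note (not part of the commit): the general-p branch implements the textbook decomposition ||x||_p = (sum(|x_i|^p))^(1/p), while the infinity-norm branches reduce |x| with max/min. A minimal NumPy/PyTorch sketch checking both against torch.norm:

import numpy as np
import torch

x = np.random.randn(2, 3, 4).astype(np.float32)

# General-p branch: Abs -> Power(p) -> ReduceSum -> Power(1/p).
for p in (-2.0, -1.0, 2.5):
    decomposed = np.sum(np.abs(x) ** p, axis=(0, 1)) ** (1.0 / p)
    reference = torch.norm(torch.from_numpy(x), p, dim=[0, 1]).numpy()
    assert np.allclose(decomposed, reference, rtol=1e-4), f"mismatch for p={p}"

# Infinity-norm branches: Abs -> ReduceMax / ReduceMin.
inf_ref = torch.norm(torch.from_numpy(x), float('inf'), dim=[0, 1]).numpy()
assert np.allclose(np.max(np.abs(x), axis=(0, 1)), inf_ref)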
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
@@ -49,6 +49,7 @@ OP_CONVERTER(translate_max_pool2d);
OP_CONVERTER(translate_masked_fill);
OP_CONVERTER(translate_mean);
OP_CONVERTER(translate_neg);
OP_CONVERTER(translate_norm);
OP_CONVERTER(translate_new_full);
OP_CONVERTER(translate_new_ones);
OP_CONVERTER(translate_new_zeros);
@@ -145,6 +146,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::mul_", op::inplace_op<op::translate_1to1_match_2_inputs<opset8::Multiply>>},
{"aten::ne", op::translate_1to1_match_2_inputs<opset8::NotEqual>},
{"aten::neg", op::translate_neg},
{"aten::norm", op::translate_norm},
{"aten::new_full", op::translate_new_full},
{"aten::new_ones", op::translate_new_ones},
{"aten::new_zeros", op::translate_new_zeros},
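With the entry registered in get_supported_ops(), a traced or scripted model calling torch.norm should route through translate_norm during conversion. A hypothetical end-to-end sketch, assuming a later OpenVINO release where openvino.convert_model accepts a torch.nn.Module directly (the PyTorch frontend API was still evolving when this was merged):

import numpy as np
import torch
import openvino as ov  # assumption: a release exposing ov.convert_model for torch modules


class NormModel(torch.nn.Module):
    def forward(self, x):
        # torch.norm with p=2 dispatches to aten::norm, mapped to ReduceL2.
        return torch.norm(x, 2, [0, 1], False)


example = torch.randn(2, 3, 4)
ov_model = ov.convert_model(NormModel(), example_input=example)  # hypothetical usage
compiled = ov.Core().compile_model(ov_model, "CPU")
result = compiled([example.numpy()])[0]
print(np.allclose(result, NormModel()(example).numpy(), atol=1e-5))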
37 changes: 37 additions & 0 deletions tests/layer_tests/pytorch_tests/test_norm.py
@@ -0,0 +1,37 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
from pytorch_layer_test_class import PytorchLayerTest
import numpy as np
import torch


@pytest.mark.parametrize('p', [-2, -1, 0, 1, 2, 2.5, float('inf'), float('-inf')])
@pytest.mark.parametrize('dim', [[0], [0, 1], [0, 1, 2]])
@pytest.mark.parametrize('keepdim', [True, False])
class TestNorm(PytorchLayerTest):

    def _prepare_input(self):
        return (np.random.randn(2, 3, 4, 5), )

    def create_model(self, p, dim, keepdim):
        class aten_norm(torch.nn.Module):

            def __init__(self, p, dim, keepdim) -> None:
                super().__init__()
                self.p = p
                self.dim = dim
                self.keepdim = keepdim

            def forward(self, input_data):
                return torch.norm(input_data, self.p, self.dim, self.keepdim)

        ref_net = None

        return aten_norm(p, dim, keepdim), ref_net, "aten::norm"

    @pytest.mark.nightly
    def test_norm(self, ie_device, precision, ir_version, p, dim, keepdim):
        self._test(*self.create_model(p, dim, keepdim),
                   ie_device, precision, ir_version)

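To run just this layer test locally, something like the following should work, assuming the layer-test environment (the conftest fixtures providing ie_device, precision and ir_version) is set up and the path is relative to the repository root:

import pytest

# Hypothetical local invocation; the test cases here are marked nightly.
pytest.main([
    "tests/layer_tests/pytorch_tests/test_norm.py",
    "-m", "nightly",
    "-v",
])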