forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request openvinotoolkit#55 from sikorsl1/lsikorsk/add_aten…
…_norm Add aten::norm operator and layer test
- Loading branch information
Showing
3 changed files
with
91 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
// Copyright (C) 2018-2022 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "openvino/frontend/pytorch/node_context.hpp" | ||
#include "openvino/opsets/opset8.hpp" | ||
#include "utils.hpp" | ||
|
||
namespace ov { | ||
namespace frontend { | ||
namespace pytorch { | ||
namespace op { | ||
|
||
OutputVector translate_norm(NodeContext& context) { | ||
auto input_tensor = context.get_input(0); | ||
auto p = context.const_input<float>(1); | ||
auto dim = context.get_input(2); | ||
auto keep_dim = context.const_input<bool>(3); | ||
|
||
OutputVector res; | ||
|
||
if (p == 1) { | ||
auto reduce_l1 = context.mark_node(std::make_shared<opset8::ReduceL1>(input_tensor, dim, keep_dim)); | ||
res.push_back(reduce_l1); | ||
} else if (p == 2) { | ||
auto reduce_l2 = context.mark_node(std::make_shared<opset8::ReduceL2>(input_tensor, dim, keep_dim)); | ||
res.push_back(reduce_l2); | ||
} else if (p == std::numeric_limits<float>::infinity()) { | ||
auto abs = context.mark_node(std::make_shared<opset8::Abs>(input_tensor)); | ||
auto max = context.mark_node(std::make_shared<opset8::ReduceMax>(abs, dim, keep_dim)); | ||
res.push_back(max); | ||
} else if (p == -std::numeric_limits<float>::infinity()) { | ||
auto abs = context.mark_node(std::make_shared<opset8::Abs>(input_tensor)); | ||
auto min = context.mark_node(std::make_shared<opset8::ReduceMin>(abs, dim, keep_dim)); | ||
res.push_back(min); | ||
} else { | ||
auto const_p = context.mark_node(opset8::Constant::create(element::f64, Shape{1}, {p})); | ||
auto const_p_inv = context.mark_node(opset8::Constant::create(element::f64, Shape{1}, {1.0 / p})); | ||
auto abs = context.mark_node(std::make_shared<opset8::Abs>(input_tensor)); | ||
auto pow = context.mark_node(std::make_shared<opset8::Power>(abs, const_p)); | ||
auto sum = context.mark_node(std::make_shared<opset8::ReduceSum>(pow, dim, keep_dim)); | ||
auto pow_inv = context.mark_node(std::make_shared<opset8::Power>(sum, const_p_inv)); | ||
res.push_back(pow_inv); | ||
} | ||
|
||
return res; | ||
}; | ||
|
||
} // namespace op | ||
} // namespace pytorch | ||
} // namespace frontend | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Copyright (C) 2018-2022 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pytest | ||
from pytorch_layer_test_class import PytorchLayerTest | ||
import numpy as np | ||
import torch | ||
|
||
|
||
@pytest.mark.parametrize('p', [-2, -1, 0, 1, 2, 2.5, float('inf'), float('-inf')]) | ||
@pytest.mark.parametrize('dim', [[0], [0, 1], [0, 1, 2]]) | ||
@pytest.mark.parametrize('keepdim', [True, False]) | ||
class TestNorm(PytorchLayerTest): | ||
|
||
def _prepare_input(self): | ||
return (np.random.randn(2, 3, 4, 5), ) | ||
|
||
def create_model(self, p, dim, keepdim): | ||
class aten_norm(torch.nn.Module): | ||
|
||
def __init__(self, p, dim, keepdim) -> None: | ||
super().__init__() | ||
self.p = p | ||
self.dim = dim | ||
self.keepdim = keepdim | ||
|
||
def forward(self, input_data): | ||
return torch.norm(input_data, self.p, self.dim, self.keepdim) | ||
|
||
ref_net = None | ||
|
||
return aten_norm(p, dim, keepdim), ref_net, "aten::norm" | ||
|
||
@pytest.mark.nightly | ||
def test_norm(self, ie_device, precision, ir_version, p, dim, keepdim): | ||
self._test(*self.create_model(p, dim, keepdim), | ||
ie_device, precision, ir_version) |