[PT FE]: support aten::hann_window
eaidova committed Jun 3, 2024
1 parent ed09df8 commit 7d4bc9a
Showing 3 changed files with 156 additions and 0 deletions.
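This commit adds an `aten::hann_window` converter to the PyTorch frontend, registers it in the op table, and adds layer tests covering the plain, `periodic`, `dtype`, and `out` overloads. Once converted, models that build a Hann window on the fly work end to end; here is a minimal sketch (module name, shapes, and input are illustrative, and it assumes a recent OpenVINO that exposes `openvino.convert_model`):

```python
import torch
import openvino as ov


class SpectrogramWindow(torch.nn.Module):
    # Illustrative module: applies a non-periodic Hann window to each frame.
    def forward(self, frames):
        # Under tracing, shape[-1] is a constant, so this records an aten::hann_window node.
        return frames * torch.hann_window(frames.shape[-1], periodic=False)


ov_model = ov.convert_model(SpectrogramWindow(), example_input=torch.ones(4, 32))
```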
69 changes: 69 additions & 0 deletions src/frontends/pytorch/src/op/hann_window.cpp
@@ -0,0 +1,69 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/sin.hpp"
#include "openvino/op/subtract.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_hann_window(const NodeContext& context) {
    // aten::hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None,
    //                   bool? pin_memory=None) -> Tensor
    // aten::hann_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None,
    //                            Device? device=None, bool? pin_memory=None) -> Tensor
    // aten::hann_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)
    // aten::hann_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)
    num_inputs_check(context, 1, 6);
    auto window_size = context.get_input(0);
    bool periodic = true;
    auto num_inputs = context.get_input_size();
    // Only the .periodic (6 inputs) and .periodic_out (3 inputs) overloads carry the periodic flag at input 1.
    if ((num_inputs == 3 || num_inputs == 6) && !context.input_is_none(1)) {
        periodic = context.const_input<bool>(1);
    }
    auto zero_f = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
    auto one_f = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
    auto window_size_f = context.mark_node(std::make_shared<v0::Convert>(window_size, element::f32));
    auto range = context.mark_node(std::make_shared<v4::Range>(zero_f, window_size_f, one_f, ov::element::f32));
    auto pi = context.mark_node(v0::Constant::create(ov::element::f32, Shape{}, {static_cast<float>(M_PI)}));
    auto output = context.mark_node(std::make_shared<v1::Multiply>(range, pi));
    // hann(n) = sin^2(pi * n / factor), where factor is N for a periodic window and N - 1 otherwise.
    auto factor = window_size_f;
    if (!periodic) {
        factor = context.mark_node(std::make_shared<v1::Subtract>(window_size_f, one_f));
    }
    output = context.mark_node(std::make_shared<v1::Divide>(output, factor));
    auto sin = context.mark_node(std::make_shared<v0::Sin>(output));
    Output<Node> squared_sin = context.mark_node(std::make_shared<v1::Multiply>(sin, sin));
    // dtype sits at input 1 in the plain 5-input overload and at input 2 in the 6-input periodic overload.
    if (num_inputs > 3) {
        size_t dtype_id = num_inputs == 5 ? 1 : 2;
        if (!context.input_is_none(dtype_id)) {
            squared_sin = apply_dtype(context, dtype_id, squared_sin);
        }
    }
    // The .out overloads pass the out tensor as the last input: align the result type and mutate it in place.
    if (num_inputs <= 3) {
        size_t out_id = num_inputs == 3 ? 2 : 1;
        if (!context.input_is_none(out_id)) {
            squared_sin = context.mark_node(std::make_shared<v1::ConvertLike>(squared_sin, context.get_input(out_id)));
            context.mutate_input(out_id, squared_sin);
        }
    }
    return {squared_sin};
}

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
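For reference, the graph built above evaluates sin²(πn/M), with M = N for a periodic window and M = N − 1 otherwise; by the half-angle identity this equals the textbook 0.5·(1 − cos(2πn/M)) form of the Hann window. A small NumPy cross-check against `torch.hann_window`, as a sketch assuming both libraries are installed:

```python
import numpy as np
import torch


def hann_reference(window_size: int, periodic: bool = True) -> np.ndarray:
    # Mirrors the converter: Range -> multiply by pi -> divide by factor -> Sin -> square.
    n = np.arange(window_size, dtype=np.float32)
    factor = window_size if periodic else window_size - 1
    return np.sin(np.pi * n / factor) ** 2


for periodic in (True, False):
    expected = torch.hann_window(10, periodic=periodic).numpy()
    assert np.allclose(hann_reference(10, periodic), expected, atol=1e-6)
```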
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
@@ -100,6 +100,7 @@ OP_CONVERTER(translate_glu);
OP_CONVERTER(translate_grid_sampler);
OP_CONVERTER(translate_group_norm);
OP_CONVERTER(translate_gru);
OP_CONVERTER(translate_hann_window);
OP_CONVERTER(translate_hardtanh);
OP_CONVERTER(translate_if);
OP_CONVERTER(translate_im2col);
@@ -479,6 +480,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::group_norm", op::translate_group_norm},
{"aten::gru", op::translate_gru},
{"aten::gt", op::translate_1to1_match_2_inputs_align_types<opset10::Greater>},
{"aten::hann_window", op::translate_hann_window},
{"aten::hardsigmoid", op::quantizable_op<op::translate_1to1_match_1_inputs<opset10::HSigmoid>>},
{"aten::hardsigmoid_",
op::quantizable_op<op::inplace_op<op::translate_1to1_match_1_inputs<opset10::HSigmoid>>>},
85 changes: 85 additions & 0 deletions tests/layer_tests/pytorch_tests/test_hann_window.py
@@ -0,0 +1,85 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest

from pytorch_layer_test_class import PytorchLayerTest, skip_if_export


class TestHannWindow(PytorchLayerTest):
    def _prepare_input(self, window_size, out=False, out_dtype="float32"):
        import numpy as np

        if not out:
            return (np.array(window_size),)
        return (np.array(window_size), np.zeros((window_size,), dtype=out_dtype))

    def create_model(self, periodic, dtype, out):
        import torch

        dtype_mapping = {
            "float32": torch.float32,
            "float64": torch.float64,
            "float16": torch.float16,
        }

        torch_dtype = dtype_mapping.get(dtype)

        class aten_hann_window(torch.nn.Module):
            def __init__(self, periodic, dtype, out):
                super(aten_hann_window, self).__init__()
                self.periodic = periodic
                self.dtype = dtype

                if out:
                    self.forward = self.forward_out if periodic is None else self.forward_periodic_out
                elif dtype:
                    self.forward = self.forward_dtype if periodic is None else self.forward_dtype_periodic
                elif periodic is not None:
                    self.forward = self.forward_periodic

            def forward(self, x):
                return torch.hann_window(x)

            def forward_out(self, x, out):
                return torch.hann_window(x, out=out)

            def forward_periodic_out(self, x, out):
                return torch.hann_window(x, periodic=self.periodic, out=out)

            def forward_dtype(self, x):
                return torch.hann_window(x, dtype=self.dtype)

            def forward_dtype_periodic(self, x):
                return torch.hann_window(x, periodic=self.periodic, dtype=self.dtype)

            def forward_periodic(self, x):
                return torch.hann_window(x, periodic=self.periodic)

        ref_net = None

        return aten_hann_window(periodic, torch_dtype, out), ref_net, "aten::hann_window"

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("window_size", [2, 10, 32])
    @pytest.mark.parametrize(("dtype", "out", "out_dtype", "periodic"), [
        [None, False, None, None],
        [None, True, "float32", None],
        [None, True, "float64", None],
        [None, True, "float32", False],
        [None, True, "float64", False],
        [None, True, "float32", True],
        [None, True, "float64", True],
        [None, False, "", False],
        [None, False, "", True],
        ["float32", False, "", None],
        ["float64", False, "", None],
        ["float32", False, "", False],
        ["float64", False, "", False],
        ["float32", False, "", True],
        ["float64", False, "", True],
    ])
    def test_hann_window(self, window_size, dtype, out, out_dtype, periodic, ie_device, precision, ir_version):
        self._test(*self.create_model(periodic, dtype, out), ie_device, precision, ir_version,
                   kwargs_to_prepare_input={"window_size": window_size, "out": out, "out_dtype": out_dtype})

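Assuming the standard layer-test environment, the new cases can be run locally with `pytest tests/layer_tests/pytorch_tests/test_hann_window.py`; each parameter row exercises one of the `aten::hann_window` overloads handled by the converter above.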