diff --git a/src/frontends/pytorch/src/op/hann_window.cpp b/src/frontends/pytorch/src/op/hann_window.cpp
new file mode 100644
index 00000000000000..56a223211d7516
--- /dev/null
+++ b/src/frontends/pytorch/src/op/hann_window.cpp
@@ -0,0 +1,69 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/convert.hpp"
+#include "openvino/op/convert_like.hpp"
+#include "openvino/op/divide.hpp"
+#include "openvino/op/multiply.hpp"
+#include "openvino/op/range.hpp"
+#include "openvino/op/sin.hpp"
+#include "openvino/op/subtract.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+using namespace ov::op;
+
+OutputVector translate_hann_window(const NodeContext& context) {
+    // aten::hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool?
+    // pin_memory=None) -> Tensor
+    // aten::hann_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None,
+    // Device? device=None, bool? pin_memory=None) -> Tensor
+    // aten::hann_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)
+    // aten::hann_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)
+    num_inputs_check(context, 1, 6);
+    auto window_size = context.get_input(0);
+    bool periodic = true;
+    auto num_inputs = context.get_input_size();
+    if ((num_inputs == 3 || num_inputs == 6) && !context.input_is_none(1)) {
+        periodic = context.const_input<bool>(1);
+    }
+    auto zero_f = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
+    auto one_f = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
+    auto window_size_f = context.mark_node(std::make_shared<v0::Convert>(window_size, element::f32));
+    auto range = context.mark_node(std::make_shared<v4::Range>(zero_f, window_size_f, one_f, ov::element::f32));
+    auto pi = context.mark_node(v0::Constant::create(ov::element::f32, Shape{}, {static_cast<float>(M_PI)}));
+    auto output = context.mark_node(std::make_shared<v1::Multiply>(range, pi));
+    auto factor = window_size_f;
+    if (!periodic) {
+        factor = context.mark_node(std::make_shared<v1::Subtract>(window_size_f, one_f));
+    }
+    output = context.mark_node(std::make_shared<v1::Divide>(output, factor));
+    auto sin = context.mark_node(std::make_shared<v0::Sin>(output));
+    Output<Node> squared_sin = context.mark_node(std::make_shared<v1::Multiply>(sin, sin));
+    if (num_inputs > 3) {
+        size_t dtype_id = num_inputs == 5 ? 1 : 2;
+        if (!context.input_is_none(dtype_id)) {
+            squared_sin = apply_dtype(context, dtype_id, squared_sin);
+        }
+    }
+    if (num_inputs <= 3) {
+        size_t out_id = num_inputs == 3 ? 2 : 1;
+        if (!context.input_is_none(out_id)) {
+            squared_sin = context.mark_node(std::make_shared<v1::ConvertLike>(squared_sin, context.get_input(out_id)));
+            context.mutate_input(out_id, squared_sin);
+        }
+    }
+    return {squared_sin};
+};
+
+}  // namespace op
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
\ No newline at end of file
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index ef53c75d0fe369..9cc73f854bbf6a 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -100,6 +100,7 @@ OP_CONVERTER(translate_glu);
 OP_CONVERTER(translate_grid_sampler);
 OP_CONVERTER(translate_group_norm);
 OP_CONVERTER(translate_gru);
+OP_CONVERTER(translate_hann_window);
 OP_CONVERTER(translate_hardtanh);
 OP_CONVERTER(translate_if);
 OP_CONVERTER(translate_im2col);
@@ -479,6 +480,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
     {"aten::group_norm", op::translate_group_norm},
     {"aten::gru", op::translate_gru},
     {"aten::gt", op::translate_1to1_match_2_inputs_align_types<opset10::Greater>},
+    {"aten::hann_window", op::translate_hann_window},
     {"aten::hardsigmoid", op::quantizable_op<op::translate_1to1_match_1_input<opset10::HSigmoid>>},
     {"aten::hardsigmoid_", op::quantizable_op<op::inplace_op<op::translate_1to1_match_1_input<opset10::HSigmoid>>>},
     {"aten::hardswish", op::quantizable_op<op::translate_1to1_match_1_input<opset10::HSwish>>},
diff --git a/tests/layer_tests/pytorch_tests/test_hann_window.py b/tests/layer_tests/pytorch_tests/test_hann_window.py
new file mode 100644
index 00000000000000..2af8f6dfd71bfe
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_hann_window.py
@@ -0,0 +1,85 @@
+# Copyright (C) 2018-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+
+from pytorch_layer_test_class import PytorchLayerTest, skip_if_export
+
+
+class TestHannWindow(PytorchLayerTest):
+    def _prepare_input(self, window_size, out=False, out_dtype="float32"):
+        import numpy as np
+
+        if not out:
+            return (np.array(window_size),)
+        return (np.array(window_size), np.zeros((window_size,), dtype=out_dtype))
+
+    def create_model(self, periodic, dtype, out):
+        import torch
+
+        dtype_mapping = {
+            "float32": torch.float32,
+            "float64": torch.float64,
+            "float16": torch.float16
+        }
+
+        torch_dtype = dtype_mapping.get(dtype)
+
+        class aten_hann_window(torch.nn.Module):
+            def __init__(self, periodic, dtype, out):
+                super(aten_hann_window, self).__init__()
+                self.periodic = periodic
+                self.dtype = dtype
+
+                if out:
+                    self.forward = self.forward_out if periodic is None else self.forward_periodic_out
+                elif dtype:
+                    self.forward = self.forward_dtype if periodic is None else self.forward_dtype_periodic
+                elif periodic is not None:
+                    self.forward = self.forward_periodic
+
+            def forward(self, x):
+                return torch.hann_window(x)
+
+            def forward_out(self, x, out):
+                return torch.hann_window(x, out=out)
+
+            def forward_periodic_out(self, x, out):
+                return torch.hann_window(x, periodic=self.periodic, out=out)
+
+            def forward_dtype(self, x):
+                return torch.hann_window(x, dtype=self.dtype)
+
+            def forward_dtype_periodic(self, x):
+                return torch.hann_window(x, periodic=self.periodic, dtype=self.dtype)
+
+            def forward_periodic(self, x):
+                return torch.hann_window(x, periodic=self.periodic)
+
+        ref_net = None
+
+        return aten_hann_window(periodic, torch_dtype, out), ref_net, "aten::hann_window"
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    @pytest.mark.parametrize("window_size", [2, 10, 32])
+    @pytest.mark.parametrize(("dtype", "out", "out_dtype", "periodic"), [
+        [None, False, None, None],
+        [None, True, "float32", None],
+        [None, True, "float64", None],
+        [None, True, "float32", False],
+        [None, True, "float64", False],
+        [None, True, "float32", True],
+        [None, True, "float64", True],
+        [None, False, "", False],
+        [None, False, "", True],
+        ["float32", False, "", None],
+        ["float64", False, "", None],
+        ["float32", False, "", False],
+        ["float64", False, "", False],
+        ["float32", False, "", True],
+        ["float64", False, "", True],
+    ])
+    def test_hann_window(self, window_size, dtype, out, out_dtype, periodic, ie_device, precision, ir_version):
+        self._test(*self.create_model(periodic, dtype, out), ie_device, precision,
+                   ir_version, kwargs_to_prepare_input={"window_size": window_size, "out": out, "out_dtype": out_dtype})
\ No newline at end of file
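
Note (reviewer aid, not part of the patch): the converter decomposes aten::hann_window into
Range -> Multiply(pi) -> Divide(factor) -> Sin -> Multiply(sin, sin), i.e. w[n] = sin^2(pi*n/factor)
with factor = N for the periodic window and factor = N - 1 for the symmetric one. Below is a minimal
standalone NumPy sketch of that decomposition, checked against torch.hann_window for the same sizes
the layer test uses; the file name and helper name are illustrative, not taken from the patch.

    # hann_window_check.py -- illustrative sketch, not part of this PR
    import numpy as np
    import torch

    def hann_window_decomposed(window_length: int, periodic: bool = True) -> np.ndarray:
        """Mirror the OV graph: Range -> *pi -> /factor -> Sin -> square."""
        n = np.arange(window_length, dtype=np.float32)              # v4::Range(0, N, 1)
        factor = window_length if periodic else window_length - 1   # v1::Subtract for the symmetric case
        s = np.sin(n * np.float32(np.pi) / factor)                  # v1::Multiply, v1::Divide, v0::Sin
        return s * s                                                # v1::Multiply(sin, sin)

    if __name__ == "__main__":
        for periodic in (True, False):
            for n in (2, 10, 32):  # same window sizes as test_hann_window.py
                ref = torch.hann_window(n, periodic=periodic).numpy()
                np.testing.assert_allclose(hann_window_decomposed(n, periodic), ref, atol=1e-6)
        print("decomposition matches torch.hann_window")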