Skip to content

Commit

Permalink
[Op] ISTFT op class, shape infer and reference (#28606)
Browse files Browse the repository at this point in the history
### Details:
 - ISTFT op class
 - ISTFT shape inference (tested by type_prop; CPU registration and CPU
shape inference to be added with the CPU enablement PR)
 - ISTFT reference implementation

### Tickets:
 - 159379, 159380
 
 ### Related PRs:
 - #28807
 - #28743
  • Loading branch information
mitruska authored Feb 5, 2025
1 parent e068aef commit e4611ea
Show file tree
Hide file tree
Showing 15 changed files with 1,731 additions and 1 deletion.
62 changes: 62 additions & 0 deletions src/core/include/openvino/op/istft.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/op.hpp"

namespace ov::op::v16 {
/// \brief An operation ISTFT that computes the Inverse Short Time Fourier Transform.
///
/// The node is built from either four inputs (data, window, frame_size, frame_step) or five inputs
/// (the same four plus signal_length) and carries two boolean attributes: center and normalized.
/// \ingroup ov_ops_cpp_api
class OPENVINO_API ISTFT : public Op {
public:
OPENVINO_OP("ISTFT", "opset16");
/// \brief Default constructor; both attributes keep their in-class initial value (false).
ISTFT() = default;

/// \brief Constructs an ISTFT operation with signal length to be inferred
///
/// \param data Input data
/// \param window Window values applied in ISTFT
/// \param frame_size Scalar value representing the size of Fourier Transform
/// \param frame_step The distance (number of samples) between successive window frames
/// \param center Flag signaling if the signal input has been padded before STFT
/// \param normalized Flag signaling if the STFT result has been normalized
ISTFT(const Output<Node>& data,
const Output<Node>& window,
const Output<Node>& frame_size,
const Output<Node>& frame_step,
const bool center,
const bool normalized);

/// \brief Constructs an ISTFT operation with signal length provided
///
/// \param data Input data
/// \param window Window values applied in ISTFT
/// \param frame_size Scalar value representing the size of Fourier Transform
/// \param frame_step The distance (number of samples) between successive window frames
/// \param signal_length The signal length of the original signal
/// \param center Flag signaling if the signal input has been padded before STFT
/// \param normalized Flag signaling if the STFT result has been normalized
ISTFT(const Output<Node>& data,
const Output<Node>& window,
const Output<Node>& frame_size,
const Output<Node>& frame_step,
const Output<Node>& signal_length,
const bool center,
const bool normalized);

// Overrides of the Op hooks. Their definitions are not part of this header.
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

// Accessors for the `center` attribute (presumably backed by m_center; definitions live elsewhere).
bool get_center() const;
void set_center(const bool center);

// Read accessor for the `normalized` attribute. Note: no matching setter is declared in this class.
bool get_normalized() const;

private:
// Attribute storage; both flags default to false.
bool m_center = false;
bool m_normalized = false;
};
} // namespace ov::op::v16
1 change: 1 addition & 0 deletions src/core/include/openvino/op/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
#include "openvino/op/is_finite.hpp"
#include "openvino/op/is_inf.hpp"
#include "openvino/op/is_nan.hpp"
#include "openvino/op/istft.hpp"
#include "openvino/op/less.hpp"
#include "openvino/op/less_eq.hpp"
#include "openvino/op/log.hpp"
Expand Down
1 change: 1 addition & 0 deletions src/core/include/openvino/opsets/opset16_tbl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ _OPENVINO_OP_REG(ShapeOf, ov::op::v3)

// New operations added in opset16
_OPENVINO_OP_REG(Identity, ov::op::v16)
_OPENVINO_OP_REG(ISTFT, ov::op::v16)
22 changes: 22 additions & 0 deletions src/core/reference/include/openvino/reference/istft.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/core/shape.hpp"

namespace ov {
namespace reference {
/// \brief Reference (float) implementation of the Inverse Short Time Fourier Transform.
///
/// \param in_data       STFT values to invert; its shape is passed in \p signal_shape.
/// \param window        Window values; \p window_shape[0] gives their count.
/// \param final_result  Output buffer receiving the reconstructed signal(s).
/// \param signal_shape  Shape of \p in_data (the definition names this parameter `data_shape`):
///                      3D for a single signal, otherwise the first dimension is the batch.
/// \param window_shape  Shape of \p window.
/// \param frame_size    Size of the Fourier transform applied to every frame.
/// \param frame_step    Distance (in samples) between successive frames.
/// \param length        Desired output signal length; a value <= 0 means "infer it".
/// \param center        Whether the original signal was padded before the STFT.
/// \param normalized    Whether the STFT result was normalized.
void istft(const float* in_data,
const float* window,
float* final_result,
const Shape& signal_shape,
const Shape& window_shape,
const int64_t frame_size,
const int64_t frame_step,
const int64_t length,
const bool center,
const bool normalized);
} // namespace reference
} // namespace ov
131 changes: 131 additions & 0 deletions src/core/reference/src/op/istft.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/reference/istft.hpp"

#include <algorithm>
#include <functional>
#include <vector>

#include "openvino/core/shape.hpp"
#include "openvino/reference/add.hpp"
#include "openvino/reference/irdft.hpp"
#include "openvino/reference/transpose.hpp"

namespace ov {
namespace reference {
/// \brief Reference (float) implementation of the Inverse Short Time Fourier Transform.
///
/// Every frame of the input is inverted with IRDFT and overlap-added at a stride of
/// \p frame_step; the sum is then divided by the overlap-added window (and multiplied by
/// sqrt(frame_size) when \p normalized is set). Samples where the window sum is zero become 0.
///
/// \param in_data       STFT values, read as [batch,] fft_bins, frames, 2 (fft_bins = frame_size / 2 + 1).
/// \param window        Window values; shorter windows are zero-padded and centred to \p frame_size.
/// \param final_result  Output buffer of batch * output_length floats; zero-filled first, so any
///                      tail beyond the reconstructed samples stays 0.
/// \param data_shape    Shape of \p in_data: 3D (no batch) or with a leading batch dimension.
/// \param window_shape  Shape of \p window; only window_shape[0] is used.
/// \param frame_size    Size of the Fourier transform applied to every frame.
/// \param frame_step    Distance (in samples) between successive frames.
/// \param length        Output signal length; if <= 0 it is inferred from the frame count.
/// \param center        If true, frame_size / 2 leading samples are skipped when copying out.
/// \param normalized    If true, the result is scaled by sqrt(frame_size).
void istft(const float* in_data,
           const float* window,
           float* final_result,
           const Shape& data_shape,
           const Shape& window_shape,
           const int64_t frame_size,
           const int64_t frame_step,
           const int64_t length,
           const bool center,
           const bool normalized) {
    const auto is_data_3D = data_shape.size() == 3;
    const size_t frames_axis = 1 + (is_data_3D ? 0 : 1);
    const size_t batch_size = is_data_3D ? 1 : data_shape[0];
    const auto num_frames = data_shape[frames_axis];

    // Length of the overlap-added signal before it is trimmed to the output length.
    const auto signal_length = (num_frames - 1) * frame_step + frame_size;
    const int64_t final_signal_length = length > 0 ? length : (center ? (signal_length - frame_size) : signal_length);
    std::fill(final_result, final_result + batch_size * final_signal_length, 0.f);

    const auto frame_size_dim = static_cast<size_t>(frame_size);
    const auto frame_size_dim_shape = Shape{frame_size_dim};
    const auto frame_size_dim_shape_out = Shape{frame_size_dim, 2};
    const auto fft_out_shape = Shape{static_cast<size_t>((frame_size_dim / 2) + 1), 2};

    // Zero-pad the window to frame_size and centre it. Only `window_length` values are copied,
    // so a window longer than frame_size can no longer overflow the padded buffer.
    const auto window_length = std::min<size_t>(window_shape[0], frame_size_dim);
    std::vector<float> pad_window(frame_size_dim, 0.f);
    std::copy(window, window + window_length, pad_window.begin() + (frame_size_dim - window_length) / 2);

    // Reorder [batch, bins, frames, 2] -> [batch, frames, bins, 2] so every frame is contiguous.
    std::vector<float> data_t(in_data, in_data + shape_size(data_shape));
    const auto stft_transp_out_shape = Shape{batch_size, num_frames, fft_out_shape[0], fft_out_shape[1]};
    transpose(reinterpret_cast<const char*>(in_data),
              reinterpret_cast<char*>(data_t.data()),
              Shape{batch_size, fft_out_shape[0], num_frames, fft_out_shape[1]},
              sizeof(float),
              {0, 2, 1, 3},
              stft_transp_out_shape);

    // Post-processing: divide by the window sum, optionally undoing the 1/sqrt(frame_size)
    // normalization. Multiplying by 1.f is exact, so one lambda covers both modes.
    const float scale = normalized ? static_cast<float>(std::sqrt(frame_size)) : 1.f;
    const auto postprocess = [scale](float value, float win_sum) {
        return win_sum != 0.f ? (value * scale) / win_sum : 0.f;
    };

    const auto fft_out_shape_size = shape_size(fft_out_shape);
    const auto in_batch_single_step = num_frames * fft_out_shape_size;
    const int64_t margin = center ? (frame_size / 2) : 0;
    const int64_t data_end = signal_length - margin;
    const int64_t copy_end = final_signal_length < data_end ? final_signal_length : data_end;

    // The overlap-added window depends only on the frame layout, not on the data,
    // so it is accumulated once and shared by all batches.
    std::vector<float> window_sum(signal_length, 0.f);
    for (size_t frame_idx = 0; frame_idx < num_frames; ++frame_idx) {
        const auto sum_frame = window_sum.begin() + frame_idx * frame_step;
        std::transform(sum_frame, sum_frame + frame_size, pad_window.begin(), sum_frame, func::add<float>);
    }

    // Scratch buffers reused across frames and batches.
    std::vector<float> batch_signal(signal_length);
    std::vector<float> frame_signal(frame_size_dim);
    std::vector<float> frame_data(fft_out_shape_size);

    for (size_t batch = 0; batch < batch_size; ++batch) {
        // All offsets are derived from the batch index; accumulating a cumulative offset
        // onto a running pointer would step past the buffer from the third batch on.
        const float* batch_in = data_t.data() + batch * in_batch_single_step;
        std::fill(batch_signal.begin(), batch_signal.end(), 0.f);

        for (size_t frame_idx = 0; frame_idx < num_frames; ++frame_idx) {
            const float* frame_in = batch_in + frame_idx * fft_out_shape_size;
            frame_data.assign(frame_in, frame_in + fft_out_shape_size);
            reference::irdft(frame_data,
                             fft_out_shape,
                             {0},
                             frame_signal.data(),
                             frame_size_dim_shape_out,
                             frame_size_dim_shape,
                             frame_size);

            // Overlap-add the reconstructed frame into the batch signal.
            const auto out_frame = batch_signal.begin() + frame_idx * frame_step;
            std::transform(frame_signal.begin(), frame_signal.end(), out_frame, out_frame, func::add<float>);
        }

        std::transform(batch_signal.begin(),
                       batch_signal.end(),
                       window_sum.begin(),
                       batch_signal.begin(),
                       postprocess);

        // Skip the centre padding (if any) and copy at most the requested number of samples.
        const auto result_start = batch_signal.begin() + margin;
        std::copy(result_start, result_start + copy_end, final_result + batch * final_signal_length);
    }
}
} // namespace reference
} // namespace ov
119 changes: 119 additions & 0 deletions src/core/shape_inference/include/istft_shape_inference.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "dimension_util.hpp"
#include "openvino/op/istft.hpp"
#include "utils.hpp"

namespace ov::op::v16 {
/// \brief Infers the output shape of the ISTFT operation.
///
/// \param op            ISTFT node being inferred; used for error reporting and its `center` attribute.
/// \param input_shapes  Shapes of the 4 or 5 inputs: data, window, frame_size, frame_step[, signal_length].
/// \param ta            Tensor accessor used to read constant values of frame_size, frame_step
///                      and signal_length when they are available.
/// \return A single-element vector holding the output shape: [signal_length] for 3D data or
///         [batch, signal_length] for 4D data. For data of dynamic rank the data shape itself is returned.
///
/// Validation failures are reported through NODE_VALIDATION_CHECK / NODE_SHAPE_INFER_CHECK.
template <class TShape, class TRShape = result_shape_t<TShape>>
std::vector<TRShape> shape_infer(const ISTFT* op,
const std::vector<TShape>& input_shapes,
const ITensorAccessor& ta = make_tensor_accessor()) {
using TDim = typename TRShape::value_type;
using TDimVal = typename TDim::value_type;

// The optional fifth input is signal_length.
const auto inputs_count = input_shapes.size();
const auto is_in_count_correct = inputs_count == 4 || inputs_count == 5;
NODE_VALIDATION_CHECK(op, is_in_count_correct);

const auto& data_shape = input_shapes[0];
const auto& window_shape = input_shapes[1];
const auto& frame_size_shape = input_shapes[2];
const auto& frame_step_shape = input_shapes[3];

// Rank checks use compatible(), so inputs of dynamic rank pass them.
const auto data_shape_rank = data_shape.rank();
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
data_shape_rank.compatible(3) || data_shape_rank.compatible(4),
"The shape of data must be 3D or 4D.");
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
window_shape.rank().compatible(1),
"The shape of window must be 1D [window_size].");
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
frame_size_shape.rank().compatible(0),
"The shape of frame_size must be a scalar.");
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
frame_step_shape.rank().compatible(0),
"The shape of frame_step must be a scalar.");

// Both results are tested with `if (...)` below before use, i.e. the values may be unavailable.
const auto frame_size = get_input_const_data_as<TRShape, int64_t>(op, 2, ta);
const auto frame_step = get_input_const_data_as<TRShape, int64_t>(op, 3, ta);

if (frame_size) {
const auto& frame_size_val = (*frame_size)[0];
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
0 < frame_size_val,
"Provided frame size must be greater than zero, but got: ",
frame_size_val);
// The window length can only be validated when the window shape is static.
const bool is_win_shape_correct =
window_shape.is_dynamic() || (TDimVal{0} < window_shape[0].get_length() &&
window_shape[0].get_length() <= static_cast<TDimVal>(frame_size_val));

NODE_SHAPE_INFER_CHECK(op,
input_shapes,
is_win_shape_correct,
"Window input dimension must be in range [1, ",
frame_size_val,
"].");
}

if (frame_step) {
const auto& frame_step_val = (*frame_step)[0];
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
0 < frame_step_val,
"Provided frame step must be greater than zero, but got: ",
frame_step_val);
}

// For the input with dynamic rank, output shape is also fully dynamic
if (data_shape_rank.is_dynamic()) {
return {data_shape};
}
const auto is_data_3D = data_shape.size() == 3;

// Init output shape with dynamic dimension and update if more info can be inferred
std::vector<TRShape> output_shapes{TRShape{TDim(ov::util::dim::inf_bound)}};
if (inputs_count == 5) {
// signal_length may be a scalar or a 1D tensor with a single element.
const auto& length_shape = input_shapes[4];
const bool has_len_valid_shape =
length_shape.rank().is_dynamic() ||
(length_shape.size() == 0 || (length_shape.size() == 1 && length_shape[0].compatible(1)));
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
has_len_valid_shape,
"The shape of 'signal_length' input must be a scalar or single element 1D tensor.");

const auto sig_len_in = get_input_const_data_as_shape<TRShape>(op, 4, ta);
if (sig_len_in) { // Set desired length of the signal dimension, if provided
output_shapes[0][0] = TDim{(*sig_len_in)[0]};
}
} else if (frame_size && frame_step) { // Otherwise infer the length of the signal
const auto& frame_size_val = (*frame_size)[0];
const auto& frame_step_val = (*frame_step)[0];

// The frames axis is 1 for 3D data and 2 for 4D (batched) data.
const int64_t frames_axis = 1 + (is_data_3D ? 0 : 1);
const TDim& num_frames_dim = data_shape[frames_axis];
// signal_length = (num_frames - 1) * frame_step, plus frame_size only when `center` is false.
TDim signal_length = (num_frames_dim - 1) * frame_step_val;
if (!op->get_center()) {
signal_length += frame_size_val;
}
output_shapes[0][0] = std::move(signal_length);
}

if (!is_data_3D) { // Copy batch dimension
const auto& batch_dim = data_shape[0];
output_shapes[0].insert(output_shapes[0].begin(), batch_dim);
}

return output_shapes;
}
} // namespace ov::op::v16
Loading

0 comments on commit e4611ea

Please sign in to comment.