Skip to content

Commit

Permalink
Add support for adaptive_avgpool2d operation
Browse files Browse the repository at this point in the history
  • Loading branch information
nikileshx committed Feb 3, 2025
1 parent 4034423 commit 54468f5
Show file tree
Hide file tree
Showing 7 changed files with 190 additions and 0 deletions.
25 changes: 25 additions & 0 deletions tests/ttnn/unit_tests/operations/test_adaptive_avg_pool2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0

from loguru import logger

import torch
import pytest
import math
from tests.ttnn.utils_for_testing import assert_with_pcc
import ttnn


def test_run_adaptive_avg_pool2d(device):
input_shape = [1, 4, 8, 16]
output_size = (8, 8)
dtype = ttnn.float32

torch.manual_seed(0)
torch_input_tensor = torch.randn(input_shape)
torch_output_tensor = torch.nn.functional.adaptive_avg_pool2d(torch_input_tensor, output_size)
input_tensor = ttnn.from_torch(torch_input_tensor, dtype=dtype, device=device)
output_tensor = ttnn.adaptive_avg_pool2d(input_tensor, ttnn.Shape(output_size))

output_tensor = ttnn.to_torch(output_tensor)
assert_with_pcc(torch_output_tensor, output_tensor)
2 changes: 2 additions & 0 deletions ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,8 @@ set(TTNN_OP_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/generic/generic_pools.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/generic/generic_pools_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/global_avg_pool/global_avg_pool.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/adaptive_avg_pool/adaptive_avg_pool.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/adaptive_avg_pool/adaptive_avg_pool_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/upsample/device//upsample_bilinear_program_factory_multicore.cpp
Expand Down
2 changes: 2 additions & 0 deletions ttnn/cpp/pybind11/operations/__init__.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "ttnn/operations/pool/downsample/downsample_pybind.hpp"
#include "ttnn/operations/pool/generic/generic_pools_pybind.hpp"
#include "ttnn/operations/pool/global_avg_pool/global_avg_pool_pybind.hpp"
#include "ttnn/operations/pool/adaptive_avg_pool/adaptive_avg_pool_pybind.hpp"
#include "ttnn/operations/pool/upsample/upsample_pybind.hpp"
#include "ttnn/operations/reduction/reduction_pybind.hpp"
#include "ttnn/operations/sliding_window/sliding_window_pybind.hpp"
Expand Down Expand Up @@ -126,6 +127,7 @@ void py_module(py::module& module) {
avgpool::py_module(m_pool);
upsample::py_module(m_pool);
downsample::py_bind_downsample(m_pool);
pool::py_bind_adaptive_avg_pool(m_pool);

auto m_normalization = module.def_submodule("normalization", "normalization operations");
normalization::py_module(m_normalization);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "adaptive_avg_pool.hpp"
#include "ttnn/operations/core/core.hpp"

namespace ttnn::operations::pool {

ttnn::Tensor AdaptiveAvgPool2DOperation::invoke(
const ttnn::Tensor& input, const ttnn::Shape& output_size, const std::optional<MemoryConfig>& mem_config) {
TT_FATAL(input.storage_type() == StorageType::DEVICE, "Input tensor must be on device");

auto input_mem_config = input.memory_config();
auto input_shape = input.get_logical_shape();
auto input_height = input_shape[-2];
auto input_width = input_shape[-1];
auto output_height = output_size[0];
auto output_width = output_size[1];

auto channels = input_shape.rank() == 3 ? input_shape[0] : input_shape[0] * input_shape[1];
auto output_shape = input_shape;
output_shape[-2] = output_height;
output_shape[-1] = output_width;

auto output_mem_config = mem_config.value_or(input_mem_config);
// create output tensor
auto output_tensor =
create_device_tensor(output_shape, input.get_dtype(), input.get_layout(), input.device(), output_mem_config);

std::vector<float> input_buffer(input.volume());
tt::tt_metal::tensor_impl::read_data_from_device_buffer<float>(
input.device()->command_queue(), input.device_buffer(), input_buffer.data(), true);

std::vector<float> output_buffer(output_tensor.volume(), 0.0f);

auto input_strides = input.strides();
auto output_strides = output_tensor.strides();
for (uint32_t c = 0; c < channels; ++c) {
for (uint32_t oh = 0; oh < output_height; ++oh) {
int64_t ih0 = start_index(oh, output_height, input_height);
int64_t ih1 = end_index(oh, output_height, input_height);
int64_t kh = ih1 - ih0;

for (uint32_t ow = 0; ow < output_width; ++ow) {
int64_t iw0 = start_index(ow, output_width, input_width);
int64_t iw1 = end_index(ow, output_width, input_width);
int64_t kw = iw1 - iw0;

float sum = 0.0f;
for (int64_t ih = ih0; ih < ih1; ++ih) {
for (int64_t iw = iw0; iw < iw1; ++iw) {
size_t input_idx = c * input_strides[1] + ih * input_width + iw;
sum += input_buffer[input_idx];
}
}
size_t output_idx = c * output_strides[1] + oh * output_width + ow;
output_buffer[output_idx] = sum / kh / kw;
}
}
}

auto output_tensor_buffer = tt::tt_metal::owned_buffer::create<float>(output_buffer.size());
std::copy(output_buffer.begin(), output_buffer.end(), output_tensor_buffer.begin());

auto output_tensor_with_data =
Tensor(OwnedStorage{output_tensor_buffer}, output_shape, input.get_dtype(), input.get_layout());
if (input.device() != nullptr) {
output_tensor_with_data = output_tensor_with_data.to(input.device(), output_mem_config);
}

return output_tensor_with_data;
}

} // namespace ttnn::operations::pool
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ttnn/decorators.hpp"

namespace ttnn {
namespace operations::pool {

inline int64_t start_index(int64_t out_idx, int64_t out_size, int64_t in_size) {
return (out_idx * in_size) / out_size;
}

inline int64_t end_index(int64_t out_idx, int64_t out_size, int64_t in_size) {
return ((out_idx + 1) * in_size + out_size - 1) / out_size;
}

struct AdaptiveAvgPool2DOperation {
static ttnn::Tensor invoke(
const ttnn::Tensor& input, const ttnn::Shape& output_size, const std::optional<MemoryConfig>& mem_config);
};

} // namespace operations::pool

constexpr auto adaptive_avg_pool2d =
ttnn::register_operation<"ttnn::adaptive_avg_pool2d", ttnn::operations::pool::AdaptiveAvgPool2DOperation>();

} // namespace ttnn
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "adaptive_avg_pool_pybind.hpp"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "pybind11/decorators.hpp"
#include "ttnn/operations/pool/adaptive_avg_pool/adaptive_avg_pool.hpp"
#include "ttnn/types.hpp"

namespace ttnn::operations::pool {
namespace detail {
template <typename pool_operation_t>
void bind_adaptive_avg_pool(pybind11::module& module, const pool_operation_t& operation, const char* doc) {
bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_overload_t{
[](const decltype(operation)& self,
const ttnn::Tensor& input,
const ttnn::Shape& output_size,
const std::optional<ttnn::MemoryConfig>& memory_config) {
return self(input, output_size, memory_config);
},
pybind11::arg("input").noconvert(),
pybind11::arg("output_size"),
pybind11::kw_only(),
pybind11::arg("memory_config") = std::nullopt});
}
} // namespace detail

void py_bind_adaptive_avg_pool(pybind11::module& module) {
detail::bind_adaptive_avg_pool(
module,
ttnn::adaptive_avg_pool2d,
R"doc(adaptive_avg_pool2d(input: ttnn.Tensor, output_size: ttnn.Shape) -> ttnn.Tensor
Applies a 2D adaptive average pooling operation on the input tensor.
Args:
* :attr:`input`: Input tensor.
* :attr:`output_size`: Target output size.
* :attr:`<optional> mem_config`.
)doc");
}
} // namespace ttnn::operations::pool
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once
#include "pybind11/pybind_fwd.hpp"

namespace ttnn::operations::pool {
void py_bind_adaptive_avg_pool(pybind11::module& module);
} // namespace ttnn::operations::pool

0 comments on commit 54468f5

Please sign in to comment.