forked from tenstorrent/tt-metal
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for adaptive_avgpool2d operation
- Loading branch information
Showing
7 changed files
with
190 additions
and
0 deletions.
There are no files selected for viewing
25 changes: 25 additions & 0 deletions
25
tests/ttnn/unit_tests/operations/test_adaptive_avg_pool2d.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from loguru import logger | ||
|
||
import torch | ||
import pytest | ||
import math | ||
from tests.ttnn.utils_for_testing import assert_with_pcc | ||
import ttnn | ||
|
||
|
||
def test_run_adaptive_avg_pool2d(device): | ||
input_shape = [1, 4, 8, 16] | ||
output_size = (8, 8) | ||
dtype = ttnn.float32 | ||
|
||
torch.manual_seed(0) | ||
torch_input_tensor = torch.randn(input_shape) | ||
torch_output_tensor = torch.nn.functional.adaptive_avg_pool2d(torch_input_tensor, output_size) | ||
input_tensor = ttnn.from_torch(torch_input_tensor, dtype=dtype, device=device) | ||
output_tensor = ttnn.adaptive_avg_pool2d(input_tensor, ttnn.Shape(output_size)) | ||
|
||
output_tensor = ttnn.to_torch(output_tensor) | ||
assert_with_pcc(torch_output_tensor, output_tensor) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
75 changes: 75 additions & 0 deletions
75
ttnn/cpp/ttnn/operations/pool/adaptive_avg_pool/adaptive_avg_pool.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "adaptive_avg_pool.hpp" | ||
#include "ttnn/operations/core/core.hpp" | ||
|
||
namespace ttnn::operations::pool { | ||
|
||
ttnn::Tensor AdaptiveAvgPool2DOperation::invoke( | ||
const ttnn::Tensor& input, const ttnn::Shape& output_size, const std::optional<MemoryConfig>& mem_config) { | ||
TT_FATAL(input.storage_type() == StorageType::DEVICE, "Input tensor must be on device"); | ||
|
||
auto input_mem_config = input.memory_config(); | ||
auto input_shape = input.get_logical_shape(); | ||
auto input_height = input_shape[-2]; | ||
auto input_width = input_shape[-1]; | ||
auto output_height = output_size[0]; | ||
auto output_width = output_size[1]; | ||
|
||
auto channels = input_shape.rank() == 3 ? input_shape[0] : input_shape[0] * input_shape[1]; | ||
auto output_shape = input_shape; | ||
output_shape[-2] = output_height; | ||
output_shape[-1] = output_width; | ||
|
||
auto output_mem_config = mem_config.value_or(input_mem_config); | ||
// create output tensor | ||
auto output_tensor = | ||
create_device_tensor(output_shape, input.get_dtype(), input.get_layout(), input.device(), output_mem_config); | ||
|
||
std::vector<float> input_buffer(input.volume()); | ||
tt::tt_metal::tensor_impl::read_data_from_device_buffer<float>( | ||
input.device()->command_queue(), input.device_buffer(), input_buffer.data(), true); | ||
|
||
std::vector<float> output_buffer(output_tensor.volume(), 0.0f); | ||
|
||
auto input_strides = input.strides(); | ||
auto output_strides = output_tensor.strides(); | ||
for (uint32_t c = 0; c < channels; ++c) { | ||
for (uint32_t oh = 0; oh < output_height; ++oh) { | ||
int64_t ih0 = start_index(oh, output_height, input_height); | ||
int64_t ih1 = end_index(oh, output_height, input_height); | ||
int64_t kh = ih1 - ih0; | ||
|
||
for (uint32_t ow = 0; ow < output_width; ++ow) { | ||
int64_t iw0 = start_index(ow, output_width, input_width); | ||
int64_t iw1 = end_index(ow, output_width, input_width); | ||
int64_t kw = iw1 - iw0; | ||
|
||
float sum = 0.0f; | ||
for (int64_t ih = ih0; ih < ih1; ++ih) { | ||
for (int64_t iw = iw0; iw < iw1; ++iw) { | ||
size_t input_idx = c * input_strides[1] + ih * input_width + iw; | ||
sum += input_buffer[input_idx]; | ||
} | ||
} | ||
size_t output_idx = c * output_strides[1] + oh * output_width + ow; | ||
output_buffer[output_idx] = sum / kh / kw; | ||
} | ||
} | ||
} | ||
|
||
auto output_tensor_buffer = tt::tt_metal::owned_buffer::create<float>(output_buffer.size()); | ||
std::copy(output_buffer.begin(), output_buffer.end(), output_tensor_buffer.begin()); | ||
|
||
auto output_tensor_with_data = | ||
Tensor(OwnedStorage{output_tensor_buffer}, output_shape, input.get_dtype(), input.get_layout()); | ||
if (input.device() != nullptr) { | ||
output_tensor_with_data = output_tensor_with_data.to(input.device(), output_mem_config); | ||
} | ||
|
||
return output_tensor_with_data; | ||
} | ||
|
||
} // namespace ttnn::operations::pool |
30 changes: 30 additions & 0 deletions
30
ttnn/cpp/ttnn/operations/pool/adaptive_avg_pool/adaptive_avg_pool.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include "ttnn/decorators.hpp" | ||
|
||
namespace ttnn { | ||
namespace operations::pool { | ||
|
||
inline int64_t start_index(int64_t out_idx, int64_t out_size, int64_t in_size) { | ||
return (out_idx * in_size) / out_size; | ||
} | ||
|
||
inline int64_t end_index(int64_t out_idx, int64_t out_size, int64_t in_size) { | ||
return ((out_idx + 1) * in_size + out_size - 1) / out_size; | ||
} | ||
|
||
struct AdaptiveAvgPool2DOperation { | ||
static ttnn::Tensor invoke( | ||
const ttnn::Tensor& input, const ttnn::Shape& output_size, const std::optional<MemoryConfig>& mem_config); | ||
}; | ||
|
||
} // namespace operations::pool | ||
|
||
constexpr auto adaptive_avg_pool2d = | ||
ttnn::register_operation<"ttnn::adaptive_avg_pool2d", ttnn::operations::pool::AdaptiveAvgPool2DOperation>(); | ||
|
||
} // namespace ttnn |
46 changes: 46 additions & 0 deletions
46
ttnn/cpp/ttnn/operations/pool/adaptive_avg_pool/adaptive_avg_pool_pybind.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "adaptive_avg_pool_pybind.hpp" | ||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> | ||
#include "pybind11/decorators.hpp" | ||
#include "ttnn/operations/pool/adaptive_avg_pool/adaptive_avg_pool.hpp" | ||
#include "ttnn/types.hpp" | ||
|
||
namespace ttnn::operations::pool { | ||
namespace detail { | ||
template <typename pool_operation_t> | ||
void bind_adaptive_avg_pool(pybind11::module& module, const pool_operation_t& operation, const char* doc) { | ||
bind_registered_operation( | ||
module, | ||
operation, | ||
doc, | ||
ttnn::pybind_overload_t{ | ||
[](const decltype(operation)& self, | ||
const ttnn::Tensor& input, | ||
const ttnn::Shape& output_size, | ||
const std::optional<ttnn::MemoryConfig>& memory_config) { | ||
return self(input, output_size, memory_config); | ||
}, | ||
pybind11::arg("input").noconvert(), | ||
pybind11::arg("output_size"), | ||
pybind11::kw_only(), | ||
pybind11::arg("memory_config") = std::nullopt}); | ||
} | ||
} // namespace detail | ||
|
||
void py_bind_adaptive_avg_pool(pybind11::module& module) { | ||
detail::bind_adaptive_avg_pool( | ||
module, | ||
ttnn::adaptive_avg_pool2d, | ||
R"doc(adaptive_avg_pool2d(input: ttnn.Tensor, output_size: ttnn.Shape) -> ttnn.Tensor | ||
Applies a 2D adaptive average pooling operation on the input tensor. | ||
Args: | ||
* :attr:`input`: Input tensor. | ||
* :attr:`output_size`: Target output size. | ||
* :attr:`<optional> mem_config`. | ||
)doc"); | ||
} | ||
} // namespace ttnn::operations::pool |
10 changes: 10 additions & 0 deletions
10
ttnn/cpp/ttnn/operations/pool/adaptive_avg_pool/adaptive_avg_pool_pybind.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
#include "pybind11/pybind_fwd.hpp" | ||
|
||
namespace ttnn::operations::pool { | ||
void py_bind_adaptive_avg_pool(pybind11::module& module); | ||
} // namespace ttnn::operations::pool |