From 408d4a6eb540b630ff171b604c079689da15ed8f Mon Sep 17 00:00:00 2001 From: Borys Bradel <164946524+bbradelTT@users.noreply.github.com> Date: Mon, 27 Jan 2025 13:42:33 -0500 Subject: [PATCH] #14898: pass in pad value to transpose in reduce (#17142) ### Ticket Link to Github Issue #14898 Subset of previous PR (https://github.com/tenstorrent/tt-metal/pull/16989) that caused a hang in (Single-card) Demo tests and got reverted. Verified that this pipeline passes for this subset of changes: https://github.com/tenstorrent/tt-metal/actions/runs/12992459972 ### Problem description - transpose was filling in non-logical areas with default pad value when called from reduce ### What's changed - pass in an appropriate pad value for transpose to use - also mark a method that should only be used by pool to be deprecated to deter other uses ### Checklist - [x] Post commit CI passes https://github.com/tenstorrent/tt-metal/actions/runs/12992465641 - [ ] Blackhole Post commit (if applicable) - [ ] Model regression CI testing passes (if applicable) - [ ] Device performance regression CI testing passes (if applicable) - [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests passes - [x] New/Existing tests provide coverage for changes --- .../unit_testing/misc/test_min_max.py | 2 -- tests/ttnn/unit_tests/operations/test_max.py | 15 ++++++++++++++- .../unit_tests/operations/test_reduction_mean.py | 1 + .../reduction/generic/generic_reductions.cpp | 15 ++++++++++----- .../reduction/generic/generic_reductions.hpp | 1 + 5 files changed, 26 insertions(+), 8 deletions(-) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_min_max.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_min_max.py index 94cce42f17a7..acae71248476 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_min_max.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_min_max.py @@ -53,8 +53,6 @@ def test_min_max_for_dim_hw(device, use_program_cache, shape_dim, kind, layout): if kind == "max": value = x.max() elif kind == "min": - if N * C % 32 != 0: - pytest.skip("global min with Tensor dimension N*C not multiple of 32 is not supported at this time.") value = x.min() elif kind == "mean": value = x.mean() diff --git a/tests/ttnn/unit_tests/operations/test_max.py b/tests/ttnn/unit_tests/operations/test_max.py index 411fbd0ab44d..f6536f16f4e5 100644 --- a/tests/ttnn/unit_tests/operations/test_max.py +++ b/tests/ttnn/unit_tests/operations/test_max.py @@ -8,7 +8,7 @@ import ttnn from tests.ttnn.utils_for_testing import assert_with_pcc -from models.utility_functions import torch_random +from models.utility_functions import torch_random, is_grayskull @pytest.mark.parametrize("batch_size", [1, 16, 1, 16]) @@ -99,11 +99,24 @@ def test_max_global(device, batch_size, h, w): ((2, 32, 32, 64), -3), ((32, 32, 64), -3), ((1, 2, 3, 4), -1), + ((2, 22, 37, 41), -4), + ((2, 32, 64, 64), -3), + ((2, 22, 37, 41), -3), + ((2, 32, 64, 64), -2), + ((2, 22, 37, 41), -1), + ((2, 32, 64, 64), -1), + ((2, 22, 37), -3), + ((2, 22, 37), -2), + ((2, 22, 37), -1), + ((1, 6, 7), -3), + ((32, 6, 7), -3), ], ) @pytest.mark.parametrize("keepdim", [True, False]) def test_max_dim(device, input_shape_and_dim, keepdim): input_shape, max_dim = input_shape_and_dim + if is_grayskull() and (input_shape[-1] % 32 != 0 or input_shape[-2] % 32 != 0 or input_shape[max_dim] % 32 != 0): + pytest.skip("If not a tile size multiple, may fail on GS if run all the tests in this file. #17084") torch_input_tensor = torch_random(input_shape, -100, 100, dtype=torch.bfloat16) torch_output_tensor, _ = torch.max(torch_input_tensor, dim=max_dim, keepdim=keepdim) diff --git a/tests/ttnn/unit_tests/operations/test_reduction_mean.py b/tests/ttnn/unit_tests/operations/test_reduction_mean.py index b9e8786ca387..e9146dc8e61f 100644 --- a/tests/ttnn/unit_tests/operations/test_reduction_mean.py +++ b/tests/ttnn/unit_tests/operations/test_reduction_mean.py @@ -22,6 +22,7 @@ def test_mean(device, batch_size, h, w, dim): torch_output_tensor = torch.mean(torch_input_tensor, dim=dim, keepdim=True, dtype=torch.bfloat16) input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) + ttnn.fill_implicit_tile_padding(input_tensor, 42) # garbage padding to test that mean removes it output_tensor = ttnn.mean(input_tensor, dim=dim) output_tensor = ttnn.to_torch(output_tensor) diff --git a/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp b/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp index 6daaa3308292..aa83584adfc2 100644 --- a/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp @@ -50,6 +50,12 @@ ttnn::SmallVector generate_reduce_dim( return dim; } +float get_pad_value(ReduceType reduce_type) { + return reduce_type == ReduceType::Max + ? -std::numeric_limits::infinity() + : (reduce_type == ReduceType::Min ? std::numeric_limits::infinity() : 0); +} + template static Tensor reduce_impl( const Tensor& input_tensor_arg, @@ -79,6 +85,7 @@ static Tensor reduce_impl( auto input_tensor = ttnn::unsqueeze_to_4D(input_tensor_arg); Tensor output_tensor; + float pad_value = get_pad_value(reduce_type); bool single_reduce_op = (dim.size() == 1 && (dim[0] == rank - 1 || dim[0] == rank - 2)) || (dim.size() == 2 && dim[1] == rank - 1 && dim[0] == rank - 2); if (!single_reduce_op) { @@ -92,7 +99,7 @@ static Tensor reduce_impl( int adjusted_dim = offset + i_dim; int reduce_dim = adjusted_dim; if (transpose) { - output_tensor = ttnn::transpose(output_tensor, adjusted_dim, 2, memory_config); + output_tensor = ttnn::transpose(output_tensor, adjusted_dim, -2, memory_config, pad_value); reduce_dim = 2; } if (use_reduce_type) { @@ -115,7 +122,7 @@ static Tensor reduce_impl( /*reshape=*/false); } if (transpose) { - output_tensor = ttnn::transpose(output_tensor, adjusted_dim, -2, memory_config); + output_tensor = ttnn::transpose(output_tensor, adjusted_dim, -2, memory_config, pad_value); } } } @@ -241,9 +248,7 @@ Tensor Reduce::invoke( const std::optional& compute_kernel_config, float scalar) { ttnn::SmallVector dim = generate_reduce_dim(input_tensor_arg, dim_arg); - float pad_value = reduce_type == ReduceType::Max - ? -std::numeric_limits::infinity() - : (reduce_type == ReduceType::Min ? std::numeric_limits::infinity() : 0); + float pad_value = get_pad_value(reduce_type); bool is_tiled = input_tensor_arg.get_layout() == TILE_LAYOUT; auto input_tensor = is_tiled ? ttnn::fill_implicit_tile_padding(input_tensor_arg, pad_value) : input_tensor_arg; if constexpr (reduce_type == ReduceType::Std || reduce_type == ReduceType::Var) { diff --git a/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.hpp b/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.hpp index 137dba6f7ce4..f592a5575089 100644 --- a/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.hpp +++ b/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.hpp @@ -34,6 +34,7 @@ struct Reduce { }; // Entry point for pool op, which uses non-standard tensors that cannot be padded. +[[deprecated]] Tensor pool_sum( const Tensor& input_tensor_arg, int dim_arg,