# #12662: add keepdim fixes to reduce (#16163)
### Ticket
Links to GitHub issues:
- #12662
- #14898
- #13361
- #12170

### Problem description
- Tile padding caused issues for `max` reductions (see the sketch below).
- `keepdim=False` errored out instead of being handled.
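
A hypothetical illustration of the padding hazard (shapes and fill value chosen for the example, not taken from the commit): tiles are 32×32, so a reduction over a non-tile-aligned axis can observe filler values, and for `max` a filler of 0 beats all-negative real data.

```python
import torch

# Logical data: all negative, height 30 (not tile-aligned).
data = torch.full((1, 1, 30, 32), -5.0)

# Physical, tile-aligned buffer padded with zeros (hypothetical fill value).
padded = torch.zeros((1, 1, 32, 32))
padded[..., :30, :] = data

print(data.max())    # tensor(-5.)  -- correct result
print(padded.max())  # tensor(0.)   -- the padding wins: the bug class described above
```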

### What's changed
- Removed the error on `keepdim=False` and adjusted the code to handle `keepdim=False` properly (a usage sketch follows this list).
- Adding padding within `min`/`max` to ensure it is set up properly has been pushed back to a future PR.
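
A minimal sketch of the call that previously raised and now works, mirroring the new unit test (assumes an open `device` handle, as provided by the test fixtures):

```python
import torch
import ttnn

torch_input = torch.rand((2, 32, 32, 64), dtype=torch.bfloat16)
expected, _ = torch.max(torch_input, dim=-3, keepdim=False)  # shape (2, 32, 64)

input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)
output = ttnn.max(input_tensor, dim=-3, keepdim=False)  # previously: TT_THROW("keepdim=False is not supported")
output = ttnn.to_torch(ttnn.from_device(output))        # should match `expected` within PCC tolerance
```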

### Checklist
- [x] Post commit CI passes
https://github.com/tenstorrent/tt-metal/actions/runs/12432801168
- [x] Blackhole Post commit (if applicable)
https://github.com/tenstorrent/tt-metal/actions/runs/12423085751
- [x] Model regression CI testing passes (if applicable)
https://github.com/tenstorrent/tt-metal/actions/runs/12423092106 (same results as main:
https://github.com/tenstorrent/tt-metal/actions/runs/12422179419/job/34683976776)
- [x] Device performance regression CI testing passes (if applicable)
https://github.com/tenstorrent/tt-metal/actions/runs/12423088573
- [ ] **(For models and ops writers)** Full [new
models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml)
tests pass
- [x] New/Existing tests provide coverage for changes
bbradelTT authored Dec 20, 2024
1 parent 5272cee commit ec1869e
Showing 2 changed files with 49 additions and 29 deletions.
**tests/ttnn/unit_tests/operations/test_max.py** (26 additions, 0 deletions)
```diff
@@ -90,3 +90,29 @@ def test_max_global(device, batch_size, h, w):
     output_tensor = output_tensor[0, 0, 0]
 
     assert_with_pcc(torch_output_tensor, output_tensor)
+
+
+@pytest.mark.parametrize(
+    "input_shape_and_dim",
+    [
+        ((32, 32, 32, 64), -4),
+        ((2, 32, 32, 64), -3),
+        ((32, 32, 64), -3),
+    ],
+)
+@pytest.mark.parametrize("keepdim", [True, False])
+def test_max_dim(device, input_shape_and_dim, keepdim):
+    input_shape, max_dim = input_shape_and_dim
+
+    torch_input_tensor = torch_random(input_shape, -100, 100, dtype=torch.bfloat16)
+    torch_output_tensor, _ = torch.max(torch_input_tensor, dim=max_dim, keepdim=keepdim)
+
+    input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
+
+    output_tensor = ttnn.max(input_tensor, dim=max_dim, keepdim=keepdim)
+    output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
+    output_tensor = ttnn.from_device(output_tensor)
+
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    assert_with_pcc(torch_output_tensor, output_tensor)
```
**ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp** (23 additions, 29 deletions)
```diff
@@ -22,10 +22,6 @@ static Tensor reduce_impl(
     float scalar,
     bool reshape) {
     using ttnn::operations::experimental::auto_format::AutoFormat;
-    if (not keepdim) {
-        TT_THROW("keepdim=False is not supported");
-    }
-
     auto input_shape = input_tensor_arg.get_shape();
     auto rank = input_shape.size();
     auto memory_config = memory_config_arg.value_or(input_tensor_arg.memory_config());
@@ -58,41 +54,39 @@
         rank);
     }
 
-    if (dim.size() == 1 && rank == 4) {
-        if (dim[0] == rank - 3) {
-            auto out_shape = input_tensor_arg.get_legacy_shape();
-            out_shape[1] = 1;
-
-            Tensor output = ttnn::transpose(input_tensor_arg, 1, -2, memory_config);
-            output = reduce_impl<reduce_type>(output, 2, keepdim, memory_config, compute_kernel_config, scalar, false);
-            output = ttnn::transpose(output, 1, -2, memory_config);
-            return AutoFormat::format_output_tensor(output, out_shape, input_tensor_arg.device(), Layout::TILE);
-        } else if (dim[0] == 0) {
-            auto out_shape = input_tensor_arg.get_legacy_shape();
-            out_shape[0] = 1;
-
-            Tensor output = ttnn::transpose(input_tensor_arg, 0, -2, memory_config);
-            output = reduce_impl<reduce_type>(output, 2, keepdim, memory_config, compute_kernel_config, scalar, false);
-            output = ttnn::transpose(output, 0, -2, memory_config);
-            return AutoFormat::format_output_tensor(output, out_shape, input_tensor_arg.device(), Layout::TILE);
-        }
-    }
 
     std::sort(dim.begin(), dim.end());
 
     ttnn::SmallVector<uint32_t> output_shape;
-    ttnn::SmallVector<uint32_t> padded_output_shape;
     for (int axis = 0; axis < input_shape.size(); axis++) {
         if (std::find(dim.begin(), dim.end(), axis) != dim.end()) {
             if (keepdim) {
                 output_shape.push_back(1);
-                padded_output_shape.push_back(axis >= rank - 2 ? ttnn::TILE_SIZE : 1);
             }
         } else {
             // Get the shape for the output tensor
             output_shape.push_back(input_shape[axis]);
-            // Get the padded shape for the output tensor
-            padded_output_shape.push_back(input_shape.value[axis]);
         }
     }
 
+    if (dim.size() == 1 && (rank == 3 || rank == 4)) {
+        if (dim[0] == 1 && rank == 4) {
+            Tensor output = ttnn::transpose(input_tensor_arg, 1, -2, memory_config);
+            output = reduce_impl<reduce_type>(
+                output, 2, /*keepdim=*/true, memory_config, compute_kernel_config, scalar, /*reshape=*/true);
+            output = ttnn::transpose(output, 1, -2, memory_config);
+            if (reshape) {
+                output = ttnn::reshape(output, ttnn::Shape{output_shape});
+            }
+            return output;
+        } else if (dim[0] == 0) {
+            Tensor output = ttnn::transpose(input_tensor_arg, 0, -2, memory_config);
+            output = reduce_impl<reduce_type>(
+                output, -2, /*keepdim=*/true, memory_config, compute_kernel_config, scalar, /*reshape=*/true);
+            output = ttnn::transpose(output, 0, -2, memory_config);
+            if (reshape) {
+                output = ttnn::reshape(output, ttnn::Shape{output_shape});
+            }
+            return output;
+        }
+    }

@@ -199,7 +193,7 @@
     }
 
     if (reshape) {
-        output_tensor = ttnn::reshape(output_tensor, ttnn::Shape{output_shape, padded_output_shape});
+        output_tensor = ttnn::reshape(output_tensor, ttnn::Shape{output_shape});
     }
 
     return output_tensor;
```
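For context on the new dim-0/dim-1 branches above: they implement a transpose–reduce–transpose pattern, moving the target axis into the height position (-2), reducing there with `keepdim=true` so the rank is preserved, transposing back, and reshaping away the reduced axis only when the caller asked for `keepdim=False`. A PyTorch sketch of the same idea (hypothetical helper, not part of the commit):

```python
import torch

def max_via_transpose(x: torch.Tensor, dim: int, keepdim: bool) -> torch.Tensor:
    # Move the target axis to the height position (-2), where the
    # reduction is natively supported.
    x_t = x.transpose(dim, -2)
    # Reduce with keepdim=True so the rank is preserved for the transpose back.
    reduced, _ = torch.max(x_t, dim=-2, keepdim=True)
    out = reduced.transpose(dim, -2)
    if not keepdim:
        out = out.squeeze(dim)  # the reshape drops the reduced axis when keepdim=False
    return out

x = torch.randn(2, 32, 32, 64)
assert torch.equal(max_via_transpose(x, 1, False), torch.max(x, dim=1, keepdim=False).values)
assert torch.equal(max_via_transpose(x, 0, True), torch.max(x, dim=0, keepdim=True).values)
```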
