#14308: add native h-c padding handling to transpose_hc kernel
sjameelTT committed Nov 19, 2024
1 parent 76f8ae2 commit d4ee500
Showing 9 changed files with 378 additions and 91 deletions.
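
In effect, a tile-layout H-C transpose whose channel count is not a multiple of the tile height no longer needs a separate pad op. A minimal sketch of the call pattern this enables (standalone-script form; the repo's tests instead use a pytest `device` fixture, and the exact comparison used here is an assumption, not the test's `assert_with_pcc`):

import torch
import ttnn

device = ttnn.open_device(device_id=0)

torch_input = torch.randn(1, 17, 32, 32, dtype=torch.bfloat16)  # C = 17: not a multiple of 32
torch_output = torch_input.transpose(1, 2)

tt_input = ttnn.from_torch(torch_input, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
tt_output = ttnn.transpose(tt_input, 1, 2)  # padding is now handled natively inside the HC kernel
result = ttnn.to_torch(tt_output)

assert result.shape == torch_output.shape
assert torch.allclose(result.to(torch.float32), torch_output.to(torch.float32))

ttnn.close_device(device)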
@@ -459,7 +459,7 @@ def test_rotary_embedding_llama_with_program_cache(
 
     num_ops = 2  # 2 * rope
     if mode == "decode":
-        num_ops += 4  # embedding + transpose + pad + interleaved_to_sharded
+        num_ops += 3  # embedding + transpose + interleaved_to_sharded
 
         # When batch size is 1, transpose is a no-op
         if batch == 1:
@@ -14,6 +14,8 @@
 from models.utility_functions import skip_for_grayskull, skip_for_blackhole
 from tests.ttnn.utils_for_testing import assert_with_pcc
 
+torch.manual_seed(2005)
+
 
 def transpose(
     input_shape,
@@ -124,9 +126,7 @@ def test_transpose_hc_program_cache(dtype, device, use_program_cache):
     H = 32
     W = 32
     input_shape = (N, C, H, W)
-    # CACHE MISS since its single core
-    # Cache size 2 more because of pad op in single core impl + transpose
-    transpose(input_shape, device, dim0=1, dim1=-2, expected_program_cache_size=4, input_dtype=dtype)
+    transpose(input_shape, device, dim0=1, dim1=-2, expected_program_cache_size=3, input_dtype=dtype)
 
 
 @pytest.mark.parametrize(
@@ -839,7 +839,17 @@ def test_transpose_5d(shape, dims, layout, device):
 
 @pytest.mark.parametrize(
     "shape",
-    [[1, 5, 10, 15], [1, 1, 1, 2]],
+    [
+        [1, 5, 10, 15],
+        [1, 1, 1, 2],
+        [1, 3, 2, 1],
+        [1, 17, 1, 1],
+        [1, 1, 16, 1],
+        [1, 1, 17, 1],
+        [1, 1, 1, 17],
+        [2, 1, 1, 1],
+        [2, 33, 33, 33],
+    ],
 )
 @pytest.mark.parametrize(
     "dims",
@@ -852,13 +862,15 @@ def test_transpose_5d(shape, dims, layout, device):
"layout",
[ttnn.TILE_LAYOUT],
)
def test_transpose_issue_11650_10350(shape, dims, layout, device):
@pytest.mark.parametrize(
"dtype",
[ttnn.float32, ttnn.bfloat16],
)
def test_transpose_issue_11650_10350(shape, dims, layout, dtype, device):
torch_input = torch.randn(shape, dtype=torch.bfloat16)
torch_output = torch_input.transpose(dims[0], dims[1])

tt_input = ttnn.from_torch(torch_input, dtype=ttnn.DataType.BFLOAT16, layout=layout, device=device)
print(tt_input)
tt_input = ttnn.from_torch(torch_input, dtype=dtype, layout=layout, device=device)
tt_output = ttnn.transpose(tt_input, dims[0], dims[1])
print(tt_output)
tt_output = ttnn.to_torch(tt_output)
assert_with_pcc(torch_output, tt_output, 0.9999)
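
The new shapes straddle tile boundaries in C, H, and W. For the H-C swap this commit targets, the quantity that matters is the padded channel dimension; a quick illustrative check, assuming the usual 32x32 tile:

TILE_HEIGHT = 32

def round_up(a: int, b: int) -> int:
    return -(-a // b) * b

shapes = [[1, 5, 10, 15], [1, 1, 1, 2], [1, 3, 2, 1], [1, 17, 1, 1],
          [1, 1, 16, 1], [1, 1, 17, 1], [1, 1, 1, 17], [2, 1, 1, 1],
          [2, 33, 33, 33]]
for n, c, h, w in shapes:
    # After an H-C swap in TILE layout, C becomes a tiled dim and is padded up.
    print((n, c, h, w), "-> padded C after transpose:", round_up(c, TILE_HEIGHT))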
12 changes: 0 additions & 12 deletions ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp
@@ -60,17 +60,6 @@ ttnn::Tensor permute_impl(const ttnn::Tensor &a, const SmallVector<uint32_t>& di
     // Convert tensor back to original
     auto input_shape = a.get_logical_shape();
 
-    // create_output_tensor shape is useless when we potentially have new padding to deal with
-    SmallVector<uint32_t> output_shape = {input_shape[N], input_shape[C], input_shape[H], input_shape[W]};
-    SmallVector<uint32_t> padded_output_shape = output_shape;
-
-    uint32_t input_rank = a.get_logical_shape().rank();
-    if (a.layout() == Layout::TILE) {
-        padded_output_shape[input_rank - 1] = tt::round_up(padded_output_shape[input_rank - 1], tt::constants::TILE_WIDTH);
-        padded_output_shape[input_rank - 2] = tt::round_up(padded_output_shape[input_rank - 2], tt::constants::TILE_HEIGHT);
-    }
-
-    ttnn::Shape final_shape = ttnn::Shape(output_shape, padded_output_shape);
     auto formatted_input_tensor = a;
     bool typecast = formatted_input_tensor.get_dtype() == DataType::BFLOAT8_B and formatted_input_tensor.get_layout() == Layout::TILE and (pad_n or pad_c) and !a.is_sharded();
     formatted_input_tensor = typecast ? ttnn::typecast(formatted_input_tensor, DataType::BFLOAT16) : formatted_input_tensor;
@@ -130,7 +119,6 @@ ttnn::Tensor permute_impl(const ttnn::Tensor &a, const SmallVector<uint32_t>& di
     } else {
         TT_ASSERT(false, "Illegal permute args");
     }
-    output = ttnn::reshape(output, final_shape);
     output = typecast ? ttnn::typecast(output, DataType::BFLOAT8_B) : output;
     return output;
 }
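
The deleted reshape is made redundant rather than relocated: with the compute_output_specs change in the last file of this diff, the transpose op itself now reports the correct logical and padded output shapes, so permute_impl no longer has to rebuild final_shape by hand.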
@@ -0,0 +1,169 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "dataflow_api.h"

// Utility functions
inline constexpr uint32_t div_up(uint32_t a, uint32_t b) {
    return static_cast<uint32_t>((a + b - 1) / b);
}

inline constexpr uint32_t round_up(uint32_t a, uint32_t b) {
    return b * div_up(a, b);
}

void kernel_main() {

    // Retrieve runtime arguments
    uint32_t dst_addr = get_arg_val<uint32_t>(0);
    uint32_t start_tile_idx = get_arg_val<uint32_t>(1);
    uint32_t end_tile_idx = get_arg_val<uint32_t>(2);

    // Compile-time constants
    constexpr bool dst_is_dram = get_compile_time_arg_val(0) == 1;
    constexpr uint32_t element_size = get_compile_time_arg_val(1);
    constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(2);
    constexpr uint32_t C = get_compile_time_arg_val(3);
    constexpr uint32_t H = get_compile_time_arg_val(4);
    constexpr uint32_t W = get_compile_time_arg_val(5);
    constexpr uint32_t TILE_HEIGHT = get_compile_time_arg_val(6);
    constexpr uint32_t TILE_WIDTH = get_compile_time_arg_val(7);
    constexpr uint32_t FACE_HEIGHT = get_compile_time_arg_val(8);
    constexpr uint32_t FACE_WIDTH = get_compile_time_arg_val(9);

    // Derived compile-time constants
    constexpr uint32_t TILE_HW = TILE_HEIGHT * TILE_WIDTH;
    constexpr uint8_t NUM_FACES_H = TILE_HEIGHT / FACE_HEIGHT;
    constexpr uint8_t NUM_FACES_W = TILE_WIDTH / FACE_WIDTH;

    constexpr uint32_t C_p = round_up(C, TILE_HEIGHT);
    constexpr uint32_t H_p = round_up(H, TILE_HEIGHT);
    constexpr uint32_t W_p = round_up(W, TILE_WIDTH);

    constexpr uint32_t W_t = W_p / TILE_WIDTH;
    constexpr uint32_t H_t = H_p / TILE_HEIGHT;
    constexpr uint32_t C_t = C_p / TILE_HEIGHT;

    constexpr uint32_t SUBTILE_LINE_BYTES = FACE_WIDTH * element_size;

    // Initialize address generator
    const uint32_t tile_bytes = get_tile_size(cb_id_out0);
    const auto input_data_format = get_dataformat(cb_id_out0);

    const InterleavedAddrGenFast<dst_is_dram, TILE_HW> s = {
        .bank_base_address = dst_addr,
        .page_size = tile_bytes,
        .data_format = input_data_format
    };

    // Calculate actual data height in the last tile
    constexpr uint32_t H_last_tile = H - (H_t - 1) * TILE_HEIGHT;

    // Number of faces holding real data in the last tile row along H
    uint8_t remainder_faces_h = (H_last_tile + FACE_HEIGHT - 1) / FACE_HEIGHT;
    if (remainder_faces_h > NUM_FACES_H) {
        // Ensure it does not exceed the maximum number of faces per tile
        remainder_faces_h = NUM_FACES_H;
    }

    uint32_t remainder = H_last_tile % FACE_HEIGHT;
    uint8_t sub_tile_lines_real = (remainder == 0) ? FACE_HEIGHT : static_cast<uint8_t>(remainder);

    // Precompute constants used in inner loops
    const uint32_t face_height_width = FACE_HEIGHT * FACE_WIDTH;
    const uint32_t num_faces_wh = NUM_FACES_W * FACE_WIDTH;

    // Main single loop over all tiles
    for (uint32_t tile_idx = start_tile_idx; tile_idx < end_tile_idx; ++tile_idx) {
        // Compute n, c, h, w from tile_idx
        uint32_t w = tile_idx % W_t;
        uint32_t temp = tile_idx / W_t;

        uint32_t h = temp % H_t;
        temp /= H_t;

        uint32_t c = temp % C;
        uint32_t n = temp / C;

        // Recalculate variables from the original loops
        uint32_t output_ct_index = c / TILE_HEIGHT;
        uint32_t rem = c % TILE_HEIGHT;

        // Calculate the index inside the face matrix
        uint32_t output_face_h = rem / FACE_HEIGHT;
        uint32_t output_sub_tile_line = rem % FACE_HEIGHT;

        // Starting H row of this input tile; H occupies dim 1 (the old channel slot) in the output
        uint32_t output_h = h * TILE_HEIGHT;

        // Synchronization and read address retrieval
        cb_wait_front(cb_id_out0, 1);
        uint32_t l1_read_addr = get_read_ptr(cb_id_out0);

        // Determine the number of faces in the height dimension
        uint8_t num_faces_h = (h == H_t - 1) ? remainder_faces_h : NUM_FACES_H;

        // Precompute parts of linear_idx that remain constant within the inner loops
        // linear_idx = n * H * C_t * W_t + output_h_face_line * C_t * W_t + output_ct_index * W_t + w
        // We can precompute n * H * C_t * W_t + output_ct_index * W_t + w
        uint32_t base_linear_idx = n * H * C_t * W_t + output_ct_index * W_t + w;

        // Iterate over faces in the height dimension
        for (uint8_t face_h = 0; face_h < num_faces_h; ++face_h) {
            // Compute output_h_face once per face_h
            uint32_t output_h_face = output_h + face_h * FACE_HEIGHT;

            // Precompute the additive factor for output_h_face_line
            uint32_t base_output_h_face_line = output_h_face;

            // Iterate over faces in the width dimension
            for (uint8_t face_w = 0; face_w < NUM_FACES_W; ++face_w) {
                // Compute output_w_face once per face_w
                uint32_t output_w_face = w + face_w * FACE_WIDTH;

                // Precompute the offset multiplier for the current face_w
                uint32_t face_w_offset = face_w * face_height_width;

                // Determine the number of sub-tile lines to process
                bool is_last_sub_tile_line = (h == H_t - 1) && (face_h == num_faces_h - 1);
                uint8_t sub_tile_lines = is_last_sub_tile_line ? sub_tile_lines_real : FACE_HEIGHT;

                // Precompute offset for the current face_h
                uint32_t face_h_offset = output_face_h * NUM_FACES_W * face_height_width;

                // Iterate over sub-tile lines
                for (uint8_t sub_tile_line = 0; sub_tile_line < sub_tile_lines; ++sub_tile_line) {
                    // Compute the complete output_h_face_line
                    uint32_t output_h_face_line = base_output_h_face_line + sub_tile_line;

                    // Compute the linear index
                    uint32_t linear_idx = base_linear_idx + output_h_face_line * C_t * W_t;

                    // Compute the offset
                    uint32_t offset = (face_h_offset + face_w_offset + output_sub_tile_line * FACE_WIDTH) * element_size;

                    // Compute the write address
                    uint64_t write_noc_base_addr = get_noc_addr(linear_idx, s, offset);

                    // Perform asynchronous write
                    noc_async_write(l1_read_addr, write_noc_base_addr, SUBTILE_LINE_BYTES);

                    // Increment the read address
                    l1_read_addr += SUBTILE_LINE_BYTES;
                }

                // Skip the input's padding lines when not all lines in this face are real
                if (is_last_sub_tile_line) {
                    l1_read_addr += (FACE_HEIGHT - sub_tile_lines) * SUBTILE_LINE_BYTES;
                }
            }
        }

        // Ensure all asynchronous writes are completed before proceeding
        noc_async_write_barrier();

        // Remove the processed tile from the front of the buffer
        cb_pop_front(cb_id_out0, 1);
    }
}
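
As a reading aid rather than part of the commit, the kernel's index bookkeeping can be modeled host-side in Python; the sketch below mirrors the kernel's names and uses the (2, 33, 33, 33) shape from the tests, assuming 32x32 tiles with 16x16 faces:

TILE_HEIGHT = TILE_WIDTH = 32
FACE_HEIGHT = FACE_WIDTH = 16
NUM_FACES_H = TILE_HEIGHT // FACE_HEIGHT  # 2

def div_up(a, b):
    return (a + b - 1) // b

C, H, W = 33, 33, 33                                           # from test shape [2, 33, 33, 33]
C_t, H_t, W_t = (div_up(C, TILE_HEIGHT), div_up(H, TILE_HEIGHT),
                 div_up(W, TILE_WIDTH))                        # 2, 2, 2

# Last-tile bookkeeping along H, as computed before the kernel's main loop.
H_last_tile = H - (H_t - 1) * TILE_HEIGHT                      # 1 real row in the last tile
remainder_faces_h = min(div_up(H_last_tile, FACE_HEIGHT), NUM_FACES_H)  # 1 face
sub_tile_lines_real = H_last_tile % FACE_HEIGHT or FACE_HEIGHT          # 1 real line

def decompose(tile_idx):
    """Split an input tile index into (n, c, h, w), as the kernel does."""
    w = tile_idx % W_t
    temp = tile_idx // W_t
    h = temp % H_t
    temp //= H_t
    return temp // C, temp % C, h, w

def linear_idx(n, c, h, w, row):
    """Output tile receiving row `row` (0-based within the tile) of input tile (n, c, h, w)."""
    output_h_face_line = h * TILE_HEIGHT + row   # position along H, dim 1 of the output
    output_ct_index = c // TILE_HEIGHT           # output tile column along the padded C dim
    return n * H * C_t * W_t + output_h_face_line * C_t * W_t + output_ct_index * W_t + w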
@@ -56,8 +56,6 @@ void Transpose::validate(const std::vector<Tensor> &input_tensors) const {
     if (row_major) {
         auto BUFFER_ALIGNMENT = input_tensor.buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM ? DRAM_ALIGNMENT : L1_ALIGNMENT;
         TT_FATAL((W * input_tensor.element_size()) % BUFFER_ALIGNMENT == 0, "Buffer is not aligned for this implementation row_size_bytes {} buffer_alignment {}", W * input_tensor.element_size(), BUFFER_ALIGNMENT);
-    } else {
-        TT_FATAL(C % TILE_HEIGHT == 0, "Error");
     }
     TT_FATAL(
         input_tensor.get_dtype() == DataType::BFLOAT16 || input_tensor.get_dtype() == DataType::FLOAT32, "Error");
@@ -92,9 +90,21 @@ std::vector<ttnn::TensorSpec> Transpose::compute_output_specs(const std::vector<
         std::swap(output_padded_shape[0], output_padded_shape[1]);
         break;
     case TransposeOpDim::HC:
-        std::swap(output_shape[1], output_shape[2]);
-        std::swap(output_padded_shape[1], output_padded_shape[2]);
-        break;
+        if (input_tensor.is_sharded() || input_tensor.get_layout() != Layout::TILE) {
+            std::swap(output_shape[1], output_shape[2]);
+            std::swap(output_padded_shape[1], output_padded_shape[2]);
+            break;
+        } else {
+            uint32_t C = output_shape[1];
+            uint32_t C_p = tt::round_up(C, input_tensor.get_tile().get_height());
+            uint32_t H = output_shape[2];
+            output_shape[1] = H;
+            output_shape[2] = C;
+            output_padded_shape[1] = H;
+            output_padded_shape[2] = C_p;
+            break;
+        }
+
     case TransposeOpDim::WH:
         std::swap(output_shape[2], output_shape[3]);
         std::swap(output_padded_shape[2], output_padded_shape[3]);
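
A worked example of the new tile-layout HC branch above, with illustrative values (interleaved input, 32-high tile):

# Input logical shape (N, C, H, W) = (1, 17, 32, 32); transpose dims 1 and 2.
N, C, H, W = 1, 17, 32, 32
tile_height = 32
C_p = -(-C // tile_height) * tile_height  # round_up(17, 32) == 32

logical_output = (N, H, C, W)    # (1, 32, 17, 32): H and C swapped
padded_output = (N, H, C_p, W)   # (1, 32, 32, 32): C, now a tiled dim, padded to a tile boundary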