Skip to content

Commit

Permalink
#14790: adding padding awareness
Browse files Browse the repository at this point in the history
  • Loading branch information
sjameelTT committed Nov 18, 2024
1 parent e5ee6a7 commit f32286d
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -872,5 +872,6 @@ def test_transpose_issue_11650_10350(shape, dims, layout, dtype, device):

tt_input = ttnn.from_torch(torch_input, dtype=dtype, layout=layout, device=device)
tt_output = ttnn.transpose(tt_input, dims[0], dims[1])
print(tt_output)
tt_output = ttnn.to_torch(tt_output)
assert_with_pcc(torch_output, tt_output, 0.9999)
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <stdint.h>
#include "dataflow_api.h"

#include "debug/dprint.h"

template <typename T>
FORCE_INLINE void fill_with_val(uint32_t begin_addr, uint32_t n, T val) {
auto* ptr = reinterpret_cast<volatile tt_l1_ptr T*>(begin_addr);
for (uint32_t i = 0; i < n; ++i) {
DPRINT << "fill_with_val: " << i << " " << val << ENDL();
ptr[i] = val;
}
}

void kernel_main() {
uint32_t src_addr = get_arg_val<uint32_t>(0);
uint32_t num_tiles = get_arg_val<uint32_t>(1);
uint32_t start_id = get_arg_val<uint32_t>(2);

constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1;

constexpr uint32_t cb_id_in0 = 0;

// ublocks size defined in tiles
constexpr uint32_t onetile = 1;
const uint32_t tile_bytes = get_tile_size(cb_id_in0);
const DataFormat data_format = get_dataformat(cb_id_in0);

const InterleavedAddrGenFast<src_is_dram> s = {
.bank_base_address = src_addr,
.page_size = tile_bytes,
.data_format = data_format
};

// read a ublock of tiles from src to CB, and then push the ublock to unpacker
#ifdef BACKWARDS
uint32_t end_id = start_id - num_tiles;
for (uint32_t i = start_id; i != end_id; -- i) {
#else
uint32_t end_id = start_id + num_tiles;
for (uint32_t i = start_id; i < end_id; ++ i) {
#endif
cb_reserve_back(cb_id_in0, onetile);
uint32_t l1_write_addr = get_write_ptr(cb_id_in0);
noc_async_read_tile(i, s, l1_write_addr);
noc_async_read_barrier();
cb_push_back(cb_id_in0, onetile);
}

cb_reserve_back(tt::CB::c_in1, 1);
uint32_t l1_write_addr = get_write_ptr(tt::CB::c_in1);
fill_with_val<uint32_t>(l1_write_addr, 8, 123123);
cb_push_back(tt::CB::c_in1, 1);

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
// SPDX-License-Identifier: Apache-2.0

#include "dataflow_api.h"
#include "debug/dprint.h"

// Utility functions
inline constexpr uint32_t div_up(uint32_t a, uint32_t b) {
FORCE_INLINE constexpr uint32_t div_up(uint32_t a, uint32_t b) {
return static_cast<uint32_t>((a + b - 1) / b);
}

inline constexpr uint32_t round_up(uint32_t a, uint32_t b) {
FORCE_INLINE constexpr uint32_t round_up(uint32_t a, uint32_t b) {
return b * div_up(a, b);
}

Expand Down Expand Up @@ -166,4 +167,34 @@ void kernel_main() {
// Remove the processed tile from the front of the buffer
cb_pop_front(cb_id_out0, 1);
}

// add padding
if constexpr (C_p > C) {
cb_wait_front(tt::CB::c_in1, 1);
uint32_t l1_read_ptr = get_read_ptr(tt::CB::c_in1);
constexpr uint32_t N = 1;
uint32_t c_t = C_t - 1;
constexpr uint32_t num_padded_tiles = 1*H*W_t;
for (uint32_t tile_idx = 0; tile_idx < num_padded_tiles; ++tile_idx) {
// Map tile_idx to (n, h, w_t)
uint32_t n = tile_idx / (H * W_t);
uint32_t remainder1 = tile_idx % (H * W_t);
uint32_t h = remainder1 / W_t;
uint32_t w_t = remainder1 % W_t;
uint8_t C_in_tile = C % TILE_HEIGHT;
uint8_t face_c_start = C_in_tile/ FACE_HEIGHT;
for (uint8_t face_c = face_c_start; face_c < NUM_FACES_H; ++face_c) {
uint8_t sub_tile_line_start = face_c == face_c_start ? C_in_tile % FACE_HEIGHT : 0;
for (uint8_t face_w = 0; face_w < NUM_FACES_W; ++face_w) {
for (uint8_t sub_tile_line = sub_tile_line_start; sub_tile_line < FACE_HEIGHT; ++sub_tile_line) {
uint32_t linear_idx = n * H * C_t * W_t + h * C_t * W_t + c_t * W_t + w_t;
uint32_t offset = (face_c * NUM_FACES_W * FACE_HEIGHT * FACE_WIDTH + face_w * FACE_HEIGHT * FACE_WIDTH + sub_tile_line * FACE_WIDTH) * element_size;
uint64_t write_noc_base_addr = get_noc_addr(linear_idx, s, offset);
noc_async_write(l1_read_ptr, write_noc_base_addr, SUBTILE_LINE_BYTES);
}
}
}
}
cb_pop_front(tt::CB::c_in1, 1);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -540,11 +540,18 @@ operation::ProgramWithCallbacks transpose_hc_multi_core_tiled_interleaved(const
auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_tensor_tiles);

uint32_t src0_cb_index = tt::CB::c_in0;
uint32_t padding_cb_index = tt::CB::c_in1;

tt::tt_metal::CircularBufferConfig cb_src0_config =
tt::tt_metal::CircularBufferConfig(2 * single_tile_size, {{src0_cb_index, cb_data_format}})
.set_page_size(src0_cb_index, single_tile_size);
auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src0_config);

tt::tt_metal::CircularBufferConfig cb_src1_config =
tt::tt_metal::CircularBufferConfig(face_shape[1] * a.element_size(), {{padding_cb_index, cb_data_format}})
.set_page_size(padding_cb_index, face_shape[1] * a.element_size());
auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src1_config);


// create reader kernel with compile time and runtime args
tt::tt_metal::Buffer *src_buffer = a.buffer();
Expand All @@ -553,7 +560,7 @@ operation::ProgramWithCallbacks transpose_hc_multi_core_tiled_interleaved(const

tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel(
program,
"ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/reader_unary_interleaved_start_id.cpp",
"ttnn/cpp/ttnn/operations/data_movement/transpose/device/kernels/dataflow/reader_unary_interleaved_start_id_with_padding.cpp",
total_cores,
tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args));

Expand Down

0 comments on commit f32286d

Please sign in to comment.