Skip to content

Commit

Permalink
Use dense packed CB indices for Matmul (#17081)
Browse files Browse the repository at this point in the history
### Ticket
[Link to Github
Issue](#16954)

### Problem description
Most CBs between 1 and 16 are unused. The causes the dispatcher to waste
timing initializing many unneeded CBs, so it would be better to pack
them starting at 0.


### Checklist
- [x] Post commit CI passes
- [x] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] **(For models and ops writers)** Full [new
models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml)
tests passes
- [ ] New/Existing tests provide coverage for changes
  • Loading branch information
yugaoTT authored Jan 24, 2025
1 parent 7e81658 commit c203bf6
Show file tree
Hide file tree
Showing 10 changed files with 39 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ void MAIN {

constexpr uint32_t in0_cb_id = tt::CBIndex::c_0;
constexpr uint32_t in1_cb_id = tt::CBIndex::c_1;
constexpr uint32_t out_cb_id = tt::CBIndex::c_16;
constexpr uint32_t mm_partials_cb_id = tt::CBIndex::c_24;
constexpr uint32_t out_cb_id = tt::CBIndex::c_4;
constexpr uint32_t mm_partials_cb_id = tt::CBIndex::c_5;

constexpr uint32_t untilize_mode_out_cb_id = untilize_out ? mm_partials_cb_id : out_cb_id;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ void kernel_main() {
reinterpret_cast<volatile tt_l1_ptr uint32_t*>(in0_mcast_sender_semaphore_addr);

// L1 array
constexpr uint32_t cb_l1_array = tt::CBIndex::c_5;
constexpr uint32_t cb_l1_array = tt::CBIndex::c_6;
uint32_t in0_mcast_sender_semaphore_valid_addr = get_write_ptr(cb_l1_array);
volatile tt_l1_ptr uint32_t* in0_mcast_sender_semaphore_valid_addr_ptr =
reinterpret_cast<volatile tt_l1_ptr uint32_t*>(in0_mcast_sender_semaphore_valid_addr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ void kernel_main() {
constexpr uint32_t cb_id_in1 = 1;

// WRITER
constexpr uint32_t cb_id_out0 = 16;
constexpr uint32_t cb_id_out0 = tt::CBIndex::c_4;

// WRITER
// single-tile
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ void kernel_main() {
#endif

constexpr uint32_t cb_id_in1 = 1;
constexpr uint32_t cb_id_out = 16;
constexpr uint32_t cb_id_out_reshard = 17;
constexpr uint32_t cb_id_out = tt::CBIndex::c_4;
constexpr uint32_t cb_id_out_reshard = tt::CBIndex::c_6;
constexpr uint32_t in1_single_tile_size_bytes = get_tile_size(cb_id_in1);
constexpr uint32_t in1_block_size_bytes = in1_block_num_tiles * in1_single_tile_size_bytes;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ void kernel_main() {
#endif

// WRITER
constexpr uint32_t cb_id_out0 = 16;
constexpr uint32_t cb_id_out0 = tt::CBIndex::c_4;
constexpr uint32_t output_single_tile_size_bytes = get_tile_size(cb_id_out0);
constexpr const uint32_t output_tile_hw = get_tile_hw(cb_id_out0);
constexpr DataFormat output_data_format = get_dataformat(cb_id_out0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ void kernel_main() {
constexpr uint32_t cb_id_in1 = 1;

// WRITER
constexpr uint32_t cb_id_out0 = 16;
constexpr uint32_t cb_id_out0 = tt::CBIndex::c_4;

#ifdef IN1_SHARDED
const uint32_t in1_num_tiles = batch * num_blocks * in1_block_h * in1_block_w;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
.defines = mm_kernel_defines});

// Create circular buffers
uint32_t src0_cb_index = 0;
uint32_t src0_cb_index = tt::CBIndex::c_0;
tt_metal::CircularBufferConfig src0_cb_config =
tt_metal::CircularBufferConfig(in0_CB_size, {{src0_cb_index, in0_data_format}})
.set_page_size(src0_cb_index, in0_single_tile_size)
Expand All @@ -598,7 +598,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
in0_CB_size / in0_single_tile_size,
in0_CB_size);

uint32_t src1_cb_index = 1;
uint32_t src1_cb_index = tt::CBIndex::c_1;
tt_metal::CircularBufferConfig src1_cb_config =
tt_metal::CircularBufferConfig(in1_CB_size, {{src1_cb_index, in1_data_format}})
.set_page_size(src1_cb_index, in1_single_tile_size)
Expand All @@ -617,7 +617,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
in1_CB_size / in1_single_tile_size,
in1_CB_size);

uint32_t src2_cb_index = 2;
uint32_t src2_cb_index = tt::CBIndex::c_2;
CBHandle cb_src2 = 0;
if (in0_is_sharded) {
tt_metal::CircularBufferConfig src2_cb_config =
Expand All @@ -635,14 +635,14 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
in2_CB_size);

// Local L1 to store temp vars
uint32_t l1_cb_index = 5;
uint32_t l1_cb_index = tt::CBIndex::c_6;
CircularBufferConfig cb_for_l1_array_config =
CircularBufferConfig(32 * 2, {{l1_cb_index, tt::DataFormat::Float16_b}}).set_page_size(l1_cb_index, 32 * 2);
tt_metal::CreateCircularBuffer(program, all_cores, cb_for_l1_array_config);
}

uint32_t output_cb_index = tt::CBIndex::c_16;
uint32_t interm0_cb_index = 24;
uint32_t output_cb_index = tt::CBIndex::c_4;
uint32_t interm0_cb_index = tt::CBIndex::c_5;
tt_metal::CircularBufferConfig interm0_cb_config =
tt_metal::CircularBufferConfig(0, {{interm0_cb_index, interm0_data_format}});
tt_metal::CircularBufferConfig output_cb_config =
Expand Down Expand Up @@ -698,7 +698,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0(

tt_metal::CBHandle cb_src3 = 0;
if (bias_buffer != nullptr) {
uint32_t src3_cb_index = 3;
uint32_t src3_cb_index = tt::CBIndex::c_3;
tt_metal::CircularBufferConfig cb_src3_config =
tt_metal::CircularBufferConfig(in3_CB_size, {{src3_cb_index, bias_data_format}})
.set_page_size(src3_cb_index, bias_single_tile_size)
Expand Down Expand Up @@ -1370,7 +1370,7 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
.defines = mm_kernel_defines});

// Create circular buffers
uint32_t src0_cb_index = 0;
uint32_t src0_cb_index = tt::CBIndex::c_0;
tt_metal::CircularBufferConfig src0_cb_config =
tt_metal::CircularBufferConfig(in0_CB_size, {{src0_cb_index, in0_data_format}})
.set_page_size(src0_cb_index, in0_single_tile_size)
Expand All @@ -1387,7 +1387,7 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
in0_CB_size / in0_single_tile_size,
in0_CB_size);

uint32_t src2_cb_index = 2;
uint32_t src2_cb_index = tt::CBIndex::c_2;
CBHandle cb_src2 = 0;
if (in0_is_sharded and extract_shard_sub_blocks) { // in0_is_sharded is technically redundant
tt_metal::CircularBufferConfig src2_cb_config =
Expand All @@ -1405,7 +1405,7 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
in2_CB_size);
}

uint32_t src1_cb_index = 1;
uint32_t src1_cb_index = tt::CBIndex::c_1;
tt_metal::CircularBufferConfig src1_cb_config =
tt_metal::CircularBufferConfig(in1_CB_size, {{src1_cb_index, in1_data_format}})
.set_page_size(src1_cb_index, in1_single_tile_size)
Expand All @@ -1419,8 +1419,8 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
in1_CB_size / in1_single_tile_size,
in1_CB_size);

uint32_t output_cb_index = tt::CBIndex::c_16;
uint32_t interm0_cb_index = 24;
uint32_t output_cb_index = tt::CBIndex::c_4;
uint32_t interm0_cb_index = tt::CBIndex::c_5;
tt_metal::CircularBufferConfig interm0_cb_config =
tt_metal::CircularBufferConfig(0, {{interm0_cb_index, interm0_data_format}});
tt_metal::CircularBufferConfig output_cb_config =
Expand Down Expand Up @@ -1475,7 +1475,7 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
out_CB_size);

if (bias_buffer != nullptr) {
uint32_t src3_cb_index = 3;
uint32_t src3_cb_index = tt::CBIndex::c_3;
tt_metal::CircularBufferConfig cb_src3_config =
tt_metal::CircularBufferConfig(in3_CB_size, {{src3_cb_index, bias_data_format}})
.set_page_size(src3_cb_index, bias_single_tile_size)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
.defines = mm_kernel_defines});

// Create circular buffers
uint32_t src0_cb_index = 0;
uint32_t src0_cb_index = tt::CBIndex::c_0;
tt_metal::CircularBufferConfig src0_cb_config =
tt_metal::CircularBufferConfig(in0_CB_size, {{src0_cb_index, in0_data_format}})
.set_page_size(src0_cb_index, in0_single_tile_size)
Expand All @@ -727,7 +727,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
in0_CB_size / in0_single_tile_size,
in0_CB_size);

uint32_t src1_cb_index = 1;
uint32_t src1_cb_index = tt::CBIndex::c_1;
tt_metal::CircularBufferConfig src1_cb_config =
tt_metal::CircularBufferConfig(in1_CB_size, {{src1_cb_index, in1_data_format}})
.set_page_size(src1_cb_index, in1_single_tile_size)
Expand All @@ -744,7 +744,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
in1_CB_size / in1_single_tile_size,
in1_CB_size);

uint32_t src2_cb_index = 2;
uint32_t src2_cb_index = tt::CBIndex::c_2;
CBHandle cb_src2 = 0;
if (in0_block_sharded) {
tt_metal::CircularBufferConfig src2_cb_config =
Expand All @@ -762,14 +762,14 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
in2_CB_size);

// Local L1 to store temp vars
uint32_t l1_cb_index = 5;
uint32_t l1_cb_index = tt::CBIndex::c_6;
CircularBufferConfig cb_for_l1_array_config =
CircularBufferConfig(32 * 2, {{l1_cb_index, tt::DataFormat::Float16_b}}).set_page_size(l1_cb_index, 32 * 2);
tt_metal::CreateCircularBuffer(program, all_cores, cb_for_l1_array_config);
}

uint32_t output_cb_index = tt::CBIndex::c_16;
uint32_t interm0_cb_index = 24;
uint32_t output_cb_index = tt::CBIndex::c_4;
uint32_t interm0_cb_index = tt::CBIndex::c_5;
tt_metal::CircularBufferConfig interm0_cb_config =
tt_metal::CircularBufferConfig(0, {{interm0_cb_index, interm0_data_format}});
tt_metal::CircularBufferConfig output_cb_config =
Expand Down Expand Up @@ -825,7 +825,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(

// CB for bias
if (bias_buffer != nullptr) {
uint32_t src3_cb_index = 3;
uint32_t src3_cb_index = tt::CBIndex::c_3;
tt_metal::CircularBufferConfig cb_src3_config =
tt_metal::CircularBufferConfig(in3_CB_size, {{src3_cb_index, bias_data_format}})
.set_page_size(src3_cb_index, bias_single_tile_size)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ operation::ProgramWithCallbacks create_program_dram_sharded(
log_debug(LogOp, "in1_single_tile_size: {}", in1_single_tile_size);

// Create circular buffers
uint32_t src0_cb_index = 0;
uint32_t src0_cb_index = tt::CBIndex::c_0;
tt_metal::CircularBufferConfig src0_cb_config =
tt_metal::CircularBufferConfig(in0_CB_size, {{src0_cb_index, in0_data_format}})
.set_page_size(src0_cb_index, in0_single_tile_size)
Expand All @@ -456,7 +456,7 @@ operation::ProgramWithCallbacks create_program_dram_sharded(
in0_CB_size / in0_single_tile_size,
in0_CB_size);

uint32_t src1_cb_index = 1;
uint32_t src1_cb_index = tt::CBIndex::c_1;
tt_metal::CircularBufferConfig src1_cb_config =
tt_metal::CircularBufferConfig(in1_CB_size, {{src1_cb_index, in1_data_format}})
.set_page_size(src1_cb_index, in1_single_tile_size)
Expand All @@ -470,7 +470,7 @@ operation::ProgramWithCallbacks create_program_dram_sharded(
in1_CB_size / in1_single_tile_size,
in1_CB_size);

uint32_t src2_cb_index = 2;
uint32_t src2_cb_index = tt::CBIndex::c_2;
tt_metal::CircularBufferConfig src2_cb_config =
tt_metal::CircularBufferConfig(in2_CB_size, {{src2_cb_index, in0_data_format}})
.set_page_size(src2_cb_index, in0_single_tile_size)
Expand All @@ -485,8 +485,8 @@ operation::ProgramWithCallbacks create_program_dram_sharded(
in2_CB_size / in0_single_tile_size,
in2_CB_size);

uint32_t output_cb_index = tt::CBIndex::c_16;
uint32_t interm0_cb_index = 24;
uint32_t output_cb_index = tt::CBIndex::c_4;
uint32_t interm0_cb_index = tt::CBIndex::c_5;
tt_metal::CircularBufferConfig interm0_cb_config =
tt_metal::CircularBufferConfig(0, {{interm0_cb_index, interm0_data_format}});
tt_metal::CircularBufferConfig output_cb_config =
Expand Down Expand Up @@ -537,7 +537,7 @@ operation::ProgramWithCallbacks create_program_dram_sharded(
out_CB_size);

// resharded output
uint32_t output_reshard_cb_index = 17;
uint32_t output_reshard_cb_index = tt::CBIndex::c_6;
std::map<uint8_t, tt::DataFormat> output_reshard_cb_data_format_spec{
{output_reshard_cb_index, output_data_format},
};
Expand All @@ -549,7 +549,7 @@ operation::ProgramWithCallbacks create_program_dram_sharded(
auto cb_output_reshard = tt_metal::CreateCircularBuffer(program, all_cores_in_rect_grid, output_reshard_cb_config);

if (bias_buffer != nullptr) {
uint32_t src3_cb_index = 3;
uint32_t src3_cb_index = tt::CBIndex::c_3;
tt_metal::CircularBufferConfig cb_src3_config =
tt_metal::CircularBufferConfig(in3_CB_size, {{src3_cb_index, bias_data_format}})
.set_page_size(src3_cb_index, bias_single_tile_size)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ operation::ProgramWithCallbacks create_program(
}

// Create circular buffers
uint32_t src0_cb_index = 0;
uint32_t src0_cb_index = tt::CBIndex::c_0;
tt_metal::CircularBufferConfig cb_src0_config =
tt_metal::CircularBufferConfig(in0_CB_size, {{src0_cb_index, in0_data_format}})
.set_page_size(src0_cb_index, in0_single_tile_size)
Expand All @@ -281,7 +281,7 @@ operation::ProgramWithCallbacks create_program(
}
auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config);

uint32_t src1_cb_index = 1;
uint32_t src1_cb_index = tt::CBIndex::c_1;
tt_metal::CircularBufferConfig cb_src1_config =
tt_metal::CircularBufferConfig(in1_CB_size, {{src1_cb_index, in1_data_format}})
.set_page_size(src1_cb_index, in1_single_tile_size)
Expand All @@ -291,8 +291,8 @@ operation::ProgramWithCallbacks create_program(
}
auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src1_config);

uint32_t output_cb_index = tt::CBIndex::c_16;
uint32_t interm0_cb_index = 24;
uint32_t output_cb_index = tt::CBIndex::c_4;
uint32_t interm0_cb_index = tt::CBIndex::c_5;
tt_metal::CircularBufferConfig interm0_cb_config =
tt_metal::CircularBufferConfig(0, {{interm0_cb_index, interm0_data_format}});
tt_metal::CircularBufferConfig output_cb_config =
Expand Down

0 comments on commit c203bf6

Please sign in to comment.