diff --git a/tt_metal/api/tt-metalium/circular_buffer_types.hpp b/tt_metal/api/tt-metalium/circular_buffer_types.hpp index 73da09c2603..b7afc3c034a 100644 --- a/tt_metal/api/tt-metalium/circular_buffer_types.hpp +++ b/tt_metal/api/tt-metalium/circular_buffer_types.hpp @@ -35,6 +35,20 @@ class CircularBufferConfig { CircularBufferConfig( uint32_t total_size, const std::map& data_format_spec, const Buffer& buffer); + // For flatbuffer deserialization, set all private members. + CircularBufferConfig( + uint32_t total_size, + std::optional globally_allocated_address, + const std::array, NUM_CIRCULAR_BUFFERS>& data_formats, + const std::array, NUM_CIRCULAR_BUFFERS>& page_sizes, + const std::array, NUM_CIRCULAR_BUFFERS>& tiles, + const std::unordered_set& buffer_indices, + const std::unordered_set& local_buffer_indices, + const std::unordered_set& remote_buffer_indices, + bool dynamic_cb, + uint32_t max_size, + uint32_t buffer_size); + CircularBufferConfig& set_page_size(uint8_t buffer_index, uint32_t page_size); CircularBufferConfig& set_total_size(uint32_t total_size); @@ -59,6 +73,11 @@ class CircularBufferConfig { const std::array, NUM_CIRCULAR_BUFFERS>& page_sizes() const; + // These 3 getters are not typically used, but needed for flatbuffer serialization + bool dynamic_cb() const; + uint32_t max_size() const; + uint32_t buffer_size() const; + const Buffer* shadow_global_buffer{nullptr}; class Builder { diff --git a/tt_metal/impl/CMakeLists.txt b/tt_metal/impl/CMakeLists.txt index e3ef8b4d276..be43586763b 100644 --- a/tt_metal/impl/CMakeLists.txt +++ b/tt_metal/impl/CMakeLists.txt @@ -45,12 +45,24 @@ set(IMPL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/trace/trace.cpp ${CMAKE_CURRENT_SOURCE_DIR}/trace/trace_buffer.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event/event.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/base_types_from_flatbuffer.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/base_types_to_flatbuffer.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/buffer_types_from_flatbuffer.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/buffer_types_to_flatbuffer.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/program_types_from_flatbuffer.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/program_types_to_flatbuffer.cpp ) # Include helper functions and generate headers from flatbuffer schemas include(flatbuffers) -set(FLATBUFFER_SCHEMAS) # Empty to start, coming soon. +set(FLATBUFFER_SCHEMAS + ${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/light_metal_binary.fbs + ${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/command.fbs + ${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/base_types.fbs + ${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/buffer_types.fbs + ${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/program_types.fbs +) foreach(FBS_FILE ${FLATBUFFER_SCHEMAS}) GENERATE_FBS_HEADER(${FBS_FILE}) diff --git a/tt_metal/impl/buffers/circular_buffer_types.cpp b/tt_metal/impl/buffers/circular_buffer_types.cpp index 07be1fc60c4..259ce6989df 100644 --- a/tt_metal/impl/buffers/circular_buffer_types.cpp +++ b/tt_metal/impl/buffers/circular_buffer_types.cpp @@ -27,6 +27,31 @@ CircularBufferConfig::CircularBufferConfig( this->set_config(data_format_spec); } +// For flatbuffer deserialization, set all private members. +CircularBufferConfig::CircularBufferConfig( + uint32_t total_size, + std::optional globally_allocated_address, + const std::array, NUM_CIRCULAR_BUFFERS>& data_formats, + const std::array, NUM_CIRCULAR_BUFFERS>& page_sizes, + const std::array, NUM_CIRCULAR_BUFFERS>& tiles, + const std::unordered_set& buffer_indices, + const std::unordered_set& local_buffer_indices, + const std::unordered_set& remote_buffer_indices, + bool dynamic_cb, + uint32_t max_size, + uint32_t buffer_size) : + total_size_(total_size), + globally_allocated_address_(globally_allocated_address), + data_formats_(data_formats), + page_sizes_(page_sizes), + tiles_(tiles), + buffer_indices_(buffer_indices), + local_buffer_indices_(local_buffer_indices), + remote_buffer_indices_(remote_buffer_indices), + dynamic_cb_(dynamic_cb), + max_size_(max_size), + buffer_size_(buffer_size) {} + CircularBufferConfig& CircularBufferConfig::set_page_size(uint8_t buffer_index, uint32_t page_size) { if (buffer_index > NUM_CIRCULAR_BUFFERS - 1) { TT_THROW( @@ -156,6 +181,12 @@ const std::array, NUM_CIRCULAR_BUFFERS>& CircularBufferC return this->page_sizes_; } +bool CircularBufferConfig::dynamic_cb() const { return this->dynamic_cb_; } + +uint32_t CircularBufferConfig::max_size() const { return this->max_size_; } + +uint32_t CircularBufferConfig::buffer_size() const { return this->buffer_size_; } + CircularBufferConfig::Builder CircularBufferConfig::Builder::LocalBuilder( CircularBufferConfig& parent, uint8_t buffer_index) { auto is_remote_index = parent.remote_buffer_indices_.find(buffer_index) != parent.remote_buffer_indices_.end(); diff --git a/tt_metal/impl/flatbuffer/base_types.fbs b/tt_metal/impl/flatbuffer/base_types.fbs new file mode 100644 index 00000000000..d493274c33a --- /dev/null +++ b/tt_metal/impl/flatbuffer/base_types.fbs @@ -0,0 +1,77 @@ +namespace tt.tt_metal.flatbuffer; + + +enum Arch: uint { + Grayskull = 0, + Wormhole_b0 = 1, + Blackhole = 2, +} + +enum DataMovementProcessor : byte { + RISCV_0, + RISCV_1 +} + +enum NOC : byte { + NOC_0, + NOC_1 +} + +enum NOC_MODE : byte { + DM_DEDICATED_NOC, + DM_DYNAMIC_NOC +} + +enum EthMode : ubyte { + SENDER = 0, + RECEIVER = 1, + IDLE = 2 +} + +enum MathFidelity : ubyte { + LoFi = 0, + HiFi2 = 2, + HiFi3 = 3, + HiFi4 = 4, + Invalid = 255 +} + +enum DataFormat : uint8 { + Float32 = 0, + Float16 = 1, + Bfp8 = 2, + Bfp4 = 3, + Bfp2 = 11, + Float16_b = 5, + Bfp8_b = 6, + Bfp4_b = 7, + Bfp2_b = 15, + Lf8 = 10, + Fp8_e4m3 = 26, // 0x1A in decimal + Int8 = 14, + Tf32 = 4, + UInt8 = 30, + UInt16 = 9, + Int32 = 8, + UInt32 = 24, + RawUInt8 = 240, // 0xf0 in decimal + RawUInt16 = 241, // 0xf1 in decimal + RawUInt32 = 242, // 0xf2 in decimal + Invalid = 255 +} + +enum UnpackToDestMode : byte { + Default, + UnpackToDestFp32 +} + +table DefineEntry { + key: string; + value: string; +} + +// Rather than serialize all members, use Tile constructor arguments. +struct Tile { + tile_shape: [uint32:2]; + transpose_tile: bool; +} diff --git a/tt_metal/impl/flatbuffer/base_types_from_flatbuffer.cpp b/tt_metal/impl/flatbuffer/base_types_from_flatbuffer.cpp new file mode 100644 index 00000000000..99687403014 --- /dev/null +++ b/tt_metal/impl/flatbuffer/base_types_from_flatbuffer.cpp @@ -0,0 +1,96 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "flatbuffer/base_types_from_flatbuffer.hpp" + +namespace tt::tt_metal { + +DataMovementProcessor from_flatbuffer(flatbuffer::DataMovementProcessor in) { + switch (in) { + case flatbuffer::DataMovementProcessor::RISCV_0: return DataMovementProcessor::RISCV_0; + case flatbuffer::DataMovementProcessor::RISCV_1: return DataMovementProcessor::RISCV_1; + } + TT_THROW("Unsupported DataMovementProcessor from flatbuffer."); +} + +NOC from_flatbuffer(flatbuffer::NOC in) { + switch (in) { + case flatbuffer::NOC::NOC_0: return NOC::NOC_0; + case flatbuffer::NOC::NOC_1: return NOC::NOC_1; + } + TT_THROW("Unsupported NOC from flatbuffer."); +} + +NOC_MODE from_flatbuffer(flatbuffer::NOC_MODE in) { + switch (in) { + case flatbuffer::NOC_MODE::DM_DEDICATED_NOC: return NOC_MODE::DM_DEDICATED_NOC; + case flatbuffer::NOC_MODE::DM_DYNAMIC_NOC: return NOC_MODE::DM_DYNAMIC_NOC; + } + TT_THROW("Unsupported NOC_MODE from flatbuffer."); +} + +Eth from_flatbuffer(flatbuffer::EthMode in) { + switch (in) { + case flatbuffer::EthMode::SENDER: return Eth::SENDER; + case flatbuffer::EthMode::RECEIVER: return Eth::RECEIVER; + case flatbuffer::EthMode::IDLE: return Eth::IDLE; + } + TT_THROW("Unsupported EthMode from flatbuffer."); +} + +MathFidelity from_flatbuffer(flatbuffer::MathFidelity input) { + switch (input) { + case flatbuffer::MathFidelity::LoFi: return MathFidelity::LoFi; + case flatbuffer::MathFidelity::HiFi2: return MathFidelity::HiFi2; + case flatbuffer::MathFidelity::HiFi3: return MathFidelity::HiFi3; + case flatbuffer::MathFidelity::HiFi4: return MathFidelity::HiFi4; + case flatbuffer::MathFidelity::Invalid: return MathFidelity::Invalid; + } + TT_THROW("Unsupported MathFidelity from flatbuffer."); +} + +UnpackToDestMode from_flatbuffer(flatbuffer::UnpackToDestMode input) { + switch (input) { + case flatbuffer::UnpackToDestMode::UnpackToDestFp32: return UnpackToDestMode::UnpackToDestFp32; + case flatbuffer::UnpackToDestMode::Default: return UnpackToDestMode::Default; + } + TT_THROW("Unsupported UnpackToDestMode from flatbuffer."); +} + +tt::DataFormat from_flatbuffer(flatbuffer::DataFormat input) { + switch (input) { + case flatbuffer::DataFormat::Float32: return tt::DataFormat::Float32; + case flatbuffer::DataFormat::Float16: return tt::DataFormat::Float16; + case flatbuffer::DataFormat::Bfp8: return tt::DataFormat::Bfp8; + case flatbuffer::DataFormat::Bfp4: return tt::DataFormat::Bfp4; + case flatbuffer::DataFormat::Bfp2: return tt::DataFormat::Bfp2; + case flatbuffer::DataFormat::Float16_b: return tt::DataFormat::Float16_b; + case flatbuffer::DataFormat::Bfp8_b: return tt::DataFormat::Bfp8_b; + case flatbuffer::DataFormat::Bfp4_b: return tt::DataFormat::Bfp4_b; + case flatbuffer::DataFormat::Bfp2_b: return tt::DataFormat::Bfp2_b; + case flatbuffer::DataFormat::Lf8: return tt::DataFormat::Lf8; + case flatbuffer::DataFormat::Fp8_e4m3: return tt::DataFormat::Fp8_e4m3; + case flatbuffer::DataFormat::Int8: return tt::DataFormat::Int8; + case flatbuffer::DataFormat::Tf32: return tt::DataFormat::Tf32; + case flatbuffer::DataFormat::UInt8: return tt::DataFormat::UInt8; + case flatbuffer::DataFormat::UInt16: return tt::DataFormat::UInt16; + case flatbuffer::DataFormat::Int32: return tt::DataFormat::Int32; + case flatbuffer::DataFormat::UInt32: return tt::DataFormat::UInt32; + case flatbuffer::DataFormat::RawUInt8: return tt::DataFormat::RawUInt8; + case flatbuffer::DataFormat::RawUInt16: return tt::DataFormat::RawUInt16; + case flatbuffer::DataFormat::RawUInt32: return tt::DataFormat::RawUInt32; + case flatbuffer::DataFormat::Invalid: return tt::DataFormat::Invalid; + } + TT_THROW("Unsupported DataFormat from flatbuffer."); +} + +Tile from_flatbuffer(const flatbuffer::Tile& tile_fb) { + const auto& shape = *tile_fb.tile_shape(); + // Tile shape is already 2D in flatbuffer schema. + std::array tile_shape = {shape[0], shape[1]}; + bool transpose_tile = tile_fb.transpose_tile(); + return Tile(tile_shape, transpose_tile); +} + +} // namespace tt::tt_metal diff --git a/tt_metal/impl/flatbuffer/base_types_from_flatbuffer.hpp b/tt_metal/impl/flatbuffer/base_types_from_flatbuffer.hpp new file mode 100644 index 00000000000..5463235f9a6 --- /dev/null +++ b/tt_metal/impl/flatbuffer/base_types_from_flatbuffer.hpp @@ -0,0 +1,29 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "base_types_generated.h" +#include +#include +#include +#include +#include +#include + +namespace tt::tt_metal { + +DataMovementProcessor from_flatbuffer(flatbuffer::DataMovementProcessor in); + +NOC from_flatbuffer(flatbuffer::NOC in); +NOC_MODE from_flatbuffer(flatbuffer::NOC_MODE in); +Eth from_flatbuffer(flatbuffer::EthMode in); + +MathFidelity from_flatbuffer(flatbuffer::MathFidelity input); +UnpackToDestMode from_flatbuffer(flatbuffer::UnpackToDestMode input); +tt::DataFormat from_flatbuffer(flatbuffer::DataFormat input); + +Tile from_flatbuffer(const flatbuffer::Tile& tile_fb); + +} // namespace tt::tt_metal diff --git a/tt_metal/impl/flatbuffer/base_types_to_flatbuffer.cpp b/tt_metal/impl/flatbuffer/base_types_to_flatbuffer.cpp new file mode 100644 index 00000000000..6638ae0230c --- /dev/null +++ b/tt_metal/impl/flatbuffer/base_types_to_flatbuffer.cpp @@ -0,0 +1,99 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "flatbuffer/base_types_to_flatbuffer.hpp" + +namespace tt::tt_metal { + +// Original types defined in data_types.hpp +flatbuffer::DataMovementProcessor to_flatbuffer(DataMovementProcessor in) { + switch (in) { + case DataMovementProcessor::RISCV_0: return flatbuffer::DataMovementProcessor::RISCV_0; + case DataMovementProcessor::RISCV_1: return flatbuffer::DataMovementProcessor::RISCV_1; + } + TT_THROW("Unsupported DataMovementProcessor to flatbuffer."); +} + +flatbuffer::NOC to_flatbuffer(NOC in) { + switch (in) { + case NOC::NOC_0: return flatbuffer::NOC::NOC_0; + case NOC::NOC_1: return flatbuffer::NOC::NOC_1; + } + TT_THROW("Unsupported NOC to flatbuffer."); +} + +flatbuffer::NOC_MODE to_flatbuffer(NOC_MODE in) { + switch (in) { + case NOC_MODE::DM_DEDICATED_NOC: return flatbuffer::NOC_MODE::DM_DEDICATED_NOC; + case NOC_MODE::DM_DYNAMIC_NOC: return flatbuffer::NOC_MODE::DM_DYNAMIC_NOC; + } + TT_THROW("Unsupported NOC_MODE to flatbuffer."); +} + +flatbuffer::EthMode to_flatbuffer(Eth in) { + switch (in) { + case Eth::SENDER: return flatbuffer::EthMode::SENDER; + case Eth::RECEIVER: return flatbuffer::EthMode::RECEIVER; + case Eth::IDLE: return flatbuffer::EthMode::IDLE; + } + TT_THROW("Unsupported Eth to flatbuffer."); +} + +// Original types defined in base_types.hpp +flatbuffer::MathFidelity to_flatbuffer(MathFidelity input) { + switch (input) { + case MathFidelity::LoFi: return flatbuffer::MathFidelity::LoFi; + case MathFidelity::HiFi2: return flatbuffer::MathFidelity::HiFi2; + case MathFidelity::HiFi3: return flatbuffer::MathFidelity::HiFi3; + case MathFidelity::HiFi4: return flatbuffer::MathFidelity::HiFi4; + case MathFidelity::Invalid: return flatbuffer::MathFidelity::Invalid; + } + TT_THROW("Unsupported MathFidelity to flatbuffer."); +} + +flatbuffer::UnpackToDestMode to_flatbuffer(UnpackToDestMode input) { + switch (input) { + case UnpackToDestMode::UnpackToDestFp32: return flatbuffer::UnpackToDestMode::UnpackToDestFp32; + case UnpackToDestMode::Default: return flatbuffer::UnpackToDestMode::Default; + } + TT_THROW("Unsupported UnpackToDestMode to flatbuffer."); +} + +// Original types defined in tt_backend_api_types.hpp +flatbuffer::DataFormat to_flatbuffer(tt::DataFormat input) { + switch (input) { + case tt::DataFormat::Float32: return flatbuffer::DataFormat::Float32; + case tt::DataFormat::Float16: return flatbuffer::DataFormat::Float16; + case tt::DataFormat::Bfp8: return flatbuffer::DataFormat::Bfp8; + case tt::DataFormat::Bfp4: return flatbuffer::DataFormat::Bfp4; + case tt::DataFormat::Bfp2: return flatbuffer::DataFormat::Bfp2; + case tt::DataFormat::Float16_b: return flatbuffer::DataFormat::Float16_b; + case tt::DataFormat::Bfp8_b: return flatbuffer::DataFormat::Bfp8_b; + case tt::DataFormat::Bfp4_b: return flatbuffer::DataFormat::Bfp4_b; + case tt::DataFormat::Bfp2_b: return flatbuffer::DataFormat::Bfp2_b; + case tt::DataFormat::Lf8: return flatbuffer::DataFormat::Lf8; + case tt::DataFormat::Fp8_e4m3: return flatbuffer::DataFormat::Fp8_e4m3; + case tt::DataFormat::Int8: return flatbuffer::DataFormat::Int8; + case tt::DataFormat::Tf32: return flatbuffer::DataFormat::Tf32; + case tt::DataFormat::UInt8: return flatbuffer::DataFormat::UInt8; + case tt::DataFormat::UInt16: return flatbuffer::DataFormat::UInt16; + case tt::DataFormat::Int32: return flatbuffer::DataFormat::Int32; + case tt::DataFormat::UInt32: return flatbuffer::DataFormat::UInt32; + case tt::DataFormat::RawUInt8: return flatbuffer::DataFormat::RawUInt8; + case tt::DataFormat::RawUInt16: return flatbuffer::DataFormat::RawUInt16; + case tt::DataFormat::RawUInt32: return flatbuffer::DataFormat::RawUInt32; + case tt::DataFormat::Invalid: return flatbuffer::DataFormat::Invalid; + } + TT_THROW("Unsupported DataFormat to flatbuffer."); +} + +flatbuffer::Tile to_flatbuffer(const Tile& tile) { + TT_FATAL(tile.get_tile_shape().size() == 2, "Conversion to Flatbuffer expecting 2D Tile Shapes."); + std::array shape = {tile.get_tile_shape()[0], tile.get_tile_shape()[1]}; + + return flatbuffer::Tile( + flatbuffers::span(shape), tile.get_transpose_within_face() && tile.get_transpose_of_faces()); +} + +} // namespace tt::tt_metal diff --git a/tt_metal/impl/flatbuffer/base_types_to_flatbuffer.hpp b/tt_metal/impl/flatbuffer/base_types_to_flatbuffer.hpp new file mode 100644 index 00000000000..ebffd4de5f6 --- /dev/null +++ b/tt_metal/impl/flatbuffer/base_types_to_flatbuffer.hpp @@ -0,0 +1,31 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "base_types_generated.h" +#include +#include +#include +#include +#include +#include + +namespace tt::tt_metal { + +flatbuffer::DataMovementProcessor to_flatbuffer(DataMovementProcessor in); +flatbuffer::NOC to_flatbuffer(NOC in); +flatbuffer::NOC_MODE to_flatbuffer(NOC_MODE in); +flatbuffer::EthMode to_flatbuffer(Eth in); + +flatbuffer::MathFidelity to_flatbuffer(MathFidelity input); +flatbuffer::UnpackToDestMode to_flatbuffer(UnpackToDestMode input); +flatbuffer::DataFormat to_flatbuffer(tt::DataFormat input); + +flatbuffer::Tile to_flatbuffer(const Tile& tile); + +flatbuffers::Offset>> to_flatbuffer( + const std::array, NUM_CIRCULAR_BUFFERS>& tiles, flatbuffers::FlatBufferBuilder& builder); + +} // namespace tt::tt_metal diff --git a/tt_metal/impl/flatbuffer/buffer_types.fbs b/tt_metal/impl/flatbuffer/buffer_types.fbs new file mode 100644 index 00000000000..988d5e812b4 --- /dev/null +++ b/tt_metal/impl/flatbuffer/buffer_types.fbs @@ -0,0 +1,59 @@ +include "flatbuffer/base_types.fbs"; + +namespace tt.tt_metal.flatbuffer; + +enum BufferType: ushort { + DRAM = 0, + L1 = 1, + SystemMemory = 2, + L1Small = 3, + Trace = 4, +} + +enum TensorMemoryLayout: ushort { + None = 0, + Interleaved = 1, + SingleBank = 2, + HeightSharded = 3, + WidthSharded = 4, + BlockSharded = 5, +} + +table InterleavedBufferConfig { + device_id: int; // Reference to IDevice *device; + size: int; // Size in bytes + page_size: int; // Size of unit being interleaved. For non-interleaved buffers: size == page_size + buffer_type: BufferType; + buffer_layout: TensorMemoryLayout; +} + +struct CBConfigPageSize { + index: uint32; // The index in the array + size: uint32; // The page-size value for this index +} + +struct CBConfigDataFormat { + index: uint32; // The index in the array + format: DataFormat; // The data format for this index +} + + +struct CBConfigTile { + index: uint32; // The index in the array + tile: Tile; // The tile for this index +} + +table CircularBufferConfig { + total_size: uint32; + globally_allocated_address: uint32; // Optional behavior can be handled with a default value (or union) + data_formats: [CBConfigDataFormat]; // Mimic optional array in C++ by using KV map. + page_sizes: [CBConfigPageSize]; // Mimic optional array in C++ by using KV map. + tiles: [CBConfigTile]; // Mimic optional array in C++ by using KV map. + shadow_buf_global_id: uint32; + buffer_indices: [uint8]; + local_buffer_indices: [uint8]; + remote_buffer_indices: [uint8]; + dynamic_cb: bool; + max_size: uint32; + buffer_size: uint32; +} diff --git a/tt_metal/impl/flatbuffer/buffer_types_from_flatbuffer.cpp b/tt_metal/impl/flatbuffer/buffer_types_from_flatbuffer.cpp new file mode 100644 index 00000000000..ae5df1ab2fc --- /dev/null +++ b/tt_metal/impl/flatbuffer/buffer_types_from_flatbuffer.cpp @@ -0,0 +1,78 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "flatbuffer/buffer_types_from_flatbuffer.hpp" + +namespace tt::tt_metal { + +BufferType from_flatbuffer(flatbuffer::BufferType type) { + switch (type) { + case flatbuffer::BufferType::DRAM: return BufferType::DRAM; + case flatbuffer::BufferType::L1: return BufferType::L1; + case flatbuffer::BufferType::SystemMemory: return BufferType::SYSTEM_MEMORY; + case flatbuffer::BufferType::L1Small: return BufferType::L1_SMALL; + case flatbuffer::BufferType::Trace: return BufferType::TRACE; + } + TT_THROW("Unsupported BufferType from flatbuffer."); +} + +CircularBufferConfig from_flatbuffer( + const flatbuffer::CircularBufferConfig* config_fb, const Buffer* shadow_global_buffer) { + TT_FATAL(config_fb, "Invalid CircularBufferConfig FlatBuffer object"); + + std::optional globally_allocated_address = + (config_fb->globally_allocated_address() == 0) + ? std::nullopt + : std::optional(config_fb->globally_allocated_address()); + + std::array, NUM_CIRCULAR_BUFFERS> data_formats = {}; + if (config_fb->data_formats()) { + for (auto entry : *config_fb->data_formats()) { + data_formats[entry->index()] = from_flatbuffer(entry->format()); + } + } + + std::array, NUM_CIRCULAR_BUFFERS> page_sizes = {}; + if (config_fb->page_sizes()) { + for (auto entry : *config_fb->page_sizes()) { + page_sizes[entry->index()] = entry->size(); + } + } + + std::array, NUM_CIRCULAR_BUFFERS> tiles = {}; + if (config_fb->tiles()) { + for (auto entry : *config_fb->tiles()) { + tiles[entry->index()] = from_flatbuffer(entry->tile()); + } + } + + // Convert FlatBuffer vector to unordered_set of uint8_t + auto create_uint8_set = [](auto* fb_vector) { + std::unordered_set result; + if (fb_vector) { + result.insert(fb_vector->begin(), fb_vector->end()); + } + return result; + }; + + // Constructor supports being able to specify all private members. shadow_global_buffer is public. + CircularBufferConfig config( + config_fb->total_size(), + globally_allocated_address, + data_formats, + page_sizes, + tiles, + create_uint8_set(config_fb->buffer_indices()), + create_uint8_set(config_fb->local_buffer_indices()), + create_uint8_set(config_fb->remote_buffer_indices()), + config_fb->dynamic_cb(), + config_fb->max_size(), + config_fb->buffer_size()); + + config.shadow_global_buffer = shadow_global_buffer; + + return config; +} + +} // namespace tt::tt_metal diff --git a/tt_metal/impl/flatbuffer/buffer_types_from_flatbuffer.hpp b/tt_metal/impl/flatbuffer/buffer_types_from_flatbuffer.hpp new file mode 100644 index 00000000000..74a5579ca51 --- /dev/null +++ b/tt_metal/impl/flatbuffer/buffer_types_from_flatbuffer.hpp @@ -0,0 +1,18 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "buffer_types_generated.h" +#include +#include "flatbuffer/base_types_from_flatbuffer.hpp" + +namespace tt::tt_metal { + +BufferType from_flatbuffer(flatbuffer::BufferType type); + +CircularBufferConfig from_flatbuffer( + const flatbuffer::CircularBufferConfig* config_fb, const Buffer* shadow_global_buffer); + +} // namespace tt::tt_metal diff --git a/tt_metal/impl/flatbuffer/buffer_types_to_flatbuffer.cpp b/tt_metal/impl/flatbuffer/buffer_types_to_flatbuffer.cpp new file mode 100644 index 00000000000..0c3f4c3822b --- /dev/null +++ b/tt_metal/impl/flatbuffer/buffer_types_to_flatbuffer.cpp @@ -0,0 +1,79 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "flatbuffer/buffer_types_to_flatbuffer.hpp" + +namespace tt::tt_metal { + +// Original types defined in buffer_constants.hpp +flatbuffer::BufferType to_flatbuffer(BufferType type) { + switch (type) { + case BufferType::DRAM: return flatbuffer::BufferType::DRAM; + case BufferType::L1: return flatbuffer::BufferType::L1; + case BufferType::SYSTEM_MEMORY: return flatbuffer::BufferType::SystemMemory; + case BufferType::L1_SMALL: return flatbuffer::BufferType::L1Small; + case BufferType::TRACE: return flatbuffer::BufferType::Trace; + } + TT_THROW("Unsupported BufferType to flatbuffer."); +} + +// Original types defined in buffer_constants.hpp +flatbuffer::TensorMemoryLayout to_flatbuffer(TensorMemoryLayout layout) { + switch (layout) { + case TensorMemoryLayout::INTERLEAVED: return flatbuffer::TensorMemoryLayout::Interleaved; + case TensorMemoryLayout::SINGLE_BANK: return flatbuffer::TensorMemoryLayout::SingleBank; + case TensorMemoryLayout::HEIGHT_SHARDED: return flatbuffer::TensorMemoryLayout::HeightSharded; + case TensorMemoryLayout::WIDTH_SHARDED: return flatbuffer::TensorMemoryLayout::WidthSharded; + case TensorMemoryLayout::BLOCK_SHARDED: return flatbuffer::TensorMemoryLayout::BlockSharded; + } + TT_THROW("Unsupported TensorMemoryLayout to flatbuffer."); +} + +// For page sizes, keep lambda usage consistent across types. +static inline uint32_t to_flatbuffer(const uint32_t& value) { return value; } + +// Original type defined in circular_buffer_types.hpp +flatbuffers::Offset to_flatbuffer( + const CircularBufferConfig& config, flatbuffers::FlatBufferBuilder& builder) { + // Convert optional arrays of various types to Flatbuffers vectors. + auto create_fb_vec_of_structs = [&](const auto& array, auto fb_type_tag) { + using FlatBufferType = decltype(fb_type_tag); + std::vector vec; + for (size_t i = 0; i < array.size(); i++) { + if (array[i]) { + vec.push_back(FlatBufferType{i, to_flatbuffer(*array[i])}); + } + } + return builder.CreateVectorOfStructs(vec); + }; + + // Convert unordered_set of uint8_t to FlatBuffer vector + auto create_fb_vec_of_uint8 = [&](const auto& set) { + return builder.CreateVector(std::vector(set.begin(), set.end())); + }; + + // Optional shadow buffer for dynamically allocated CBs, get global_id or use 0 as none/nullptr. + // auto& ctx = LightMetalCaptureContext::Get(); + // auto shadow_buf_global_id = config.shadow_global_buffer ? ctx.GetGlobalId(config.shadow_global_buffer) : 0; + // TODO (kmabee) - Uncomment above code once capture library is merged. Temp hack here for now. + uint32_t shadow_buf_global_id = 0; + + // Create the FlatBuffer object + return flatbuffer::CreateCircularBufferConfig( + builder, + config.total_size(), + config.globally_allocated_address().value_or(0), // Optional, default 0 if nullopt. + create_fb_vec_of_structs(config.data_formats(), flatbuffer::CBConfigDataFormat{}), + create_fb_vec_of_structs(config.page_sizes(), flatbuffer::CBConfigPageSize{}), + create_fb_vec_of_structs(config.tiles(), flatbuffer::CBConfigTile{}), + shadow_buf_global_id, + create_fb_vec_of_uint8(config.buffer_indices()), + create_fb_vec_of_uint8(config.local_buffer_indices()), + create_fb_vec_of_uint8(config.remote_buffer_indices()), + config.dynamic_cb(), + config.max_size(), + config.buffer_size()); +} + +} // namespace tt::tt_metal diff --git a/tt_metal/impl/flatbuffer/buffer_types_to_flatbuffer.hpp b/tt_metal/impl/flatbuffer/buffer_types_to_flatbuffer.hpp new file mode 100644 index 00000000000..a963488da09 --- /dev/null +++ b/tt_metal/impl/flatbuffer/buffer_types_to_flatbuffer.hpp @@ -0,0 +1,19 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "buffer_types_generated.h" +#include "flatbuffer/base_types_to_flatbuffer.hpp" +#include + +namespace tt::tt_metal { + +flatbuffer::BufferType to_flatbuffer(BufferType type); +flatbuffer::TensorMemoryLayout to_flatbuffer(TensorMemoryLayout layout); + +flatbuffers::Offset to_flatbuffer( + const CircularBufferConfig& config, flatbuffers::FlatBufferBuilder& builder); + +} // namespace tt::tt_metal diff --git a/tt_metal/impl/flatbuffer/command.fbs b/tt_metal/impl/flatbuffer/command.fbs new file mode 100644 index 00000000000..b21a4a5dba2 --- /dev/null +++ b/tt_metal/impl/flatbuffer/command.fbs @@ -0,0 +1,125 @@ +// Define schema for tracing host API calls, called Commands in this context. +include "flatbuffer/buffer_types.fbs"; +include "flatbuffer/program_types.fbs"; + +namespace tt.tt_metal.flatbuffer; + +table ReplayTraceCommand { + // TODO (kmabee) - add device. + cq_id: int; + tid: int; + blocking: bool; +} + +table EnqueueTraceCommand { + // TODO (kmabee) - add device. + cq_id: int; + tid: int; + blocking: bool; +} + +table LoadTraceCommand { + tid: int; // Pointer to trace data. + cq_id: int; +} + +table ReleaseTraceCommand { + // TODO (kmabee) - add device. + tid: int; // Pointer to trace data. +} + +table CreateBufferCommand { + global_id: uint32; + config: InterleavedBufferConfig; // Later grow to union for Sharded. + address: uint32; // Optional for pre-allocated buffers. +} + +table DeallocateBufferCommand { + global_id: uint32; // Reference to Buffer to be deallocated +} + +table EnqueueWriteBufferCommand { + cq_global_id: uint32; // reference to CommandQueue + buffer_global_id: uint32; // Reference to Buffer used as destination + src: [uint32]; // Data to be written. Support only some types for now. + blocking: bool; +} + +table EnqueueReadBufferCommand { + cq_global_id: uint32; // reference to CommandQueue + buffer_global_id: uint32; // Reference to Buffer used as source + blocking: bool; +} + +table FinishCommand { + cq_global_id: uint32; // reference to CommandQueue + sub_device_ids: [ubyte]; // array of uint8 values representing SubDeviceId::Id +} + +table CreateProgramCommand { + global_id: uint32; +} + +table EnqueueProgramCommand { + cq_global_id: uint32; // reference to CommandQueue + program_global_id: uint32; // Reference to Program + blocking: bool; +} + +table CreateKernelCommand { + global_id: uint32; // Reference to Kernel + program_global_id: uint32; // Reference to Program + file_name: string; // Later replace with src, then binary + core_spec: CoreSpec; + kernel_config: KernelConfig; +} + +table SetRuntimeArgsUint32Command { + program_global_id: uint32; // Reference to Program + kernel_global_id: uint32; // Reference to Kernel + core_spec: CoreSpec; + args: [uint32]; // Arguments to be passed to kernel +} + +table SetRuntimeArgsCommand { + kernel_global_id: uint32; // Reference to Kernel + core_spec: CoreSpec; + args: [RuntimeArg]; // Arguments to be passed to kernel +} + +table CreateCircularBufferCommand { + global_id: uint32; // Reference to CBHandle + program_global_id: uint32; // Reference to Program + core_spec: CoreSpec; + config: CircularBufferConfig; +} + +table LightMetalCompareCommand { + cq_global_id: uint32; // reference to CommandQueue + buffer_global_id: uint32; // Reference to Buffer used as destination + golden_data: [uint32]; // Golden data to compare against at replay + is_user_data: bool; // Informational, denote if golden data is from user or capture +} + +union CommandType { + ReplayTraceCommand, + EnqueueTraceCommand, + LoadTraceCommand, + ReleaseTraceCommand, + CreateBufferCommand, + DeallocateBufferCommand, + EnqueueWriteBufferCommand, + EnqueueReadBufferCommand, + FinishCommand, + CreateProgramCommand, + EnqueueProgramCommand, + CreateKernelCommand, + SetRuntimeArgsUint32Command, + SetRuntimeArgsCommand, + CreateCircularBufferCommand, + LightMetalCompareCommand, +} + +table Command { + cmd: CommandType; +} diff --git a/tt_metal/impl/flatbuffer/light_metal_binary.fbs b/tt_metal/impl/flatbuffer/light_metal_binary.fbs new file mode 100644 index 00000000000..619e69bf01c --- /dev/null +++ b/tt_metal/impl/flatbuffer/light_metal_binary.fbs @@ -0,0 +1,38 @@ +include "flatbuffer/command.fbs"; + +namespace tt.tt_metal.flatbuffer; + +// Represents the Descriptor struct inside TraceDescriptor, given slightly less vague name here. +table TraceDescriptorMetaData { + num_completion_worker_cores: uint32; + num_traced_programs_needing_go_signal_multicast: uint32; + num_traced_programs_needing_go_signal_unicast: uint32; +} + +// Represents a key-value pair for SubDeviceId -> TraceDescriptorMetaData mapping +table SubDeviceDescriptorMapping { + sub_device_id: uint8; + descriptor: TraceDescriptorMetaData; +} + +// Matches C++ struct TraceDescriptor +table TraceDescriptor { + trace_data: [uint32]; + sub_device_descriptors: [SubDeviceDescriptorMapping]; // Vector of key-value pairs + sub_device_ids: [uint8]; // Optimized vector of sub_device_ids +} + +// Associate key (trace_id) to value (TraceDescriptor) +table TraceDescriptorByTraceId { + trace_id: uint32 (key); + desc: TraceDescriptor; +} + +// Top level Binary to represent a host+device workload as LightMetalBinary. +table LightMetalBinary { + // TODO (kmabee) - Git Hash, Versioning, SystemDesc, etc. + commands: [tt.tt_metal.flatbuffer.Command]; + trace_descriptors: [TraceDescriptorByTraceId]; // Metal "Traces" +} + +root_type LightMetalBinary; diff --git a/tt_metal/impl/flatbuffer/program_types.fbs b/tt_metal/impl/flatbuffer/program_types.fbs new file mode 100644 index 00000000000..0d3b338fc90 --- /dev/null +++ b/tt_metal/impl/flatbuffer/program_types.fbs @@ -0,0 +1,74 @@ +include "flatbuffer/base_types.fbs"; + +namespace tt.tt_metal.flatbuffer; + +table CoreCoord { + x: int; + y: int; +} + +table CoreRange { + start: CoreCoord; + end: CoreCoord; +} + +table CoreRangeSet { + ranges: [CoreRange]; +} + +union CoreSpec { + CoreCoord, + CoreRange, + CoreRangeSet +} + +table DataMovementConfig { + processor: DataMovementProcessor; + noc: NOC; + noc_mode: NOC_MODE; + compile_args: [uint32]; // Array of compile arguments + defines: [DefineEntry]; // Key-value pair map for defines +} + +table ComputeConfig { + math_fidelity: MathFidelity; + fp32_dest_acc_en: bool; + dst_full_sync_en: bool; + unpack_to_dest_mode: [UnpackToDestMode]; // Array of unpack modes + bfp8_pack_precise: bool; + math_approx_mode: bool; + compile_args: [uint32]; // Array of compile arguments + defines: [DefineEntry]; // Key-value pair map for defines +} + +table EthernetConfig { + eth_mode: EthMode; + noc: NOC; + processor: DataMovementProcessor; + compile_args: [uint32]; // Array of compile arguments + defines: [DefineEntry]; // Key-value pair map for defines +} + +// Union to include multiple configurations +union KernelConfig { + DataMovementConfig, + ComputeConfig, + EthernetConfig +} + +struct UInt32Value { + value: uint32; +} + +struct BufferGlobalId { + id: uint32; +} + +union RuntimeArgValue { + UInt32Value, + BufferGlobalId, +} + +table RuntimeArg { + value: RuntimeArgValue; +} diff --git a/tt_metal/impl/flatbuffer/program_types_from_flatbuffer.cpp b/tt_metal/impl/flatbuffer/program_types_from_flatbuffer.cpp new file mode 100644 index 00000000000..27bac75649e --- /dev/null +++ b/tt_metal/impl/flatbuffer/program_types_from_flatbuffer.cpp @@ -0,0 +1,136 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "flatbuffer/program_types_from_flatbuffer.hpp" +#include "flatbuffer/base_types_from_flatbuffer.hpp" + +namespace tt::tt_metal { + +std::variant from_flatbuffer( + const flatbuffer::CoreSpec core_spec, const void* flatbuffer_union) { + switch (core_spec) { + case flatbuffer::CoreSpec::CoreCoord: { + auto core_coord = static_cast(flatbuffer_union); + TT_FATAL(core_coord, "Invalid CoreCoord data"); + return CoreCoord{core_coord->x(), core_coord->y()}; + } + case flatbuffer::CoreSpec::CoreRange: { + auto core_range = static_cast(flatbuffer_union); + TT_FATAL(core_range, "Invalid CoreRange data"); + return CoreRange{ + {core_range->start()->x(), core_range->start()->y()}, {core_range->end()->x(), core_range->end()->y()}}; + } + case flatbuffer::CoreSpec::CoreRangeSet: { + auto core_range_set = static_cast(flatbuffer_union); + TT_FATAL(core_range_set, "Invalid CoreRangeSet data"); + std::vector ranges; + for (const auto range : *core_range_set->ranges()) { + ranges.emplace_back( + CoreCoord{range->start()->x(), range->start()->y()}, + CoreCoord{range->end()->x(), range->end()->y()}); + } + return CoreRangeSet{ranges}; + } + default: throw std::runtime_error("Unhandled CoreSpec type in from_flatbuffer"); + } +} + +DataMovementConfig from_flatbuffer(const flatbuffer::DataMovementConfig* fb_config) { + DataMovementConfig config; + + // Extract processor, noc, and noc_mode + config.processor = from_flatbuffer(fb_config->processor()); + config.noc = from_flatbuffer(fb_config->noc()); + config.noc_mode = from_flatbuffer(fb_config->noc_mode()); + + // Extract compile_args + auto fb_compile_args = fb_config->compile_args(); + config.compile_args.assign(fb_compile_args->begin(), fb_compile_args->end()); + + // Extract defines + auto fb_defines = fb_config->defines(); + for (auto fb_define : *fb_defines) { + config.defines.emplace(fb_define->key()->str(), fb_define->value()->str()); + } + + return config; +} + +ComputeConfig from_flatbuffer(const flatbuffer::ComputeConfig* fb_config) { + ComputeConfig config; + + // Extract math_fidelity and boolean flags + config.math_fidelity = from_flatbuffer(fb_config->math_fidelity()); + config.fp32_dest_acc_en = fb_config->fp32_dest_acc_en(); + config.dst_full_sync_en = fb_config->dst_full_sync_en(); + config.bfp8_pack_precise = fb_config->bfp8_pack_precise(); + config.math_approx_mode = fb_config->math_approx_mode(); + + // Extract unpack_to_dest_mode + auto fb_unpack_modes = fb_config->unpack_to_dest_mode(); + config.unpack_to_dest_mode.reserve(fb_unpack_modes->size()); + for (auto fb_mode : *fb_unpack_modes) { + config.unpack_to_dest_mode.push_back(from_flatbuffer(fb_mode)); + } + + // Extract compile_args + auto fb_compile_args = fb_config->compile_args(); + config.compile_args.assign(fb_compile_args->begin(), fb_compile_args->end()); + + // Extract defines + auto fb_defines = fb_config->defines(); + for (auto fb_define : *fb_defines) { + config.defines.emplace(fb_define->key()->str(), fb_define->value()->str()); + } + + return config; +} + +EthernetConfig from_flatbuffer(const flatbuffer::EthernetConfig* fb_config) { + EthernetConfig config; + + // Extract eth_mode, noc, and processor + config.eth_mode = from_flatbuffer(fb_config->eth_mode()); + config.noc = from_flatbuffer(fb_config->noc()); + config.processor = from_flatbuffer(fb_config->processor()); + + // Extract compile_args + auto fb_compile_args = fb_config->compile_args(); + config.compile_args.assign(fb_compile_args->begin(), fb_compile_args->end()); + + // Extract defines + auto fb_defines = fb_config->defines(); + for (auto fb_define : *fb_defines) { + config.defines.emplace(fb_define->key()->str(), fb_define->value()->str()); + } + + return config; +} + +std::variant from_flatbuffer( + const flatbuffer::KernelConfig config_type, const void* flatbuffer_union) { + switch (config_type) { + case flatbuffer::KernelConfig::DataMovementConfig: + return from_flatbuffer(static_cast(flatbuffer_union)); + case flatbuffer::KernelConfig::ComputeConfig: + return from_flatbuffer(static_cast(flatbuffer_union)); + case flatbuffer::KernelConfig::EthernetConfig: + return from_flatbuffer(static_cast(flatbuffer_union)); + case flatbuffer::KernelConfig::NONE: + throw std::runtime_error("Unhandled KernelConfig type in from_flatbuffer."); + } + TT_THROW("Unhandled KernelConfig type in from_flatbuffer."); +} + +std::vector from_flatbuffer(const flatbuffers::Vector* fb_sub_device_ids) { + std::vector sub_device_ids(fb_sub_device_ids ? fb_sub_device_ids->size() : 0); + + for (size_t i = 0; i < sub_device_ids.size(); ++i) { + sub_device_ids[i] = SubDeviceId{(*fb_sub_device_ids)[i]}; + } + + return sub_device_ids; +} + +} // namespace tt::tt_metal diff --git a/tt_metal/impl/flatbuffer/program_types_from_flatbuffer.hpp b/tt_metal/impl/flatbuffer/program_types_from_flatbuffer.hpp new file mode 100644 index 00000000000..f6742c1d3ba --- /dev/null +++ b/tt_metal/impl/flatbuffer/program_types_from_flatbuffer.hpp @@ -0,0 +1,26 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "program_types_generated.h" +#include +#include +#include + +namespace tt::tt_metal { + +std::variant from_flatbuffer( + const flatbuffer::CoreSpec core_spec, const void* flatbuffer_union); + +DataMovementConfig from_flatbuffer(const flatbuffer::DataMovementConfig* fb_config); +ComputeConfig from_flatbuffer(const flatbuffer::ComputeConfig* fb_config); +EthernetConfig from_flatbuffer(const flatbuffer::EthernetConfig* fb_config); + +std::variant from_flatbuffer( + const flatbuffer::KernelConfig config_type, const void* flatbuffer_union); + +std::vector from_flatbuffer(const flatbuffers::Vector* fb_sub_device_ids); + +} // namespace tt::tt_metal diff --git a/tt_metal/impl/flatbuffer/program_types_to_flatbuffer.cpp b/tt_metal/impl/flatbuffer/program_types_to_flatbuffer.cpp new file mode 100644 index 00000000000..a3d8e875819 --- /dev/null +++ b/tt_metal/impl/flatbuffer/program_types_to_flatbuffer.cpp @@ -0,0 +1,190 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "flatbuffer/base_types_to_flatbuffer.hpp" +#include "flatbuffer/program_types_to_flatbuffer.hpp" +#include +namespace tt::tt_metal { + +// Original types defined in core_coord.hpp +std::pair> to_flatbuffer( + flatbuffers::FlatBufferBuilder& builder, const std::variant& core_spec) { + return std::visit( + tt::stl::overloaded{ + [&](const CoreCoord& spec) -> std::pair> { + auto core_coord = flatbuffer::CreateCoreCoord(builder, spec.x, spec.y); + return {flatbuffer::CoreSpec::CoreCoord, core_coord.Union()}; + }, + [&](const CoreRange& spec) -> std::pair> { + auto start = flatbuffer::CreateCoreCoord(builder, spec.start_coord.x, spec.start_coord.y); + auto end = flatbuffer::CreateCoreCoord(builder, spec.end_coord.x, spec.end_coord.y); + auto core_range = flatbuffer::CreateCoreRange(builder, start, end); + return {flatbuffer::CoreSpec::CoreRange, core_range.Union()}; + }, + [&](const CoreRangeSet& spec) -> std::pair> { + std::vector> range_offsets; + for (const auto& range : spec.ranges()) { + auto start = flatbuffer::CreateCoreCoord(builder, range.start_coord.x, range.start_coord.y); + auto end = flatbuffer::CreateCoreCoord(builder, range.end_coord.x, range.end_coord.y); + range_offsets.push_back(flatbuffer::CreateCoreRange(builder, start, end)); + } + auto ranges_vector = builder.CreateVector(range_offsets); + auto core_range_set = flatbuffer::CreateCoreRangeSet(builder, ranges_vector); + return {flatbuffer::CoreSpec::CoreRangeSet, core_range_set.Union()}; + }}, + core_spec); +} + +// Original types defined in kernel_types.hpp +std::pair> to_flatbuffer( + flatbuffers::FlatBufferBuilder& builder, const DataMovementConfig& config) { + // Convert defines (map) to FlatBuffer format + std::vector> defines_vector; + for (const auto& [key, value] : config.defines) { + auto key_offset = builder.CreateString(key); + auto value_offset = builder.CreateString(value); + defines_vector.push_back(flatbuffer::CreateDefineEntry(builder, key_offset, value_offset)); + } + auto defines_offset = builder.CreateVector(defines_vector); + + auto compile_args_offset = builder.CreateVector(config.compile_args); + auto config_offset = flatbuffer::CreateDataMovementConfig( + builder, + to_flatbuffer(config.processor), + to_flatbuffer(config.noc), + to_flatbuffer(config.noc_mode), + compile_args_offset, + defines_offset); + + return {flatbuffer::KernelConfig::DataMovementConfig, config_offset.Union()}; +} + +std::pair> to_flatbuffer( + flatbuffers::FlatBufferBuilder& builder, const ComputeConfig& config) { + // Convert defines (map) to FlatBuffer format + std::vector> defines_vector; + for (const auto& [key, value] : config.defines) { + auto key_offset = builder.CreateString(key); + auto value_offset = builder.CreateString(value); + defines_vector.push_back(flatbuffer::CreateDefineEntry(builder, key_offset, value_offset)); + } + auto defines_offset = builder.CreateVector(defines_vector); + + // Convert unpack_to_dest_mode to FlatBuffer format + std::vector unpack_modes; + for (const auto& mode : config.unpack_to_dest_mode) { + unpack_modes.push_back(to_flatbuffer(mode)); + } + auto unpack_modes_offset = builder.CreateVector(unpack_modes); + + auto compile_args_offset = builder.CreateVector(config.compile_args); + auto config_offset = flatbuffer::CreateComputeConfig( + builder, + to_flatbuffer(config.math_fidelity), + config.fp32_dest_acc_en, + config.dst_full_sync_en, + unpack_modes_offset, + config.bfp8_pack_precise, + config.math_approx_mode, + compile_args_offset, + defines_offset); + + return {flatbuffer::KernelConfig::ComputeConfig, config_offset.Union()}; +} + +std::pair> to_flatbuffer( + flatbuffers::FlatBufferBuilder& builder, const EthernetConfig& config) { + // Convert defines (map) to FlatBuffer format + std::vector> defines_vector; + for (const auto& [key, value] : config.defines) { + auto key_offset = builder.CreateString(key); + auto value_offset = builder.CreateString(value); + defines_vector.push_back(flatbuffer::CreateDefineEntry(builder, key_offset, value_offset)); + } + auto defines_offset = builder.CreateVector(defines_vector); + + auto compile_args_offset = builder.CreateVector(config.compile_args); + auto config_offset = flatbuffer::CreateEthernetConfig( + builder, + to_flatbuffer(config.eth_mode), + to_flatbuffer(config.noc), + to_flatbuffer(config.processor), + compile_args_offset, + defines_offset); + + return {flatbuffer::KernelConfig::EthernetConfig, config_offset.Union()}; +} + +// Generic function for variant, specialized for each type above. +std::pair> to_flatbuffer( + flatbuffers::FlatBufferBuilder& builder, + const std::variant& config) { + return std::visit( + [&](auto&& cfg) { + using T = std::decay_t; + static_assert( + std::is_same_v || std::is_same_v || + std::is_same_v, + "Unhandled config type in to_flatbuffer."); + return to_flatbuffer(builder, cfg); + }, + config); +} + +std::pair> to_flatbuffer( + flatbuffers::FlatBufferBuilder& builder, const ReaderDataMovementConfig& config) { + const DataMovementConfig& base_config = config; // Cast to base + return to_flatbuffer(builder, base_config); +} + +std::pair> to_flatbuffer( + flatbuffers::FlatBufferBuilder& builder, const WriterDataMovementConfig& config) { + const DataMovementConfig& base_config = config; // Cast to base + return to_flatbuffer(builder, base_config); +} + +flatbuffers::Offset create_runtime_arg( + flatbuffers::FlatBufferBuilder& builder, const std::variant& arg) { + flatbuffer::RuntimeArgValue value_type; + + flatbuffers::Offset value_offset = std::visit( + tt::stl::overloaded{ + [&](uint32_t arg_value) -> flatbuffers::Offset { + value_type = flatbuffer::RuntimeArgValue::UInt32Value; + return builder.CreateStruct(tt_metal::flatbuffer::UInt32Value{arg_value}).Union(); + }, + [&](Buffer* arg_value) -> flatbuffers::Offset { + // auto& ctx = LightMetalCaptureContext::Get(); + // uint32_t buffer_global_id = ctx.GetGlobalId(arg_value); + // TODO (kmabee) - Uncomment above code once capture library is merged. Temp hack here for now. + uint32_t buffer_global_id = 0; + value_type = flatbuffer::RuntimeArgValue::BufferGlobalId; + return builder.CreateStruct(tt_metal::flatbuffer::BufferGlobalId{buffer_global_id}).Union(); + }}, + arg); + + return flatbuffer::CreateRuntimeArg(builder, value_type, value_offset); +} + +flatbuffers::Offset>> to_flatbuffer( + flatbuffers::FlatBufferBuilder& builder, const std::shared_ptr& runtime_args) { + std::vector> arg_offsets; + + for (const auto& arg : *runtime_args) { + arg_offsets.push_back(create_runtime_arg(builder, arg)); + } + + return builder.CreateVector(arg_offsets); +} + +flatbuffers::Offset> to_flatbuffer( + flatbuffers::FlatBufferBuilder& builder, tt::stl::Span sub_device_ids) { + std::vector fb_sub_device_ids(sub_device_ids.size()); + for (size_t i = 0; i < sub_device_ids.size(); ++i) { + fb_sub_device_ids[i] = sub_device_ids[i].id; + } + return builder.CreateVector(fb_sub_device_ids); +} + +} // namespace tt::tt_metal diff --git a/tt_metal/impl/flatbuffer/program_types_to_flatbuffer.hpp b/tt_metal/impl/flatbuffer/program_types_to_flatbuffer.hpp new file mode 100644 index 00000000000..858cdfdc0da --- /dev/null +++ b/tt_metal/impl/flatbuffer/program_types_to_flatbuffer.hpp @@ -0,0 +1,48 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "flatbuffer/base_types_to_flatbuffer.hpp" +#include "program_types_generated.h" +#include +#include +#include +#include +#include + +namespace tt::tt_metal { + +std::pair> to_flatbuffer( + flatbuffers::FlatBufferBuilder& builder, const std::variant& core_spec); + +std::pair> to_flatbuffer( + flatbuffers::FlatBufferBuilder& builder, const DataMovementConfig& config); + +std::pair> to_flatbuffer( + flatbuffers::FlatBufferBuilder& builder, const ComputeConfig& config); + +std::pair> to_flatbuffer( + flatbuffers::FlatBufferBuilder& builder, const EthernetConfig& config); + +std::pair> to_flatbuffer( + flatbuffers::FlatBufferBuilder& builder, + const std::variant& config); + +std::pair> to_flatbuffer( + flatbuffers::FlatBufferBuilder& builder, const ReaderDataMovementConfig& config); + +std::pair> to_flatbuffer( + flatbuffers::FlatBufferBuilder& builder, const WriterDataMovementConfig& config); + +flatbuffers::Offset create_runtime_arg( + flatbuffers::FlatBufferBuilder& builder, const std::variant& arg); + +flatbuffers::Offset>> to_flatbuffer( + flatbuffers::FlatBufferBuilder& builder, const std::shared_ptr& runtime_args); + +flatbuffers::Offset> to_flatbuffer( + flatbuffers::FlatBufferBuilder& builder, tt::stl::Span sub_device_ids); + +} // namespace tt::tt_metal