Skip to content

Commit

Permalink
LightMetal - Add flatbuffer schema and conversion to/from functions f…
Browse files Browse the repository at this point in the history
…or various types (#17039)

 - Code is compiled, but not used by anything yet, will be used after
   subsequent merge of light metal capture/replay libraries.
 - Remove default case statements from case statements for enums so when
   new enum values are added, compile error is seen to force updates.
 - All the PR feedback implemented. Add throws after switch statements for gcc12
 - CircularBufferConfig needed tweaks, add accessors for 3 private
   members for capture, and new constructor to set all private members
   for replay, following lengthy PR/offline discussion (avoid friend).
  • Loading branch information
kmabeeTT committed Feb 2, 2025
1 parent f8fe02d commit e4b974e
Show file tree
Hide file tree
Showing 20 changed files with 1,285 additions and 1 deletion.
19 changes: 19 additions & 0 deletions tt_metal/api/tt-metalium/circular_buffer_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,20 @@ class CircularBufferConfig {
CircularBufferConfig(
uint32_t total_size, const std::map<uint8_t, tt::DataFormat>& data_format_spec, const Buffer& buffer);

// For flatbuffer deserialization, set all private members.
CircularBufferConfig(
uint32_t total_size,
std::optional<uint32_t> globally_allocated_address,
const std::array<std::optional<tt::DataFormat>, NUM_CIRCULAR_BUFFERS>& data_formats,
const std::array<std::optional<uint32_t>, NUM_CIRCULAR_BUFFERS>& page_sizes,
const std::array<std::optional<Tile>, NUM_CIRCULAR_BUFFERS>& tiles,
const std::unordered_set<uint8_t>& buffer_indices,
const std::unordered_set<uint8_t>& local_buffer_indices,
const std::unordered_set<uint8_t>& remote_buffer_indices,
bool dynamic_cb,
uint32_t max_size,
uint32_t buffer_size);

CircularBufferConfig& set_page_size(uint8_t buffer_index, uint32_t page_size);

CircularBufferConfig& set_total_size(uint32_t total_size);
Expand All @@ -59,6 +73,11 @@ class CircularBufferConfig {

const std::array<std::optional<uint32_t>, NUM_CIRCULAR_BUFFERS>& page_sizes() const;

// These 3 getters are not typically used, but needed for flatbuffer serialization
bool dynamic_cb() const;
uint32_t max_size() const;
uint32_t buffer_size() const;

const Buffer* shadow_global_buffer{nullptr};

class Builder {
Expand Down
14 changes: 13 additions & 1 deletion tt_metal/impl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,24 @@ set(IMPL_SRC
${CMAKE_CURRENT_SOURCE_DIR}/trace/trace.cpp
${CMAKE_CURRENT_SOURCE_DIR}/trace/trace_buffer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/base_types_from_flatbuffer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/base_types_to_flatbuffer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/buffer_types_from_flatbuffer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/buffer_types_to_flatbuffer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/program_types_from_flatbuffer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/program_types_to_flatbuffer.cpp
)

# Include helper functions and generate headers from flatbuffer schemas
include(flatbuffers)

set(FLATBUFFER_SCHEMAS) # Empty to start, coming soon.
set(FLATBUFFER_SCHEMAS
${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/light_metal_binary.fbs
${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/command.fbs
${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/base_types.fbs
${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/buffer_types.fbs
${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/program_types.fbs
)

foreach(FBS_FILE ${FLATBUFFER_SCHEMAS})
GENERATE_FBS_HEADER(${FBS_FILE})
Expand Down
31 changes: 31 additions & 0 deletions tt_metal/impl/buffers/circular_buffer_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,31 @@ CircularBufferConfig::CircularBufferConfig(
this->set_config(data_format_spec);
}

// For flatbuffer deserialization, set all private members.
CircularBufferConfig::CircularBufferConfig(
uint32_t total_size,
std::optional<uint32_t> globally_allocated_address,
const std::array<std::optional<tt::DataFormat>, NUM_CIRCULAR_BUFFERS>& data_formats,
const std::array<std::optional<uint32_t>, NUM_CIRCULAR_BUFFERS>& page_sizes,
const std::array<std::optional<Tile>, NUM_CIRCULAR_BUFFERS>& tiles,
const std::unordered_set<uint8_t>& buffer_indices,
const std::unordered_set<uint8_t>& local_buffer_indices,
const std::unordered_set<uint8_t>& remote_buffer_indices,
bool dynamic_cb,
uint32_t max_size,
uint32_t buffer_size) :
total_size_(total_size),
globally_allocated_address_(globally_allocated_address),
data_formats_(data_formats),
page_sizes_(page_sizes),
tiles_(tiles),
buffer_indices_(buffer_indices),
local_buffer_indices_(local_buffer_indices),
remote_buffer_indices_(remote_buffer_indices),
dynamic_cb_(dynamic_cb),
max_size_(max_size),
buffer_size_(buffer_size) {}

CircularBufferConfig& CircularBufferConfig::set_page_size(uint8_t buffer_index, uint32_t page_size) {
if (buffer_index > NUM_CIRCULAR_BUFFERS - 1) {
TT_THROW(
Expand Down Expand Up @@ -156,6 +181,12 @@ const std::array<std::optional<uint32_t>, NUM_CIRCULAR_BUFFERS>& CircularBufferC
return this->page_sizes_;
}

bool CircularBufferConfig::dynamic_cb() const { return this->dynamic_cb_; }

uint32_t CircularBufferConfig::max_size() const { return this->max_size_; }

uint32_t CircularBufferConfig::buffer_size() const { return this->buffer_size_; }

CircularBufferConfig::Builder CircularBufferConfig::Builder::LocalBuilder(
CircularBufferConfig& parent, uint8_t buffer_index) {
auto is_remote_index = parent.remote_buffer_indices_.find(buffer_index) != parent.remote_buffer_indices_.end();
Expand Down
77 changes: 77 additions & 0 deletions tt_metal/impl/flatbuffer/base_types.fbs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
namespace tt.tt_metal.flatbuffer;


enum Arch: uint {
Grayskull = 0,
Wormhole_b0 = 1,
Blackhole = 2,
}

enum DataMovementProcessor : byte {
RISCV_0,
RISCV_1
}

enum NOC : byte {
NOC_0,
NOC_1
}

enum NOC_MODE : byte {
DM_DEDICATED_NOC,
DM_DYNAMIC_NOC
}

enum EthMode : ubyte {
SENDER = 0,
RECEIVER = 1,
IDLE = 2
}

enum MathFidelity : ubyte {
LoFi = 0,
HiFi2 = 2,
HiFi3 = 3,
HiFi4 = 4,
Invalid = 255
}

enum DataFormat : uint8 {
Float32 = 0,
Float16 = 1,
Bfp8 = 2,
Bfp4 = 3,
Bfp2 = 11,
Float16_b = 5,
Bfp8_b = 6,
Bfp4_b = 7,
Bfp2_b = 15,
Lf8 = 10,
Fp8_e4m3 = 26, // 0x1A in decimal
Int8 = 14,
Tf32 = 4,
UInt8 = 30,
UInt16 = 9,
Int32 = 8,
UInt32 = 24,
RawUInt8 = 240, // 0xf0 in decimal
RawUInt16 = 241, // 0xf1 in decimal
RawUInt32 = 242, // 0xf2 in decimal
Invalid = 255
}

enum UnpackToDestMode : byte {
Default,
UnpackToDestFp32
}

table DefineEntry {
key: string;
value: string;
}

// Rather than serialize all members, use Tile constructor arguments.
struct Tile {
tile_shape: [uint32:2];
transpose_tile: bool;
}
96 changes: 96 additions & 0 deletions tt_metal/impl/flatbuffer/base_types_from_flatbuffer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "flatbuffer/base_types_from_flatbuffer.hpp"

namespace tt::tt_metal {

DataMovementProcessor from_flatbuffer(flatbuffer::DataMovementProcessor in) {
switch (in) {
case flatbuffer::DataMovementProcessor::RISCV_0: return DataMovementProcessor::RISCV_0;
case flatbuffer::DataMovementProcessor::RISCV_1: return DataMovementProcessor::RISCV_1;
}
TT_THROW("Unsupported DataMovementProcessor from flatbuffer.");
}

NOC from_flatbuffer(flatbuffer::NOC in) {
switch (in) {
case flatbuffer::NOC::NOC_0: return NOC::NOC_0;
case flatbuffer::NOC::NOC_1: return NOC::NOC_1;
}
TT_THROW("Unsupported NOC from flatbuffer.");
}

NOC_MODE from_flatbuffer(flatbuffer::NOC_MODE in) {
switch (in) {
case flatbuffer::NOC_MODE::DM_DEDICATED_NOC: return NOC_MODE::DM_DEDICATED_NOC;
case flatbuffer::NOC_MODE::DM_DYNAMIC_NOC: return NOC_MODE::DM_DYNAMIC_NOC;
}
TT_THROW("Unsupported NOC_MODE from flatbuffer.");
}

Eth from_flatbuffer(flatbuffer::EthMode in) {
switch (in) {
case flatbuffer::EthMode::SENDER: return Eth::SENDER;
case flatbuffer::EthMode::RECEIVER: return Eth::RECEIVER;
case flatbuffer::EthMode::IDLE: return Eth::IDLE;
}
TT_THROW("Unsupported EthMode from flatbuffer.");
}

MathFidelity from_flatbuffer(flatbuffer::MathFidelity input) {
switch (input) {
case flatbuffer::MathFidelity::LoFi: return MathFidelity::LoFi;
case flatbuffer::MathFidelity::HiFi2: return MathFidelity::HiFi2;
case flatbuffer::MathFidelity::HiFi3: return MathFidelity::HiFi3;
case flatbuffer::MathFidelity::HiFi4: return MathFidelity::HiFi4;
case flatbuffer::MathFidelity::Invalid: return MathFidelity::Invalid;
}
TT_THROW("Unsupported MathFidelity from flatbuffer.");
}

UnpackToDestMode from_flatbuffer(flatbuffer::UnpackToDestMode input) {
switch (input) {
case flatbuffer::UnpackToDestMode::UnpackToDestFp32: return UnpackToDestMode::UnpackToDestFp32;
case flatbuffer::UnpackToDestMode::Default: return UnpackToDestMode::Default;
}
TT_THROW("Unsupported UnpackToDestMode from flatbuffer.");
}

tt::DataFormat from_flatbuffer(flatbuffer::DataFormat input) {
switch (input) {
case flatbuffer::DataFormat::Float32: return tt::DataFormat::Float32;
case flatbuffer::DataFormat::Float16: return tt::DataFormat::Float16;
case flatbuffer::DataFormat::Bfp8: return tt::DataFormat::Bfp8;
case flatbuffer::DataFormat::Bfp4: return tt::DataFormat::Bfp4;
case flatbuffer::DataFormat::Bfp2: return tt::DataFormat::Bfp2;
case flatbuffer::DataFormat::Float16_b: return tt::DataFormat::Float16_b;
case flatbuffer::DataFormat::Bfp8_b: return tt::DataFormat::Bfp8_b;
case flatbuffer::DataFormat::Bfp4_b: return tt::DataFormat::Bfp4_b;
case flatbuffer::DataFormat::Bfp2_b: return tt::DataFormat::Bfp2_b;
case flatbuffer::DataFormat::Lf8: return tt::DataFormat::Lf8;
case flatbuffer::DataFormat::Fp8_e4m3: return tt::DataFormat::Fp8_e4m3;
case flatbuffer::DataFormat::Int8: return tt::DataFormat::Int8;
case flatbuffer::DataFormat::Tf32: return tt::DataFormat::Tf32;
case flatbuffer::DataFormat::UInt8: return tt::DataFormat::UInt8;
case flatbuffer::DataFormat::UInt16: return tt::DataFormat::UInt16;
case flatbuffer::DataFormat::Int32: return tt::DataFormat::Int32;
case flatbuffer::DataFormat::UInt32: return tt::DataFormat::UInt32;
case flatbuffer::DataFormat::RawUInt8: return tt::DataFormat::RawUInt8;
case flatbuffer::DataFormat::RawUInt16: return tt::DataFormat::RawUInt16;
case flatbuffer::DataFormat::RawUInt32: return tt::DataFormat::RawUInt32;
case flatbuffer::DataFormat::Invalid: return tt::DataFormat::Invalid;
}
TT_THROW("Unsupported DataFormat from flatbuffer.");
}

Tile from_flatbuffer(const flatbuffer::Tile& tile_fb) {
const auto& shape = *tile_fb.tile_shape();
// Tile shape is already 2D in flatbuffer schema.
std::array<uint32_t, 2> tile_shape = {shape[0], shape[1]};
bool transpose_tile = tile_fb.transpose_tile();
return Tile(tile_shape, transpose_tile);
}

} // namespace tt::tt_metal
29 changes: 29 additions & 0 deletions tt_metal/impl/flatbuffer/base_types_from_flatbuffer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "base_types_generated.h"
#include <buffer_constants.hpp>
#include <kernel_types.hpp>
#include <data_types.hpp>
#include <tt_backend_api_types.hpp>
#include <tile.hpp>
#include <circular_buffer_constants.h>

namespace tt::tt_metal {

DataMovementProcessor from_flatbuffer(flatbuffer::DataMovementProcessor in);

NOC from_flatbuffer(flatbuffer::NOC in);
NOC_MODE from_flatbuffer(flatbuffer::NOC_MODE in);
Eth from_flatbuffer(flatbuffer::EthMode in);

MathFidelity from_flatbuffer(flatbuffer::MathFidelity input);
UnpackToDestMode from_flatbuffer(flatbuffer::UnpackToDestMode input);
tt::DataFormat from_flatbuffer(flatbuffer::DataFormat input);

Tile from_flatbuffer(const flatbuffer::Tile& tile_fb);

} // namespace tt::tt_metal
Loading

0 comments on commit e4b974e

Please sign in to comment.