diff --git a/.github/workflows/cpp-post-commit.yaml b/.github/workflows/cpp-post-commit.yaml index 1177c1f6efe..0feaa3b80cb 100644 --- a/.github/workflows/cpp-post-commit.yaml +++ b/.github/workflows/cpp-post-commit.yaml @@ -64,6 +64,7 @@ jobs: {name: stl, cmd: "./build/test/tt_metal/unit_tests_stl"}, {name: distributed, cmd: "./build/test/tt_metal/distributed/distributed_unit_tests_${{ inputs.arch }} --gtest_filter=MeshDeviceSuite.*"}, + {name: lightmetal, cmd: "./build/test/tt_metal/unit_tests_lightmetal"}, {name: dispatch multicmd queue, cmd: "TT_METAL_GTEST_NUM_HW_CQS=2 ./build/test/tt_metal/unit_tests_dispatch_${{ inputs.arch }} --gtest_filter=MultiCommandQueue*Fixture.*"}, {name: ttnn cpp unit tests, cmd: ./build/test/ttnn/unit_tests_ttnn}, diff --git a/CMakeLists.txt b/CMakeLists.txt index 0a36f8d106d..e36603fff9a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -111,6 +111,7 @@ message(STATUS "Build TT METAL Tests: ${TT_METAL_BUILD_TESTS}") message(STATUS "Build TTNN Tests: ${TTNN_BUILD_TESTS}") message(STATUS "Build with Unity builds: ${TT_UNITY_BUILDS}") message(STATUS "Build with Shared TTNN Sublibraries: ${ENABLE_TTNN_SHARED_SUBLIBS}") +message(STATUS "Build with LightMetal Trace Enabled: ${TT_ENABLE_LIGHT_METAL_TRACE}") ############################################################################################################################ @@ -232,6 +233,13 @@ add_link_options( "$<$:-fsanitize=undefined>" ) +# Planned to be temporary, remove later. +if(TT_ENABLE_LIGHT_METAL_TRACE) + add_compile_definitions(TT_ENABLE_LIGHT_METAL_TRACE=1) +else() + add_compile_definitions(TT_ENABLE_LIGHT_METAL_TRACE=0) +endif() + if(ENABLE_CODE_TIMERS) add_compile_definitions(TT_ENABLE_CODE_TIMERS) endif() diff --git a/build_metal.sh b/build_metal.sh index 0ade43090ce..d4be0cd657d 100755 --- a/build_metal.sh +++ b/build_metal.sh @@ -29,6 +29,7 @@ show_help() { echo " --clean Remove build workspaces." 
echo " --build-static-libs Build tt_metal (not ttnn) as a static lib (BUILD_SHARED_LIBS=OFF)" echo " --disable-unity-builds Disable Unity builds" + echo " --disable-light-metal-trace Disable Light Metal tracing to binary." echo " --cxx-compiler-path Set path to C++ compiler." echo " --c-compiler-path Set path to C++ compiler." echo " --ttnn-shared-sub-libs Use shared libraries for ttnn." @@ -58,6 +59,7 @@ build_programming_examples="OFF" build_tt_train="OFF" build_static_libs="OFF" unity_builds="ON" +light_metal_trace="ON" build_all="OFF" cxx_compiler_path="" c_compiler_path="" @@ -88,6 +90,7 @@ build-programming-examples build-tt-train build-static-libs disable-unity-builds +disable-light-metal-trace release development debug @@ -155,6 +158,8 @@ while true; do ttnn_shared_sub_libs="ON";; --disable-unity-builds) unity_builds="OFF";; + --disable-light-metal-trace) + light_metal_trace="OFF";; --cxx-compiler-path) cxx_compiler_path="$2";shift;; --c-compiler-path) @@ -218,6 +223,7 @@ echo "INFO: Install Prefix: $cmake_install_prefix" echo "INFO: Build tests: $build_tests" echo "INFO: Enable Unity builds: $unity_builds" echo "INFO: TTNN Shared sub libs : $ttnn_shared_sub_libs" +echo "INFO: Enable Light Metal Trace: $light_metal_trace" # Prepare cmake arguments cmake_args+=("-B" "$build_dir") @@ -308,6 +314,12 @@ else cmake_args+=("-DTT_UNITY_BUILDS=OFF") fi +if [ "$light_metal_trace" = "ON" ]; then + cmake_args+=("-DTT_ENABLE_LIGHT_METAL_TRACE=ON") +else + cmake_args+=("-DTT_ENABLE_LIGHT_METAL_TRACE=OFF") +fi + if [ "$build_all" = "ON" ]; then cmake_args+=("-DTT_METAL_BUILD_TESTS=ON") cmake_args+=("-DTTNN_BUILD_TESTS=ON") diff --git a/cmake/project_options.cmake b/cmake/project_options.cmake index 3187b2efc10..3937b609500 100644 --- a/cmake/project_options.cmake +++ b/cmake/project_options.cmake @@ -19,6 +19,7 @@ option(ENABLE_CCACHE "Build with compiler cache" FALSE) option(TT_UNITY_BUILDS "Build with Unity builds" ON) option(BUILD_TT_TRAIN "Enables build of tt-train" 
OFF) option(ENABLE_TTNN_SHARED_SUBLIBS "Use shared libraries for ttnn to speed up incremental builds" OFF) +option(TT_ENABLE_LIGHT_METAL_TRACE "Enable Light Metal Trace" ON) ########################################################################################### diff --git a/tests/tt_metal/tt_metal/CMakeLists.txt b/tests/tt_metal/tt_metal/CMakeLists.txt index 1e1da2ac982..e162b7cbc13 100644 --- a/tests/tt_metal/tt_metal/CMakeLists.txt +++ b/tests/tt_metal/tt_metal/CMakeLists.txt @@ -69,6 +69,7 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/llk) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/perf_microbenchmark) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/stl) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/noc) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/lightmetal) add_custom_target( metal_tests @@ -92,4 +93,5 @@ add_custom_target( unit_tests_llk unit_tests_stl unit_tests_noc + unit_tests_lightmetal ) diff --git a/tests/tt_metal/tt_metal/lightmetal/CMakeLists.txt b/tests/tt_metal/tt_metal/lightmetal/CMakeLists.txt new file mode 100644 index 00000000000..c8d1015f344 --- /dev/null +++ b/tests/tt_metal/tt_metal/lightmetal/CMakeLists.txt @@ -0,0 +1,23 @@ +set(UNIT_TESTS_LIGHTMETAL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/test_lightmetal.cpp) + +add_executable(unit_tests_lightmetal ${UNIT_TESTS_LIGHTMETAL_SRC}) +TT_ENABLE_UNITY_BUILD(unit_tests_lightmetal) + +target_link_libraries(unit_tests_lightmetal PUBLIC test_metal_common_libs) + +target_include_directories( + unit_tests_lightmetal + PRIVATE + "$" + ${PROJECT_SOURCE_DIR}/tests + ${PROJECT_SOURCE_DIR}/tests/tt_metal/tt_metal/common + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/common +) + +set_target_properties( + unit_tests_lightmetal + PROPERTIES + RUNTIME_OUTPUT_DIRECTORY + ${PROJECT_BINARY_DIR}/test/tt_metal +) diff --git a/tests/tt_metal/tt_metal/lightmetal/lightmetal_fixture.hpp b/tests/tt_metal/tt_metal/lightmetal/lightmetal_fixture.hpp new file mode 100644 index 00000000000..28e1604874f --- /dev/null 
+++ b/tests/tt_metal/tt_metal/lightmetal/lightmetal_fixture.hpp
@@ -0,0 +1,74 @@
+// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "dispatch_fixture.hpp"
+#include
+#include
+#include
+#include
+#include
+#include
+#include "command_queue_fixture.hpp"
+#include
+
+// GoogleTest fixture for Light Metal capture tests on a single device. A test calls
+// CreateDeviceAndBeginCapture() first; TearDown() then ends the capture, reports a failure if
+// the binary is empty, optionally saves it to disk, and closes the device.
+class SingleDeviceLightMetalFixture : public CommandQueueFixture {
+protected:
+    // Destination used when the captured binary is written to disk.
+    std::string trace_bin_path_;
+    // Assigned in CreateDeviceAndBeginCapture(); defaults to false so TearDown() never reads
+    // an indeterminate value if that call was not reached.
+    bool write_bin_to_disk_ = false;
+
+    void SetUp() override {
+        this->validate_dispatch_mode();
+        this->arch_ = tt::get_arch_from_string(tt::test_utils::get_umd_arch_name());
+    }
+
+    // Creates the device with the given trace region size and starts Light Metal capture.
+    // trace_bin_path: where to save the binary; when empty, "/tmp/<test name>.bin" is used.
+    // replay_binary: accepted for interface compatibility; not consumed by this fixture.
+    void CreateDeviceAndBeginCapture(
+        const size_t trace_region_size, const bool replay_binary = true, const std::string& trace_bin_path = "") {
+        // Skip writing to disk by default, unless user sets env var for local testing
+        write_bin_to_disk_ = tt::parse_env("LIGHTMETAL_SAVE_BINARY", false);
+
+        // Honor a caller-supplied path; otherwise derive a default from the running test's name.
+        if (!trace_bin_path.empty()) {
+            this->trace_bin_path_ = trace_bin_path;
+        } else {
+            const auto test_info = ::testing::UnitTest::GetInstance()->current_test_info();
+            auto trace_filename = test_info ? std::string(test_info->name()) + ".bin" : "lightmetal_trace.bin";
+            this->trace_bin_path_ = "/tmp/" + trace_filename;
+        }
+
+        this->create_device(trace_region_size);
+        // TODO (kmabee) - revisit placement. CreateDevice() path calls CreateKernel() on programs not
+        // created with CreateProgram() traced API which leads to "program not in global_id map"
+        LightMetalBeginCapture();
+    }
+
+    // Ends Light Metal capture, optionally writes the binary to disk, then closes the device.
+    void TearDown() override {
+        LightMetalBinary binary = LightMetalEndCapture();
+
+        // ADD_FAILURE() is non-fatal: FAIL() would return from TearDown() right here and
+        // skip the CloseDevice() call below, leaving the device open.
+        if (binary.is_empty()) {
+            ADD_FAILURE() << "Light Metal Binary is empty for test, unexpected.";
+        }
+        if (write_bin_to_disk_ && !this->trace_bin_path_.empty() && !binary.is_empty()) {
+            log_info(tt::LogTest, "Writing light metal binary {} bytes to {}", binary.size(), this->trace_bin_path_);
+            binary.save_to_file(this->trace_bin_path_);
+        }
+
+        if (!this->IsSlowDispatch()) {
+            tt::tt_metal::CloseDevice(this->device_);
+        }
+    }
+};
diff --git a/tests/tt_metal/tt_metal/lightmetal/test_lightmetal.cpp b/tests/tt_metal/tt_metal/lightmetal/test_lightmetal.cpp
new file mode 100644
index 00000000000..083e072a322
--- /dev/null
+++ b/tests/tt_metal/tt_metal/lightmetal/test_lightmetal.cpp
@@ -0,0 +1,379 @@
+// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include
+#include
+
+#include "lightmetal_fixture.hpp"
+#include
+#include "env_lib.hpp"
+#include "gtest/gtest.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include "lightmetal_capture_utils.hpp"
+
+using std::vector;
+using namespace tt;
+using namespace tt::tt_metal;
+
+namespace tt::tt_metal {
+namespace {
+
+// Single RISC, no CB's here. Very simple.
+// Builds a program that runs the loopback_dram_copy data-movement kernel on core {0, 0}.
+Program create_simple_datamovement_program(Buffer& input, Buffer& output, Buffer& l1_buffer) {
+    Program program = CreateProgram();
+    constexpr CoreCoord core = {0, 0};
+
+    KernelHandle dram_copy_kernel_id = CreateKernel(
+        program,
+        "tt_metal/programming_examples/loopback/kernels/loopback_dram_copy.cpp",
+        core,
+        DataMovementConfig{.processor = DataMovementProcessor::RISCV_0, .noc = NOC::RISCV_0_default});
+
+    // Since all interleaved buffers have size == page_size, they are entirely contained in the first DRAM bank
+    const uint32_t input_bank_id = 0;
+    const uint32_t output_bank_id = 0;
+
+    // Runtime args are plain uint32_t values (addresses, bank ids, size), in the order listed here.
+    const std::vector<uint32_t> runtime_args = {
+        l1_buffer.address(), input.address(), input_bank_id, output.address(), output_bank_id, l1_buffer.size()};
+
+    // Note - this interface doesn't take Buffer, just data.
+    SetRuntimeArgs(program, dram_copy_kernel_id, core, runtime_args);
+
+    return program;
+}
+
+// Reader, writer and SFPU compute (exp) kernels on core {0, 0}. Copied from test_EnqueueTrace.cpp
+Program create_simple_unary_program(Buffer& input, Buffer& output, Buffer* cb_input_buffer = nullptr) {
+    Program program = CreateProgram();
+    IDevice* device = input.device();
+    CoreCoord worker = {0, 0};
+    auto reader_kernel = CreateKernel(
+        program,
+        "tt_metal/kernels/dataflow/reader_unary.cpp",
+        worker,
+        DataMovementConfig{.processor = DataMovementProcessor::RISCV_1, .noc = NOC::RISCV_1_default});
+
+    auto writer_kernel = CreateKernel(
+        program,
+        "tt_metal/kernels/dataflow/writer_unary.cpp",
+        worker,
+        DataMovementConfig{.processor = DataMovementProcessor::RISCV_0, .noc = NOC::RISCV_0_default});
+
+    auto sfpu_kernel = CreateKernel(
+        program,
+        "tt_metal/kernels/compute/eltwise_sfpu.cpp",
+        worker,
+        ComputeConfig{
+            .math_approx_mode = true,
+            .compile_args = {1, 1},
+            .defines = {{"SFPU_OP_EXP_INCLUDE", "1"}, {"SFPU_OP_CHAIN_0", "exp_tile_init(); exp_tile(0);"}}});
+
+    CircularBufferConfig input_cb_config = CircularBufferConfig(2048, {{tt::CBIndex::c_0, tt::DataFormat::Float16_b}})
+                                               .set_page_size(tt::CBIndex::c_0, 2048);
+
+    // For testing dynamic CB: when a buffer is supplied, the input CB is placed at that buffer's address.
+    if (cb_input_buffer) {
+        input_cb_config.set_globally_allocated_address(*cb_input_buffer);
+    }
+
+    CoreRange core_range({0, 0});
+    CreateCircularBuffer(program, core_range, input_cb_config);
+    std::shared_ptr<RuntimeArgs> writer_runtime_args = std::make_shared<RuntimeArgs>();
+    std::shared_ptr<RuntimeArgs> reader_runtime_args = std::make_shared<RuntimeArgs>();
+
+    *writer_runtime_args = {&output, uint32_t{0}, output.num_pages()};
+    *reader_runtime_args = {&input, uint32_t{0}, input.num_pages()};
+
+    SetRuntimeArgs(device, detail::GetKernel(program, writer_kernel), worker, writer_runtime_args);
+    SetRuntimeArgs(device, detail::GetKernel(program, reader_kernel), worker, reader_runtime_args);
+
+    CircularBufferConfig output_cb_config = CircularBufferConfig(2048, {{tt::CBIndex::c_16, tt::DataFormat::Float16_b}})
+                                                .set_page_size(tt::CBIndex::c_16, 2048);
+
+    CreateCircularBuffer(program, core_range, output_cb_config);
+    return program;
+}
+
+// Overwrites `buffer` with 0xDEADBEEF words (blocking), reads it back, and expects both to match.
+void write_junk_to_buffer(CommandQueue& command_queue, Buffer& buffer) {
+    vector<uint32_t> dummy_write_data(buffer.size() / sizeof(uint32_t), 0xDEADBEEF);
+    vector<uint32_t> dummy_read_data(buffer.size() / sizeof(uint32_t), 0);
+    EnqueueWriteBuffer(command_queue, buffer, dummy_write_data.data(), true);
+    EnqueueReadBuffer(command_queue, buffer, dummy_read_data.data(), true);
+    for (size_t i = 0; i < dummy_read_data.size(); i++) {
+        log_trace(tt::LogMetalTrace, "i: {:3d} output: {:x} after write+read of dummy data", i, dummy_read_data[i]);
+    }
+    EXPECT_EQ(dummy_write_data, dummy_read_data);
+}
+
+// TODO (kmabee) - consider looping over blocking_flags in some/all tests once stable.
+constexpr bool kBlocking = true;
+constexpr bool kNonBlocking = false;
+vector<bool> blocking_flags = {kBlocking, kNonBlocking};
+
+using LightMetalBasicTest = SingleDeviceLightMetalFixture;
+
+// Test that create buffer, write, readback, and verify works when traced + replayed.
+TEST_F(LightMetalBasicTest, CreateBufferEnqueueWriteRead) {
+    CreateDeviceAndBeginCapture(4096);
+
+    CommandQueue& command_queue = this->device_->command_queue();
+    uint32_t num_loops = 5;
+    bool keep_buffers_alive = true;
+    std::vector<std::shared_ptr<Buffer>> buffers_vec;
+
+    for (uint32_t loop_idx = 0; loop_idx < num_loops; loop_idx++) {
+        log_debug(tt::LogTest, "Running loop: {}", loop_idx);
+
+        // Switch to use top level CreateBuffer API that has trace support.
+        uint32_t size_bytes = 64;  // 16 elements.
+        auto buffer = CreateBuffer(InterleavedBufferConfig{this->device_, size_bytes, size_bytes, BufferType::DRAM});
+        log_debug(
+            tt::LogTest,
+            "created buffer loop: {} with size: {} bytes addr: 0x{:x}",
+            loop_idx,
+            buffer->size(),
+            buffer->address());
+
+        if (keep_buffers_alive && loop_idx > 1) {
+            buffers_vec.push_back(buffer);
+        }
+
+        // We don't want to capture inputs in binary, but do it to start for testing.
+        uint32_t start_val = loop_idx * 100;
+        vector<uint32_t> input_data(buffer->size() / sizeof(uint32_t), 0);
+        for (uint32_t i = 0; i < input_data.size(); i++) {
+            input_data[i] = start_val + i;
+        }
+        log_debug(tt::LogTest, "initialize input_data with {} elements start_val: {}", input_data.size(), start_val);
+
+        vector<uint32_t> readback_data;
+        readback_data.resize(input_data.size());  // This is required.
+
+        // Write data to buffer, then read outputs and verify against expected.
+        EnqueueWriteBuffer(command_queue, *buffer, input_data.data(), /*blocking=*/true);
+        // This will verify that readback matches between capture + replay
+        LightMetalCompareToCapture(command_queue, *buffer, readback_data.data());
+
+        EXPECT_EQ(input_data, readback_data);
+
+        // For dev/debug go ahead and print the results. Had a replay bug, was seeing wrong data.
+        for (size_t i = 0; i < readback_data.size(); i++) {
+            log_debug(tt::LogMetalTrace, "loop: {} rd_data i: {:3d} => data: {}", loop_idx, i, readback_data[i]);
+        }
+    }
+
+    // If any Buffers were kept alive for testing, Deallocate them now to exercise that path for capture/replay.
+    if (keep_buffers_alive) {
+        log_info(tt::LogTest, "Explicitly deallocating {} buffers now.", buffers_vec.size());
+        for (auto& buffer : buffers_vec) {
+            DeallocateBuffer(*buffer);
+        }
+    }
+
+    Finish(command_queue);
+}
+
+// Test simple case of single datamovement program on single RISC works for trace + replay.
+TEST_F(LightMetalBasicTest, SingleRISCDataMovement) {
+    CreateDeviceAndBeginCapture(4096);
+
+    uint32_t size_bytes = 64;  // 16 elements.
+    auto input = CreateBuffer(InterleavedBufferConfig{this->device_, size_bytes, size_bytes, BufferType::DRAM});
+    auto output = CreateBuffer(InterleavedBufferConfig{this->device_, size_bytes, size_bytes, BufferType::DRAM});
+    auto l1_buffer = CreateBuffer(InterleavedBufferConfig{this->device_, size_bytes, size_bytes, BufferType::L1});
+    log_debug(
+        tt::LogTest,
+        "Created 3 Buffers. input: 0x{:x} output: 0x{:x} l1_buffer: 0x{:x}",
+        input->address(),
+        output->address(),
+        l1_buffer->address());
+
+    CommandQueue& command_queue = this->device_->command_queue();
+
+    Program simple_program = create_simple_datamovement_program(*input, *output, *l1_buffer);
+    vector<uint32_t> input_data(input->size() / sizeof(uint32_t), 0);
+    for (uint32_t i = 0; i < input_data.size(); i++) {
+        input_data[i] = i;
+    }
+
+    vector<uint32_t> eager_output_data;
+    eager_output_data.resize(input_data.size());
+
+    // Write data to buffer, enqueue program, then read outputs and verify against expected.
+    EnqueueWriteBuffer(command_queue, *input, input_data.data(), /*blocking=*/true);
+    EnqueueProgram(command_queue, simple_program, /*blocking=*/true);
+    // This will verify that outputs match between capture + replay
+    LightMetalCompareToCapture(command_queue, *output, eager_output_data.data());
+
+    EXPECT_EQ(eager_output_data, input_data);
+
+    // For dev/debug go ahead and print the results
+    for (size_t i = 0; i < eager_output_data.size(); i++) {
+        log_debug(tt::LogMetalTrace, "i: {:3d} input: {} output: {}", i, input_data[i], eager_output_data[i]);
+    }
+
+    Finish(command_queue);
+}
+
+// Test simple case of 3 riscs used for datamovement and compute works for trace + replay.
+TEST_F(LightMetalBasicTest, ThreeRISCDataMovementCompute) {
+    CreateDeviceAndBeginCapture(4096);
+
+    uint32_t size_bytes = 64;  // 16 elements.
+    auto input = CreateBuffer(InterleavedBufferConfig{this->device_, size_bytes, size_bytes, BufferType::DRAM});
+    auto output = CreateBuffer(InterleavedBufferConfig{this->device_, size_bytes, size_bytes, BufferType::DRAM});
+
+    CommandQueue& command_queue = this->device_->command_queue();
+
+    // TODO (kmabee) - There is issue with using make_shared, revisit this.
+    // auto simple_program = std::make_shared<Program>(create_simple_unary_program(*input,
+    // *output));
+    auto simple_program = create_simple_unary_program(*input, *output);
+
+    vector<uint32_t> input_data(input->size() / sizeof(uint32_t), 0);
+    for (uint32_t i = 0; i < input_data.size(); i++) {
+        input_data[i] = i;
+    }
+
+    // Write data to buffer, enqueue program, then read outputs.
+    EnqueueWriteBuffer(command_queue, *input, input_data.data(), /*blocking=*/true);
+    EnqueueProgram(command_queue, simple_program, /*blocking=*/true);
+    // This will verify that outputs match between capture + replay
+    LightMetalCompareToCapture(command_queue, *output);  // No read return
+
+    Finish(command_queue);
+}
+
+// Test simple case of 3 riscs used for datamovement and compute works for trace + replay. Also include dynamic CB.
+TEST_F(LightMetalBasicTest, ThreeRISCDataMovementComputeDynamicCB) {
+    CreateDeviceAndBeginCapture(4096);
+
+    uint32_t buf_size_bytes = 64;  // 16 elements.
+    uint32_t cb_size_bytes = 2048;
+    auto input = CreateBuffer(InterleavedBufferConfig{this->device_, buf_size_bytes, buf_size_bytes, BufferType::DRAM});
+    auto output =
+        CreateBuffer(InterleavedBufferConfig{this->device_, buf_size_bytes, buf_size_bytes, BufferType::DRAM});
+    auto cb_in_buf = CreateBuffer(InterleavedBufferConfig{this->device_, cb_size_bytes, cb_size_bytes, BufferType::L1});
+    log_info(
+        tt::LogTest,
+        "Created 3 Buffers. 0x{:x} 0x{:x} 0x{:x}",
+        input->address(),
+        output->address(),
+        cb_in_buf->address());
+
+    CommandQueue& command_queue = this->device_->command_queue();
+    auto simple_program = create_simple_unary_program(*input, *output, cb_in_buf.get());
+
+    vector<uint32_t> input_data(input->size() / sizeof(uint32_t), 0);
+    for (uint32_t i = 0; i < input_data.size(); i++) {
+        input_data[i] = i;
+    }
+
+    // Write data to buffer, enqueue program, then read outputs.
+    EnqueueWriteBuffer(command_queue, *input, input_data.data(), /*blocking=*/true);
+    EnqueueProgram(command_queue, simple_program, /*blocking=*/true);
+    // This will verify that outputs match between capture + replay
+    LightMetalCompareToCapture(command_queue, *output);  // No read return
+
+    Finish(command_queue);
+}
+
+// Test simple compute test with metal trace, but no explicit trace replay (added automatically by light metal trace).
+TEST_F(LightMetalBasicTest, SingleProgramTraceCapture) {
+    CreateDeviceAndBeginCapture(4096);
+
+    uint32_t size_bytes = 64;  // 16 elements. Was 2048 in original test.
+    auto input = CreateBuffer(InterleavedBufferConfig{this->device_, size_bytes, size_bytes, BufferType::DRAM});
+    auto output = CreateBuffer(InterleavedBufferConfig{this->device_, size_bytes, size_bytes, BufferType::DRAM});
+
+    CommandQueue& command_queue = this->device_->command_queue();
+    Program simple_program = create_simple_unary_program(*input, *output);
+
+    // Setup input data for program with some simple values.
+    vector<uint32_t> input_data(input->size() / sizeof(uint32_t), 0);
+    for (uint32_t i = 0; i < input_data.size(); i++) {
+        input_data[i] = i;
+    }
+
+    std::vector<uint32_t> eager_output_data(input_data.size());
+
+    // Initial run w/o trace. Preloads binary cache, and captures golden output.
+    EnqueueWriteBuffer(command_queue, *input, input_data.data(), /*blocking=*/true);
+    EnqueueProgram(command_queue, simple_program, /*blocking=*/true);
+    // This will verify that outputs match between capture + replay.
+    LightMetalCompareToCapture(command_queue, *output, eager_output_data.data());
+
+    // Write junk to output buffer to help make sure trace run from standalone binary works.
+    write_junk_to_buffer(command_queue, *output);
+
+    // Now enable Metal Trace and run program again for capture.
+    uint32_t tid = BeginTraceCapture(this->device_, command_queue.id());
+    EnqueueProgram(command_queue, simple_program, false);
+    EndTraceCapture(this->device_, command_queue.id(), tid);
+
+    // Verify trace output during replay matches expected output from original capture.
+    LightMetalCompareToGolden(command_queue, *output, eager_output_data.data());
+
+    // Done
+    Finish(command_queue);
+    ReleaseTrace(this->device_, tid);
+}
+
+// Two chained programs (input -> interm -> output) with metal trace, but no explicit trace replay (added automatically by light metal trace).
+TEST_F(LightMetalBasicTest, TwoProgramTraceCapture) {
+    CreateDeviceAndBeginCapture(4096);
+
+    uint32_t size_bytes = 64;  // 16 elements. Was 2048 in original test.
+    auto input = CreateBuffer(InterleavedBufferConfig{this->device_, size_bytes, size_bytes, BufferType::DRAM});
+    auto interm = CreateBuffer(InterleavedBufferConfig{this->device_, size_bytes, size_bytes, BufferType::DRAM});
+    auto output = CreateBuffer(InterleavedBufferConfig{this->device_, size_bytes, size_bytes, BufferType::DRAM});
+
+    CommandQueue& command_queue = this->device_->command_queue();
+
+    Program op0 = create_simple_unary_program(*input, *interm);
+    Program op1 = create_simple_unary_program(*interm, *output);
+
+    // Setup input data for program with some simple values.
+    vector<uint32_t> input_data(input->size() / sizeof(uint32_t), 0);
+    for (uint32_t i = 0; i < input_data.size(); i++) {
+        input_data[i] = i;
+    }
+
+    std::vector<uint32_t> eager_output_data(input_data.size());
+
+    // Initial run w/o trace. Preloads binary cache, and captures golden output.
+    EnqueueWriteBuffer(command_queue, *input, input_data.data(), /*blocking=*/true);
+    EnqueueProgram(command_queue, op0, /*blocking=*/true);
+    EnqueueProgram(command_queue, op1, /*blocking=*/true);
+    // This will verify that outputs match between capture + replay.
+    LightMetalCompareToCapture(command_queue, *output, eager_output_data.data());
+    Finish(command_queue);
+
+    // Write junk to output buffer to help make sure trace run from standalone binary works.
+    write_junk_to_buffer(command_queue, *output);
+
+    // Now enable Metal Trace and run program again for capture.
+    uint32_t tid = BeginTraceCapture(this->device_, command_queue.id());
+    EnqueueProgram(command_queue, op0, false);
+    EnqueueProgram(command_queue, op1, false);
+    EndTraceCapture(this->device_, command_queue.id(), tid);
+
+    // Verify trace output during replay matches expected output from original capture.
+    LightMetalCompareToGolden(command_queue, *output, eager_output_data.data());
+
+    // Done
+    Finish(command_queue);
+    ReleaseTrace(this->device_, tid);
+}
+
+} // namespace
+} // namespace tt::tt_metal
diff --git a/tt_metal/CMakeLists.txt b/tt_metal/CMakeLists.txt index 768c9318eac..e88549d0cd9 100644 --- a/tt_metal/CMakeLists.txt +++ b/tt_metal/CMakeLists.txt @@ -33,6 +33,7 @@ target_link_libraries( HAL::grayskull HAL::wormhole HAL::blackhole + FlatBuffers::FlatBuffers ) target_precompile_headers( diff --git a/tt_metal/api/tt-metalium/device.hpp b/tt_metal/api/tt-metalium/device.hpp index 3a3238668d7..81577228be9 100644 --- a/tt_metal/api/tt-metalium/device.hpp +++ b/tt_metal/api/tt-metalium/device.hpp @@ -48,6 +48,10 @@ class CommandQueue; class TraceBuffer; struct TraceDescriptor; +namespace detail { +struct TraceDescriptor; +} + inline namespace v0 { class IDevice { diff --git a/tt_metal/api/tt-metalium/lightmetal_capture_utils.hpp b/tt_metal/api/tt-metalium/lightmetal_capture_utils.hpp new file mode 100644 index 00000000000..5c6aec97b59 --- /dev/null +++ b/tt_metal/api/tt-metalium/lightmetal_capture_utils.hpp @@ -0,0 +1,54 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "lightmetal/host_api_capture_helpers.hpp" +#include + +namespace tt::tt_metal { + +// Note: LightMetalCompare functions could have been inside host_api.hpp / command_queue.cpp but seems better +// to not make as visible, since these are APIs used at light-metal capture time for verification purposes. + +// clang-format off +/** + * Reads a buffer from the device and captures return data as golden inside Light Metal Binary, and optionally returns to user. + * When replaying Light Metal Binary, buffer is read and data is compared to the capture-time golden data. 
+ * + * Return value: void + * + * | Argument | Description | Type | Valid Range | Required | + * |----------------|-----------------------------------------------------------------------------------|-------------------------------------|----------------------------------------|----------| + * | cq | The command queue object which dispatches the command to the hardware | CommandQueue & | | Yes | + * | buffer | The device buffer we are reading from | Buffer & or std::shared_ptr | | Yes | + * | dst | The memory where the result will be stored, if provided | void* | | No | + */ +// clang-format on +void LightMetalCompareToCapture( + CommandQueue& cq, + const std::variant, std::shared_ptr>& buffer, + void* dst = nullptr); + +// clang-format off +/** + * Accepts user-supplied golden data, stored inside Light Metal Binary. + * When replaying Light Metal Binary, buffer is read and data is compared to the user-supplied golden data. + * + * Return value: void + * + * | Argument | Description | Type | Valid Range | Required | + * |----------------|-----------------------------------------------------------------------------------|-------------------------------------|----------------------------------------|----------| + * | cq | The command queue object which dispatches the command to the hardware | CommandQueue & | | Yes | + * | buffer | The device buffer we are reading from | Buffer & or std::shared_ptr | | Yes | + * | golden_data | User supplied expected/golden data for buffer | void* | | Yes | + */ +// clang-format on + +void LightMetalCompareToGolden( + CommandQueue& cq, + const std::variant, std::shared_ptr>& buffer, + void* golden_data); + +} // namespace tt::tt_metal diff --git a/tt_metal/impl/CMakeLists.txt b/tt_metal/impl/CMakeLists.txt index 3ba20f30f52..2c35f81f4cb 100644 --- a/tt_metal/impl/CMakeLists.txt +++ b/tt_metal/impl/CMakeLists.txt @@ -52,6 +52,9 @@ set(IMPL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/buffer_types_to_flatbuffer.cpp 
${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/program_types_from_flatbuffer.cpp ${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/program_types_to_flatbuffer.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/lightmetal/lightmetal_capture.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/lightmetal/lightmetal_capture_utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/lightmetal/host_api_capture_helpers.cpp ) # Include helper functions and generate headers from flatbuffer schemas diff --git a/tt_metal/impl/device/device.cpp b/tt_metal/impl/device/device.cpp index 7dd8756cffa..605cbdcc822 100644 --- a/tt_metal/impl/device/device.cpp +++ b/tt_metal/impl/device/device.cpp @@ -12,6 +12,7 @@ #include #include #include +#include "lightmetal/lightmetal_capture.hpp" #include "tracy/Tracy.hpp" #include #include "dprint_server.hpp" @@ -1450,6 +1451,13 @@ void Device::end_trace(const uint8_t cq_id, const uint32_t tid) { this->id_, active_sub_device_manager->id()); this->command_queues_[cq_id]->record_end(); + + // Capture Trace if light metal trace capturing is enabled. 
+ auto& lm_capture_ctx = LightMetalCaptureContext::get(); + if (lm_capture_ctx.is_tracing()) { + lm_capture_ctx.capture_trace_descriptor(*trace_buffer->desc, tid); + } + Trace::initialize_buffer(this->command_queue(cq_id), trace_buffer); this->mark_allocations_unsafe(); }, diff --git a/tt_metal/impl/dispatch/host_runtime_commands.cpp b/tt_metal/impl/dispatch/host_runtime_commands.cpp index 68eb075e998..e1e0dfa8b5b 100644 --- a/tt_metal/impl/dispatch/host_runtime_commands.cpp +++ b/tt_metal/impl/dispatch/host_runtime_commands.cpp @@ -43,6 +43,7 @@ #include #include +#include "lightmetal/host_api_capture_helpers.hpp" using namespace tt::tt_metal; @@ -513,6 +514,8 @@ void EnqueueReadBuffer( const std::variant, std::shared_ptr>& buffer, void* dst, bool blocking) { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureEnqueueReadBuffer, cq, buffer, dst, blocking); Buffer& buffer_obj = detail::GetBufferObject(buffer); BufferRegion region(0, buffer_obj.size()); EnqueueReadSubBuffer(cq, buffer, dst, region, blocking); @@ -543,6 +546,8 @@ void EnqueueWriteBuffer( const std::variant, std::shared_ptr>& buffer, HostDataType src, bool blocking) { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureEnqueueWriteBuffer, cq, buffer, src, blocking); Buffer& buffer_obj = detail::GetBufferObject(buffer); BufferRegion region(0, buffer_obj.size()); EnqueueWriteSubBuffer(cq, buffer, std::move(src), region, blocking); @@ -562,6 +567,8 @@ void EnqueueWriteSubBuffer( void EnqueueProgram(CommandQueue& cq, Program& program, bool blocking) { ZoneScoped; + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureEnqueueProgram, cq, program, blocking); detail::DispatchStateCheck(true); IDevice* device = cq.device(); @@ -632,6 +639,8 @@ bool EventQuery(const std::shared_ptr& event) { } void Finish(CommandQueue& cq, tt::stl::Span sub_device_ids) { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + 
LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureFinish, cq, sub_device_ids); detail::DispatchStateCheck(true); cq.finish(sub_device_ids); TT_ASSERT( @@ -643,6 +652,8 @@ void Finish(CommandQueue& cq, tt::stl::Span sub_device_ids) { } void EnqueueTrace(CommandQueue& cq, uint32_t trace_id, bool blocking) { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureEnqueueTrace, cq, trace_id, blocking); detail::DispatchStateCheck(true); TT_FATAL(cq.device()->get_trace(trace_id) != nullptr, "Trace instance {} must exist on device", trace_id); cq.enqueue_trace(trace_id, blocking); diff --git a/tt_metal/impl/flatbuffer/buffer_types_to_flatbuffer.cpp b/tt_metal/impl/flatbuffer/buffer_types_to_flatbuffer.cpp index 0c3f4c3822b..33d1fe52571 100644 --- a/tt_metal/impl/flatbuffer/buffer_types_to_flatbuffer.cpp +++ b/tt_metal/impl/flatbuffer/buffer_types_to_flatbuffer.cpp @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #include "flatbuffer/buffer_types_to_flatbuffer.hpp" +#include "lightmetal/lightmetal_capture.hpp" // For LightMetalCaptureContext namespace tt::tt_metal { @@ -54,10 +55,8 @@ flatbuffers::Offset to_flatbuffer( }; // Optional shadow buffer for dynamically allocated CBs, get global_id or use 0 as none/nullptr. - // auto& ctx = LightMetalCaptureContext::Get(); - // auto shadow_buf_global_id = config.shadow_global_buffer ? ctx.GetGlobalId(config.shadow_global_buffer) : 0; - // TODO (kmabee) - Uncomment above code once capture library is merged. Temp hack here for now. - uint32_t shadow_buf_global_id = 0; + auto& ctx = LightMetalCaptureContext::get(); + auto shadow_buf_global_id = config.shadow_global_buffer ? 
ctx.get_global_id(config.shadow_global_buffer) : 0; // Create the FlatBuffer object return flatbuffer::CreateCircularBufferConfig( diff --git a/tt_metal/impl/flatbuffer/program_types_from_flatbuffer.hpp b/tt_metal/impl/flatbuffer/program_types_from_flatbuffer.hpp index f8176eb0f98..930ebe230e7 100644 --- a/tt_metal/impl/flatbuffer/program_types_from_flatbuffer.hpp +++ b/tt_metal/impl/flatbuffer/program_types_from_flatbuffer.hpp @@ -54,8 +54,7 @@ std::variant kernel_config_fr return from_flatbuffer(cmd->kernel_config_as_DataMovementConfig()); case flatbuffer::KernelConfig::ComputeConfig: return from_flatbuffer(cmd->kernel_config_as_ComputeConfig()); case flatbuffer::KernelConfig::EthernetConfig: return from_flatbuffer(cmd->kernel_config_as_EthernetConfig()); - case flatbuffer::KernelConfig::NONE: - throw std::runtime_error("Unhandled KernelConfig type in from_flatbuffer."); + case flatbuffer::KernelConfig::NONE: TT_THROW("Unhandled KernelConfig type in from_flatbuffer."); } TT_THROW("Unhandled KernelConfig type in from_flatbuffer."); } diff --git a/tt_metal/impl/flatbuffer/program_types_to_flatbuffer.cpp b/tt_metal/impl/flatbuffer/program_types_to_flatbuffer.cpp index a3d8e875819..6c8f1570604 100644 --- a/tt_metal/impl/flatbuffer/program_types_to_flatbuffer.cpp +++ b/tt_metal/impl/flatbuffer/program_types_to_flatbuffer.cpp @@ -4,7 +4,9 @@ #include "flatbuffer/base_types_to_flatbuffer.hpp" #include "flatbuffer/program_types_to_flatbuffer.hpp" +#include "lightmetal/lightmetal_capture.hpp" // For LightMetalCaptureContext #include + namespace tt::tt_metal { // Original types defined in core_coord.hpp @@ -155,10 +157,8 @@ flatbuffers::Offset create_runtime_arg( return builder.CreateStruct(tt_metal::flatbuffer::UInt32Value{arg_value}).Union(); }, [&](Buffer* arg_value) -> flatbuffers::Offset { - // auto& ctx = LightMetalCaptureContext::Get(); - // uint32_t buffer_global_id = ctx.GetGlobalId(arg_value); - // TODO (kmabee) - Uncomment above code once capture library is 
merged. Temp hack here for now. - uint32_t buffer_global_id = 0; + auto& ctx = LightMetalCaptureContext::get(); + uint32_t buffer_global_id = ctx.get_global_id(arg_value); value_type = flatbuffer::RuntimeArgValue::BufferGlobalId; return builder.CreateStruct(tt_metal::flatbuffer::BufferGlobalId{buffer_global_id}).Union(); }}, diff --git a/tt_metal/impl/lightmetal/host_api_capture_helpers.cpp b/tt_metal/impl/lightmetal/host_api_capture_helpers.cpp new file mode 100644 index 00000000000..9d4905bb2c6 --- /dev/null +++ b/tt_metal/impl/lightmetal/host_api_capture_helpers.cpp @@ -0,0 +1,394 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include +#include +#include "lightmetal/host_api_capture_helpers.hpp" +#include "command_generated.h" +#include "lightmetal/lightmetal_capture.hpp" +#include "flatbuffer/base_types_to_flatbuffer.hpp" +#include "flatbuffer/program_types_to_flatbuffer.hpp" +#include "flatbuffer/buffer_types_to_flatbuffer.hpp" + +namespace tt::tt_metal { + +////////////////////////////////////////////////////////////// +// Debug Code // +////////////////////////////////////////////////////////////// + +namespace { +// This can be useful for debug. Not all data types are currently supported, can use this during developmenmt. 
+void PrintHostDataType(const HostDataType& data) { + std::visit( + tt::stl::overloaded{ + [](const std::shared_ptr>& value) { + log_info(tt::LogMetalTrace, "HostDataType contains: std::shared_ptr>"); + }, + [](const std::shared_ptr>& value) { + log_info(tt::LogMetalTrace, "HostDataType contains: std::shared_ptr>"); + }, + [](const std::shared_ptr>& value) { + log_info(tt::LogMetalTrace, "HostDataType contains: std::shared_ptr>"); + }, + [](const std::shared_ptr>& value) { + log_info(tt::LogMetalTrace, "HostDataType contains: std::shared_ptr>"); + }, + [](const std::shared_ptr>& value) { + log_info(tt::LogMetalTrace, "HostDataType contains: std::shared_ptr>"); + }, + [](const std::shared_ptr>& value) { + log_info(tt::LogMetalTrace, "HostDataType contains: std::shared_ptr>"); + }, + [](const void* value) { log_info(tt::LogMetalTrace, "HostDataType contains: const void*"); }, + [](auto&&) { log_info(tt::LogMetalTrace, "HostDataType contains: Unknown type"); }}, + data); +} +} // namespace + +////////////////////////////////////////////////////////////// +// Host API tracing helper functions // +////////////////////////////////////////////////////////////// + +// Generic helper to build command and add to vector of cmds (CQ) - no need to make public +namespace { +void CaptureCommand(tt::tt_metal::flatbuffer::CommandType cmd_type, ::flatbuffers::Offset fb_offset) { + auto& ctx = LightMetalCaptureContext::get(); + ctx.get_cmds_vector().push_back(tt::tt_metal::flatbuffer::CreateCommand(ctx.get_builder(), cmd_type, fb_offset)); +} +} // namespace + +void CaptureReplayTrace(IDevice* device, uint8_t cq_id, uint32_t trace_id, bool blocking) { + auto& ctx = LightMetalCaptureContext::get(); + log_debug(tt::LogMetalTrace, "{}: cq_id: {} trace_id: {} blocking: {}", __FUNCTION__, cq_id, trace_id, blocking); + auto cmd = tt::tt_metal::flatbuffer::CreateReplayTraceCommand(ctx.get_builder(), cq_id, trace_id, blocking); + 
CaptureCommand(tt::tt_metal::flatbuffer::CommandType::ReplayTraceCommand, cmd.Union()); +} + +void CaptureEnqueueTrace(CommandQueue& cq, uint32_t trace_id, bool blocking) { + auto& ctx = LightMetalCaptureContext::get(); + log_debug(tt::LogMetalTrace, "{}: cq_id: {} trace_id: {} blocking: {}", __FUNCTION__, cq.id(), trace_id, blocking); + auto cmd = tt::tt_metal::flatbuffer::CreateEnqueueTraceCommand(ctx.get_builder(), cq.id(), trace_id, blocking); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::EnqueueTraceCommand, cmd.Union()); +} + +void CaptureLoadTrace(IDevice* device, uint8_t cq_id, uint32_t trace_id) { + auto& ctx = LightMetalCaptureContext::get(); + log_debug(tt::LogMetalTrace, "{}: cq_id: {} trace_id: {}", __FUNCTION__, cq_id, trace_id); + auto cmd = tt::tt_metal::flatbuffer::CreateLoadTraceCommand(ctx.get_builder(), trace_id, cq_id); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::LoadTraceCommand, cmd.Union()); +} + +void CaptureReleaseTrace(IDevice* device, uint32_t trace_id) { + auto& ctx = LightMetalCaptureContext::get(); + log_debug(tt::LogMetalTrace, "{}: trace_id: {}", __FUNCTION__, trace_id); + auto cmd = tt::tt_metal::flatbuffer::CreateReleaseTraceCommand(ctx.get_builder(), trace_id); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::ReleaseTraceCommand, cmd.Union()); +} + +void CaptureCreateBuffer(const std::shared_ptr& buffer, const InterleavedBufferConfig& config) { + auto& ctx = LightMetalCaptureContext::get(); + + uint32_t buffer_global_id = ctx.add_to_map(buffer.get()); + log_debug( + tt::LogMetalTrace, + "{}: size: {} page_size: {} buffer_type: {} buffer_layout: {} buffer_global_id: {}", + __FUNCTION__, + config.size, + config.page_size, + config.buffer_type, + config.buffer_layout, + buffer_global_id); + + assert(config.device->id() == 0 && "multichip not supported yet"); + auto buffer_config_offset = tt::tt_metal::flatbuffer::CreateInterleavedBufferConfig( + ctx.get_builder(), + config.device->id(), + config.size, + 
config.page_size, + to_flatbuffer(config.buffer_type), + to_flatbuffer(config.buffer_layout)); + auto cmd = + tt::tt_metal::flatbuffer::CreateCreateBufferCommand(ctx.get_builder(), buffer_global_id, buffer_config_offset); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::CreateBufferCommand, cmd.Union()); +} + +void CaptureDeallocateBuffer(Buffer& buffer) { + auto& ctx = LightMetalCaptureContext::get(); + + // Kind of a workaround, but Program Binaries buffer is created via Buffer::create() but can be + // deallocated on Program destruction while capturing is still enabled depending on test structure (scope) + // so let's just not capture these DeallocateBuffer() calls since they will occur on playback naturally. + if (!ctx.is_in_map(&buffer)) { + log_debug(tt::LogMetalTrace, "Cannot capture DeallocateBuffer() without CreateBuffer() - ignoring."); + return; + } + + auto buffer_global_id = ctx.get_global_id(&buffer); + + log_debug( + tt::LogMetalTrace, + "{}: buffer_global_id: {} size: {} address: {}", + __FUNCTION__, + buffer_global_id, + buffer.size(), + buffer.address()); + + auto cmd = tt::tt_metal::flatbuffer::CreateDeallocateBufferCommand(ctx.get_builder(), buffer_global_id); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::DeallocateBufferCommand, cmd.Union()); +} + +void CaptureEnqueueWriteBuffer( + CommandQueue& cq, + std::variant, std::shared_ptr> buffer, + HostDataType src, + bool blocking) { + auto& ctx = LightMetalCaptureContext::get(); + + // We don't want to use shared_ptr to extend lifetime of buffer when adding to global_id map. + Buffer* buffer_ptr = std::holds_alternative>(buffer) + ? std::get>(buffer).get() + : &std::get>(buffer).get(); + + uint32_t cq_global_id = cq.id(); // TODO (kmabee) - consider storing/getting CQ from global map instead. 
+ uint32_t buffer_global_id = ctx.get_global_id(buffer_ptr); + + log_debug( + tt::LogMetalTrace, "{}: cq_global_id: {} buffer_global_id: {}", __FUNCTION__, cq_global_id, buffer_global_id); + // PrintHostDataType(src); // Debug + + // TODO (kmabee) - Currently support limited data formats. Long term we might not store data in flatbuffer, + // but have it provided at runtime so just do what's easiest here and support few types for now. + ::flatbuffers::Offset<::flatbuffers::Vector> src_vector; + if (auto* uint32_vec = std::get_if>>(&src)) { + src_vector = ctx.get_builder().CreateVector(**uint32_vec); + } else if (auto* uint16_vec = std::get_if>>(&src)) { + // Convert uint16_t to uint32_t before creating the FlatBuffers vector + std::vector converted(uint16_vec->get()->begin(), uint16_vec->get()->end()); + src_vector = ctx.get_builder().CreateVector(converted); + } else if (auto* void_ptr = std::get_if(&src)) { + // Assuming the void* points to a buffer of uint32_t values. Infer size, cast to uint32_t. + size_t num_elements = buffer_ptr->size() / sizeof(uint32_t); + auto uint32_data = static_cast(*void_ptr); + src_vector = ctx.get_builder().CreateVector(uint32_data, num_elements); + } else { + TT_THROW("Unsupported HostDataType for captureEnqueueWriteBuffer()"); + } + + auto cmd = tt::tt_metal::flatbuffer::CreateEnqueueWriteBufferCommand( + ctx.get_builder(), cq_global_id, buffer_global_id, src_vector, blocking); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::EnqueueWriteBufferCommand, cmd.Union()); +} + +void CaptureEnqueueReadBuffer( + CommandQueue& cq, + std::variant, std::shared_ptr> buffer, + void* dst, + bool blocking) { + auto& ctx = LightMetalCaptureContext::get(); + + // We don't want to use shared_ptr to extend lifetime of buffer when adding to global_id map. + Buffer* buffer_ptr = std::holds_alternative>(buffer) + ? 
std::get>(buffer).get() + : &std::get>(buffer).get(); + + uint32_t cq_global_id = cq.id(); // TODO (kmabee) - consider storing/getting CQ from global map instead. + uint32_t buffer_global_id = ctx.get_global_id(buffer_ptr); + + log_debug( + tt::LogMetalTrace, "{}: cq_global_id: {} buffer_global_id: {}", __FUNCTION__, cq_global_id, buffer_global_id); + + // Idea store a read_global_id to keep track of read results. + auto cmd = tt::tt_metal::flatbuffer::CreateEnqueueReadBufferCommand( + ctx.get_builder(), cq_global_id, buffer_global_id, blocking); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::EnqueueReadBufferCommand, cmd.Union()); +} + +void CaptureFinish(CommandQueue& cq, tt::stl::Span sub_device_ids) { + auto& ctx = LightMetalCaptureContext::get(); + uint32_t cq_global_id = cq.id(); // TODO (kmabee) - consider storing/getting CQ from global map instead. + + // Use to_flatbuffer to convert SubDeviceIds to FlatBuffer vector + auto fb_sub_device_ids = to_flatbuffer(ctx.get_builder(), sub_device_ids); + + log_debug( + tt::LogMetalTrace, "{}: cq_global_id: {} sub_devices: {}", __FUNCTION__, cq_global_id, sub_device_ids.size()); + auto cmd = tt::tt_metal::flatbuffer::CreateFinishCommand(ctx.get_builder(), cq_global_id, fb_sub_device_ids); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::FinishCommand, cmd.Union()); +} + +void CaptureCreateProgram(Program& program) { + auto& ctx = LightMetalCaptureContext::get(); + uint32_t program_global_id = ctx.add_to_map(&program); + log_debug(tt::LogMetalTrace, "{}: program_global_id: {}", __FUNCTION__, program_global_id); + + auto cmd = tt::tt_metal::flatbuffer::CreateCreateProgramCommand(ctx.get_builder(), program_global_id); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::CreateProgramCommand, cmd.Union()); +} + +void CaptureEnqueueProgram(CommandQueue& cq, Program& program, bool blocking) { + auto& ctx = LightMetalCaptureContext::get(); + + // When Metal Trace is enabled, skip EnqueueProgram capture 
(replaced with LoadTrace + ReplayTrace) + if (cq.sysmem_manager().get_bypass_mode()) { + return; + } + + uint32_t cq_global_id = cq.id(); // TODO (kmabee) - consider storing/getting CQ from global map instead. + uint32_t program_global_id = ctx.get_global_id(&program); + log_debug( + tt::LogMetalTrace, "{}: cq_global_id: {} program_global_id: {}", __FUNCTION__, cq_global_id, program_global_id); + + auto cmd = tt::tt_metal::flatbuffer::CreateEnqueueProgramCommand( + ctx.get_builder(), cq_global_id, program_global_id, blocking); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::EnqueueProgramCommand, cmd.Union()); +} + +void CaptureCreateKernel( + KernelHandle kernel_id, + Program& program, + const std::string& file_name, + const std::variant& core_spec, + const std::variant& config) { + auto& ctx = LightMetalCaptureContext::get(); + + std::shared_ptr kernel = program.get_kernel(kernel_id); + uint32_t kernel_global_id = ctx.add_to_map(kernel.get()); + uint32_t program_global_id = ctx.get_global_id(&program); + log_debug( + tt::LogMetalTrace, + "{}: file_name: {} kernel_global_id: {} (kernel_id: {}) program_global_id: {}", + __FUNCTION__, + file_name, + kernel_global_id, + kernel_id, + program_global_id); + + auto& fbb = ctx.get_builder(); + auto filename_offset = fbb.CreateString(file_name); + auto [core_spec_type, core_spec_offset] = to_flatbuffer(fbb, core_spec); + auto [kernel_config_type, kernel_config_offset] = to_flatbuffer(fbb, config); + + auto cmd = tt::tt_metal::flatbuffer::CreateCreateKernelCommand( + fbb, + kernel_global_id, + program_global_id, + filename_offset, + core_spec_type, + core_spec_offset, + kernel_config_type, + kernel_config_offset); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::CreateKernelCommand, cmd.Union()); +} + +void CaptureSetRuntimeArgsUint32( + const Program& program, + KernelHandle kernel_id, + const std::variant& core_spec, + tt::stl::Span runtime_args) { + auto& ctx = LightMetalCaptureContext::get(); + + 
std::shared_ptr kernel = program.get_kernel(kernel_id); + uint32_t program_global_id = ctx.get_global_id(&program); + uint32_t kernel_global_id = ctx.get_global_id(kernel.get()); + log_debug( + tt::LogMetalTrace, + "{}(uint32): kernel_global_id: {} program_global_id: {} rt_args: {}", + __FUNCTION__, + kernel_global_id, + program_global_id, + runtime_args.size()); + + auto& fbb = ctx.get_builder(); + auto [core_spec_type, core_spec_offset] = to_flatbuffer(fbb, core_spec); + auto rt_args_offset = fbb.CreateVector(runtime_args.data(), runtime_args.size()); + + auto cmd = tt::tt_metal::flatbuffer::CreateSetRuntimeArgsUint32Command( + fbb, program_global_id, kernel_global_id, core_spec_type, core_spec_offset, rt_args_offset); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::SetRuntimeArgsUint32Command, cmd.Union()); +} + +void CaptureSetRuntimeArgs( + IDevice* device, + const std::shared_ptr& kernel, + const std::variant& core_spec, + const std::shared_ptr& runtime_args) { + auto& ctx = LightMetalCaptureContext::get(); + auto& fbb = ctx.get_builder(); + uint32_t kernel_global_id = ctx.get_global_id(kernel.get()); + auto [core_spec_type, core_spec_offset] = to_flatbuffer(fbb, core_spec); + auto rt_args_offset = to_flatbuffer(fbb, runtime_args); + log_debug( + tt::LogMetalTrace, + "{}(RuntimeArgs): kernel_global_id: {} rt_args_size: {}", + __FUNCTION__, + kernel_global_id, + runtime_args->size()); + + auto cmd = tt::tt_metal::flatbuffer::CreateSetRuntimeArgsCommand( + fbb, kernel_global_id, core_spec_type, core_spec_offset, rt_args_offset); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::SetRuntimeArgsCommand, cmd.Union()); +} + +void CaptureCreateCircularBuffer( + CBHandle& cb_handle, + Program& program, + const std::variant& core_spec, + const CircularBufferConfig& config) { + auto& ctx = LightMetalCaptureContext::get(); + auto& fbb = ctx.get_builder(); + uint32_t cb_global_id = ctx.add_to_map(cb_handle); + uint32_t program_global_id = 
ctx.get_global_id(&program); + auto [core_spec_type, core_spec_offset] = to_flatbuffer(fbb, core_spec); + auto cb_config_offset = to_flatbuffer(config, fbb); + log_debug( + tt::LogMetalTrace, + "{}: cb_global_id: {} program_global_id: {} ", + __FUNCTION__, + cb_global_id, + program_global_id); + + auto cmd = tt::tt_metal::flatbuffer::CreateCreateCircularBufferCommand( + fbb, cb_global_id, program_global_id, core_spec_type, core_spec_offset, cb_config_offset); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::CreateCircularBufferCommand, cmd.Union()); +} + +void CaptureLightMetalCompare( + CommandQueue& cq, + std::variant, std::shared_ptr> buffer, + void* golden_data, + bool is_user_data) { + auto& ctx = LightMetalCaptureContext::get(); + + // We don't want to use shared_ptr to extend lifetime of buffer when adding to global_id map. + Buffer* buffer_ptr = std::holds_alternative>(buffer) + ? std::get>(buffer).get() + : &std::get>(buffer).get(); + + uint32_t cq_global_id = cq.id(); // TODO (kmabee) - consider storing/getting CQ from global map instead. 
+ uint32_t buffer_global_id = ctx.get_global_id(buffer_ptr); + + // Calculate num uint32_t elements in buffer, and convert golden void* to vector + size_t golden_data_len = buffer_ptr->size() / sizeof(uint32_t); + const uint32_t* golden_data_uint32 = static_cast(golden_data); + std::vector golden_data_vector(golden_data_uint32, golden_data_uint32 + golden_data_len); + + log_debug( + tt::LogMetalTrace, + "{}: buffer_global_id: {} is_user_data: {} golden_data_len: {}", + __FUNCTION__, + buffer_global_id, + is_user_data, + golden_data_len); + + // Serialize golden_data into FlatBuffer + auto golden_data_fb = ctx.get_builder().CreateVector(golden_data_vector); + + auto cmd = tt::tt_metal::flatbuffer::CreateLightMetalCompareCommand( + ctx.get_builder(), cq_global_id, buffer_global_id, golden_data_fb, is_user_data); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::LightMetalCompareCommand, cmd.Union()); +} + +} // namespace tt::tt_metal diff --git a/tt_metal/impl/lightmetal/host_api_capture_helpers.hpp b/tt_metal/impl/lightmetal/host_api_capture_helpers.hpp new file mode 100644 index 00000000000..3639fd3b90b --- /dev/null +++ b/tt_metal/impl/lightmetal/host_api_capture_helpers.hpp @@ -0,0 +1,133 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "flatbuffers/flatbuffers.h" +#include "lightmetal/lightmetal_capture.hpp" +#include +#include +#include +#include + +namespace tt::tt_metal { + +// Many forward decls and aliases to reduce includes. 
+class CommandQueue; +struct DataMovementConfig; +struct ComputeConfig; +struct EthernetConfig; + +inline namespace v0 { +class IDevice; +struct BufferConfig; +struct CircularBufferConfig; +using RuntimeArgs = std::vector>; +} // namespace v0 + +////////////////////////////////////////////////////////////// +// TRACE GUARD & LIGHT METAL TRACE MACRO // +////////////////////////////////////////////////////////////// + +// This struct will disable further tracing in current scope, and re-enable +// when scope ends. Prevents recursive tracing of host APIs. +struct TraceScope { + // Provide an inline definition in the header + static inline thread_local int depth = 0; + // Increment depth on entering scope, decrement on exiting + TraceScope() { ++depth; } + ~TraceScope() { --depth; } +}; + +} // namespace tt::tt_metal + +#if defined(TT_ENABLE_LIGHT_METAL_TRACE) && (TT_ENABLE_LIGHT_METAL_TRACE == 1) + +#define LIGHT_METAL_TRACE_FUNCTION_ENTRY() tt::tt_metal::TraceScope __traceScopeGuard + +#define LIGHT_METAL_TRACE_FUNCTION_CALL(capture_func, ...) \ + do { \ + log_trace( \ + tt::LogMetalTrace, \ + "LIGHT_METAL_TRACE_FUNCTION_CALL: {} via {} istracing: {} depth: {}", \ + #capture_func, \ + __FUNCTION__, \ + LightMetalCaptureContext::get().is_tracing(), \ + tt::tt_metal::TraceScope::depth); \ + if (LightMetalCaptureContext::get().is_tracing() && tt::tt_metal::TraceScope::depth == 1) { \ + capture_func(__VA_ARGS__); \ + } \ + } while (0) +#else + +#define LIGHT_METAL_TRACE_FUNCTION_ENTRY() +#define LIGHT_METAL_TRACE_FUNCTION_CALL(capture_func, ...) 
\ + do { \ + } while (0) + +#endif + +namespace tt::tt_metal { + +// Per Command type capture helper functions +void CaptureReplayTrace(IDevice* device, uint8_t cq_id, uint32_t tid, bool blocking); + +void CaptureEnqueueTrace(CommandQueue& cq, uint32_t tid, bool blocking); + +void CaptureLoadTrace(IDevice* device, const uint8_t cq_id, const uint32_t tid); + +void CaptureReleaseTrace(IDevice* device, uint32_t tid); + +void CaptureCreateBuffer(const std::shared_ptr& buffer, const InterleavedBufferConfig& config); + +void CaptureDeallocateBuffer(Buffer& buffer); + +void CaptureEnqueueWriteBuffer( + CommandQueue& cq, + std::variant, std::shared_ptr> buffer, + HostDataType src, + bool blocking); + +void CaptureEnqueueReadBuffer( + CommandQueue& cq, + std::variant, std::shared_ptr> buffer, + void* dst, + bool blocking); + +void CaptureFinish(CommandQueue& cq, tt::stl::Span sub_device_ids); +void CaptureCreateProgram(Program& program); +void CaptureEnqueueProgram(CommandQueue& cq, Program& program, bool blocking); + +void CaptureCreateKernel( + KernelHandle kernel_id, + Program& program, + const std::string& file_name, + const std::variant& core_spec, + const std::variant& config); + +void CaptureSetRuntimeArgsUint32( + const Program& program, + KernelHandle kernel_id, + const std::variant& core_spec, + tt::stl::Span runtime_args); + +void CaptureSetRuntimeArgs( + IDevice* device, + const std::shared_ptr& kernel, + const std::variant& core_spec, + const std::shared_ptr& runtime_args); + +void CaptureCreateCircularBuffer( + CBHandle& cb_handle, + Program& program, + const std::variant& core_spec, + const CircularBufferConfig& config); + +void CaptureLightMetalCompare( + CommandQueue& cq, + std::variant, std::shared_ptr> buffer, + void* golden_data, + bool is_user_data); + +} // namespace tt::tt_metal diff --git a/tt_metal/impl/lightmetal/lightmetal_capture.cpp b/tt_metal/impl/lightmetal/lightmetal_capture.cpp new file mode 100644 index 00000000000..c1c7d4e4dee --- 
/dev/null +++ b/tt_metal/impl/lightmetal/lightmetal_capture.cpp @@ -0,0 +1,234 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include "lightmetal/lightmetal_capture.hpp" +#include "flatbuffers/flatbuffers.h" +#include "command_generated.h" +#include "light_metal_binary_generated.h" +#include +#include +#include +#include + +namespace tt::tt_metal { +inline namespace v0 { + +LightMetalCaptureContext::LightMetalCaptureContext() : is_tracing_(false), builder_() {} + +// Singleton instance accessor +LightMetalCaptureContext& LightMetalCaptureContext::get() { + static LightMetalCaptureContext instance; + return instance; +} + +bool LightMetalCaptureContext::is_tracing() const { return is_tracing_; } + +void LightMetalCaptureContext::set_tracing(bool is_tracing) { is_tracing_ = is_tracing; } + +flatbuffers::FlatBufferBuilder& LightMetalCaptureContext::get_builder() { return builder_; } + +std::vector>& LightMetalCaptureContext::get_cmds_vector() { + return cmds_vec_; +} + +void LightMetalCaptureContext::capture_trace_descriptor(const TraceDescriptor& trace_desc, const uint32_t tid) { + trace_descs_vec_.push_back(to_flatbuffer(builder_, trace_desc, tid)); +} + +// Create final flatbuffer binary from the built up data and return to caller as blob. +// If light_metal_binary itself (flatbuffer object) is of interest, could return it instead. 
+LightMetalBinary LightMetalCaptureContext::create_light_metal_binary() { + auto cmds_vec_fb = builder_.CreateVector(cmds_vec_); + auto sorted_trace_descs = builder_.CreateVectorOfSortedTables(&trace_descs_vec_); + auto light_metal_binary = + tt::tt_metal::flatbuffer::CreateLightMetalBinary(builder_, cmds_vec_fb, sorted_trace_descs); + builder_.Finish(light_metal_binary); + + const uint8_t* buffer_ptr = builder_.GetBufferPointer(); + size_t buffer_size = builder_.GetSize(); + + std::vector binary_data(buffer_ptr, buffer_ptr + buffer_size); + return LightMetalBinary(std::move(binary_data)); +} + +// Reset some internal state, and ensure tracing isn't active. Should only be called at start of tracing. +void LightMetalCaptureContext::reset() { + TT_ASSERT(!is_tracing_, "Cannot reset light metal capture context while tracing is enabled."); + builder_.Clear(); + next_global_id_ = 0; + cmds_vec_.clear(); + trace_descs_vec_.clear(); + buffer_to_global_id_map_.clear(); + program_to_global_id_map_.clear(); + kernel_to_global_id_map_.clear(); + cb_handle_to_global_id_map_.clear(); +} + +//////////////////////////////////////////// +// Object Map Public Accessors // +//////////////////////////////////////////// + +bool LightMetalCaptureContext::is_in_map(const Buffer* obj) { + return buffer_to_global_id_map_.find(obj) != buffer_to_global_id_map_.end(); +} + +uint32_t LightMetalCaptureContext::add_to_map(const Buffer* obj) { + if (is_in_map(obj)) { + log_warning(tt::LogMetalTrace, "Buffer already exists in global_id map."); + } + uint32_t global_id = next_global_id_++; + buffer_to_global_id_map_[obj] = global_id; + return global_id; +} + +void LightMetalCaptureContext::remove_from_map(const Buffer* obj) { + if (!is_in_map(obj)) { + log_warning(tt::LogMetalTrace, "Buffer not found in global_id map."); + } + buffer_to_global_id_map_.erase(obj); +} + +uint32_t LightMetalCaptureContext::get_global_id(const Buffer* obj) { + auto it = buffer_to_global_id_map_.find(obj); + if (it != 
buffer_to_global_id_map_.end()) { + return it->second; + } else { + TT_THROW("Buffer not found in global_id global_id map"); + } +} + +bool LightMetalCaptureContext::is_in_map(const Program* obj) { + return program_to_global_id_map_.find(obj) != program_to_global_id_map_.end(); +} + +uint32_t LightMetalCaptureContext::add_to_map(const Program* obj) { + if (is_in_map(obj)) { + log_warning(tt::LogMetalTrace, "Program already exists in global_id map."); + } + uint32_t global_id = next_global_id_++; + program_to_global_id_map_[obj] = global_id; + return global_id; +} + +void LightMetalCaptureContext::remove_from_map(const Program* obj) { + if (!is_in_map(obj)) { + log_warning(tt::LogMetalTrace, "Program not found in global_id map."); + } + program_to_global_id_map_.erase(obj); +} + +uint32_t LightMetalCaptureContext::get_global_id(const Program* obj) { + auto it = program_to_global_id_map_.find(obj); + if (it != program_to_global_id_map_.end()) { + return it->second; + } else { + TT_THROW("Program not found in global_id map."); + } +} + +bool LightMetalCaptureContext::is_in_map(const Kernel* obj) { + return kernel_to_global_id_map_.find(obj) != kernel_to_global_id_map_.end(); +} + +uint32_t LightMetalCaptureContext::add_to_map(const Kernel* obj) { + if (is_in_map(obj)) { + log_warning(tt::LogMetalTrace, "Kernel already exists in global_id map."); + } + uint32_t global_id = next_global_id_++; + kernel_to_global_id_map_[obj] = global_id; + return global_id; +} + +void LightMetalCaptureContext::remove_from_map(const Kernel* obj) { + if (!is_in_map(obj)) { + log_warning(tt::LogMetalTrace, "Kernel not found in global_id map."); + } + kernel_to_global_id_map_.erase(obj); +} + +uint32_t LightMetalCaptureContext::get_global_id(const Kernel* obj) { + auto it = kernel_to_global_id_map_.find(obj); + if (it != kernel_to_global_id_map_.end()) { + return it->second; + } else { + TT_THROW("Kernel not found in global_id map."); + } +} + +bool LightMetalCaptureContext::is_in_map(const 
CBHandle handle) { + return cb_handle_to_global_id_map_.find(handle) != cb_handle_to_global_id_map_.end(); +} + +uint32_t LightMetalCaptureContext::add_to_map(const CBHandle handle) { + if (is_in_map(handle)) { + log_warning(tt::LogMetalTrace, "CBHandle already exists in global_id map."); + } + uint32_t global_id = next_global_id_++; + cb_handle_to_global_id_map_[handle] = global_id; + return global_id; +} + +void LightMetalCaptureContext::remove_from_map(const CBHandle handle) { + if (!is_in_map(handle)) { + log_warning(tt::LogMetalTrace, "CBHandle not found in global_id map."); + } + cb_handle_to_global_id_map_.erase(handle); +} + +uint32_t LightMetalCaptureContext::get_global_id(const CBHandle handle) { + auto it = cb_handle_to_global_id_map_.find(handle); + if (it != cb_handle_to_global_id_map_.end()) { + return it->second; + } else { + TT_THROW("CBHandle not found in global_id map."); + } +} + +//////////////////////////////////////////// +// Non-Class Helper Functions // +//////////////////////////////////////////// + +// Serialize tt-metal traceDescriptor and trace_id to flatbuffer format. 
+TraceDescriptorByTraceIdOffset to_flatbuffer( + flatbuffers::FlatBufferBuilder& builder, const TraceDescriptor& trace_desc, const uint32_t trace_id) { + // Serialize the trace_data vector + auto trace_data_offset = builder.CreateVector(trace_desc.data); + + // Serialize the sub_device_descriptors (map) + std::vector> + sub_device_descriptor_offsets; + for (const auto& [sub_device_id, descriptor] : trace_desc.descriptors) { + auto descriptor_offset = tt::tt_metal::flatbuffer::CreateTraceDescriptorMetaData( + builder, + descriptor.num_completion_worker_cores, + descriptor.num_traced_programs_needing_go_signal_multicast, + descriptor.num_traced_programs_needing_go_signal_unicast); + auto mapping_offset = tt::tt_metal::flatbuffer::CreateSubDeviceDescriptorMapping( + builder, + sub_device_id.to_index(), // No need for static_cast; directly use uint8_t + descriptor_offset); + sub_device_descriptor_offsets.push_back(mapping_offset); + } + auto sub_device_descriptors_offset = builder.CreateVector(sub_device_descriptor_offsets); + + // Serialize the sub_device_ids vector + std::vector sub_device_ids_converted; + sub_device_ids_converted.reserve(trace_desc.sub_device_ids.size()); + for (const auto& sub_device_id : trace_desc.sub_device_ids) { + sub_device_ids_converted.push_back(sub_device_id.to_index()); + } + auto sub_device_ids_offset = builder.CreateVector(sub_device_ids_converted); + + // Create the TraceDescriptor + auto trace_descriptor_offset = tt::tt_metal::flatbuffer::CreateTraceDescriptor( + builder, trace_data_offset, sub_device_descriptors_offset, sub_device_ids_offset); + + // Create the TraceDescriptorByTraceId + return tt::tt_metal::flatbuffer::CreateTraceDescriptorByTraceId(builder, trace_id, trace_descriptor_offset); +} + +} // namespace v0 +} // namespace tt::tt_metal diff --git a/tt_metal/impl/lightmetal/lightmetal_capture.hpp b/tt_metal/impl/lightmetal/lightmetal_capture.hpp new file mode 100644 index 00000000000..3712e666108 --- /dev/null +++ 
b/tt_metal/impl/lightmetal/lightmetal_capture.hpp @@ -0,0 +1,92 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include +#include + +// Forward decl for command_generated.h +namespace tt::tt_metal::flatbuffer { +class Command; +} + +// Forward decl for light_metal_binary_generated.h +namespace tt::tt_metal::flatbuffer { +struct TraceDescriptor; +struct TraceDescriptorByTraceId; +} // namespace tt::tt_metal::flatbuffer + +// Forward decl for trace_buffer.hpp +namespace tt::tt_metal { +class TraceDescriptor; +} + +namespace tt::tt_metal { +inline namespace v0 { + +class Buffer; +class Program; +class Kernel; +using CBHandle = uintptr_t; +using TraceDescriptorByTraceIdOffset = flatbuffers::Offset; + +class LightMetalCaptureContext { +public: + static LightMetalCaptureContext& get(); + + bool is_tracing() const; + void set_tracing(bool tracing); + + flatbuffers::FlatBufferBuilder& get_builder(); + std::vector>& get_cmds_vector(); + void capture_trace_descriptor(const TraceDescriptor& trace_desc, uint32_t tid); + LightMetalBinary create_light_metal_binary(); + void reset(); + + // Object Map Public Accessors + bool is_in_map(const Buffer* obj); + uint32_t add_to_map(const Buffer* obj); + void remove_from_map(const Buffer* obj); + uint32_t get_global_id(const Buffer* obj); + bool is_in_map(const Program* obj); + uint32_t add_to_map(const Program* obj); + void remove_from_map(const Program* obj); + uint32_t get_global_id(const Program* obj); + bool is_in_map(const Kernel* obj); + uint32_t add_to_map(const Kernel* obj); + void remove_from_map(const Kernel* obj); + uint32_t get_global_id(const Kernel* obj); + bool is_in_map(const CBHandle handle); + uint32_t add_to_map(const CBHandle handle); + void remove_from_map(const CBHandle handle); + uint32_t get_global_id(const CBHandle handle); + +private: + LightMetalCaptureContext(); // Private constructor + + bool is_tracing_ = false; + 
flatbuffers::FlatBufferBuilder builder_; + std::vector> cmds_vec_; + std::vector trace_descs_vec_; + + // Object maps for associating each object with a global_id + uint32_t next_global_id_ = 0; // Shared across all object types. + std::unordered_map buffer_to_global_id_map_; + std::unordered_map program_to_global_id_map_; + std::unordered_map kernel_to_global_id_map_; + std::unordered_map cb_handle_to_global_id_map_; + // TODO (kmabee) - consider adding map for CommandQueue object. + + LightMetalCaptureContext(const LightMetalCaptureContext&) = delete; + LightMetalCaptureContext& operator=(const LightMetalCaptureContext&) = delete; +}; + +TraceDescriptorByTraceIdOffset to_flatbuffer( + flatbuffers::FlatBufferBuilder& builder, const TraceDescriptor& trace_desc, uint32_t trace_id); + +} // namespace v0 +} // namespace tt::tt_metal diff --git a/tt_metal/impl/lightmetal/lightmetal_capture_utils.cpp b/tt_metal/impl/lightmetal/lightmetal_capture_utils.cpp new file mode 100644 index 00000000000..d33250777b1 --- /dev/null +++ b/tt_metal/impl/lightmetal/lightmetal_capture_utils.cpp @@ -0,0 +1,38 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "lightmetal/host_api_capture_helpers.hpp" +#include +#include + +namespace tt::tt_metal { + +void LightMetalCompareToCapture( + CommandQueue& cq, const std::variant, std::shared_ptr>& buffer, void* dst) { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + + // If dst ptr is not provided, just allocate temp space for rd return capture/usage. + std::vector rd_data_tmp; + if (!dst) { + size_t buffer_size = std::holds_alternative>(buffer) + ? std::get>(buffer).get().size() + : std::get>(buffer)->size(); + rd_data_tmp.resize(buffer_size / sizeof(uint32_t)); + dst = rd_data_tmp.data(); + } + + EnqueueReadBuffer(cq, buffer, dst, true); // Blocking read to get golden value. 
+ LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureLightMetalCompare, cq, buffer, dst, false); +} + +void LightMetalCompareToGolden( + CommandQueue& cq, + const std::variant, std::shared_ptr>& buffer, + void* golden_data) { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureLightMetalCompare, cq, buffer, golden_data, true); +} + +} // namespace tt::tt_metal diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp index 2ccb761ed09..f1a36ce8f7a 100644 --- a/tt_metal/tt_metal.cpp +++ b/tt_metal/tt_metal.cpp @@ -35,6 +35,7 @@ #include "tracy/Tracy.hpp" #include +#include "lightmetal/host_api_capture_helpers.hpp" #include "llrt.hpp" @@ -933,7 +934,12 @@ bool CloseDevice(IDevice* device) { return tt::DevicePool::instance().close_device(device_id); } -Program CreateProgram() { return Program(); } +Program CreateProgram() { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + auto program = Program(); + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureCreateProgram, program); + return program; +} KernelHandle CreateDataMovementKernel( Program& program, @@ -1019,7 +1025,8 @@ KernelHandle CreateKernel( const std::string& file_name, const std::variant& core_spec, const std::variant& config) { - return std::visit( + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + KernelHandle kernel = std::visit( [&](auto&& cfg) -> KernelHandle { CoreRangeSet core_ranges = GetCoreRangeSet(core_spec); KernelSource kernel_src(file_name, KernelSource::FILE_PATH); @@ -1033,6 +1040,9 @@ KernelHandle CreateKernel( } }, config); + + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureCreateKernel, kernel, program, file_name, core_spec, config); + return kernel; } KernelHandle CreateKernelFromString( @@ -1060,8 +1070,11 @@ CBHandle CreateCircularBuffer( Program& program, const std::variant& core_spec, const CircularBufferConfig& config) { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); CoreRangeSet core_ranges = GetCoreRangeSet(core_spec); - return program.add_circular_buffer(core_ranges, config); + auto cb_handle = 
program.add_circular_buffer(core_ranges, config); + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureCreateCircularBuffer, cb_handle, program, core_spec, config); + return cb_handle; } const CircularBufferConfig& GetCircularBufferConfig(Program& program, CBHandle cb_handle) { @@ -1141,7 +1154,8 @@ GlobalSemaphore CreateGlobalSemaphore( } std::shared_ptr CreateBuffer(const InterleavedBufferConfig& config) { - return Buffer::create( + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + auto buffer = Buffer::create( config.device, config.size, config.page_size, @@ -1150,6 +1164,9 @@ std::shared_ptr CreateBuffer(const InterleavedBufferConfig& config) { std::nullopt, std::nullopt, std::nullopt); + + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureCreateBuffer, buffer, config); + return buffer; } std::shared_ptr CreateBuffer(const InterleavedBufferConfig& config, DeviceAddr address) { return Buffer::create( @@ -1208,7 +1225,11 @@ std::shared_ptr CreateBuffer(const ShardedBufferConfig& config, SubDevic sub_device_id); } -void DeallocateBuffer(Buffer& buffer) { buffer.deallocate(); } +void DeallocateBuffer(Buffer& buffer) { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureDeallocateBuffer, buffer); + buffer.deallocate(); +} void AssignGlobalBufferToProgram(const std::shared_ptr& buffer, Program& program) { detail::DispatchStateCheck(not buffer->device()->using_slow_dispatch()); @@ -1220,6 +1241,8 @@ void SetRuntimeArgs( KernelHandle kernel_id, const std::variant& core_spec, stl::Span runtime_args) { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureSetRuntimeArgsUint32, program, kernel_id, core_spec, runtime_args); ZoneScoped; std::visit([&](auto&& core_spec) { SetRuntimeArgsImpl(program, kernel_id, core_spec, runtime_args); }, core_spec); } @@ -1246,7 +1269,9 @@ void SetRuntimeArgs( const std::shared_ptr& kernel, const std::variant& core_spec, const std::shared_ptr& runtime_args) { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); 
detail::DispatchStateCheck(not device->using_slow_dispatch()); + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureSetRuntimeArgs, device, kernel, core_spec, runtime_args); SetRuntimeArgsImpl(kernel, core_spec, std::move(runtime_args), false); } @@ -1289,22 +1314,51 @@ uint32_t BeginTraceCapture(IDevice* device, const uint8_t cq_id) { return tid; } -void EndTraceCapture(IDevice* device, const uint8_t cq_id, const uint32_t tid) { device->end_trace(cq_id, tid); } +void EndTraceCapture(IDevice* device, const uint8_t cq_id, const uint32_t tid) { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + device->end_trace(cq_id, tid); + // When light metal tracing is enabled, TraceDescriptor will be serialized via end_trace() and this + // will serialize the LightMetalLoadTraceId call to be used during replay to load trace back to device. + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureLoadTrace, device, cq_id, tid); + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureReplayTrace, device, cq_id, tid, true); // blocking=true +} void ReplayTrace(IDevice* device, const uint8_t cq_id, const uint32_t tid, const bool blocking) { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureReplayTrace, device, cq_id, tid, blocking); device->replay_trace(cq_id, tid, blocking); } -void ReleaseTrace(IDevice* device, const uint32_t tid) { device->release_trace(tid); } +void ReleaseTrace(IDevice* device, const uint32_t tid) { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureReleaseTrace, device, tid); + device->release_trace(tid); +} -// Light Metal Begin/End Capture APIs are stubs for now, filled in soon. +// This is nop if compile time define not set. 
void LightMetalBeginCapture() { - log_warning(tt::LogMetalTrace, "Begin LightMetalBinary Capture - not yet implemented."); +#if defined(TT_ENABLE_LIGHT_METAL_TRACE) && (TT_ENABLE_LIGHT_METAL_TRACE == 1) + log_debug(tt::LogMetalTrace, "Begin LightMetalBinary Capture"); + auto& lm_capture_ctx = LightMetalCaptureContext::get(); + lm_capture_ctx.reset(); // Clear previous traces if any, ensure tracing disabled + lm_capture_ctx.set_tracing(true); // Enable tracing +#else + log_warning(tt::LogMetalTrace, "TT_ENABLE_LIGHT_METAL_TRACE!=1, ignoring LightMetalBeginCapture()"); +#endif } +// This is nop if compile time define not set, return empty vector. LightMetalBinary LightMetalEndCapture() { - log_warning(tt::LogMetalTrace, "End LightMetalBinary Capture - not yet implemented."); +#if defined(TT_ENABLE_LIGHT_METAL_TRACE) && (TT_ENABLE_LIGHT_METAL_TRACE == 1) + log_debug(tt::LogMetalTrace, "End LightMetalBinary Capture"); + auto& lm_capture_ctx = LightMetalCaptureContext::get(); + TT_ASSERT(lm_capture_ctx.is_tracing(), "Light Metal Capture was not enabled."); + lm_capture_ctx.set_tracing(false); // Disable tracing + return lm_capture_ctx.create_light_metal_binary(); +#else + log_warning(tt::LogMetalTrace, "TT_ENABLE_LIGHT_METAL_TRACE!=1, ignoring LightMetalEndCapture()"); return {}; +#endif } void LoadTrace(IDevice* device, const uint8_t cq_id, const uint32_t trace_id, const TraceDescriptor& trace_desc) {