diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/LoadTrace.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/LoadTrace.rst new file mode 100644 index 00000000000..ae495816c93 --- /dev/null +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/LoadTrace.rst @@ -0,0 +1,4 @@ +LoadTrace +========= + +.. doxygenfunction:: tt::tt_metal::v0::LoadTrace diff --git a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/command_queue.rst b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/command_queue.rst index e56f5612c6e..6f6a5df2b5f 100644 --- a/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/command_queue.rst +++ b/docs/source/tt-metalium/tt_metal/apis/host_apis/command_queue/command_queue.rst @@ -16,6 +16,7 @@ CommandQueue ReplayTrace ReleaseTrace EnqueueTrace + LoadTrace LightMetalBeginCapture LightMetalEndCapture Finish diff --git a/tt_metal/api/tt-metalium/command_queue.hpp b/tt_metal/api/tt-metalium/command_queue.hpp index 7889adda403..5099810b71d 100644 --- a/tt_metal/api/tt-metalium/command_queue.hpp +++ b/tt_metal/api/tt-metalium/command_queue.hpp @@ -160,7 +160,7 @@ class EnqueueTraceCommand : public Command { Buffer& buffer; IDevice* device; SystemMemoryManager& manager; - std::shared_ptr& descriptor; + std::shared_ptr& descriptor; std::array& expected_num_workers_completed; bool clear_count; NOC noc_index; @@ -170,7 +170,7 @@ class EnqueueTraceCommand : public Command { uint32_t command_queue_id, IDevice* device, SystemMemoryManager& manager, - std::shared_ptr& descriptor, + std::shared_ptr& descriptor, Buffer& buffer, std::array& expected_num_workers_completed, NOC noc_index, diff --git a/tt_metal/api/tt-metalium/device.hpp b/tt_metal/api/tt-metalium/device.hpp index 576ddfc659f..ba447ae6b96 100644 --- a/tt_metal/api/tt-metalium/device.hpp +++ b/tt_metal/api/tt-metalium/device.hpp @@ -46,6 +46,7 @@ class SubDevice; class JitBuildEnv; class CommandQueue; class TraceBuffer; +struct TraceDescriptor; inline namespace v0 { @@ -186,6 +187,9 @@ class IDevice { virtual uint32_t get_trace_buffers_size() const = 0; virtual void set_trace_buffers_size(uint32_t size) = 0; + // Light Metal + virtual void load_trace(uint8_t cq_id, uint32_t trace_id, const TraceDescriptor& trace_desc) = 0; + virtual bool using_slow_dispatch() const = 0; virtual bool using_fast_dispatch() const = 0; diff --git a/tt_metal/api/tt-metalium/device_impl.hpp b/tt_metal/api/tt-metalium/device_impl.hpp index b7af05ef3fd..888e8b3739f 100644 --- a/tt_metal/api/tt-metalium/device_impl.hpp +++ b/tt_metal/api/tt-metalium/device_impl.hpp @@ -22,6 +22,7 @@ #include "hardware_command_queue.hpp" #include "sub_device_manager_tracker.hpp" #include "sub_device_types.hpp" +#include "trace_buffer.hpp" #include "span.hpp" #include "program_cache.hpp" @@ -179,6 +180,9 @@ class Device : public IDevice { uint32_t get_trace_buffers_size() const override { return trace_buffers_size_; } void set_trace_buffers_size(uint32_t size) override { trace_buffers_size_ = size; } + // Light Metal + void load_trace(uint8_t cq_id, uint32_t trace_id, const TraceDescriptor& trace_desc) override; + bool using_slow_dispatch() const override; bool using_fast_dispatch() const override; diff --git a/tt_metal/api/tt-metalium/hardware_command_queue.hpp b/tt_metal/api/tt-metalium/hardware_command_queue.hpp index a9550761574..f9c4f40a2e9 100644 --- a/tt_metal/api/tt-metalium/hardware_command_queue.hpp +++ b/tt_metal/api/tt-metalium/hardware_command_queue.hpp @@ -84,7 +84,7 @@ class CommandQueue { volatile bool is_dprint_server_hung(); volatile bool is_noc_hung(); - void record_begin(const uint32_t tid, std::shared_ptr ctx); + void record_begin(const uint32_t tid, std::shared_ptr ctx); void record_end(); void set_num_worker_sems_on_dispatch(uint32_t num_worker_sems); void set_go_signal_noc_data_on_dispatch(const vector_memcpy_aligned& go_signal_noc_data); @@ -149,7 +149,7 @@ class CommandQueue { uint32_t size_B; uint32_t completion_queue_reader_core = 0; std::optional tid_; - std::shared_ptr trace_ctx; + std::shared_ptr trace_ctx; std::thread completion_queue_thread; SystemMemoryManager& manager; std::array config_buffer_mgr; diff --git a/tt_metal/api/tt-metalium/host_api.hpp b/tt_metal/api/tt-metalium/host_api.hpp index 4667813f320..433edafd273 100644 --- a/tt_metal/api/tt-metalium/host_api.hpp +++ b/tt_metal/api/tt-metalium/host_api.hpp @@ -34,6 +34,7 @@ namespace tt { namespace tt_metal { class CommandQueue; +struct TraceDescriptor; inline namespace v0 { class Program; @@ -906,6 +907,22 @@ void LightMetalBeginCapture(); // clang-format on LightMetalBinary LightMetalEndCapture(); +// clang-format off +/** + * Load an existing trace descriptor onto a particular device and command queue and assign it as user-provided trace id. Useful for Light Metal Binary replay. + * + * Return value: void + * + * | Argument | Description | Type | Valid Range | Required | + * |--------------|------------------------------------------------------------------------|-------------------------------|------------------------------------|----------| + * | device | The device to load the trace onto. | IDevice * | | Yes | + * | cq_id | The command queue id to load the trace onto. | uint8_t | | Yes | + * | trace_id | A unique id to represent the trace on device. | uint32_t | | Yes | + * | trace_desc | The trace descriptor to load onto the device. | TraceDescriptor& | | Yes | + */ +// clang-format on +void LoadTrace(IDevice* device, uint8_t cq_id, uint32_t trace_id, const TraceDescriptor& trace_desc); + // clang-format off /** * Read device side profiler data and dump results into device side CSV log diff --git a/tt_metal/api/tt-metalium/mesh_device.hpp b/tt_metal/api/tt-metalium/mesh_device.hpp index 288cd124937..ae67a2c0b75 100644 --- a/tt_metal/api/tt-metalium/mesh_device.hpp +++ b/tt_metal/api/tt-metalium/mesh_device.hpp @@ -173,6 +173,9 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this #include -#include "buffer.hpp" #include "sub_device_types.hpp" namespace tt::tt_metal { -namespace detail { +// Forward decl to avoid including header +inline namespace v0 { +class Buffer; +} + struct TraceDescriptor { struct Descriptor { uint32_t num_completion_worker_cores = 0; @@ -30,13 +33,12 @@ struct TraceDescriptor { std::vector sub_device_ids; std::vector data; }; -} // namespace detail struct TraceBuffer { - std::shared_ptr desc; + std::shared_ptr desc; std::shared_ptr buffer; - TraceBuffer(std::shared_ptr desc, std::shared_ptr buffer); + TraceBuffer(std::shared_ptr desc, std::shared_ptr buffer); ~TraceBuffer(); }; diff --git a/tt_metal/distributed/mesh_device.cpp b/tt_metal/distributed/mesh_device.cpp index aabe52e264c..05365931331 100644 --- a/tt_metal/distributed/mesh_device.cpp +++ b/tt_metal/distributed/mesh_device.cpp @@ -651,6 +651,12 @@ void MeshDevice::set_trace_buffers_size(uint32_t size) { reference_device()->set_trace_buffers_size(size); } +// Light Metal +void MeshDevice::load_trace(const uint8_t cq_id, const uint32_t trace_id, const TraceDescriptor& trace_desc) { + TT_THROW("load_trace() is not supported on MeshDevice - use individual devices instead"); + reference_device()->load_trace(cq_id, trace_id, trace_desc); +} + // Dispatch and initialization bool MeshDevice::initialize(const uint8_t num_hw_cqs, size_t l1_small_size, size_t trace_region_size, tt::stl::Span l1_bank_remap, bool minimal) { work_executor_->initialize(); diff --git a/tt_metal/impl/device/device.cpp b/tt_metal/impl/device/device.cpp index f8b522dcc59..cb11b22b2e9 100644 --- a/tt_metal/impl/device/device.cpp +++ b/tt_metal/impl/device/device.cpp @@ -1588,6 +1588,24 @@ void Device::end_trace(const uint8_t cq_id, const uint32_t tid) { this->mark_allocations_unsafe(); } +// Load the TraceDescriptor for a given trace_id to the device. A combination of logic from begin/end_trace. +void Device::load_trace(const uint8_t cq_id, const uint32_t trace_id, const TraceDescriptor& trace_desc) { + this->mark_allocations_safe(); + + auto* active_sub_device_manager = sub_device_manager_tracker_->get_active_sub_device_manager(); + TT_FATAL( + active_sub_device_manager->get_trace(trace_id) == nullptr, + "Trace already exists for trace_id {} on device {}'s active sub-device manager {}", + trace_id, + this->id_, + active_sub_device_manager->id()); + + auto& trace_buffer = active_sub_device_manager->create_trace(trace_id); + *trace_buffer->desc = trace_desc; + Trace::initialize_buffer(this->command_queue(cq_id), trace_buffer); + this->mark_allocations_unsafe(); +} + void Device::replay_trace(const uint8_t cq_id, const uint32_t tid, const bool blocking) { ZoneScoped; TracyTTMetalReplayTrace(this->id(), tid); diff --git a/tt_metal/impl/dispatch/command_queue.cpp b/tt_metal/impl/dispatch/command_queue.cpp index 6407d33ef79..7adc0425ae3 100644 --- a/tt_metal/impl/dispatch/command_queue.cpp +++ b/tt_metal/impl/dispatch/command_queue.cpp @@ -312,7 +312,7 @@ EnqueueTraceCommand::EnqueueTraceCommand( uint32_t command_queue_id, IDevice* device, SystemMemoryManager& manager, - std::shared_ptr& descriptor, + std::shared_ptr& descriptor, Buffer& buffer, std::array & expected_num_workers_completed, NOC noc_index, diff --git a/tt_metal/impl/dispatch/hardware_command_queue.cpp b/tt_metal/impl/dispatch/hardware_command_queue.cpp index 1b14b1302fc..062f7e34f5f 100644 --- a/tt_metal/impl/dispatch/hardware_command_queue.cpp +++ b/tt_metal/impl/dispatch/hardware_command_queue.cpp @@ -701,7 +701,7 @@ volatile bool CommandQueue::is_dprint_server_hung() { return dprint_server_hang; volatile bool CommandQueue::is_noc_hung() { return illegal_noc_txn_hang; } -void CommandQueue::record_begin(const uint32_t tid, std::shared_ptr ctx) { +void CommandQueue::record_begin(const uint32_t tid, std::shared_ptr ctx) { auto num_sub_devices = this->device_->num_sub_devices(); // Record the original value of expected_num_workers_completed, and reset it to 0. std::copy( diff --git a/tt_metal/impl/trace/trace.cpp b/tt_metal/impl/trace/trace.cpp index d0f8082a048..1dc733b4bcd 100644 --- a/tt_metal/impl/trace/trace.cpp +++ b/tt_metal/impl/trace/trace.cpp @@ -70,7 +70,7 @@ std::atomic Trace::global_trace_id = 0; uint32_t Trace::next_id() { return global_trace_id++; } std::shared_ptr Trace::create_empty_trace_buffer() { - return std::make_shared(std::make_shared(), nullptr); + return std::make_shared(std::make_shared(), nullptr); } void Trace::initialize_buffer(CommandQueue& cq, const std::shared_ptr& trace_buffer) { diff --git a/tt_metal/impl/trace/trace_buffer.cpp b/tt_metal/impl/trace/trace_buffer.cpp index f9efe387ad9..6491170a95e 100644 --- a/tt_metal/impl/trace/trace_buffer.cpp +++ b/tt_metal/impl/trace/trace_buffer.cpp @@ -6,10 +6,11 @@ #include #include +#include "buffer.hpp" namespace tt::tt_metal { -TraceBuffer::TraceBuffer(std::shared_ptr desc, std::shared_ptr buffer) : +TraceBuffer::TraceBuffer(std::shared_ptr desc, std::shared_ptr buffer) : desc(std::move(desc)), buffer(std::move(buffer)) {} TraceBuffer::~TraceBuffer() { diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp index 43eb33ebe7a..2264834644d 100644 --- a/tt_metal/tt_metal.cpp +++ b/tt_metal/tt_metal.cpp @@ -1353,6 +1353,10 @@ LightMetalBinary LightMetalEndCapture() { return {}; } +void LoadTrace(IDevice* device, const uint8_t cq_id, const uint32_t trace_id, const TraceDescriptor& trace_desc) { + device->load_trace(cq_id, trace_id, trace_desc); +} + void Synchronize(IDevice* device, const std::optional cq_id, tt::stl::Span sub_device_ids) { if (std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr) { if (cq_id.has_value()) {