Skip to content

Commit

Permalink
LightMetal - Add LoadTrace() API and move TraceDescriptor out of deta…
Browse files Browse the repository at this point in the history
…il namespace (#17039)

 - Will be used by Light Metal replay after an upcoming PR: when executing
   a metal-trace traced program from a binary, the TraceDescriptor is
   extracted from the flatbuffer binary and loaded onto the device through this API.
 - Unrelated: change trace_buffer.hpp to forward-declare Buffer instead of
   including buffer.hpp, reducing dependencies for users of trace_buffer.hpp
  • Loading branch information
kmabeeTT committed Jan 29, 2025
1 parent 2a82748 commit 0971366
Show file tree
Hide file tree
Showing 16 changed files with 77 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
LoadTrace
=========

.. doxygenfunction:: tt::tt_metal::v0::LoadTrace
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ CommandQueue
ReplayTrace
ReleaseTrace
EnqueueTrace
LoadTrace
LightMetalBeginCapture
LightMetalEndCapture
Finish
Expand Down
4 changes: 2 additions & 2 deletions tt_metal/api/tt-metalium/command_queue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ class EnqueueTraceCommand : public Command {
Buffer& buffer;
IDevice* device;
SystemMemoryManager& manager;
std::shared_ptr<detail::TraceDescriptor>& descriptor;
std::shared_ptr<TraceDescriptor>& descriptor;
std::array<uint32_t, dispatch_constants::DISPATCH_MESSAGE_ENTRIES>& expected_num_workers_completed;
bool clear_count;
NOC noc_index;
Expand All @@ -170,7 +170,7 @@ class EnqueueTraceCommand : public Command {
uint32_t command_queue_id,
IDevice* device,
SystemMemoryManager& manager,
std::shared_ptr<detail::TraceDescriptor>& descriptor,
std::shared_ptr<TraceDescriptor>& descriptor,
Buffer& buffer,
std::array<uint32_t, dispatch_constants::DISPATCH_MESSAGE_ENTRIES>& expected_num_workers_completed,
NOC noc_index,
Expand Down
4 changes: 4 additions & 0 deletions tt_metal/api/tt-metalium/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class SubDevice;
class JitBuildEnv;
class CommandQueue;
class TraceBuffer;
struct TraceDescriptor;

inline namespace v0 {

Expand Down Expand Up @@ -186,6 +187,9 @@ class IDevice {
virtual uint32_t get_trace_buffers_size() const = 0;
virtual void set_trace_buffers_size(uint32_t size) = 0;

// Light Metal
virtual void load_trace(uint8_t cq_id, uint32_t trace_id, const TraceDescriptor& trace_desc) = 0;

virtual bool using_slow_dispatch() const = 0;
virtual bool using_fast_dispatch() const = 0;

Expand Down
4 changes: 4 additions & 0 deletions tt_metal/api/tt-metalium/device_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "hardware_command_queue.hpp"
#include "sub_device_manager_tracker.hpp"
#include "sub_device_types.hpp"
#include "trace_buffer.hpp"
#include "span.hpp"
#include "program_cache.hpp"

Expand Down Expand Up @@ -179,6 +180,9 @@ class Device : public IDevice {
// Returns the value currently stored in trace_buffers_size_.
uint32_t get_trace_buffers_size() const override { return trace_buffers_size_; }
// Overwrites trace_buffers_size_ with the given size; no validation is performed here.
void set_trace_buffers_size(uint32_t size) override { trace_buffers_size_ = size; }

// Light Metal
void load_trace(uint8_t cq_id, uint32_t trace_id, const TraceDescriptor& trace_desc) override;

bool using_slow_dispatch() const override;
bool using_fast_dispatch() const override;

Expand Down
4 changes: 2 additions & 2 deletions tt_metal/api/tt-metalium/hardware_command_queue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class CommandQueue {
volatile bool is_dprint_server_hung();
volatile bool is_noc_hung();

void record_begin(const uint32_t tid, std::shared_ptr<detail::TraceDescriptor> ctx);
void record_begin(const uint32_t tid, std::shared_ptr<TraceDescriptor> ctx);
void record_end();
void set_num_worker_sems_on_dispatch(uint32_t num_worker_sems);
void set_go_signal_noc_data_on_dispatch(const vector_memcpy_aligned<uint32_t>& go_signal_noc_data);
Expand Down Expand Up @@ -149,7 +149,7 @@ class CommandQueue {
uint32_t size_B;
uint32_t completion_queue_reader_core = 0;
std::optional<uint32_t> tid_;
std::shared_ptr<detail::TraceDescriptor> trace_ctx;
std::shared_ptr<TraceDescriptor> trace_ctx;
std::thread completion_queue_thread;
SystemMemoryManager& manager;
std::array<tt::tt_metal::WorkerConfigBufferMgr, dispatch_constants::DISPATCH_MESSAGE_ENTRIES> config_buffer_mgr;
Expand Down
17 changes: 17 additions & 0 deletions tt_metal/api/tt-metalium/host_api.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ namespace tt {
namespace tt_metal {

class CommandQueue;
struct TraceDescriptor;
inline namespace v0 {

class Program;
Expand Down Expand Up @@ -906,6 +907,22 @@ void LightMetalBeginCapture();
// clang-format on
LightMetalBinary LightMetalEndCapture();

// clang-format off
/**
* Load an existing trace descriptor onto a particular device and command queue, and register it under the user-provided trace id. Useful for Light Metal Binary replay.
*
* Return value: void
*
* | Argument | Description | Type | Valid Range | Required |
* |--------------|------------------------------------------------------------------------|-------------------------------|------------------------------------|----------|
* | device | The device to load the trace onto. | IDevice * | | Yes |
* | cq_id | The command queue id to load the trace onto. | uint8_t | | Yes |
* | trace_id | A unique id to represent the trace on device. | uint32_t | | Yes |
* | trace_desc | The trace descriptor to load onto the device. | const TraceDescriptor& | | Yes |
*/
// clang-format on
void LoadTrace(IDevice* device, uint8_t cq_id, uint32_t trace_id, const TraceDescriptor& trace_desc);

// clang-format off
/**
* Read device side profiler data and dump results into device side CSV log
Expand Down
3 changes: 3 additions & 0 deletions tt_metal/api/tt-metalium/mesh_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevic
uint32_t get_trace_buffers_size() const override;
void set_trace_buffers_size(uint32_t size) override;

// Light Metal
void load_trace(uint8_t cq_id, uint32_t trace_id, const TraceDescriptor& trace_desc) override;

bool using_slow_dispatch() const override;
bool using_fast_dispatch() const override;

Expand Down
12 changes: 7 additions & 5 deletions tt_metal/api/tt-metalium/trace_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
#include <utility>
#include <variant>

#include "buffer.hpp"
#include "sub_device_types.hpp"

namespace tt::tt_metal {

namespace detail {
// Forward decl to avoid including header
inline namespace v0 {
class Buffer;
}

struct TraceDescriptor {
struct Descriptor {
uint32_t num_completion_worker_cores = 0;
Expand All @@ -30,13 +33,12 @@ struct TraceDescriptor {
std::vector<SubDeviceId> sub_device_ids;
std::vector<uint32_t> data;
};
} // namespace detail

struct TraceBuffer {
std::shared_ptr<detail::TraceDescriptor> desc;
std::shared_ptr<TraceDescriptor> desc;
std::shared_ptr<Buffer> buffer;

TraceBuffer(std::shared_ptr<detail::TraceDescriptor> desc, std::shared_ptr<Buffer> buffer);
TraceBuffer(std::shared_ptr<TraceDescriptor> desc, std::shared_ptr<Buffer> buffer);
~TraceBuffer();
};

Expand Down
6 changes: 6 additions & 0 deletions tt_metal/distributed/mesh_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,12 @@ void MeshDevice::set_trace_buffers_size(uint32_t size) {
reference_device()->set_trace_buffers_size(size);
}

// Light Metal
// Loading a trace descriptor is not supported on a MeshDevice: this override only raises the
// TT_THROW error below, and callers are expected to call load_trace() on the individual devices.
// Nothing is forwarded to reference_device(), because no statement placed after TT_THROW would
// run (assumes TT_THROW never returns - TODO confirm against its definition).
// The parameters exist only to satisfy the IDevice interface, hence [[maybe_unused]].
void MeshDevice::load_trace(
    [[maybe_unused]] const uint8_t cq_id,
    [[maybe_unused]] const uint32_t trace_id,
    [[maybe_unused]] const TraceDescriptor& trace_desc) {
    TT_THROW("load_trace() is not supported on MeshDevice - use individual devices instead");
}

// Dispatch and initialization
bool MeshDevice::initialize(const uint8_t num_hw_cqs, size_t l1_small_size, size_t trace_region_size, tt::stl::Span<const std::uint32_t> l1_bank_remap, bool minimal) {
work_executor_->initialize();
Expand Down
18 changes: 18 additions & 0 deletions tt_metal/impl/device/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1588,6 +1588,24 @@ void Device::end_trace(const uint8_t cq_id, const uint32_t tid) {
this->mark_allocations_unsafe();
}

// Load the TraceDescriptor for a given trace_id to the device. A combination of logic from begin/end_trace.
//
// cq_id      - selects the command queue (via this->command_queue(cq_id)) used to initialize the trace buffer.
// trace_id   - id under which the new trace is created on the active sub-device manager; it is a fatal
//              error (TT_FATAL) if a trace with this id already exists there.
// trace_desc - descriptor to load; its contents are copied, so the caller's object is not retained.
void Device::load_trace(const uint8_t cq_id, const uint32_t trace_id, const TraceDescriptor& trace_desc) {
// Allocations are marked safe for the duration of the load and marked unsafe again on the last line.
// NOTE(review): if TT_FATAL or any call below throws, mark_allocations_unsafe() is never reached - confirm
// that leaving allocations marked safe on that path is acceptable.
this->mark_allocations_safe();

auto* active_sub_device_manager = sub_device_manager_tracker_->get_active_sub_device_manager();
// Refuse to overwrite an existing trace registered under the same id.
TT_FATAL(
active_sub_device_manager->get_trace(trace_id) == nullptr,
"Trace already exists for trace_id {} on device {}'s active sub-device manager {}",
trace_id,
this->id_,
active_sub_device_manager->id());

// Create the trace entry for trace_id, then copy-assign the caller's descriptor into the descriptor
// object owned by that trace buffer.
auto& trace_buffer = active_sub_device_manager->create_trace(trace_id);
*trace_buffer->desc = trace_desc;
// Presumably allocates the device-side buffer and writes the descriptor's data through the selected
// command queue - TODO confirm against Trace::initialize_buffer.
Trace::initialize_buffer(this->command_queue(cq_id), trace_buffer);
this->mark_allocations_unsafe();
}

void Device::replay_trace(const uint8_t cq_id, const uint32_t tid, const bool blocking) {
ZoneScoped;
TracyTTMetalReplayTrace(this->id(), tid);
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/impl/dispatch/command_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ EnqueueTraceCommand::EnqueueTraceCommand(
uint32_t command_queue_id,
IDevice* device,
SystemMemoryManager& manager,
std::shared_ptr<detail::TraceDescriptor>& descriptor,
std::shared_ptr<TraceDescriptor>& descriptor,
Buffer& buffer,
std::array<uint32_t, dispatch_constants::DISPATCH_MESSAGE_ENTRIES> & expected_num_workers_completed,
NOC noc_index,
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/impl/dispatch/hardware_command_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ volatile bool CommandQueue::is_dprint_server_hung() { return dprint_server_hang;

// Returns the current value of illegal_noc_txn_hang.
// NOTE(review): `volatile` on a by-value bool return type has no effect for callers; it is kept
// because the definition must match the declaration in hardware_command_queue.hpp.
volatile bool CommandQueue::is_noc_hung() { return illegal_noc_txn_hang; }

void CommandQueue::record_begin(const uint32_t tid, std::shared_ptr<detail::TraceDescriptor> ctx) {
void CommandQueue::record_begin(const uint32_t tid, std::shared_ptr<TraceDescriptor> ctx) {
auto num_sub_devices = this->device_->num_sub_devices();
// Record the original value of expected_num_workers_completed, and reset it to 0.
std::copy(
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/impl/trace/trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ std::atomic<uint32_t> Trace::global_trace_id = 0;
// Hands out the next trace id: atomically increments the shared global_trace_id counter and
// returns the value it held before the increment.
uint32_t Trace::next_id() {
    return global_trace_id.fetch_add(1);
}

std::shared_ptr<TraceBuffer> Trace::create_empty_trace_buffer() {
return std::make_shared<TraceBuffer>(std::make_shared<detail::TraceDescriptor>(), nullptr);
return std::make_shared<TraceBuffer>(std::make_shared<TraceDescriptor>(), nullptr);
}

void Trace::initialize_buffer(CommandQueue& cq, const std::shared_ptr<TraceBuffer>& trace_buffer) {
Expand Down
3 changes: 2 additions & 1 deletion tt_metal/impl/trace/trace_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@

#include <utility>
#include <device.hpp>
#include "buffer.hpp"

namespace tt::tt_metal {

TraceBuffer::TraceBuffer(std::shared_ptr<detail::TraceDescriptor> desc, std::shared_ptr<Buffer> buffer) :
TraceBuffer::TraceBuffer(std::shared_ptr<TraceDescriptor> desc, std::shared_ptr<Buffer> buffer) :
desc(std::move(desc)), buffer(std::move(buffer)) {}

TraceBuffer::~TraceBuffer() {
Expand Down
4 changes: 4 additions & 0 deletions tt_metal/tt_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1353,6 +1353,10 @@ LightMetalBinary LightMetalEndCapture() {
return {};
}

// Host API entry point for loading a trace descriptor: forwards all arguments unchanged to
// device->load_trace(). `device` is dereferenced without a null check, so it must be non-null.
void LoadTrace(IDevice* device, const uint8_t cq_id, const uint32_t trace_id, const TraceDescriptor& trace_desc) {
device->load_trace(cq_id, trace_id, trace_desc);
}

void Synchronize(IDevice* device, const std::optional<uint8_t> cq_id, tt::stl::Span<const SubDeviceId> sub_device_ids) {
if (std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr) {
if (cq_id.has_value()) {
Expand Down

0 comments on commit 0971366

Please sign in to comment.