Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LightMetal - Add LoadTrace() API and move TraceDescriptor out of detail namespace (#17039) #17313

Merged
merged 1 commit into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
LoadTrace
=========

.. doxygenfunction:: tt::tt_metal::v0::LoadTrace
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ CommandQueue
ReplayTrace
ReleaseTrace
EnqueueTrace
LoadTrace
LightMetalBeginCapture
LightMetalEndCapture
Finish
Expand Down
4 changes: 2 additions & 2 deletions tt_metal/api/tt-metalium/command_queue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ class EnqueueTraceCommand : public Command {
Buffer& buffer;
IDevice* device;
SystemMemoryManager& manager;
std::shared_ptr<detail::TraceDescriptor>& descriptor;
std::shared_ptr<TraceDescriptor>& descriptor;
std::array<uint32_t, dispatch_constants::DISPATCH_MESSAGE_ENTRIES>& expected_num_workers_completed;
bool clear_count;
NOC noc_index;
Expand All @@ -170,7 +170,7 @@ class EnqueueTraceCommand : public Command {
uint32_t command_queue_id,
IDevice* device,
SystemMemoryManager& manager,
std::shared_ptr<detail::TraceDescriptor>& descriptor,
std::shared_ptr<TraceDescriptor>& descriptor,
Buffer& buffer,
std::array<uint32_t, dispatch_constants::DISPATCH_MESSAGE_ENTRIES>& expected_num_workers_completed,
NOC noc_index,
Expand Down
4 changes: 4 additions & 0 deletions tt_metal/api/tt-metalium/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class SubDevice;
class JitBuildEnv;
class CommandQueue;
class TraceBuffer;
struct TraceDescriptor;

inline namespace v0 {

Expand Down Expand Up @@ -186,6 +187,9 @@ class IDevice {
virtual uint32_t get_trace_buffers_size() const = 0;
virtual void set_trace_buffers_size(uint32_t size) = 0;

// Light Metal
virtual void load_trace(uint8_t cq_id, uint32_t trace_id, const TraceDescriptor& trace_desc) = 0;

virtual bool using_slow_dispatch() const = 0;
virtual bool using_fast_dispatch() const = 0;

Expand Down
4 changes: 4 additions & 0 deletions tt_metal/api/tt-metalium/device_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "hardware_command_queue.hpp"
#include "sub_device_manager_tracker.hpp"
#include "sub_device_types.hpp"
#include "trace_buffer.hpp"
#include "span.hpp"
#include "program_cache.hpp"

Expand Down Expand Up @@ -179,6 +180,9 @@ class Device : public IDevice {
uint32_t get_trace_buffers_size() const override { return trace_buffers_size_; }
void set_trace_buffers_size(uint32_t size) override { trace_buffers_size_ = size; }

// Light Metal
void load_trace(uint8_t cq_id, uint32_t trace_id, const TraceDescriptor& trace_desc) override;

bool using_slow_dispatch() const override;
bool using_fast_dispatch() const override;

Expand Down
4 changes: 2 additions & 2 deletions tt_metal/api/tt-metalium/hardware_command_queue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class CommandQueue {
volatile bool is_dprint_server_hung();
volatile bool is_noc_hung();

void record_begin(const uint32_t tid, std::shared_ptr<detail::TraceDescriptor> ctx);
void record_begin(const uint32_t tid, std::shared_ptr<TraceDescriptor> ctx);
void record_end();
void set_num_worker_sems_on_dispatch(uint32_t num_worker_sems);
void set_go_signal_noc_data_on_dispatch(const vector_memcpy_aligned<uint32_t>& go_signal_noc_data);
Expand Down Expand Up @@ -149,7 +149,7 @@ class CommandQueue {
uint32_t size_B;
uint32_t completion_queue_reader_core = 0;
std::optional<uint32_t> tid_;
std::shared_ptr<detail::TraceDescriptor> trace_ctx;
std::shared_ptr<TraceDescriptor> trace_ctx;
std::thread completion_queue_thread;
SystemMemoryManager& manager;
std::array<tt::tt_metal::WorkerConfigBufferMgr, dispatch_constants::DISPATCH_MESSAGE_ENTRIES> config_buffer_mgr;
Expand Down
17 changes: 17 additions & 0 deletions tt_metal/api/tt-metalium/host_api.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ namespace tt {
namespace tt_metal {

class CommandQueue;
struct TraceDescriptor;
inline namespace v0 {

class Program;
Expand Down Expand Up @@ -906,6 +907,22 @@ void LightMetalBeginCapture();
// clang-format on
LightMetalBinary LightMetalEndCapture();

// clang-format off
/**
* Load an existing trace descriptor onto a particular device and command queue and assign it as user-provided trace id. Useful for Light Metal Binary replay.
*
* Return value: void
*
* | Argument | Description | Type | Valid Range | Required |
* |--------------|------------------------------------------------------------------------|-------------------------------|------------------------------------|----------|
* | device | The device to load the trace onto. | IDevice * | | Yes |
* | cq_id | The command queue id to load the trace onto. | uint8_t | | Yes |
* | trace_id | A unique id to represent the trace on device. | uint32_t | | Yes |
* | trace_desc | The trace descriptor to load onto the device. | TraceDescriptor& | | Yes |
*/
// clang-format on
void LoadTrace(IDevice* device, uint8_t cq_id, uint32_t trace_id, const TraceDescriptor& trace_desc);

// clang-format off
/**
* Read device side profiler data and dump results into device side CSV log
Expand Down
3 changes: 3 additions & 0 deletions tt_metal/api/tt-metalium/mesh_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this<MeshDevic
uint32_t get_trace_buffers_size() const override;
void set_trace_buffers_size(uint32_t size) override;

// Light Metal
void load_trace(uint8_t cq_id, uint32_t trace_id, const TraceDescriptor& trace_desc) override;

bool using_slow_dispatch() const override;
bool using_fast_dispatch() const override;

Expand Down
12 changes: 7 additions & 5 deletions tt_metal/api/tt-metalium/trace_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
#include <utility>
#include <variant>

#include "buffer.hpp"
#include "sub_device_types.hpp"

namespace tt::tt_metal {

namespace detail {
// Forward decl to avoid including header
inline namespace v0 {
class Buffer;
}

struct TraceDescriptor {
struct Descriptor {
uint32_t num_completion_worker_cores = 0;
Expand All @@ -30,13 +33,12 @@ struct TraceDescriptor {
std::vector<SubDeviceId> sub_device_ids;
std::vector<uint32_t> data;
};
} // namespace detail

struct TraceBuffer {
std::shared_ptr<detail::TraceDescriptor> desc;
std::shared_ptr<TraceDescriptor> desc;
std::shared_ptr<Buffer> buffer;

TraceBuffer(std::shared_ptr<detail::TraceDescriptor> desc, std::shared_ptr<Buffer> buffer);
TraceBuffer(std::shared_ptr<TraceDescriptor> desc, std::shared_ptr<Buffer> buffer);
~TraceBuffer();
};

Expand Down
6 changes: 6 additions & 0 deletions tt_metal/distributed/mesh_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,12 @@ void MeshDevice::set_trace_buffers_size(uint32_t size) {
reference_device()->set_trace_buffers_size(size);
}

// Light Metal
void MeshDevice::load_trace(const uint8_t cq_id, const uint32_t trace_id, const TraceDescriptor& trace_desc) {
TT_THROW("load_trace() is not supported on MeshDevice - use individual devices instead");
reference_device()->load_trace(cq_id, trace_id, trace_desc);
}

// Dispatch and initialization
bool MeshDevice::initialize(const uint8_t num_hw_cqs, size_t l1_small_size, size_t trace_region_size, tt::stl::Span<const std::uint32_t> l1_bank_remap, bool minimal) {
work_executor_->initialize();
Expand Down
18 changes: 18 additions & 0 deletions tt_metal/impl/device/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1588,6 +1588,24 @@ void Device::end_trace(const uint8_t cq_id, const uint32_t tid) {
this->mark_allocations_unsafe();
}

// Load the TraceDescriptor for a given trace_id to the device. A combination of logic from begin/end_trace.
void Device::load_trace(const uint8_t cq_id, const uint32_t trace_id, const TraceDescriptor& trace_desc) {
this->mark_allocations_safe();

auto* active_sub_device_manager = sub_device_manager_tracker_->get_active_sub_device_manager();
TT_FATAL(
active_sub_device_manager->get_trace(trace_id) == nullptr,
"Trace already exists for trace_id {} on device {}'s active sub-device manager {}",
trace_id,
this->id_,
active_sub_device_manager->id());

auto& trace_buffer = active_sub_device_manager->create_trace(trace_id);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(outside of this PR)

These are weird APIs that return references to smart pointers. @dmakoviichuk-tt wdyt? I think these:

    const std::unique_ptr<Allocator>& get_initialized_allocator(SubDeviceId sub_device_id) const;
    std::unique_ptr<Allocator>& sub_device_allocator(SubDeviceId sub_device_id);
    std::shared_ptr<TraceBuffer>& create_trace(uint32_t tid);

Should be re-written as:

    const Allocator* get_initialized_allocator(SubDeviceId sub_device_id) const;
    Allocator* sub_device_allocator(SubDeviceId sub_device_id);
    std::shared_ptr<TraceBuffer> create_trace(uint32_t tid);

*trace_buffer->desc = trace_desc;
Trace::initialize_buffer(this->command_queue(cq_id), trace_buffer);
this->mark_allocations_unsafe();
}

void Device::replay_trace(const uint8_t cq_id, const uint32_t tid, const bool blocking) {
ZoneScoped;
TracyTTMetalReplayTrace(this->id(), tid);
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/impl/dispatch/command_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ EnqueueTraceCommand::EnqueueTraceCommand(
uint32_t command_queue_id,
IDevice* device,
SystemMemoryManager& manager,
std::shared_ptr<detail::TraceDescriptor>& descriptor,
std::shared_ptr<TraceDescriptor>& descriptor,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yikes, I think all of these can and should be const &. (outside of the PR, just a drive by comment:) )

Buffer& buffer,
std::array<uint32_t, dispatch_constants::DISPATCH_MESSAGE_ENTRIES> & expected_num_workers_completed,
NOC noc_index,
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/impl/dispatch/hardware_command_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ volatile bool CommandQueue::is_dprint_server_hung() { return dprint_server_hang;

volatile bool CommandQueue::is_noc_hung() { return illegal_noc_txn_hang; }

void CommandQueue::record_begin(const uint32_t tid, std::shared_ptr<detail::TraceDescriptor> ctx) {
void CommandQueue::record_begin(const uint32_t tid, std::shared_ptr<TraceDescriptor> ctx) {
auto num_sub_devices = this->device_->num_sub_devices();
// Record the original value of expected_num_workers_completed, and reset it to 0.
std::copy(
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/impl/trace/trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ std::atomic<uint32_t> Trace::global_trace_id = 0;
uint32_t Trace::next_id() { return global_trace_id++; }

std::shared_ptr<TraceBuffer> Trace::create_empty_trace_buffer() {
return std::make_shared<TraceBuffer>(std::make_shared<detail::TraceDescriptor>(), nullptr);
return std::make_shared<TraceBuffer>(std::make_shared<TraceDescriptor>(), nullptr);
}

void Trace::initialize_buffer(CommandQueue& cq, const std::shared_ptr<TraceBuffer>& trace_buffer) {
Expand Down
3 changes: 2 additions & 1 deletion tt_metal/impl/trace/trace_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@

#include <utility>
#include <device.hpp>
#include "buffer.hpp"

namespace tt::tt_metal {

TraceBuffer::TraceBuffer(std::shared_ptr<detail::TraceDescriptor> desc, std::shared_ptr<Buffer> buffer) :
TraceBuffer::TraceBuffer(std::shared_ptr<TraceDescriptor> desc, std::shared_ptr<Buffer> buffer) :
desc(std::move(desc)), buffer(std::move(buffer)) {}

TraceBuffer::~TraceBuffer() {
Expand Down
4 changes: 4 additions & 0 deletions tt_metal/tt_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1353,6 +1353,10 @@ LightMetalBinary LightMetalEndCapture() {
return {};
}

void LoadTrace(IDevice* device, const uint8_t cq_id, const uint32_t trace_id, const TraceDescriptor& trace_desc) {
device->load_trace(cq_id, trace_id, trace_desc);
}

void Synchronize(IDevice* device, const std::optional<uint8_t> cq_id, tt::stl::Span<const SubDeviceId> sub_device_ids) {
if (std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr) {
if (cq_id.has_value()) {
Expand Down
Loading