From 35440dea844b09a66d06977d5c65ff17cbbb3a36 Mon Sep 17 00:00:00 2001 From: Artem Yerofieiev Date: Thu, 30 Jan 2025 04:57:57 +0000 Subject: [PATCH 1/4] =?UTF-8?q?Extract=20CQ=20interface.=20HWCQ=20is=20an?= =?UTF-8?q?=20implementation=20detail=20now.=20Removed=20from=20public=20a?= =?UTF-8?q?pi=20=F0=9F=8E=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tt_metal/api/tt-metalium/command_queue.hpp | 240 +++++------------- .../tt-metalium/command_queue_commands.hpp | 203 +++++++++++++++ .../tt-metalium/command_queue_interface.hpp | 47 ++++ tt_metal/api/tt-metalium/device_impl.hpp | 1 - tt_metal/api/tt-metalium/program_impl.hpp | 7 +- tt_metal/distributed/mesh_command_queue.cpp | 4 +- tt_metal/distributed/mesh_workload_utils.cpp | 1 - tt_metal/impl/CMakeLists.txt | 2 +- tt_metal/impl/buffers/circular_buffer.cpp | 2 +- tt_metal/impl/buffers/dispatch.cpp | 14 +- tt_metal/impl/buffers/dispatch.hpp | 8 +- tt_metal/impl/device/device.cpp | 9 +- ...d_queue.cpp => command_queue_commands.cpp} | 156 +++++++----- .../impl/dispatch/hardware_command_queue.cpp | 111 ++++---- .../dispatch}/hardware_command_queue.hpp | 123 +++------ tt_metal/impl/module.mk | 2 +- tt_metal/impl/program/dispatch.cpp | 2 +- tt_metal/impl/program/program.cpp | 1 - .../sub_device/sub_device_manager_tracker.cpp | 2 +- tt_metal/tt_metal.cpp | 10 +- 20 files changed, 549 insertions(+), 396 deletions(-) create mode 100644 tt_metal/api/tt-metalium/command_queue_commands.hpp rename tt_metal/impl/dispatch/{command_queue.cpp => command_queue_commands.cpp} (81%) rename tt_metal/{api/tt-metalium => impl/dispatch}/hardware_command_queue.hpp (60%) diff --git a/tt_metal/api/tt-metalium/command_queue.hpp b/tt_metal/api/tt-metalium/command_queue.hpp index 5099810b71d..76df29533b1 100644 --- a/tt_metal/api/tt-metalium/command_queue.hpp +++ b/tt_metal/api/tt-metalium/command_queue.hpp @@ -1,217 +1,113 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. // // SPDX-License-Identifier: Apache-2.0 #pragma once -#include -#include #include -#include +#include #include -#include #include -#include -#include -#include "env_lib.hpp" -#include "command_queue_interface.hpp" -#include "device_command.hpp" -#include "lock_free_queue.hpp" -#include "program_command_sequence.hpp" #include "worker_config_buffer.hpp" -#include "program_impl.hpp" #include "trace_buffer.hpp" -#include "hardware_command_queue.hpp" +#include "memcpy.hpp" +#include "command_queue_interface.hpp" namespace tt::tt_metal { -inline namespace v0 { -class BufferRegion; +inline namespace v0 { class Event; -class Trace; -using RuntimeArgs = std::vector>; - +class Program; +class Kernel; } // namespace v0 -// Only contains the types of commands which are enqueued onto the device -enum class EnqueueCommandType { - ENQUEUE_READ_BUFFER, - ENQUEUE_WRITE_BUFFER, - GET_BUF_ADDR, - ADD_BUFFER_TO_PROGRAM, - SET_RUNTIME_ARGS, - ENQUEUE_PROGRAM, - ENQUEUE_TRACE, - ENQUEUE_RECORD_EVENT, - ENQUEUE_WAIT_FOR_EVENT, - FINISH, - FLUSH, - TERMINATE, - INVALID -}; - -string EnqueueCommandTypeToString(EnqueueCommandType ctype); +class CommandQueue { +public: + virtual ~CommandQueue() = default; -class Command { - public: - Command() {} - virtual void process() {}; - virtual EnqueueCommandType type() = 0; -}; + virtual const CoreCoord& virtual_enqueue_program_dispatch_core() const = 0; + virtual const CoreCoord& completion_queue_writer_core() const = 0; -class EnqueueProgramCommand : public Command { - private: - uint32_t command_queue_id; - IDevice* device; - NOC noc_index; - Program& program; - SystemMemoryManager& manager; - WorkerConfigBufferMgr& config_buffer_mgr; - CoreCoord dispatch_core; - CoreType dispatch_core_type; - uint32_t expected_num_workers_completed; - uint32_t packed_write_max_unicast_sub_cmds; - uint32_t dispatch_message_addr; - uint32_t multicast_cores_launch_message_wptr = 0; - uint32_t 
unicast_cores_launch_message_wptr = 0; - // TODO: There will be multiple ids once programs support spanning multiple sub_devices - SubDeviceId sub_device_id = SubDeviceId{0}; - - public: - EnqueueProgramCommand( - uint32_t command_queue_id, - IDevice* device, - NOC noc_index, - Program& program, - CoreCoord& dispatch_core, - SystemMemoryManager& manager, - WorkerConfigBufferMgr& config_buffer_mgr, - uint32_t expected_num_workers_completed, - uint32_t multicast_cores_launch_message_wptr, - uint32_t unicast_cores_launch_message_wptr, - SubDeviceId sub_device_id); - - void process(); - - EnqueueCommandType type() { return EnqueueCommandType::ENQUEUE_PROGRAM; } - - constexpr bool has_side_effects() { return true; } -}; + virtual volatile bool is_dprint_server_hung() = 0; + virtual volatile bool is_noc_hung() = 0; -class EnqueueRecordEventCommand : public Command { - private: - uint32_t command_queue_id; - IDevice* device; - NOC noc_index; - SystemMemoryManager& manager; - uint32_t event_id; - tt::stl::Span expected_num_workers_completed; - tt::stl::Span sub_device_ids; - bool clear_count; - bool write_barrier; - - public: - EnqueueRecordEventCommand( - uint32_t command_queue_id, - IDevice* device, - NOC noc_index, - SystemMemoryManager& manager, - uint32_t event_id, - tt::stl::Span expected_num_workers_completed, - tt::stl::Span sub_device_ids, - bool clear_count = false, - bool write_barrier = true); + virtual void record_begin(const uint32_t tid, const std::shared_ptr& ctx) = 0; + virtual void record_end() = 0; + virtual void set_num_worker_sems_on_dispatch(uint32_t num_worker_sems) = 0; + virtual void set_go_signal_noc_data_on_dispatch(const vector_memcpy_aligned& go_signal_noc_data) = 0; - void process(); + virtual void reset_worker_state( + bool reset_launch_msg_state, + uint32_t num_sub_devices, + const vector_memcpy_aligned& go_signal_noc_data) = 0; - EnqueueCommandType type() { return EnqueueCommandType::ENQUEUE_RECORD_EVENT; } + virtual uint32_t id() const = 
0; + virtual std::optional tid() const = 0; - constexpr bool has_side_effects() { return false; } -}; + virtual SystemMemoryManager& sysmem_manager() = 0; -class EnqueueWaitForEventCommand : public Command { - private: - uint32_t command_queue_id; - IDevice* device; - SystemMemoryManager& manager; - const Event& sync_event; - CoreType dispatch_core_type; - bool clear_count; + virtual void terminate() = 0; - public: - EnqueueWaitForEventCommand( - uint32_t command_queue_id, - IDevice* device, - SystemMemoryManager& manager, - const Event& sync_event, - bool clear_count = false); + virtual IDevice* device() = 0; - void process(); + // These functions are temporarily needed since MeshCommandQueue relies on the CommandQueue object + virtual uint32_t get_expected_num_workers_completed_for_sub_device(uint32_t sub_device_index) const = 0; + virtual void set_expected_num_workers_completed_for_sub_device(uint32_t sub_device_index, uint32_t num_workers) = 0; + virtual WorkerConfigBufferMgr& get_config_buffer_mgr(uint32_t index) = 0; - EnqueueCommandType type() { return EnqueueCommandType::ENQUEUE_WAIT_FOR_EVENT; } + virtual void enqueue_trace(const uint32_t trace_id, bool blocking) = 0; - constexpr bool has_side_effects() { return false; } -}; + virtual void enqueue_program(Program& program, bool blocking) = 0; -class EnqueueTraceCommand : public Command { - private: - uint32_t command_queue_id; - Buffer& buffer; - IDevice* device; - SystemMemoryManager& manager; - std::shared_ptr& descriptor; - std::array& expected_num_workers_completed; - bool clear_count; - NOC noc_index; - CoreCoord dispatch_core; - public: - EnqueueTraceCommand( - uint32_t command_queue_id, - IDevice* device, - SystemMemoryManager& manager, - std::shared_ptr& descriptor, + virtual void enqueue_read_buffer( + std::shared_ptr& buffer, + void* dst, + const BufferRegion& region, + bool blocking, + tt::stl::Span sub_device_ids = {}) = 0; + virtual void enqueue_read_buffer( Buffer& buffer, - std::array& 
expected_num_workers_completed, - NOC noc_index, - CoreCoord dispatch_core); - - void process(); - - EnqueueCommandType type() { return EnqueueCommandType::ENQUEUE_TRACE; } - - constexpr bool has_side_effects() { return true; } -}; - -class EnqueueTerminateCommand : public Command { - private: - uint32_t command_queue_id; - IDevice* device; - SystemMemoryManager& manager; + void* dst, + const BufferRegion& region, + bool blocking, + tt::stl::Span sub_device_ids = {}) = 0; - public: - EnqueueTerminateCommand(uint32_t command_queue_id, IDevice* device, SystemMemoryManager& manager); - - void process(); - - EnqueueCommandType type() { return EnqueueCommandType::TERMINATE; } + virtual void enqueue_record_event( + const std::shared_ptr& event, + bool clear_count = false, + tt::stl::Span sub_device_ids = {}) = 0; + virtual void enqueue_wait_for_event(const std::shared_ptr& sync_event, bool clear_count = false) = 0; + + virtual void enqueue_write_buffer( + const std::variant, std::shared_ptr>& buffer, + HostDataType src, + const BufferRegion& region, + bool blocking, + tt::stl::Span sub_device_ids = {}) = 0; + virtual void enqueue_write_buffer( + Buffer& buffer, + const void* src, + const BufferRegion& region, + bool blocking, + tt::stl::Span sub_device_ids = {}) = 0; - constexpr bool has_side_effects() { return false; } + virtual void finish(tt::stl::Span sub_device_ids) = 0; }; +// Temporarily here. Need to eliminate this +using RuntimeArgs = std::vector>; // Primitives used to place host only operations on the SW Command Queue. 
// These are used in functions exposed through tt_metal.hpp or host_api.hpp -void EnqueueGetBufferAddr(uint32_t* dst_buf_addr, const Buffer* buffer, bool blocking); -void EnqueueSetRuntimeArgs( +void GetBufferAddr(uint32_t* dst_buf_addr, const Buffer* buffer, bool blocking); +void SetRuntimeArgs( const std::shared_ptr& kernel, const CoreCoord& core_coord, std::shared_ptr runtime_args_ptr, bool blocking); -void EnqueueAddBufferToProgram( +void AddBufferToProgram( const std::variant, std::shared_ptr>& buffer, Program& program, bool blocking); } // namespace tt::tt_metal - -std::ostream& operator<<(std::ostream& os, const tt::tt_metal::EnqueueCommandType& type); diff --git a/tt_metal/api/tt-metalium/command_queue_commands.hpp b/tt_metal/api/tt-metalium/command_queue_commands.hpp new file mode 100644 index 00000000000..7b1503c5f3f --- /dev/null +++ b/tt_metal/api/tt-metalium/command_queue_commands.hpp @@ -0,0 +1,203 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "env_lib.hpp" +#include "command_queue_interface.hpp" +#include "device_command.hpp" +#include "lock_free_queue.hpp" +#include "program_command_sequence.hpp" +#include "worker_config_buffer.hpp" +#include "program_impl.hpp" +#include "trace_buffer.hpp" + +namespace tt::tt_metal { +inline namespace v0 { + +class BufferRegion; +class Event; +class Trace; + +} // namespace v0 + +// Only contains the types of commands which are enqueued onto the device +enum class EnqueueCommandType { + ENQUEUE_READ_BUFFER, + ENQUEUE_WRITE_BUFFER, + GET_BUF_ADDR, + ADD_BUFFER_TO_PROGRAM, + SET_RUNTIME_ARGS, + ENQUEUE_PROGRAM, + ENQUEUE_TRACE, + ENQUEUE_RECORD_EVENT, + ENQUEUE_WAIT_FOR_EVENT, + FINISH, + FLUSH, + TERMINATE, + INVALID +}; + +string EnqueueCommandTypeToString(EnqueueCommandType ctype); + +class Command { +public: + Command() {} + virtual void 
process() {}; + virtual EnqueueCommandType type() = 0; +}; + +class EnqueueProgramCommand : public Command { +private: + uint32_t command_queue_id; + IDevice* device; + NOC noc_index; + Program& program; + SystemMemoryManager& manager; + WorkerConfigBufferMgr& config_buffer_mgr; + CoreCoord dispatch_core; + CoreType dispatch_core_type; + uint32_t expected_num_workers_completed; + uint32_t packed_write_max_unicast_sub_cmds; + uint32_t dispatch_message_addr; + uint32_t multicast_cores_launch_message_wptr = 0; + uint32_t unicast_cores_launch_message_wptr = 0; + // TODO: There will be multiple ids once programs support spanning multiple sub_devices + SubDeviceId sub_device_id = SubDeviceId{0}; + +public: + EnqueueProgramCommand( + uint32_t command_queue_id, + IDevice* device, + NOC noc_index, + Program& program, + CoreCoord& dispatch_core, + SystemMemoryManager& manager, + WorkerConfigBufferMgr& config_buffer_mgr, + uint32_t expected_num_workers_completed, + uint32_t multicast_cores_launch_message_wptr, + uint32_t unicast_cores_launch_message_wptr, + SubDeviceId sub_device_id); + + void process(); + + EnqueueCommandType type() { return EnqueueCommandType::ENQUEUE_PROGRAM; } + + constexpr bool has_side_effects() { return true; } +}; + +class EnqueueRecordEventCommand : public Command { +private: + uint32_t command_queue_id; + IDevice* device; + NOC noc_index; + SystemMemoryManager& manager; + uint32_t event_id; + tt::stl::Span expected_num_workers_completed; + tt::stl::Span sub_device_ids; + bool clear_count; + bool write_barrier; + +public: + EnqueueRecordEventCommand( + uint32_t command_queue_id, + IDevice* device, + NOC noc_index, + SystemMemoryManager& manager, + uint32_t event_id, + tt::stl::Span expected_num_workers_completed, + tt::stl::Span sub_device_ids, + bool clear_count = false, + bool write_barrier = true); + + void process(); + + EnqueueCommandType type() { return EnqueueCommandType::ENQUEUE_RECORD_EVENT; } + + constexpr bool has_side_effects() { return 
false; } +}; + +class EnqueueWaitForEventCommand : public Command { +private: + uint32_t command_queue_id; + IDevice* device; + SystemMemoryManager& manager; + const Event& sync_event; + CoreType dispatch_core_type; + bool clear_count; + +public: + EnqueueWaitForEventCommand( + uint32_t command_queue_id, + IDevice* device, + SystemMemoryManager& manager, + const Event& sync_event, + bool clear_count = false); + + void process(); + + EnqueueCommandType type() { return EnqueueCommandType::ENQUEUE_WAIT_FOR_EVENT; } + + constexpr bool has_side_effects() { return false; } +}; + +class EnqueueTraceCommand : public Command { +private: + uint32_t command_queue_id; + Buffer& buffer; + IDevice* device; + SystemMemoryManager& manager; + std::shared_ptr& descriptor; + std::array& expected_num_workers_completed; + bool clear_count; + NOC noc_index; + CoreCoord dispatch_core; + +public: + EnqueueTraceCommand( + uint32_t command_queue_id, + IDevice* device, + SystemMemoryManager& manager, + std::shared_ptr& descriptor, + Buffer& buffer, + std::array& expected_num_workers_completed, + NOC noc_index, + CoreCoord dispatch_core); + + void process(); + + EnqueueCommandType type() { return EnqueueCommandType::ENQUEUE_TRACE; } + + constexpr bool has_side_effects() { return true; } +}; + +class EnqueueTerminateCommand : public Command { +private: + uint32_t command_queue_id; + IDevice* device; + SystemMemoryManager& manager; + +public: + EnqueueTerminateCommand(uint32_t command_queue_id, IDevice* device, SystemMemoryManager& manager); + + void process(); + + EnqueueCommandType type() { return EnqueueCommandType::TERMINATE; } + + constexpr bool has_side_effects() { return false; } +}; + +} // namespace tt::tt_metal + +std::ostream& operator<<(std::ostream& os, const tt::tt_metal::EnqueueCommandType& type); diff --git a/tt_metal/api/tt-metalium/command_queue_interface.hpp b/tt_metal/api/tt-metalium/command_queue_interface.hpp index 582859c6e1c..63283acb306 100644 --- 
a/tt_metal/api/tt-metalium/command_queue_interface.hpp +++ b/tt_metal/api/tt-metalium/command_queue_interface.hpp @@ -14,6 +14,7 @@ #include "hal.hpp" #include "dispatch_settings.hpp" #include "helpers.hpp" +#include "buffer.hpp" // FIXME: Don't do this in header files using namespace tt::tt_metal; @@ -43,6 +44,52 @@ enum class CommandQueueHostAddrType : uint8_t { UNRESERVED = 4 }; +// Used so the host knows how to properly copy data into user space from the completion queue (in hugepages) +struct ReadBufferDescriptor { + TensorMemoryLayout buffer_layout; + uint32_t page_size; + uint32_t padded_page_size; + std::shared_ptr buffer_page_mapping; + void* dst; + uint32_t dst_offset; + uint32_t num_pages_read; + uint32_t cur_dev_page_id; + uint32_t starting_host_page_id; + + ReadBufferDescriptor( + TensorMemoryLayout buffer_layout, + uint32_t page_size, + uint32_t padded_page_size, + void* dst, + uint32_t dst_offset, + uint32_t num_pages_read, + uint32_t cur_dev_page_id, + uint32_t starting_host_page_id = 0, + const std::shared_ptr& buffer_page_mapping = nullptr) : + buffer_layout(buffer_layout), + page_size(page_size), + padded_page_size(padded_page_size), + buffer_page_mapping(buffer_page_mapping), + dst(dst), + dst_offset(dst_offset), + num_pages_read(num_pages_read), + cur_dev_page_id(cur_dev_page_id), + starting_host_page_id(starting_host_page_id) {} +}; + +// Used so host knows data in completion queue is just an event ID +struct ReadEventDescriptor { + uint32_t event_id; + uint32_t global_offset; + + explicit ReadEventDescriptor(uint32_t event) : event_id(event), global_offset(0) {} + + void set_global_offset(uint32_t offset) { global_offset = offset; } + uint32_t get_global_event_id() { return global_offset + event_id; } +}; + +using CompletionReaderVariant = std::variant; + // Contains constants related to FD // // Deprecated note: for constant values, use tt::tt_metal::dispatch::DispatchConstants instead. 
diff --git a/tt_metal/api/tt-metalium/device_impl.hpp b/tt_metal/api/tt-metalium/device_impl.hpp index 6facfb97b29..1b76dc09d88 100644 --- a/tt_metal/api/tt-metalium/device_impl.hpp +++ b/tt_metal/api/tt-metalium/device_impl.hpp @@ -19,7 +19,6 @@ #include "hal.hpp" #include "command_queue_interface.hpp" #include "command_queue.hpp" -#include "hardware_command_queue.hpp" #include "sub_device_manager_tracker.hpp" #include "sub_device_types.hpp" #include "trace_buffer.hpp" diff --git a/tt_metal/api/tt-metalium/program_impl.hpp b/tt_metal/api/tt-metalium/program_impl.hpp index 12665674381..96d2e0c30f8 100644 --- a/tt_metal/api/tt-metalium/program_impl.hpp +++ b/tt_metal/api/tt-metalium/program_impl.hpp @@ -56,9 +56,12 @@ namespace distributed { class MeshWorkload; } // namespace distributed +class JitBuildOptions; class EnqueueProgramCommand; class CommandQueue; -class JitBuildOptions; +// Must be removed. Only here because its a damn friend of a Program +class HWCommandQueue; + namespace detail{ class Program_; @@ -232,7 +235,7 @@ class Program { template friend void program_dispatch::finalize_program_offsets(T&, IDevice*); template friend uint32_t program_dispatch::program_base_addr_on_core(WorkloadType&, DeviceType, HalProgrammableCoreType); - friend CommandQueue; + friend HWCommandQueue; friend EnqueueProgramCommand; friend distributed::MeshWorkload; friend detail::Internal_; diff --git a/tt_metal/distributed/mesh_command_queue.cpp b/tt_metal/distributed/mesh_command_queue.cpp index 15ac0cc4f74..7a4fda8842f 100644 --- a/tt_metal/distributed/mesh_command_queue.cpp +++ b/tt_metal/distributed/mesh_command_queue.cpp @@ -210,7 +210,7 @@ void MeshCommandQueue::read_shard_from_device( buffer_dispatch::copy_sharded_buffer_from_core_to_completion_queue( core_id, *shard_view, dispatch_params, sub_device_ids, cores[core_id], this->dispatch_core_type()); if (dispatch_params.pages_per_txn > 0) { - auto read_descriptor = std::get( + auto read_descriptor = std::get( 
*buffer_dispatch::generate_sharded_buffer_read_descriptor(dst, dispatch_params, *shard_view)); buffer_dispatch::copy_completion_queue_data_into_user_space( read_descriptor, mmio_device_id, channel, id_, device->sysmem_manager(), exit_condition); @@ -222,7 +222,7 @@ void MeshCommandQueue::read_shard_from_device( buffer_dispatch::copy_interleaved_buffer_to_completion_queue( dispatch_params, *shard_view, sub_device_ids, this->dispatch_core_type()); if (dispatch_params.pages_per_txn > 0) { - auto read_descriptor = std::get( + auto read_descriptor = std::get( *buffer_dispatch::generate_interleaved_buffer_read_descriptor(dst, dispatch_params, *shard_view)); buffer_dispatch::copy_completion_queue_data_into_user_space( read_descriptor, mmio_device_id, channel, id_, device->sysmem_manager(), exit_condition); diff --git a/tt_metal/distributed/mesh_workload_utils.cpp b/tt_metal/distributed/mesh_workload_utils.cpp index c15ba8cb230..73f42ad93ad 100644 --- a/tt_metal/distributed/mesh_workload_utils.cpp +++ b/tt_metal/distributed/mesh_workload_utils.cpp @@ -4,7 +4,6 @@ #include #include -#include #include "tt_metal/impl/program/dispatch.hpp" diff --git a/tt_metal/impl/CMakeLists.txt b/tt_metal/impl/CMakeLists.txt index e23a1cffa7c..0558711d018 100644 --- a/tt_metal/impl/CMakeLists.txt +++ b/tt_metal/impl/CMakeLists.txt @@ -22,7 +22,7 @@ set(IMPL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/program/program.cpp ${CMAKE_CURRENT_SOURCE_DIR}/program/dispatch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dispatch/debug_tools.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/dispatch/command_queue.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/dispatch/command_queue_commands.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dispatch/hardware_command_queue.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dispatch/launch_message_ring_buffer_state.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dispatch/worker_config_buffer.cpp diff --git a/tt_metal/impl/buffers/circular_buffer.cpp b/tt_metal/impl/buffers/circular_buffer.cpp index 60e7057fe74..c85389241d1 100644 --- 
a/tt_metal/impl/buffers/circular_buffer.cpp +++ b/tt_metal/impl/buffers/circular_buffer.cpp @@ -15,7 +15,7 @@ namespace { inline void GetBufferAddress(const tt::tt_metal::Buffer* buffer, uint32_t* address_on_host) { - EnqueueGetBufferAddr(address_on_host, buffer, false); + GetBufferAddr(address_on_host, buffer, false); } } // namespace diff --git a/tt_metal/impl/buffers/dispatch.cpp b/tt_metal/impl/buffers/dispatch.cpp index 8b29da1de2a..f90e54dc8ce 100644 --- a/tt_metal/impl/buffers/dispatch.cpp +++ b/tt_metal/impl/buffers/dispatch.cpp @@ -766,14 +766,14 @@ void copy_interleaved_buffer_to_completion_queue( } // Functions used to copy buffer data from completion queue into user space -std::shared_ptr generate_sharded_buffer_read_descriptor( +std::shared_ptr generate_sharded_buffer_read_descriptor( void* dst, ShardedBufferReadDispatchParams& dispatch_params, Buffer& buffer) { // Increment the src_page_index after the Read Buffer Descriptor has been populated // for the current core/txn auto initial_src_page_index = dispatch_params.src_page_index; dispatch_params.src_page_index += dispatch_params.pages_per_txn; - return std::make_shared( - std::in_place_type, + return std::make_shared( + std::in_place_type, buffer.buffer_layout(), buffer.page_size(), dispatch_params.padded_page_size, @@ -785,10 +785,10 @@ std::shared_ptr generate_sharded_ dispatch_params.buffer_page_mapping); } -std::shared_ptr generate_interleaved_buffer_read_descriptor( +std::shared_ptr generate_interleaved_buffer_read_descriptor( void* dst, BufferReadDispatchParams& dispatch_params, Buffer& buffer) { - return std::make_shared( - std::in_place_type, + return std::make_shared( + std::in_place_type, buffer.buffer_layout(), buffer.page_size(), dispatch_params.padded_page_size, @@ -799,7 +799,7 @@ std::shared_ptr generate_interlea } void copy_completion_queue_data_into_user_space( - const detail::ReadBufferDescriptor& read_buffer_descriptor, + const ReadBufferDescriptor& read_buffer_descriptor, 
chip_id_t mmio_device_id, uint16_t channel, uint32_t cq_id, diff --git a/tt_metal/impl/buffers/dispatch.hpp b/tt_metal/impl/buffers/dispatch.hpp index fb80f9f9013..0cda3654610 100644 --- a/tt_metal/impl/buffers/dispatch.hpp +++ b/tt_metal/impl/buffers/dispatch.hpp @@ -6,7 +6,7 @@ #include #include -#include // Need this for ReadBufferDesriptor -> this should be moved to a separate header +#include #include "buffer.hpp" namespace tt::tt_metal { @@ -72,7 +72,7 @@ void copy_interleaved_buffer_to_completion_queue( CoreType dispatch_core_type); void copy_completion_queue_data_into_user_space( - const detail::ReadBufferDescriptor& read_buffer_descriptor, + const ReadBufferDescriptor& read_buffer_descriptor, chip_id_t mmio_device_id, uint16_t channel, uint32_t cq_id, @@ -82,9 +82,9 @@ void copy_completion_queue_data_into_user_space( std::vector get_cores_for_sharded_buffer( bool width_split, const std::shared_ptr& buffer_page_mapping, Buffer& buffer); -std::shared_ptr<::tt::tt_metal::detail::CompletionReaderVariant> generate_sharded_buffer_read_descriptor( +std::shared_ptr<::tt::tt_metal::CompletionReaderVariant> generate_sharded_buffer_read_descriptor( void* dst, ShardedBufferReadDispatchParams& dispatch_params, Buffer& buffer); -std::shared_ptr<::tt::tt_metal::detail::CompletionReaderVariant> generate_interleaved_buffer_read_descriptor( +std::shared_ptr<::tt::tt_metal::CompletionReaderVariant> generate_interleaved_buffer_read_descriptor( void* dst, BufferReadDispatchParams& dispatch_params, Buffer& buffer); } // namespace buffer_dispatch diff --git a/tt_metal/impl/device/device.cpp b/tt_metal/impl/device/device.cpp index 6298cc49da8..32d050f1213 100644 --- a/tt_metal/impl/device/device.cpp +++ b/tt_metal/impl/device/device.cpp @@ -31,7 +31,9 @@ #include #include #include + #include "impl/dispatch/topology.hpp" +#include "impl/dispatch/hardware_command_queue.hpp" namespace tt { @@ -970,7 +972,8 @@ void Device::init_command_queue_host() { sysmem_manager_ = 
std::make_unique(this->id_, this->num_hw_cqs()); command_queues_.reserve(num_hw_cqs()); for (size_t cq_id = 0; cq_id < num_hw_cqs(); cq_id++) { - command_queues_.push_back(std::make_unique(this, cq_id, dispatch_downstream_noc, completion_queue_reader_core_)); + command_queues_.push_back( + std::make_unique(this, cq_id, dispatch_downstream_noc, completion_queue_reader_core_)); } } @@ -1068,7 +1071,7 @@ bool Device::close() { TT_THROW("Cannot close device {} that has not been initialized!", this->id_); } - for (const std::unique_ptr& hw_command_queue : command_queues_) { + for (const auto& hw_command_queue : command_queues_) { if (hw_command_queue->sysmem_manager().get_bypass_mode()) { hw_command_queue->record_end(); } @@ -1599,7 +1602,7 @@ uint8_t Device::noc_data_start_index(SubDeviceId sub_device_id, bool mcast_data, } CoreCoord Device::virtual_program_dispatch_core(uint8_t cq_id) const { - return this->command_queues_[cq_id]->virtual_enqueue_program_dispatch_core; + return this->command_queues_[cq_id]->virtual_enqueue_program_dispatch_core(); } // Main source to get NOC idx for dispatch core diff --git a/tt_metal/impl/dispatch/command_queue.cpp b/tt_metal/impl/dispatch/command_queue_commands.cpp similarity index 81% rename from tt_metal/impl/dispatch/command_queue.cpp rename to tt_metal/impl/dispatch/command_queue_commands.cpp index 256eb20d42e..c680349748c 100644 --- a/tt_metal/impl/dispatch/command_queue.cpp +++ b/tt_metal/impl/dispatch/command_queue_commands.cpp @@ -1,8 +1,8 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
// // SPDX-License-Identifier: Apache-2.0 -#include +#include #include #include @@ -20,7 +20,6 @@ #include #include #include -#include #include "program_command_sequence.hpp" #include "tt_metal/command_queue.hpp" #include @@ -114,8 +113,9 @@ EnqueueProgramCommand::EnqueueProgramCommand( this->device = device; this->dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(device->id()); this->packed_write_max_unicast_sub_cmds = get_packed_write_max_unicast_sub_cmds(this->device); - this->dispatch_message_addr = dispatch_constants::get( - this->dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE) + + this->dispatch_message_addr = + dispatch_constants::get(this->dispatch_core_type) + .get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE) + dispatch_constants::get(this->dispatch_core_type).get_dispatch_message_offset(this->sub_device_id.to_index()); } @@ -158,7 +158,13 @@ void EnqueueProgramCommand::process() { dispatch_metadata, program.get_program_binary_status(device->id())); // Issue dispatch commands for this program - program_dispatch::write_program_command_sequence(cached_program_command_sequence, this->manager, this->command_queue_id, this->dispatch_core_type, dispatch_metadata.stall_first, dispatch_metadata.stall_before_program); + program_dispatch::write_program_command_sequence( + cached_program_command_sequence, + this->manager, + this->command_queue_id, + this->dispatch_core_type, + dispatch_metadata.stall_first, + dispatch_metadata.stall_before_program); // Kernel Binaries are committed to DRAM, the first time the program runs on device. Reflect this on host. 
program.set_program_binary_status(device->id(), ProgramBinaryStatus::Committed); } @@ -198,7 +204,8 @@ void EnqueueRecordEventCommand::process() { uint32_t num_worker_counters = this->sub_device_ids.size(); uint32_t cmd_sequence_sizeB = - hal.get_alignment(HalMemType::HOST) * num_worker_counters + // CQ_PREFETCH_CMD_RELAY_INLINE + CQ_DISPATCH_CMD_WAIT + hal.get_alignment(HalMemType::HOST) * + num_worker_counters + // CQ_PREFETCH_CMD_RELAY_INLINE + CQ_DISPATCH_CMD_WAIT packed_write_sizeB + // CQ_PREFETCH_CMD_RELAY_INLINE + CQ_DISPATCH_CMD_WRITE_PACKED + unicast subcmds + event // payload align( @@ -210,22 +217,29 @@ void EnqueueRecordEventCommand::process() { HugepageDeviceCommand command_sequence(cmd_region, cmd_sequence_sizeB); CoreType dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(this->device->id()); - uint32_t dispatch_message_base_addr = dispatch_constants::get( - dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE); + uint32_t dispatch_message_base_addr = + dispatch_constants::get(dispatch_core_type) + .get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE); uint32_t last_index = num_worker_counters - 1; // We only need the write barrier for the last wait cmd for (uint32_t i = 0; i < last_index; ++i) { auto offset_index = this->sub_device_ids[i].to_index(); - uint32_t dispatch_message_addr = dispatch_message_base_addr + dispatch_constants::get(dispatch_core_type).get_dispatch_message_offset(offset_index); + uint32_t dispatch_message_addr = + dispatch_message_base_addr + + dispatch_constants::get(dispatch_core_type).get_dispatch_message_offset(offset_index); command_sequence.add_dispatch_wait( false, dispatch_message_addr, this->expected_num_workers_completed[offset_index], this->clear_count); - } auto offset_index = this->sub_device_ids[last_index].to_index(); - uint32_t dispatch_message_addr = dispatch_message_base_addr + 
dispatch_constants::get(dispatch_core_type).get_dispatch_message_offset(offset_index); + uint32_t dispatch_message_addr = + dispatch_message_base_addr + + dispatch_constants::get(dispatch_core_type).get_dispatch_message_offset(offset_index); command_sequence.add_dispatch_wait( - this->write_barrier, dispatch_message_addr, this->expected_num_workers_completed[offset_index], this->clear_count); + this->write_barrier, + dispatch_message_addr, + this->expected_num_workers_completed[offset_index], + this->clear_count); CoreType core_type = dispatch_core_manager::instance().get_dispatch_core_type(this->device->id()); uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(this->device->id()); @@ -246,8 +260,10 @@ void EnqueueRecordEventCommand::process() { event_payloads[cq_id] = {event_payload.data(), event_payload.size() * sizeof(uint32_t)}; } - uint32_t completion_q0_last_event_addr = dispatch_constants::get(core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q0_LAST_EVENT); - uint32_t completion_q1_last_event_addr = dispatch_constants::get(core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q1_LAST_EVENT); + uint32_t completion_q0_last_event_addr = dispatch_constants::get(core_type).get_device_command_queue_addr( + CommandQueueDeviceAddrType::COMPLETION_Q0_LAST_EVENT); + uint32_t completion_q1_last_event_addr = dispatch_constants::get(core_type).get_device_command_queue_addr( + CommandQueueDeviceAddrType::COMPLETION_Q1_LAST_EVENT); uint32_t address = this->command_queue_id == 0 ? 
completion_q0_last_event_addr : completion_q1_last_event_addr; const uint32_t packed_write_max_unicast_sub_cmds = get_packed_write_max_unicast_sub_cmds(this->device); command_sequence.add_dispatch_write_packed( @@ -288,13 +304,18 @@ EnqueueWaitForEventCommand::EnqueueWaitForEventCommand( } void EnqueueWaitForEventCommand::process() { - uint32_t cmd_sequence_sizeB = hal.get_alignment(HalMemType::HOST); // CQ_PREFETCH_CMD_RELAY_INLINE + CQ_DISPATCH_CMD_WAIT + uint32_t cmd_sequence_sizeB = + hal.get_alignment(HalMemType::HOST); // CQ_PREFETCH_CMD_RELAY_INLINE + CQ_DISPATCH_CMD_WAIT void* cmd_region = this->manager.issue_queue_reserve(cmd_sequence_sizeB, this->command_queue_id); HugepageDeviceCommand command_sequence(cmd_region, cmd_sequence_sizeB); - uint32_t completion_q0_last_event_addr = dispatch_constants::get(this->dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q0_LAST_EVENT); - uint32_t completion_q1_last_event_addr = dispatch_constants::get(this->dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q1_LAST_EVENT); + uint32_t completion_q0_last_event_addr = + dispatch_constants::get(this->dispatch_core_type) + .get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q0_LAST_EVENT); + uint32_t completion_q1_last_event_addr = + dispatch_constants::get(this->dispatch_core_type) + .get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q1_LAST_EVENT); uint32_t last_completed_event_address = sync_event.cq_id == 0 ? 
completion_q0_last_event_addr : completion_q1_last_event_addr; @@ -314,7 +335,7 @@ EnqueueTraceCommand::EnqueueTraceCommand( SystemMemoryManager& manager, std::shared_ptr& descriptor, Buffer& buffer, - std::array & expected_num_workers_completed, + std::array& expected_num_workers_completed, NOC noc_index, CoreCoord dispatch_core) : command_queue_id(command_queue_id), @@ -330,16 +351,22 @@ EnqueueTraceCommand::EnqueueTraceCommand( void EnqueueTraceCommand::process() { uint32_t num_sub_devices = descriptor->descriptors.size(); uint32_t pcie_alignment = hal.get_alignment(HalMemType::HOST); - uint32_t go_signals_cmd_size = align(sizeof(CQPrefetchCmd) + sizeof(CQDispatchCmd), pcie_alignment) * descriptor->descriptors.size(); + uint32_t go_signals_cmd_size = + align(sizeof(CQPrefetchCmd) + sizeof(CQDispatchCmd), pcie_alignment) * descriptor->descriptors.size(); uint32_t cmd_sequence_sizeB = - this->device->dispatch_s_enabled() * hal.get_alignment(HalMemType::HOST) + // dispatch_d -> dispatch_s sem update (send only if dispatch_s is running) - go_signals_cmd_size + // go signal cmd - (hal.get_alignment(HalMemType::HOST) + // wait to ensure that reset go signal was processed (dispatch_d) - // when dispatch_s and dispatch_d are running on 2 cores, workers update dispatch_s. dispatch_s is responsible for resetting worker count - // and giving dispatch_d the latest worker state. 
This is encapsulated in the dispatch_s wait command (only to be sent when dispatch is distributed - // on 2 cores) - (this->device->distributed_dispatcher()) * hal.get_alignment(HalMemType::HOST)) * num_sub_devices + + this->device->dispatch_s_enabled() * + hal.get_alignment( + HalMemType::HOST) + // dispatch_d -> dispatch_s sem update (send only if dispatch_s is running) + go_signals_cmd_size + // go signal cmd + (hal.get_alignment( + HalMemType::HOST) + // wait to ensure that reset go signal was processed (dispatch_d) + // when dispatch_s and dispatch_d are running on 2 cores, workers update dispatch_s. + // dispatch_s is responsible for resetting worker count and giving dispatch_d the + // latest worker state. This is encapsulated in the dispatch_s wait command (only to + // be sent when dispatch is distributed on 2 cores) + (this->device->distributed_dispatcher()) * hal.get_alignment(HalMemType::HOST)) * + num_sub_devices + hal.get_alignment(HalMemType::HOST); // CQ_PREFETCH_CMD_EXEC_BUF void* cmd_region = this->manager.issue_queue_reserve(cmd_sequence_sizeB, this->command_queue_id); @@ -349,25 +376,34 @@ void EnqueueTraceCommand::process() { DispatcherSelect dispatcher_for_go_signal = DispatcherSelect::DISPATCH_MASTER; if (this->device->dispatch_s_enabled()) { uint16_t index_bitmask = 0; - for (const auto &id : descriptor->sub_device_ids) { + for (const auto& id : descriptor->sub_device_ids) { index_bitmask |= 1 << id.to_index(); } command_sequence.add_notify_dispatch_s_go_signal_cmd(false, index_bitmask); dispatcher_for_go_signal = DispatcherSelect::DISPATCH_SLAVE; } CoreType dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(device->id()); - uint32_t dispatch_message_base_addr = dispatch_constants::get( - dispatch_core_type).get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE); + uint32_t dispatch_message_base_addr = + dispatch_constants::get(dispatch_core_type) + 
.get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE); go_msg_t reset_launch_message_read_ptr_go_signal; reset_launch_message_read_ptr_go_signal.signal = RUN_MSG_RESET_READ_PTR; reset_launch_message_read_ptr_go_signal.master_x = (uint8_t)this->dispatch_core.x; reset_launch_message_read_ptr_go_signal.master_y = (uint8_t)this->dispatch_core.y; for (const auto& [id, desc] : descriptor->descriptors) { - const auto& noc_data_start_idx = device->noc_data_start_index(id, desc.num_traced_programs_needing_go_signal_multicast, desc.num_traced_programs_needing_go_signal_unicast); - const auto& num_noc_mcast_txns = desc.num_traced_programs_needing_go_signal_multicast ? device->num_noc_mcast_txns(id) : 0; - const auto& num_noc_unicast_txns = desc.num_traced_programs_needing_go_signal_unicast ? device->num_noc_unicast_txns(id) : 0; - reset_launch_message_read_ptr_go_signal.dispatch_message_offset = (uint8_t)dispatch_constants::get(dispatch_core_type).get_dispatch_message_offset(id.to_index()); - uint32_t dispatch_message_addr = dispatch_message_base_addr + dispatch_constants::get(dispatch_core_type).get_dispatch_message_offset(id.to_index()); + const auto& noc_data_start_idx = device->noc_data_start_index( + id, + desc.num_traced_programs_needing_go_signal_multicast, + desc.num_traced_programs_needing_go_signal_unicast); + const auto& num_noc_mcast_txns = + desc.num_traced_programs_needing_go_signal_multicast ? device->num_noc_mcast_txns(id) : 0; + const auto& num_noc_unicast_txns = + desc.num_traced_programs_needing_go_signal_unicast ? 
device->num_noc_unicast_txns(id) : 0; + reset_launch_message_read_ptr_go_signal.dispatch_message_offset = + (uint8_t)dispatch_constants::get(dispatch_core_type).get_dispatch_message_offset(id.to_index()); + uint32_t dispatch_message_addr = + dispatch_message_base_addr + + dispatch_constants::get(dispatch_core_type).get_dispatch_message_offset(id.to_index()); auto index = id.to_index(); // Wait to ensure that all kernels have completed. Then send the reset_rd_ptr go_signal. command_sequence.add_dispatch_go_signal_mcast( @@ -379,21 +415,30 @@ void EnqueueTraceCommand::process() { noc_data_start_idx, dispatcher_for_go_signal); if (desc.num_traced_programs_needing_go_signal_multicast) { - this->expected_num_workers_completed[index] += device->num_worker_cores(HalProgrammableCoreType::TENSIX, id); + this->expected_num_workers_completed[index] += + device->num_worker_cores(HalProgrammableCoreType::TENSIX, id); } if (desc.num_traced_programs_needing_go_signal_unicast) { - this->expected_num_workers_completed[index] += device->num_worker_cores(HalProgrammableCoreType::ACTIVE_ETH, id); + this->expected_num_workers_completed[index] += + device->num_worker_cores(HalProgrammableCoreType::ACTIVE_ETH, id); } } - // Wait to ensure that all workers have reset their read_ptr. dispatch_d will stall until all workers have completed this step, before sending kernel config data to workers - // or notifying dispatch_s that its safe to send the go_signal. - // Clear the dispatch <--> worker semaphore, since trace starts at 0. - for (const auto &id : descriptor->sub_device_ids) { + // Wait to ensure that all workers have reset their read_ptr. dispatch_d will stall until all workers have completed + // this step, before sending kernel config data to workers or notifying dispatch_s that its safe to send the + // go_signal. Clear the dispatch <--> worker semaphore, since trace starts at 0. 
+ for (const auto& id : descriptor->sub_device_ids) { auto index = id.to_index(); - uint32_t dispatch_message_addr = dispatch_message_base_addr + dispatch_constants::get(dispatch_core_type).get_dispatch_message_offset(index); + uint32_t dispatch_message_addr = + dispatch_message_base_addr + dispatch_constants::get(dispatch_core_type).get_dispatch_message_offset(index); if (this->device->distributed_dispatcher()) { command_sequence.add_dispatch_wait( - false, dispatch_message_addr, this->expected_num_workers_completed[index], this->clear_count, false, true, 1); + false, + dispatch_message_addr, + this->expected_num_workers_completed[index], + this->clear_count, + false, + true, + 1); } command_sequence.add_dispatch_wait( false, dispatch_message_addr, this->expected_num_workers_completed[index], this->clear_count); @@ -449,7 +494,7 @@ void EnqueueTerminateCommand::process() { this->manager.fetch_queue_write(cmd_sequence_sizeB, this->command_queue_id); } -void EnqueueAddBufferToProgram( +void AddBufferToProgram( const std::variant, std::shared_ptr>& buffer, Program& program, bool blocking) { @@ -463,7 +508,7 @@ void EnqueueAddBufferToProgram( buffer); } -void EnqueueSetRuntimeArgs( +void SetRuntimeArgs( const std::shared_ptr& kernel, const CoreCoord& core_coord, std::shared_ptr runtime_args_ptr, @@ -486,7 +531,7 @@ void EnqueueSetRuntimeArgs( kernel->set_runtime_args(core_coord, resolved_runtime_args); } -void EnqueueGetBufferAddr(uint32_t* dst_buf_addr, const Buffer* buffer, bool blocking) { +void GetBufferAddr(uint32_t* dst_buf_addr, const Buffer* buffer, bool blocking) { *(static_cast(dst_buf_addr)) = buffer->address(); } @@ -598,7 +643,8 @@ void EventSynchronize(const std::shared_ptr& event) { event->event_id); while (event->device->sysmem_manager().get_last_completed_event(event->cq_id) < event->event_id) { - if (tt::llrt::RunTimeOptions::get_instance().get_test_mode_enabled() && tt::watcher_server_killed_due_to_error()) { + if 
(tt::llrt::RunTimeOptions::get_instance().get_test_mode_enabled() && + tt::watcher_server_killed_due_to_error()) { TT_FATAL( false, "Command Queue could not complete EventSynchronize. See {} for details.", @@ -648,15 +694,15 @@ v1::CommandQueueHandle v1::GetCommandQueue(IDevice* device, std::uint8_t cq_id) v1::CommandQueueHandle v1::GetDefaultCommandQueue(IDevice* device) { return GetCommandQueue(device, 0); } -void v1::EnqueueReadBuffer(CommandQueueHandle cq, const BufferHandle& buffer, std::byte *dst, bool blocking) { +void v1::EnqueueReadBuffer(CommandQueueHandle cq, const BufferHandle& buffer, std::byte* dst, bool blocking) { v0::EnqueueReadBuffer(GetDevice(cq)->command_queue(GetId(cq)), *buffer, dst, blocking); } -void v1::EnqueueWriteBuffer(CommandQueueHandle cq, const BufferHandle& buffer, const std::byte *src, bool blocking) { +void v1::EnqueueWriteBuffer(CommandQueueHandle cq, const BufferHandle& buffer, const std::byte* src, bool blocking) { v0::EnqueueWriteBuffer(GetDevice(cq)->command_queue(GetId(cq)), *buffer, src, blocking); } -void v1::EnqueueProgram(CommandQueueHandle cq, ProgramHandle &program, bool blocking) { +void v1::EnqueueProgram(CommandQueueHandle cq, ProgramHandle& program, bool blocking) { v0::EnqueueProgram(GetDevice(cq)->command_queue(GetId(cq)), program, blocking); } @@ -664,17 +710,13 @@ void v1::Finish(CommandQueueHandle cq, tt::stl::Span sub_devi v0::Finish(GetDevice(cq)->command_queue(GetId(cq))); } -IDevice* v1::GetDevice(CommandQueueHandle cq) { - return cq.device; -} +IDevice* v1::GetDevice(CommandQueueHandle cq) { return cq.device; } -std::uint8_t v1::GetId(CommandQueueHandle cq) { - return cq.id; -} +std::uint8_t v1::GetId(CommandQueueHandle cq) { return cq.id; } } // namespace tt::tt_metal -std::ostream& operator<<(std::ostream& os, EnqueueCommandType const& type) { +std::ostream& operator<<(std::ostream& os, const EnqueueCommandType& type) { switch (type) { case EnqueueCommandType::ENQUEUE_READ_BUFFER: os << 
"ENQUEUE_READ_BUFFER"; break; case EnqueueCommandType::ENQUEUE_WRITE_BUFFER: os << "ENQUEUE_WRITE_BUFFER"; break; diff --git a/tt_metal/impl/dispatch/hardware_command_queue.cpp b/tt_metal/impl/dispatch/hardware_command_queue.cpp index 4803d415ff5..08ed894c8e4 100644 --- a/tt_metal/impl/dispatch/hardware_command_queue.cpp +++ b/tt_metal/impl/dispatch/hardware_command_queue.cpp @@ -2,14 +2,18 @@ // // SPDX-License-Identifier: Apache-2.0 -#include +#include "hardware_command_queue.hpp" + #include #include "dprint_server.hpp" #include -#include #include #include +// Because we are a Friend of Program, accessing Program::get_program_transfer_info() and Program::get_kernels_buffer() +// MUST REMOVE +#include + #include "tt_metal/impl/buffers/dispatch.hpp" #include "tt_metal/impl/debug/watcher_server.hpp" #include "tt_metal/impl/program/dispatch.hpp" @@ -43,14 +47,14 @@ Buffer& get_buffer_object(const std::variant, std } // namespace -CommandQueue::CommandQueue(IDevice* device, uint32_t id, NOC noc_index, uint32_t completion_queue_reader_core) : +HWCommandQueue::HWCommandQueue(IDevice* device, uint32_t id, NOC noc_index, uint32_t completion_queue_reader_core) : manager(device->sysmem_manager()), completion_queue_thread{}, completion_queue_reader_core(completion_queue_reader_core) { ZoneScopedN("CommandQueue_constructor"); this->device_ = device; this->id_ = id; - this->noc_index = noc_index; + this->noc_index_ = noc_index; this->num_entries_in_completion_q = 0; this->num_completed_completion_q_reads = 0; @@ -76,16 +80,16 @@ CommandQueue::CommandQueue(IDevice* device, uint32_t id, NOC noc_index, uint32_t dispatch_core_manager::instance().dispatcher_d_core(device_->id(), channel, id); } } - this->virtual_enqueue_program_dispatch_core = + this->virtual_enqueue_program_dispatch_core_ = device_->virtual_core_from_logical_core(enqueue_program_dispatch_core, core_type); tt_cxy_pair completion_q_writer_location = 
dispatch_core_manager::instance().completion_queue_writer_core(device_->id(), channel, this->id_); - this->completion_queue_writer_core = CoreCoord(completion_q_writer_location.x, completion_q_writer_location.y); + this->completion_queue_writer_core_ = CoreCoord(completion_q_writer_location.x, completion_q_writer_location.y); this->exit_condition = false; - std::thread completion_queue_thread = std::thread(&CommandQueue::read_completion_queue, this); + std::thread completion_queue_thread = std::thread(&HWCommandQueue::read_completion_queue, this); this->completion_queue_thread = std::move(completion_queue_thread); // Set the affinity of the completion queue reader. set_device_thread_affinity(this->completion_queue_thread, this->completion_queue_reader_core); @@ -96,13 +100,13 @@ CommandQueue::CommandQueue(IDevice* device, uint32_t id, NOC noc_index, uint32_t reset_config_buffer_mgr(dispatch_constants::DISPATCH_MESSAGE_ENTRIES); } -uint32_t CommandQueue::id() const { return this->id_; } +uint32_t HWCommandQueue::id() const { return this->id_; } -std::optional CommandQueue::tid() const { return this->tid_; } +std::optional HWCommandQueue::tid() const { return this->tid_; } -SystemMemoryManager& CommandQueue::sysmem_manager() { return this->manager; } +SystemMemoryManager& HWCommandQueue::sysmem_manager() { return this->manager; } -void CommandQueue::set_num_worker_sems_on_dispatch(uint32_t num_worker_sems) { +void HWCommandQueue::set_num_worker_sems_on_dispatch(uint32_t num_worker_sems) { // Not needed for regular dispatch kernel if (!this->device_->dispatch_s_enabled()) { return; @@ -116,7 +120,7 @@ void CommandQueue::set_num_worker_sems_on_dispatch(uint32_t num_worker_sems) { this->manager.fetch_queue_write(cmd_sequence_sizeB, this->id_); } -void CommandQueue::set_go_signal_noc_data_on_dispatch(const vector_memcpy_aligned& go_signal_noc_data) { +void HWCommandQueue::set_go_signal_noc_data_on_dispatch(const vector_memcpy_aligned& go_signal_noc_data) { uint32_t 
pci_alignment = hal.get_alignment(HalMemType::HOST); uint32_t cmd_sequence_sizeB = align( sizeof(CQPrefetchCmd) + sizeof(CQDispatchCmd) + go_signal_noc_data.size() * sizeof(uint32_t), pci_alignment); @@ -130,21 +134,22 @@ void CommandQueue::set_go_signal_noc_data_on_dispatch(const vector_memcpy_aligne this->manager.fetch_queue_write(cmd_sequence_sizeB, this->id_); } -uint32_t CommandQueue::get_expected_num_workers_completed_for_sub_device(uint32_t sub_device_index) const { +uint32_t HWCommandQueue::get_expected_num_workers_completed_for_sub_device(uint32_t sub_device_index) const { TT_FATAL( sub_device_index < dispatch_constants::DISPATCH_MESSAGE_ENTRIES, "Expected sub_device_index to be less than dispatch_constants::DISPATCH_MESSAGE_ENTRIES"); return this->expected_num_workers_completed[sub_device_index]; } -void CommandQueue::set_expected_num_workers_completed_for_sub_device(uint32_t sub_device_index, uint32_t num_workers) { +void HWCommandQueue::set_expected_num_workers_completed_for_sub_device( + uint32_t sub_device_index, uint32_t num_workers) { TT_FATAL( sub_device_index < dispatch_constants::DISPATCH_MESSAGE_ENTRIES, "Expected sub_device_index to be less than dispatch_constants::DISPATCH_MESSAGE_ENTRIES"); this->expected_num_workers_completed[sub_device_index] = num_workers; } -void CommandQueue::reset_worker_dispatch_state_on_device(bool reset_launch_msg_state) { +void HWCommandQueue::reset_worker_dispatch_state_on_device(bool reset_launch_msg_state) { auto num_sub_devices = device_->num_sub_devices(); uint32_t go_signals_cmd_size = 0; if (reset_launch_msg_state) { @@ -183,8 +188,8 @@ void CommandQueue::reset_worker_dispatch_state_on_device(bool reset_launch_msg_s } go_msg_t reset_launch_message_read_ptr_go_signal; reset_launch_message_read_ptr_go_signal.signal = RUN_MSG_RESET_READ_PTR; - reset_launch_message_read_ptr_go_signal.master_x = (uint8_t)this->virtual_enqueue_program_dispatch_core.x; - reset_launch_message_read_ptr_go_signal.master_y = 
(uint8_t)this->virtual_enqueue_program_dispatch_core.y; + reset_launch_message_read_ptr_go_signal.master_x = (uint8_t)this->virtual_enqueue_program_dispatch_core_.x; + reset_launch_message_read_ptr_go_signal.master_y = (uint8_t)this->virtual_enqueue_program_dispatch_core_.y; for (uint32_t i = 0; i < num_sub_devices; ++i) { reset_launch_message_read_ptr_go_signal.dispatch_message_offset = (uint8_t)dispatch_constants::get(dispatch_core_type).get_dispatch_message_offset(i); @@ -225,7 +230,7 @@ void CommandQueue::reset_worker_dispatch_state_on_device(bool reset_launch_msg_s } } -void CommandQueue::reset_worker_state( +void HWCommandQueue::reset_worker_state( bool reset_launch_msg_state, uint32_t num_sub_devices, const vector_memcpy_aligned& go_signal_noc_data) { TT_FATAL(!this->manager.get_bypass_mode(), "Cannot reset worker state during trace capture"); // TODO: This could be further optimized by combining all of these into a single prefetch entry @@ -239,7 +244,7 @@ void CommandQueue::reset_worker_state( } } -CommandQueue::~CommandQueue() { +HWCommandQueue::~HWCommandQueue() { ZoneScopedN("HWCommandQueue_destructor"); if (this->exit_condition) { this->completion_queue_thread.join(); // We errored out already prior @@ -257,7 +262,7 @@ CommandQueue::~CommandQueue() { } } -void CommandQueue::increment_num_entries_in_completion_q() { +void HWCommandQueue::increment_num_entries_in_completion_q() { // Increment num_entries_in_completion_q and inform reader thread // that there is work in the completion queue to process this->num_entries_in_completion_q++; @@ -267,7 +272,7 @@ void CommandQueue::increment_num_entries_in_completion_q() { } } -void CommandQueue::set_exit_condition() { +void HWCommandQueue::set_exit_condition() { this->exit_condition = true; { std::lock_guard lock(this->reader_thread_cv_mutex); @@ -275,17 +280,17 @@ void CommandQueue::set_exit_condition() { } } -IDevice* CommandQueue::device() { return this->device_; } +IDevice* HWCommandQueue::device() { 
return this->device_; } template -void CommandQueue::enqueue_command(T& command, bool blocking, tt::stl::Span sub_device_ids) { +void HWCommandQueue::enqueue_command(T& command, bool blocking, tt::stl::Span sub_device_ids) { command.process(); if (blocking) { this->finish(sub_device_ids); } } -void CommandQueue::enqueue_read_buffer( +void HWCommandQueue::enqueue_read_buffer( std::shared_ptr& buffer, void* dst, const BufferRegion& region, @@ -296,7 +301,7 @@ void CommandQueue::enqueue_read_buffer( // Read buffer command is enqueued in the issue region and device writes requested buffer data into the completion // region -void CommandQueue::enqueue_read_buffer( +void HWCommandQueue::enqueue_read_buffer( Buffer& buffer, void* dst, const BufferRegion& region, @@ -348,7 +353,7 @@ void CommandQueue::enqueue_read_buffer( } } -void CommandQueue::enqueue_write_buffer( +void HWCommandQueue::enqueue_write_buffer( const std::variant, std::shared_ptr>& buffer, HostDataType src, const BufferRegion& region, @@ -365,11 +370,11 @@ void CommandQueue::enqueue_write_buffer( this->enqueue_write_buffer(buffer_obj, data, region, blocking, sub_device_ids); } -CoreType CommandQueue::get_dispatch_core_type() { +CoreType HWCommandQueue::get_dispatch_core_type() { return dispatch_core_manager::instance().get_dispatch_core_type(device_->id()); } -void CommandQueue::enqueue_write_buffer( +void HWCommandQueue::enqueue_write_buffer( Buffer& buffer, const void* src, const BufferRegion& region, @@ -389,7 +394,7 @@ void CommandQueue::enqueue_write_buffer( } } -void CommandQueue::enqueue_program(Program& program, bool blocking) { +void HWCommandQueue::enqueue_program(Program& program, bool blocking) { ZoneScopedN("HWCommandQueue_enqueue_program"); std::vector sub_device_ids = {program.determine_sub_device_ids(device_)}; TT_FATAL(sub_device_ids.size() == 1, "Programs must be executed on a single sub-device"); @@ -462,9 +467,9 @@ void CommandQueue::enqueue_program(Program& program, bool blocking) { 
auto command = EnqueueProgramCommand( this->id_, this->device_, - this->noc_index, + this->noc_index_, program, - this->virtual_enqueue_program_dispatch_core, + this->virtual_enqueue_program_dispatch_core_, this->manager, this->get_config_buffer_mgr(sub_device_index), expected_workers_completed, @@ -503,7 +508,7 @@ void CommandQueue::enqueue_program(Program& program, bool blocking) { expected_workers_completed); } -void CommandQueue::enqueue_record_event( +void HWCommandQueue::enqueue_record_event( const std::shared_ptr& event, bool clear_count, tt::stl::Span sub_device_ids) { ZoneScopedN("HWCommandQueue_enqueue_record_event"); @@ -522,7 +527,7 @@ void CommandQueue::enqueue_record_event( auto command = EnqueueRecordEventCommand( this->id_, this->device_, - this->noc_index, + this->noc_index_, this->manager, event->event_id, this->expected_num_workers_completed, @@ -536,12 +541,12 @@ void CommandQueue::enqueue_record_event( this->expected_num_workers_completed[id.to_index()] = 0; } } - this->issued_completion_q_reads.push(std::make_shared( - std::in_place_type, event->event_id)); + this->issued_completion_q_reads.push( + std::make_shared(std::in_place_type, event->event_id)); this->increment_num_entries_in_completion_q(); } -void CommandQueue::enqueue_wait_for_event(const std::shared_ptr& sync_event, bool clear_count) { +void HWCommandQueue::enqueue_wait_for_event(const std::shared_ptr& sync_event, bool clear_count) { ZoneScopedN("HWCommandQueue_enqueue_wait_for_event"); auto command = EnqueueWaitForEventCommand(this->id_, this->device_, this->manager, *sync_event, clear_count); @@ -552,7 +557,7 @@ void CommandQueue::enqueue_wait_for_event(const std::shared_ptr& sync_eve } } -void CommandQueue::enqueue_trace(const uint32_t trace_id, bool blocking) { +void HWCommandQueue::enqueue_trace(const uint32_t trace_id, bool blocking) { ZoneScopedN("HWCommandQueue_enqueue_trace"); auto trace_inst = this->device_->get_trace(trace_id); @@ -563,8 +568,8 @@ void 
CommandQueue::enqueue_trace(const uint32_t trace_id, bool blocking) { trace_inst->desc, *trace_inst->buffer, this->expected_num_workers_completed, - this->noc_index, - this->virtual_enqueue_program_dispatch_core); + this->noc_index_, + this->virtual_enqueue_program_dispatch_core_); this->enqueue_command(command, false, {}); @@ -593,7 +598,7 @@ void CommandQueue::enqueue_trace(const uint32_t trace_id, bool blocking) { } } -void CommandQueue::read_completion_queue() { +void HWCommandQueue::read_completion_queue() { chip_id_t mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(this->device_->id()); uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(this->device_->id()); while (true) { @@ -622,7 +627,7 @@ void CommandQueue::read_completion_queue() { std::visit( [&](auto&& read_descriptor) { using T = std::decay_t; - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) { ZoneScopedN("CompletionQueueReadData"); buffer_dispatch::copy_completion_queue_data_into_user_space( read_descriptor, @@ -631,7 +636,7 @@ void CommandQueue::read_completion_queue() { this->id_, this->manager, this->exit_condition); - } else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { ZoneScopedN("CompletionQueueReadEvent"); uint32_t read_ptr = this->manager.get_completion_queue_read_ptr(this->id_); thread_local static std::vector dispatch_cmd_and_event( @@ -671,7 +676,7 @@ void CommandQueue::read_completion_queue() { } } -void CommandQueue::finish(tt::stl::Span sub_device_ids) { +void HWCommandQueue::finish(tt::stl::Span sub_device_ids) { ZoneScopedN("HWCommandQueue_finish"); tt::log_debug(tt::LogDispatch, "Finish for command queue {}", this->id_); std::shared_ptr event = std::make_shared(); @@ -697,11 +702,17 @@ void CommandQueue::finish(tt::stl::Span sub_device_ids) { } } -volatile bool CommandQueue::is_dprint_server_hung() { return dprint_server_hang; } +volatile bool HWCommandQueue::is_dprint_server_hung() { return 
dprint_server_hang; } + +volatile bool HWCommandQueue::is_noc_hung() { return illegal_noc_txn_hang; } + +const CoreCoord& HWCommandQueue::virtual_enqueue_program_dispatch_core() const { + return this->virtual_enqueue_program_dispatch_core_; +} -volatile bool CommandQueue::is_noc_hung() { return illegal_noc_txn_hang; } +const CoreCoord& HWCommandQueue::completion_queue_writer_core() const { return this->completion_queue_writer_core_; } -void CommandQueue::record_begin(const uint32_t tid, std::shared_ptr ctx) { +void HWCommandQueue::record_begin(const uint32_t tid, const std::shared_ptr& ctx) { auto num_sub_devices = this->device_->num_sub_devices(); // Record the original value of expected_num_workers_completed, and reset it to 0. std::copy( @@ -738,7 +749,7 @@ void CommandQueue::record_begin(const uint32_t tid, std::shared_ptrtrace_ctx->data; trace_data = std::move(this->manager.get_bypass_data()); // Add command to terminate the trace buffer @@ -776,7 +787,7 @@ void CommandQueue::record_end() { this->manager.set_bypass_mode(false, true); // stop } -void CommandQueue::terminate() { +void HWCommandQueue::terminate() { ZoneScopedN("HWCommandQueue_terminate"); TT_FATAL(!this->manager.get_bypass_mode(), "Terminate cannot be used with tracing"); tt::log_debug(tt::LogDispatch, "Terminating dispatch kernels for command queue {}", this->id_); @@ -784,9 +795,9 @@ void CommandQueue::terminate() { this->enqueue_command(command, false, {}); } -WorkerConfigBufferMgr& CommandQueue::get_config_buffer_mgr(uint32_t index) { return config_buffer_mgr[index]; } +WorkerConfigBufferMgr& HWCommandQueue::get_config_buffer_mgr(uint32_t index) { return config_buffer_mgr[index]; } -void CommandQueue::reset_config_buffer_mgr(const uint32_t num_entries) { +void HWCommandQueue::reset_config_buffer_mgr(const uint32_t num_entries) { for (uint32_t i = 0; i < num_entries; ++i) { this->config_buffer_mgr[i] = WorkerConfigBufferMgr(); 
program_dispatch::initialize_worker_config_buf_mgr(this->config_buffer_mgr[i]); diff --git a/tt_metal/api/tt-metalium/hardware_command_queue.hpp b/tt_metal/impl/dispatch/hardware_command_queue.hpp similarity index 60% rename from tt_metal/api/tt-metalium/hardware_command_queue.hpp rename to tt_metal/impl/dispatch/hardware_command_queue.hpp index f9c4f40a2e9..42058185c1e 100644 --- a/tt_metal/api/tt-metalium/hardware_command_queue.hpp +++ b/tt_metal/impl/dispatch/hardware_command_queue.hpp @@ -9,6 +9,8 @@ #include #include +#include "command_queue.hpp" +#include "command_queue_commands.hpp" #include "command_queue_interface.hpp" #include "lock_free_queue.hpp" #include "worker_config_buffer.hpp" @@ -16,133 +18,78 @@ #include "trace_buffer.hpp" namespace tt::tt_metal { -inline namespace v0 { - -class Event; - -} // namespace v0 - -namespace detail { - -// Used so the host knows how to properly copy data into user space from the completion queue (in hugepages) -struct ReadBufferDescriptor { - TensorMemoryLayout buffer_layout; - uint32_t page_size; - uint32_t padded_page_size; - std::shared_ptr buffer_page_mapping; - void* dst; - uint32_t dst_offset; - uint32_t num_pages_read; - uint32_t cur_dev_page_id; - uint32_t starting_host_page_id; - - ReadBufferDescriptor( - TensorMemoryLayout buffer_layout, - uint32_t page_size, - uint32_t padded_page_size, - void* dst, - uint32_t dst_offset, - uint32_t num_pages_read, - uint32_t cur_dev_page_id, - uint32_t starting_host_page_id = 0, - const std::shared_ptr& buffer_page_mapping = nullptr) : - buffer_layout(buffer_layout), - page_size(page_size), - padded_page_size(padded_page_size), - buffer_page_mapping(buffer_page_mapping), - dst(dst), - dst_offset(dst_offset), - num_pages_read(num_pages_read), - cur_dev_page_id(cur_dev_page_id), - starting_host_page_id(starting_host_page_id) {} -}; - -// Used so host knows data in completion queue is just an event ID -struct ReadEventDescriptor { - uint32_t event_id; - uint32_t global_offset; 
- - explicit ReadEventDescriptor(uint32_t event) : event_id(event), global_offset(0) {} - - void set_global_offset(uint32_t offset) { global_offset = offset; } - uint32_t get_global_event_id() { return global_offset + event_id; } -}; - -using CompletionReaderVariant = std::variant; - -} // namespace detail -class CommandQueue { +class HWCommandQueue : public CommandQueue { public: - CommandQueue(IDevice* device, uint32_t id, NOC noc_index, uint32_t completion_queue_reader_core = 0); + HWCommandQueue(IDevice* device, uint32_t id, NOC noc_index, uint32_t completion_queue_reader_core = 0); - ~CommandQueue(); + ~HWCommandQueue() override; - CoreCoord virtual_enqueue_program_dispatch_core; - CoreCoord completion_queue_writer_core; - NOC noc_index; - volatile bool is_dprint_server_hung(); - volatile bool is_noc_hung(); + const CoreCoord& virtual_enqueue_program_dispatch_core() const override; + const CoreCoord& completion_queue_writer_core() const override; - void record_begin(const uint32_t tid, std::shared_ptr ctx); - void record_end(); - void set_num_worker_sems_on_dispatch(uint32_t num_worker_sems); - void set_go_signal_noc_data_on_dispatch(const vector_memcpy_aligned& go_signal_noc_data); + volatile bool is_dprint_server_hung() override; + volatile bool is_noc_hung() override; + + void record_begin(const uint32_t tid, const std::shared_ptr& ctx) override; + void record_end() override; + void set_num_worker_sems_on_dispatch(uint32_t num_worker_sems) override; + void set_go_signal_noc_data_on_dispatch(const vector_memcpy_aligned& go_signal_noc_data) override; void reset_worker_state( bool reset_launch_msg_state, uint32_t num_sub_devices, - const vector_memcpy_aligned& go_signal_noc_data); + const vector_memcpy_aligned& go_signal_noc_data) override; - uint32_t id() const; - std::optional tid() const; + uint32_t id() const override; + std::optional tid() const override; - SystemMemoryManager& sysmem_manager(); + SystemMemoryManager& sysmem_manager() override; - void 
terminate(); + void terminate() override; // These functions are temporarily needed since MeshCommandQueue relies on the CommandQueue object - uint32_t get_expected_num_workers_completed_for_sub_device(uint32_t sub_device_index) const; - void set_expected_num_workers_completed_for_sub_device(uint32_t sub_device_index, uint32_t num_workers); - WorkerConfigBufferMgr& get_config_buffer_mgr(uint32_t index); + uint32_t get_expected_num_workers_completed_for_sub_device(uint32_t sub_device_index) const override; + void set_expected_num_workers_completed_for_sub_device(uint32_t sub_device_index, uint32_t num_workers) override; + WorkerConfigBufferMgr& get_config_buffer_mgr(uint32_t index) override; - void enqueue_trace(const uint32_t trace_id, bool blocking); - void enqueue_program(Program& program, bool blocking); + void enqueue_trace(const uint32_t trace_id, bool blocking) override; + void enqueue_program(Program& program, bool blocking) override; void enqueue_read_buffer( std::shared_ptr& buffer, void* dst, const BufferRegion& region, bool blocking, - tt::stl::Span sub_device_ids = {}); + tt::stl::Span sub_device_ids = {}) override; void enqueue_read_buffer( Buffer& buffer, void* dst, const BufferRegion& region, bool blocking, - tt::stl::Span sub_device_ids = {}); + tt::stl::Span sub_device_ids = {}) override; void enqueue_record_event( const std::shared_ptr& event, bool clear_count = false, - tt::stl::Span sub_device_ids = {}); - void enqueue_wait_for_event(const std::shared_ptr& sync_event, bool clear_count = false); + tt::stl::Span sub_device_ids = {}) override; + void enqueue_wait_for_event(const std::shared_ptr& sync_event, bool clear_count = false) override; void enqueue_write_buffer( const std::variant, std::shared_ptr>& buffer, HostDataType src, const BufferRegion& region, bool blocking, - tt::stl::Span sub_device_ids = {}); + tt::stl::Span sub_device_ids = {}) override; void enqueue_write_buffer( Buffer& buffer, const void* src, const BufferRegion& region, bool 
blocking, - tt::stl::Span sub_device_ids = {}); + tt::stl::Span sub_device_ids = {}) override; - void finish(tt::stl::Span sub_device_ids); + void finish(tt::stl::Span sub_device_ids) override; - IDevice* device(); + IDevice* device() override; private: uint32_t id_; @@ -166,7 +113,7 @@ class CommandQueue { volatile uint32_t num_completed_completion_q_reads; // completion queue reader thread increments this after reading // an entry out of the completion queue - LockFreeQueue issued_completion_q_reads; + LockFreeQueue issued_completion_q_reads; // These values are used to reset the host side launch message wptr after a trace is captured // Trace capture is a fully host side operation, but it modifies the state of the wptrs above // To ensure that host and device are not out of sync, we reset the wptrs to their original values @@ -185,6 +132,10 @@ class CommandQueue { std::mutex reads_processed_cv_mutex; CoreType get_dispatch_core_type(); + CoreCoord virtual_enqueue_program_dispatch_core_; + CoreCoord completion_queue_writer_core_; + NOC noc_index_; + void reset_worker_dispatch_state_on_device(bool reset_launch_msg_state); void reset_config_buffer_mgr(const uint32_t num_entries); diff --git a/tt_metal/impl/module.mk b/tt_metal/impl/module.mk index b645467af0d..0fb235b4b19 100644 --- a/tt_metal/impl/module.mk +++ b/tt_metal/impl/module.mk @@ -19,7 +19,7 @@ TT_METAL_IMPL_SRCS = \ tt_metal/impl/allocator/l1_banking_allocator.cpp \ tt_metal/impl/program/program.cpp \ tt_metal/impl/dispatch/debug_tools.cpp \ - tt_metal/impl/dispatch/command_queue.cpp \ + tt_metal/impl/dispatch/command_queue_commands.cpp \ tt_metal/impl/dispatch/launch_message_ring_buffer_state.cpp \ tt_metal/impl/debug/dprint_server.cpp \ tt_metal/impl/debug/watcher_server.cpp \ diff --git a/tt_metal/impl/program/dispatch.cpp b/tt_metal/impl/program/dispatch.cpp index 9bb7ba84f72..be3c8c4e2f2 100644 --- a/tt_metal/impl/program/dispatch.cpp +++ b/tt_metal/impl/program/dispatch.cpp @@ -5,7 +5,7 @@ #include 
"tt_metal/impl/program/dispatch.hpp" #include -#include +#include #include #include #include diff --git a/tt_metal/impl/program/program.cpp b/tt_metal/impl/program/program.cpp index a5d52c19c33..3f56c7b9dd9 100644 --- a/tt_metal/impl/program/program.cpp +++ b/tt_metal/impl/program/program.cpp @@ -21,7 +21,6 @@ #include "dprint_server.hpp" #include #include -#include #include #include "tt_metal/impl/program/dispatch.hpp" #include "tt_metal/jit_build/genfiles.hpp" diff --git a/tt_metal/impl/sub_device/sub_device_manager_tracker.cpp b/tt_metal/impl/sub_device/sub_device_manager_tracker.cpp index dd05811dab4..1f64ac9e5ce 100644 --- a/tt_metal/impl/sub_device/sub_device_manager_tracker.cpp +++ b/tt_metal/impl/sub_device/sub_device_manager_tracker.cpp @@ -13,7 +13,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp index 4c3738ff7d1..76ce97f921c 100644 --- a/tt_metal/tt_metal.cpp +++ b/tt_metal/tt_metal.cpp @@ -251,18 +251,18 @@ inline void SetRuntimeArgsImpl( [&](auto&& core_spec) { using T = std::decay_t; if constexpr (std::is_same_v) { - EnqueueSetRuntimeArgs(kernel, core_spec, runtime_args, blocking); + SetRuntimeArgs(kernel, core_spec, runtime_args, blocking); } else if constexpr (std::is_same_v) { for (auto x = core_spec.start_coord.x; x <= core_spec.end_coord.x; x++) { for (auto y = core_spec.start_coord.y; y <= core_spec.end_coord.y; y++) { - EnqueueSetRuntimeArgs(kernel, CoreCoord(x, y), runtime_args, blocking); + SetRuntimeArgs(kernel, CoreCoord(x, y), runtime_args, blocking); } } } else if constexpr (std::is_same_v) { for (const auto& core_range : core_spec.ranges()) { for (auto x = core_range.start_coord.x; x <= core_range.end_coord.x; x++) { for (auto y = core_range.start_coord.y; y <= core_range.end_coord.y; y++) { - EnqueueSetRuntimeArgs(kernel, CoreCoord(x, y), runtime_args, blocking); + SetRuntimeArgs(kernel, CoreCoord(x, y), runtime_args, blocking); } } } @@ -278,7 
+278,7 @@ inline void SetRuntimeArgsImpl( bool blocking) { // SetRuntimeArgs API for Async CQ Mode (support vector of runtime args) for (size_t i = 0; i < core_spec.size(); i++) { - EnqueueSetRuntimeArgs(kernel, core_spec[i], runtime_args[i], blocking); + SetRuntimeArgs(kernel, core_spec[i], runtime_args[i], blocking); } } @@ -1257,7 +1257,7 @@ void DeallocateBuffer(Buffer& buffer) { buffer.deallocate(); } void AssignGlobalBufferToProgram(std::shared_ptr buffer, Program& program) { detail::DispatchStateCheck(not buffer->device()->using_slow_dispatch()); - EnqueueAddBufferToProgram(buffer, program, false); + AddBufferToProgram(buffer, program, false); } void SetRuntimeArgs( From 463742e4c3a2933397e4ad58dc0cedfc1c0fc14d Mon Sep 17 00:00:00 2001 From: Artem Yerofieiev Date: Thu, 30 Jan 2025 05:10:51 +0000 Subject: [PATCH 2/4] =?UTF-8?q?move=20command=5Fqueue=5Fcommands.hpp=20our?= =?UTF-8?q?=20from=20metal=20api.=20yay=20=F0=9F=8E=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tt_metal/impl/dispatch/command_queue_commands.cpp | 2 +- .../tt-metalium => impl/dispatch}/command_queue_commands.hpp | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename tt_metal/{api/tt-metalium => impl/dispatch}/command_queue_commands.hpp (100%) diff --git a/tt_metal/impl/dispatch/command_queue_commands.cpp b/tt_metal/impl/dispatch/command_queue_commands.cpp index c680349748c..27a4752b234 100644 --- a/tt_metal/impl/dispatch/command_queue_commands.cpp +++ b/tt_metal/impl/dispatch/command_queue_commands.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include +#include "command_queue_commands.hpp" #include #include diff --git a/tt_metal/api/tt-metalium/command_queue_commands.hpp b/tt_metal/impl/dispatch/command_queue_commands.hpp similarity index 100% rename from tt_metal/api/tt-metalium/command_queue_commands.hpp rename to tt_metal/impl/dispatch/command_queue_commands.hpp From 50cde96974dcdca0900b75e254e849cc03100050 Mon 
Sep 17 00:00:00 2001 From: Artem Yerofieiev Date: Thu, 30 Jan 2025 18:19:22 +0000 Subject: [PATCH 3/4] remove more apis. rename command_queue_commands to host_runtime_commands. move things around --- tt_metal/api/tt-metalium/command_queue.hpp | 15 ------ .../tt-metalium/command_queue_interface.hpp | 46 ------------------- tt_metal/api/tt-metalium/program_impl.hpp | 2 +- tt_metal/impl/CMakeLists.txt | 2 +- tt_metal/impl/buffers/circular_buffer.cpp | 11 +---- tt_metal/impl/buffers/dispatch.hpp | 46 +++++++++++++++++++ .../impl/dispatch/hardware_command_queue.cpp | 1 - .../impl/dispatch/hardware_command_queue.hpp | 4 +- ...commands.cpp => host_runtime_commands.cpp} | 43 +---------------- ...commands.hpp => host_runtime_commands.hpp} | 0 tt_metal/tt_metal.cpp | 33 +++++++++++-- 11 files changed, 81 insertions(+), 122 deletions(-) rename tt_metal/impl/dispatch/{command_queue_commands.cpp => host_runtime_commands.cpp} (95%) rename tt_metal/impl/dispatch/{command_queue_commands.hpp => host_runtime_commands.hpp} (100%) diff --git a/tt_metal/api/tt-metalium/command_queue.hpp b/tt_metal/api/tt-metalium/command_queue.hpp index 76df29533b1..6aa6b22030a 100644 --- a/tt_metal/api/tt-metalium/command_queue.hpp +++ b/tt_metal/api/tt-metalium/command_queue.hpp @@ -95,19 +95,4 @@ class CommandQueue { virtual void finish(tt::stl::Span sub_device_ids) = 0; }; -// Temporarily here. Need to eliminate this -using RuntimeArgs = std::vector>; -// Primitives used to place host only operations on the SW Command Queue. 
-// These are used in functions exposed through tt_metal.hpp or host_api.hpp -void GetBufferAddr(uint32_t* dst_buf_addr, const Buffer* buffer, bool blocking); -void SetRuntimeArgs( - const std::shared_ptr& kernel, - const CoreCoord& core_coord, - std::shared_ptr runtime_args_ptr, - bool blocking); -void AddBufferToProgram( - const std::variant, std::shared_ptr>& buffer, - Program& program, - bool blocking); - } // namespace tt::tt_metal diff --git a/tt_metal/api/tt-metalium/command_queue_interface.hpp b/tt_metal/api/tt-metalium/command_queue_interface.hpp index 63283acb306..c1d7593e243 100644 --- a/tt_metal/api/tt-metalium/command_queue_interface.hpp +++ b/tt_metal/api/tt-metalium/command_queue_interface.hpp @@ -44,52 +44,6 @@ enum class CommandQueueHostAddrType : uint8_t { UNRESERVED = 4 }; -// Used so the host knows how to properly copy data into user space from the completion queue (in hugepages) -struct ReadBufferDescriptor { - TensorMemoryLayout buffer_layout; - uint32_t page_size; - uint32_t padded_page_size; - std::shared_ptr buffer_page_mapping; - void* dst; - uint32_t dst_offset; - uint32_t num_pages_read; - uint32_t cur_dev_page_id; - uint32_t starting_host_page_id; - - ReadBufferDescriptor( - TensorMemoryLayout buffer_layout, - uint32_t page_size, - uint32_t padded_page_size, - void* dst, - uint32_t dst_offset, - uint32_t num_pages_read, - uint32_t cur_dev_page_id, - uint32_t starting_host_page_id = 0, - const std::shared_ptr& buffer_page_mapping = nullptr) : - buffer_layout(buffer_layout), - page_size(page_size), - padded_page_size(padded_page_size), - buffer_page_mapping(buffer_page_mapping), - dst(dst), - dst_offset(dst_offset), - num_pages_read(num_pages_read), - cur_dev_page_id(cur_dev_page_id), - starting_host_page_id(starting_host_page_id) {} -}; - -// Used so host knows data in completion queue is just an event ID -struct ReadEventDescriptor { - uint32_t event_id; - uint32_t global_offset; - - explicit ReadEventDescriptor(uint32_t event) : 
event_id(event), global_offset(0) {} - - void set_global_offset(uint32_t offset) { global_offset = offset; } - uint32_t get_global_event_id() { return global_offset + event_id; } -}; - -using CompletionReaderVariant = std::variant; - // Contains constants related to FD // // Deprecated note: for constant values, use tt::tt_metal::dispatch::DispatchConstants instead. diff --git a/tt_metal/api/tt-metalium/program_impl.hpp b/tt_metal/api/tt-metalium/program_impl.hpp index 96d2e0c30f8..9c854434249 100644 --- a/tt_metal/api/tt-metalium/program_impl.hpp +++ b/tt_metal/api/tt-metalium/program_impl.hpp @@ -59,7 +59,7 @@ namespace distributed { class JitBuildOptions; class EnqueueProgramCommand; class CommandQueue; -// Must be removed. Only here because its a damn friend of a Program +// Must be removed. Only here because its a friend of a Program class HWCommandQueue; namespace detail{ diff --git a/tt_metal/impl/CMakeLists.txt b/tt_metal/impl/CMakeLists.txt index 0558711d018..0c785db5b4b 100644 --- a/tt_metal/impl/CMakeLists.txt +++ b/tt_metal/impl/CMakeLists.txt @@ -22,7 +22,7 @@ set(IMPL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/program/program.cpp ${CMAKE_CURRENT_SOURCE_DIR}/program/dispatch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dispatch/debug_tools.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/dispatch/command_queue_commands.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/dispatch/host_runtime_commands.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dispatch/hardware_command_queue.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dispatch/launch_message_ring_buffer_state.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dispatch/worker_config_buffer.cpp diff --git a/tt_metal/impl/buffers/circular_buffer.cpp b/tt_metal/impl/buffers/circular_buffer.cpp index c85389241d1..8ea6cb922ef 100644 --- a/tt_metal/impl/buffers/circular_buffer.cpp +++ b/tt_metal/impl/buffers/circular_buffer.cpp @@ -12,13 +12,6 @@ #include #include -namespace { - -inline void GetBufferAddress(const tt::tt_metal::Buffer* buffer, uint32_t* address_on_host) { - GetBufferAddr(address_on_host, 
buffer, false); -} - -} // namespace namespace tt { namespace tt_metal { @@ -129,9 +122,7 @@ uint32_t CircularBuffer::address() const { return this->globally_allocated() ? globally_allocated_address_ : locally_allocated_address_.value(); } -void CircularBuffer::assign_global_address() { - GetBufferAddress(config_.shadow_global_buffer, &globally_allocated_address_); -} +void CircularBuffer::assign_global_address() { globally_allocated_address_ = config_.shadow_global_buffer->address(); } void CircularBuffer::set_global_circular_buffer(const v1::experimental::GlobalCircularBuffer& global_circular_buffer) { TT_FATAL( diff --git a/tt_metal/impl/buffers/dispatch.hpp b/tt_metal/impl/buffers/dispatch.hpp index 0cda3654610..3e9d2791106 100644 --- a/tt_metal/impl/buffers/dispatch.hpp +++ b/tt_metal/impl/buffers/dispatch.hpp @@ -11,6 +11,52 @@ namespace tt::tt_metal { +// Used so the host knows how to properly copy data into user space from the completion queue (in hugepages) +struct ReadBufferDescriptor { + TensorMemoryLayout buffer_layout; + uint32_t page_size; + uint32_t padded_page_size; + std::shared_ptr buffer_page_mapping; + void* dst; + uint32_t dst_offset; + uint32_t num_pages_read; + uint32_t cur_dev_page_id; + uint32_t starting_host_page_id; + + ReadBufferDescriptor( + TensorMemoryLayout buffer_layout, + uint32_t page_size, + uint32_t padded_page_size, + void* dst, + uint32_t dst_offset, + uint32_t num_pages_read, + uint32_t cur_dev_page_id, + uint32_t starting_host_page_id = 0, + const std::shared_ptr& buffer_page_mapping = nullptr) : + buffer_layout(buffer_layout), + page_size(page_size), + padded_page_size(padded_page_size), + buffer_page_mapping(buffer_page_mapping), + dst(dst), + dst_offset(dst_offset), + num_pages_read(num_pages_read), + cur_dev_page_id(cur_dev_page_id), + starting_host_page_id(starting_host_page_id) {} +}; + +// Used so host knows data in completion queue is just an event ID +struct ReadEventDescriptor { + uint32_t event_id; + uint32_t 
global_offset; + + explicit ReadEventDescriptor(uint32_t event) : event_id(event), global_offset(0) {} + + void set_global_offset(uint32_t offset) { global_offset = offset; } + uint32_t get_global_event_id() { return global_offset + event_id; } +}; + +using CompletionReaderVariant = std::variant; + // Contains helper functions to interface with buffers on device namespace buffer_dispatch { diff --git a/tt_metal/impl/dispatch/hardware_command_queue.cpp b/tt_metal/impl/dispatch/hardware_command_queue.cpp index 08ed894c8e4..43171242545 100644 --- a/tt_metal/impl/dispatch/hardware_command_queue.cpp +++ b/tt_metal/impl/dispatch/hardware_command_queue.cpp @@ -14,7 +14,6 @@ // MUST REMOVE #include -#include "tt_metal/impl/buffers/dispatch.hpp" #include "tt_metal/impl/debug/watcher_server.hpp" #include "tt_metal/impl/program/dispatch.hpp" diff --git a/tt_metal/impl/dispatch/hardware_command_queue.hpp b/tt_metal/impl/dispatch/hardware_command_queue.hpp index 42058185c1e..fcef722f25f 100644 --- a/tt_metal/impl/dispatch/hardware_command_queue.hpp +++ b/tt_metal/impl/dispatch/hardware_command_queue.hpp @@ -10,13 +10,15 @@ #include #include "command_queue.hpp" -#include "command_queue_commands.hpp" +#include "host_runtime_commands.hpp" #include "command_queue_interface.hpp" #include "lock_free_queue.hpp" #include "worker_config_buffer.hpp" #include "program_impl.hpp" #include "trace_buffer.hpp" +#include "tt_metal/impl/buffers/dispatch.hpp" + namespace tt::tt_metal { class HWCommandQueue : public CommandQueue { diff --git a/tt_metal/impl/dispatch/command_queue_commands.cpp b/tt_metal/impl/dispatch/host_runtime_commands.cpp similarity index 95% rename from tt_metal/impl/dispatch/command_queue_commands.cpp rename to tt_metal/impl/dispatch/host_runtime_commands.cpp index 27a4752b234..35da7a7c80f 100644 --- a/tt_metal/impl/dispatch/command_queue_commands.cpp +++ b/tt_metal/impl/dispatch/host_runtime_commands.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include 
"command_queue_commands.hpp" +#include "host_runtime_commands.hpp" #include #include @@ -494,47 +494,6 @@ void EnqueueTerminateCommand::process() { this->manager.fetch_queue_write(cmd_sequence_sizeB, this->command_queue_id); } -void AddBufferToProgram( - const std::variant, std::shared_ptr>& buffer, - Program& program, - bool blocking) { - std::visit( - [&program](auto&& b) { - using buffer_type = std::decay_t; - if constexpr (std::is_same_v>) { - program.add_buffer(b); - } - }, - buffer); -} - -void SetRuntimeArgs( - const std::shared_ptr& kernel, - const CoreCoord& core_coord, - std::shared_ptr runtime_args_ptr, - bool blocking) { - std::vector resolved_runtime_args = {}; - resolved_runtime_args.reserve(runtime_args_ptr->size()); - - for (const auto& arg : *(runtime_args_ptr)) { - std::visit( - [&resolved_runtime_args](auto&& a) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - resolved_runtime_args.push_back(a->address()); - } else { - resolved_runtime_args.push_back(a); - } - }, - arg); - } - kernel->set_runtime_args(core_coord, resolved_runtime_args); -} - -void GetBufferAddr(uint32_t* dst_buf_addr, const Buffer* buffer, bool blocking) { - *(static_cast(dst_buf_addr)) = buffer->address(); -} - inline namespace v0 { void EnqueueWriteBuffer( diff --git a/tt_metal/impl/dispatch/command_queue_commands.hpp b/tt_metal/impl/dispatch/host_runtime_commands.hpp similarity index 100% rename from tt_metal/impl/dispatch/command_queue_commands.hpp rename to tt_metal/impl/dispatch/host_runtime_commands.hpp diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp index 76ce97f921c..7a39e474a79 100644 --- a/tt_metal/tt_metal.cpp +++ b/tt_metal/tt_metal.cpp @@ -241,6 +241,29 @@ inline void SetRuntimeArgsImpl( } } +void SetRuntimeArgsImpl( + const std::shared_ptr& kernel, + const CoreCoord& core_coord, + std::shared_ptr runtime_args_ptr, + bool blocking) { + std::vector resolved_runtime_args = {}; + resolved_runtime_args.reserve(runtime_args_ptr->size()); + + for 
(const auto& arg : *(runtime_args_ptr)) { + std::visit( + [&resolved_runtime_args](auto&& a) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + resolved_runtime_args.push_back(a->address()); + } else { + resolved_runtime_args.push_back(a); + } + }, + arg); + } + kernel->set_runtime_args(core_coord, resolved_runtime_args); +} + inline void SetRuntimeArgsImpl( const std::shared_ptr kernel, const std::variant& core_spec, @@ -251,18 +274,18 @@ inline void SetRuntimeArgsImpl( [&](auto&& core_spec) { using T = std::decay_t; if constexpr (std::is_same_v) { - SetRuntimeArgs(kernel, core_spec, runtime_args, blocking); + SetRuntimeArgsImpl(kernel, core_spec, runtime_args, blocking); } else if constexpr (std::is_same_v) { for (auto x = core_spec.start_coord.x; x <= core_spec.end_coord.x; x++) { for (auto y = core_spec.start_coord.y; y <= core_spec.end_coord.y; y++) { - SetRuntimeArgs(kernel, CoreCoord(x, y), runtime_args, blocking); + SetRuntimeArgsImpl(kernel, CoreCoord(x, y), runtime_args, blocking); } } } else if constexpr (std::is_same_v) { for (const auto& core_range : core_spec.ranges()) { for (auto x = core_range.start_coord.x; x <= core_range.end_coord.x; x++) { for (auto y = core_range.start_coord.y; y <= core_range.end_coord.y; y++) { - SetRuntimeArgs(kernel, CoreCoord(x, y), runtime_args, blocking); + SetRuntimeArgsImpl(kernel, CoreCoord(x, y), runtime_args, blocking); } } } @@ -278,7 +301,7 @@ inline void SetRuntimeArgsImpl( bool blocking) { // SetRuntimeArgs API for Async CQ Mode (support vector of runtime args) for (size_t i = 0; i < core_spec.size(); i++) { - SetRuntimeArgs(kernel, core_spec[i], runtime_args[i], blocking); + SetRuntimeArgsImpl(kernel, core_spec[i], runtime_args[i], blocking); } } @@ -1257,7 +1280,7 @@ void DeallocateBuffer(Buffer& buffer) { buffer.deallocate(); } void AssignGlobalBufferToProgram(std::shared_ptr buffer, Program& program) { detail::DispatchStateCheck(not buffer->device()->using_slow_dispatch()); - 
AddBufferToProgram(buffer, program, false); + program.add_buffer(buffer); } void SetRuntimeArgs( From e0b4d39c15f5c436c01deead1f171928963e079d Mon Sep 17 00:00:00 2001 From: Artem Yerofieiev Date: Fri, 31 Jan 2025 00:33:15 +0000 Subject: [PATCH 4/4] fix clang tidy const& --- tt_metal/impl/dispatch/host_runtime_commands.hpp | 2 -- tt_metal/tt_metal.cpp | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tt_metal/impl/dispatch/host_runtime_commands.hpp b/tt_metal/impl/dispatch/host_runtime_commands.hpp index 7b1503c5f3f..56961432fd7 100644 --- a/tt_metal/impl/dispatch/host_runtime_commands.hpp +++ b/tt_metal/impl/dispatch/host_runtime_commands.hpp @@ -49,8 +49,6 @@ enum class EnqueueCommandType { INVALID }; -string EnqueueCommandTypeToString(EnqueueCommandType ctype); - class Command { public: Command() {} diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp index 7a39e474a79..8561d2f446c 100644 --- a/tt_metal/tt_metal.cpp +++ b/tt_metal/tt_metal.cpp @@ -244,7 +244,7 @@ inline void SetRuntimeArgsImpl( void SetRuntimeArgsImpl( const std::shared_ptr& kernel, const CoreCoord& core_coord, - std::shared_ptr runtime_args_ptr, + const std::shared_ptr& runtime_args_ptr, bool blocking) { std::vector resolved_runtime_args = {}; resolved_runtime_args.reserve(runtime_args_ptr->size()); @@ -267,7 +267,7 @@ void SetRuntimeArgsImpl( inline void SetRuntimeArgsImpl( const std::shared_ptr kernel, const std::variant& core_spec, - std::shared_ptr runtime_args, + const std::shared_ptr& runtime_args, bool blocking) { // SetRuntimeArgs API for Async CQ Mode std::visit(