diff --git a/tests/ttnn/unit_tests/gtests/CMakeLists.txt b/tests/ttnn/unit_tests/gtests/CMakeLists.txt
index f8f9806b62a..a76b5284298 100644
--- a/tests/ttnn/unit_tests/gtests/CMakeLists.txt
+++ b/tests/ttnn/unit_tests/gtests/CMakeLists.txt
@@ -28,6 +28,7 @@ set(TTNN_TENSOR_UNIT_TESTS_SRC
     ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_create_tensor_multi_device.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_create_tensor_with_layout.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_distributed_tensor.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_mesh_tensor.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_partition.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_shape_base.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_tensor_sharding.cpp
diff --git a/tests/ttnn/unit_tests/gtests/tensor/test_mesh_tensor.cpp b/tests/ttnn/unit_tests/gtests/tensor/test_mesh_tensor.cpp
new file mode 100644
index 00000000000..dca6dec5491
--- /dev/null
+++ b/tests/ttnn/unit_tests/gtests/tensor/test_mesh_tensor.cpp
@@ -0,0 +1,47 @@
+// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include
+#include
+
+#include "ttnn/tensor/tensor.hpp"
+#include "ttnn_test_fixtures.hpp"
+#include
+#include
+
+namespace ttnn::distributed::test {
+namespace {
+
+using MeshTensorTest = T3kMultiDeviceFixture;
+
+TEST_F(MeshTensorTest, Lifecycle) {
+    const TensorSpec tensor_spec =
+        TensorSpec(ttnn::SimpleShape{1, 1, 32, 32}, TensorLayout(DataType::FLOAT32, Layout::ROW_MAJOR, MemoryConfig{}));
+
+    Tensor input_tensor = allocate_tensor_on_mesh(tensor_spec, mesh_device_.get());
+
+    EXPECT_EQ(input_tensor.workers.size(), mesh_device_->num_devices());
+    EXPECT_TRUE(input_tensor.is_allocated());
+
+    const auto& storage = input_tensor.get_storage();
+    auto* multi_device_storage = std::get_if(&storage);
+
+    ASSERT_NE(multi_device_storage, nullptr);
+    EXPECT_NE(multi_device_storage->mesh_buffer, nullptr);
+
+    // Buffer address is the same across all device buffers.
+    const auto buffer_address = multi_device_storage->mesh_buffer->address();
+    for (auto* device : mesh_device_->get_devices()) {
+        auto buffer = multi_device_storage->get_buffer_for_device(device);
+        ASSERT_NE(buffer, nullptr);
+        EXPECT_TRUE(buffer->is_allocated());
+        EXPECT_EQ(buffer->address(), buffer_address);
+    }
+
+    input_tensor.deallocate();
+    EXPECT_FALSE(input_tensor.is_allocated());
+}
+
+}  // namespace
+}  // namespace ttnn::distributed::test
diff --git a/tt_metal/api/tt-metalium/tt_metal.hpp b/tt_metal/api/tt-metalium/tt_metal.hpp
index b86e4664a49..ea5220ca078 100644
--- a/tt_metal/api/tt-metalium/tt_metal.hpp
+++ b/tt_metal/api/tt-metalium/tt_metal.hpp
@@ -26,6 +26,7 @@ namespace detail {
 bool DispatchStateCheck(bool isFastDispatch);
 bool InWorkerThread();
+inline bool InMainThread() { return not InWorkerThread(); }
 std::map CreateDevices(
     // TODO: delete this in favour of DevicePool
diff --git a/ttnn/cpp/ttnn/distributed/api.cpp b/ttnn/cpp/ttnn/distributed/api.cpp
index 6690b490158..8c9e9a6f971 100644
--- a/ttnn/cpp/ttnn/distributed/api.cpp
+++ b/ttnn/cpp/ttnn/distributed/api.cpp
@@ -96,7 +96,7 @@ Tensor aggregate_as_tensor(
     } else {
         std::vector ordered_device_ids;
         std::unordered_map specs;
-        std::unordered_map device_buffers;
+        std::unordered_map> device_buffers;
         for (const auto& shard : tensor_shards) {
             IDevice* device = std::get(shard.get_storage()).buffer->device();
             auto device_id = device->id();
@@ -116,7 +116,8 @@ Tensor aggregate_as_tensor(
                     shard_tile.get_width());
             }
         }
-        auto storage = MultiDeviceStorage{config, ordered_device_ids, std::move(device_buffers), specs};
+        auto storage =
+            MultiDeviceStorage{config, ordered_device_ids, std::move(device_buffers), specs, /*mesh_buffer_=*/nullptr};
         return Tensor(std::move(storage), reference_shard.get_tensor_spec());
     }
 }
@@ -247,7 +248,7 @@ Tensor create_multi_device_tensor(
     if (storage_type == StorageType::MULTI_DEVICE) {
         std::vector ordered_device_ids;
         std::unordered_map specs;
-        std::unordered_map device_buffers;
+        std::unordered_map> device_buffers;
         for (const auto& tensor : tensors) {
             TT_ASSERT(
                 std::holds_alternative(tensor.get_storage()),
@@ -260,7 +261,7 @@ Tensor create_multi_device_tensor(
             specs.insert({device_id, tensor.get_tensor_spec()});
         }
         return Tensor{
-            MultiDeviceStorage{strategy, ordered_device_ids, device_buffers, specs},
+            MultiDeviceStorage{strategy, ordered_device_ids, device_buffers, specs, /*mesh_buffer_=*/nullptr},
             TensorSpec(
                 tensors.at(0).get_logical_shape(),
                 TensorLayout::fromPaddedShape(
diff --git a/ttnn/cpp/ttnn/operations/experimental/reshape/view.cpp b/ttnn/cpp/ttnn/operations/experimental/reshape/view.cpp
index 8ec4bc62713..20182d0cb8b 100644
--- a/ttnn/cpp/ttnn/operations/experimental/reshape/view.cpp
+++ b/ttnn/cpp/ttnn/operations/experimental/reshape/view.cpp
@@ -81,7 +81,7 @@ Tensor tensor_reshape(
     if (input_tensor.get_layout() == Layout::ROW_MAJOR) {
         if (tensor.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) {
             DeviceStorage device_storage = std::get(tensor.get_storage());
-            DeviceBuffer device_buffer = device_storage.get_buffer();
+            auto device_buffer = device_storage.get_buffer();
             const auto& tensor_spec = tensor.tensor_spec();
             auto page_size_bytes = tensor_spec.compute_page_size_bytes();
             device_buffer->set_page_size(page_size_bytes);
@@ -89,7 +89,7 @@ Tensor tensor_reshape(
             return Tensor(device_storage, new_spec);
         } else {
             DeviceStorage device_storage = std::get(tensor.get_storage());
-            DeviceBuffer device_buffer = device_storage.get_buffer();
+            auto device_buffer = device_storage.get_buffer();
             ShardSpecBuffer shard_spec_buffer = device_buffer->shard_spec();
             auto shard_spec = shard_spec_buffer.tensor_shard_spec;
diff --git a/ttnn/cpp/ttnn/tensor/storage.cpp b/ttnn/cpp/ttnn/tensor/storage.cpp
index a33f1f93ac3..ad385113ed8 100644
--- a/ttnn/cpp/ttnn/tensor/storage.cpp
+++ b/ttnn/cpp/ttnn/tensor/storage.cpp
@@ -6,9 +6,9 @@
 namespace tt::tt_metal {

-std::vector MultiDeviceStorage::get_buffers() const {
+std::vector> MultiDeviceStorage::get_buffers() const {
     std::lock_guard lock(buffer_mtx);
-    std::vector buf_vec;
+    std::vector> buf_vec;
     buf_vec.reserve(buffers.size());
     for (const auto& pair : buffers) {
         buf_vec.push_back(pair.second);
diff --git a/ttnn/cpp/ttnn/tensor/storage.hpp b/ttnn/cpp/ttnn/tensor/storage.hpp
index e28be85f525..16f3143edae 100644
--- a/ttnn/cpp/ttnn/tensor/storage.hpp
+++ b/ttnn/cpp/ttnn/tensor/storage.hpp
@@ -38,13 +38,13 @@ struct OwnedStorage {
     }
 };

-using DeviceBuffer = std::shared_ptr;
+// TODO: #17215 - Replace `DeviceStorage` with "mesh storage".
 struct DeviceStorage {
-    DeviceBuffer buffer;
+    std::shared_ptr buffer;
     DeviceStorage() = default;
-    DeviceStorage(DeviceBuffer buffer_) : buffer(std::move(buffer_)) {}
+    DeviceStorage(std::shared_ptr buffer_) : buffer(std::move(buffer_)) {}

-    const MemoryConfig memory_config() const {
+    MemoryConfig memory_config() const {
         if (this->buffer.get() == nullptr) {
             TT_THROW("MemoryConfig can only be obtained if the buffer is not null");
         }
@@ -59,9 +59,9 @@ struct DeviceStorage {
             .shard_spec = shard_spec};
     }

-    inline void insert_buffer(DeviceBuffer buffer_) { this->buffer = buffer_; }
+    inline void insert_buffer(const std::shared_ptr& buffer_) { this->buffer = buffer_; }

-    inline DeviceBuffer get_buffer() const { return this->buffer; }
+    inline std::shared_ptr get_buffer() const { return this->buffer; }
     static constexpr auto attribute_names = std::forward_as_tuple("memory_config");
     const auto attribute_values() const { return std::make_tuple(this->memory_config()); }
@@ -149,7 +149,7 @@ struct MultiDeviceHostStorage {
     MultiDeviceHostStorage() = default;
     MultiDeviceHostStorage(
         DistributedTensorConfig strategy_, std::vector buffers_, std::vector specs_) :
-        strategy(strategy_), buffers(buffers_), specs(specs_) {}
+        strategy(strategy_), buffers(std::move(buffers_)), specs(std::move(specs_)) {}
     MultiDeviceHostStorage(MultiDeviceHostStorage&& other) { swap(*this, other); }
     // unfotunately we need to have this code written manually.
     MultiDeviceHostStorage(const MultiDeviceHostStorage& other) {
@@ -222,8 +222,13 @@ struct MultiDeviceHostStorage {
 struct MultiDeviceStorage {
     DistributedTensorConfig strategy;
     std::vector ordered_device_ids;
-    std::unordered_map buffers;
+    std::unordered_map> buffers;
     std::unordered_map specs;
+
+    // TODO: #17215 - This isn't populated by default. Switch to creating MeshBuffer backed storage, when TTNN is ready
+    // to consume it.
+    // Eventually, `MultiDeviceStorage` will be renamed to `MeshDeviceStorage`, and unified with `DeviceStorage`.
+    std::shared_ptr mesh_buffer;
     mutable std::mutex buffer_mtx;
     mutable std::mutex shape_mtx;
     MultiDeviceStorage() = default;
@@ -235,17 +240,20 @@ struct MultiDeviceStorage {
         swap(first.ordered_device_ids, second.ordered_device_ids);
         swap(first.buffers, second.buffers);
         swap(first.specs, second.specs);
+        swap(first.mesh_buffer, second.mesh_buffer);
     }

     MultiDeviceStorage(
         DistributedTensorConfig strategy_,
         std::vector ordered_device_ids_,
-        std::unordered_map buffers_,
-        std::unordered_map specs_) :
+        std::unordered_map> buffers_,
+        std::unordered_map specs_,
+        std::shared_ptr mesh_buffer_) :
         strategy(std::move(strategy_)),
         ordered_device_ids(std::move(ordered_device_ids_)),
         buffers(std::move(buffers_)),
-        specs(std::move(specs_)) {}
+        specs(std::move(specs_)),
+        mesh_buffer(std::move(mesh_buffer_)) {}

     MultiDeviceStorage(MultiDeviceStorage&& other) { swap(*this, other); }

@@ -255,6 +263,7 @@
         strategy = other.strategy;
         buffers = other.buffers;
         specs = other.specs;
+        mesh_buffer = other.mesh_buffer;
     }

     MultiDeviceStorage& operator=(const MultiDeviceStorage& other) {
@@ -270,10 +279,10 @@ bool operator==(const MultiDeviceStorage& other) {
         return this->ordered_device_ids == other.ordered_device_ids and this->strategy == other.strategy and
-               this->buffers == other.buffers and this->specs == other.specs;
+               this->buffers == other.buffers and this->specs == other.specs and this->mesh_buffer == other.mesh_buffer;
     }

-    inline const MemoryConfig memory_config() const {
+    MemoryConfig memory_config() const {
         std::lock_guard lock(buffer_mtx);
         TT_FATAL(
             !this->ordered_device_ids.empty(), "No device ids in list. Please ensure fields are initialized properly.");
@@ -296,10 +305,12 @@
     // Helper Functions - Getters and setters to get/modify storage attributes. These are needed to
     // preinitialize empty tensor handles and use/populate them in the worker threads.
-    std::vector get_buffers() const;
+    std::vector> get_buffers() const;

-    inline void insert_buffer_and_spec_for_device(IDevice* device, const DeviceBuffer buffer, TensorSpec spec) {
+    inline void insert_buffer_and_spec_for_device(
+        IDevice* device, const std::shared_ptr& buffer, TensorSpec spec) {
         std::scoped_lock lock(buffer_mtx, shape_mtx);
+        TT_FATAL(mesh_buffer == nullptr, "MeshBuffer backed storage does not support inserting individual buffers");
         TT_ASSERT(
             device == buffer->device(),
             "Mismatch between device derived from buffer and device derived from MultiDeviceStorage.");
@@ -307,7 +318,7 @@
         specs.insert({device->id(), std::move(spec)});
     }

-    inline DeviceBuffer get_buffer_for_device(IDevice* device) const {
+    inline std::shared_ptr get_buffer_for_device(IDevice* device) const {
         std::lock_guard lock(buffer_mtx);
         TT_ASSERT(buffers.find(device->id()) != buffers.end(), "Buffer not found for device {}", device->id());
         TT_ASSERT(
@@ -316,7 +327,7 @@
         return buffers.at(device->id());
     }

-    inline DeviceBuffer& get_buffer_for_device(IDevice* device) {
+    inline std::shared_ptr& get_buffer_for_device(IDevice* device) {
         std::lock_guard lock(buffer_mtx);
         TT_ASSERT(buffers.find(device->id()) != buffers.end(), "Buffer not found for device {}", device->id());
         TT_ASSERT(
@@ -325,7 +336,7 @@
         return buffers.at(device->id());
     }

-    inline DeviceBuffer get_buffer_for_device_id(uint32_t device_id) const {
+    inline std::shared_ptr get_buffer_for_device_id(uint32_t device_id) const {
         std::lock_guard lock(buffer_mtx);
         return buffers.at(device_id);
     }
@@ -352,13 +363,16 @@
     }

     inline bool is_allocated() const {
-        std::lock_guard lock(buffer_mtx);
-
-        return std::all_of(
-            ordered_device_ids.begin(), ordered_device_ids.end(), [&buffers = this->buffers](auto&& device_id) {
-                const auto& buffer = buffers.at(device_id);
-                return buffer && buffer->is_allocated();
-            });
+        if (mesh_buffer != nullptr) {
+            return mesh_buffer->is_allocated();
+        } else {
+            std::lock_guard lock(buffer_mtx);
+            return std::all_of(
+                ordered_device_ids.begin(), ordered_device_ids.end(), [&buffers = this->buffers](auto&& device_id) {
+                    const auto& buffer = buffers.at(device_id);
+                    return buffer && buffer->is_allocated();
+                });
+        }
     }
 };
diff --git a/ttnn/cpp/ttnn/tensor/tensor.cpp b/ttnn/cpp/ttnn/tensor/tensor.cpp
index b77155538d8..24ebdefce4a 100644
--- a/ttnn/cpp/ttnn/tensor/tensor.cpp
+++ b/ttnn/cpp/ttnn/tensor/tensor.cpp
@@ -13,6 +13,8 @@
 #include
 #include
 #include
+#include "tt-metalium/mesh_device_view.hpp"
+#include "ttnn/distributed/distributed_tensor_config.hpp"
 #include "ttnn/tensor/tensor_ops.hpp"
 #include "ttnn/tensor/tensor_impl.hpp"
 #include "ttnn/tensor/tensor_impl_wrapper.hpp"
@@ -201,11 +203,9 @@ void Tensor::init(Storage storage, TensorSpec tensor_spec) {
                 tensor_attributes->tensor_spec.layout());
             // Increment main thread ref count for all tensors on device
             tensor_attributes->increment_main_thread_ref_count(this->workers.at(0));
-            // This tensor is being created from scratch in a worker. Track this and allow it to be explicitly
-            // deallocated inside the worker (composite ops do this).
-            if (tt::tt_metal::detail::InWorkerThread()) {
-                tensor_attributes->main_thread_tensor = false;
-            }
+            // Track if this tensor is being created from scratch in a worker, to allow it to be deallocated inside
+            // the worker (composite ops do this).
+            tensor_attributes->main_thread_tensor = tt::tt_metal::detail::InMainThread();
             tensor_attributes->num_shards_to_be_populated = 1;
         } else if constexpr (std::is_same_v) {
             tensor_attributes->num_shards_to_be_populated = 1;
@@ -225,11 +225,9 @@ void Tensor::init(Storage storage, TensorSpec tensor_spec) {
             }
             // Increment main thread ref count for all tensors on cluster
             tensor_attributes->increment_main_thread_ref_count(this->workers.at(0));
-            // This tensor is being created from scratch in a worker. Track this and allow it to be explicitly
-            // deallocated inside the worker (composite ops do this).
-            if (tt::tt_metal::detail::InWorkerThread()) {
-                tensor_attributes->main_thread_tensor = false;
-            }
+            // Track if this tensor is being created from scratch in a worker, to allow it to be deallocated inside
+            // the worker (composite ops do this).
+            tensor_attributes->main_thread_tensor = tt::tt_metal::detail::InMainThread();
             tensor_attributes->num_shards_to_be_populated = storage.num_buffers();
         } else if constexpr (std::is_same_v) {
             tensor_attributes->num_shards_to_be_populated = storage.num_buffers();
@@ -253,13 +251,14 @@ Tensor::Tensor(const std::vector& workers) :
         }
         MultiDeviceStorage storage;
         std::transform(
-            workers.cbegin(), workers.cend(), std::back_inserter(storage.ordered_device_ids), [](const IDevice* worker) {
-                return worker->id();
-            });
+            workers.cbegin(),
+            workers.cend(),
+            std::back_inserter(storage.ordered_device_ids),
+            [](const IDevice* worker) { return worker->id(); });
         return Storage(std::move(storage));
     }();
     tensor_attributes->num_shards_to_be_populated = workers.size();
-    if (!tt::tt_metal::detail::InWorkerThread()) {
+    if (tt::tt_metal::detail::InMainThread()) {
         tensor_attributes->increment_main_thread_ref_count(this->workers.at(0));
     } else {
         // This tensor is being created from scratch in a worker. Track this and allow it to be explicitly
@@ -299,7 +298,6 @@ Tensor& Tensor::operator=(const Tensor& other) {
         perform_cleanup_for_async_mode();
         this->workers = other.workers;
         this->tensor_attributes = other.tensor_attributes;
-        this->deallocate_through_destructor = other.deallocate_through_destructor;
         if (this->workers.size()) {
             if (not tt::tt_metal::detail::InWorkerThread()) {
                 this->tensor_attributes->increment_main_thread_ref_count(this->workers.at(0));
@@ -310,10 +308,7 @@
 }

 Tensor::Tensor(const Tensor& other) :
-    tensor_id(other.tensor_id),
-    workers(other.workers),
-    tensor_attributes(other.tensor_attributes),
-    deallocate_through_destructor(other.deallocate_through_destructor) {
+    tensor_id(other.tensor_id), workers(other.workers), tensor_attributes(other.tensor_attributes) {
     if (this->workers.size()) {
         if (not tt::tt_metal::detail::InWorkerThread()) {
             this->tensor_attributes->increment_main_thread_ref_count(this->workers.at(0));
@@ -323,8 +318,7 @@
 Tensor::~Tensor() {
     ZoneScoped;
-    this->deallocate_through_destructor = true;
-    this->deallocate();
+    this->deallocate_impl(/*force=*/false, /*deallocation_through_destructor=*/true);
     // Decrement main thread ref count for all tensors on device
     if (this->workers.size() and this->tensor_attributes) {
         this->tensor_attributes->decrement_main_thread_ref_count(this->workers.at(0));
@@ -336,144 +330,145 @@ Tensor::Tensor(
     Storage storage, const ttnn::SimpleShape& shape, DataType dtype, Layout layout, const std::optional& tile) :
     Tensor(std::move(storage), /* logical_shape */ shape, /* padded_shape */ shape, dtype, layout, tile) {}

-void Tensor::deallocate(bool force) {
+void Tensor::deallocate(bool force) { deallocate_impl(force, /*deallocation_through_destructor=*/false); }
+
+void Tensor::deallocate_impl(bool force, bool deallocation_through_destructor) {
     ZoneScopedN("TensorDeallocate");
     // GraphTracker::instance().track_function_start("Tensor::deallocate", *this, force);
-    if (this->tensor_attributes.use_count()) {
-        // Check if the attributes didn't get moved to another tensor.
-        // If not, we can deallocate this tensor.
-        std::visit(
-            [force, this](auto& storage) {
-                using T = std::decay_t;
-                if constexpr (std::is_same_v) {
-                    if (this->tensor_attributes.use_count() == 1) {
-                        std::visit([](auto&& buffer) { buffer.reset(); }, storage.buffer);
-                    }
-                } else if constexpr (std::is_same_v) {
-                    if (not this->workers.at(0)->is_initialized()) {
-                        return;
+    // Check if the attributes didn't get moved to another tensor.
+    // If not, we can deallocate this tensor.
+    if (tensor_attributes.use_count() == 0) {
+        return;
+    }
+
+    auto get_tensor_ref_count = [](const Tensor& tensor) {
+        // If owned by the main thread, deallocate this tensor only from the main thread. If owned by worker thread,
+        // allow deallocation in worker and use shared_ptr ref count, since this is a thread_local tensor
+        return (tensor.workers.at(0)->get_worker_mode() == WorkExecutorMode::SYNCHRONOUS or
+                not tensor.tensor_attributes->main_thread_tensor)
+                   ? tensor.tensor_attributes.use_count()
+                   : tensor.tensor_attributes->main_thread_ref_count;
+    };
+
+    std::visit(
+        tt::stl::overloaded{
+            [this](OwnedStorage& storage) {
+                if (this->tensor_attributes.use_count() == 1) {
+                    std::visit([](auto&& buffer) { buffer.reset(); }, storage.buffer);
+                }
+            },
+            [force, this](BorrowedStorage& storage) {
+                TT_FATAL(not force, "Cannot deallocate tensor with borrowed storage!");
+            },
+            [this](MultiDeviceHostStorage& storage) {
+                if (this->tensor_attributes.use_count() == 1) {
+                    for (int i = 0; i < storage.num_buffers(); i++) {
+                        std::visit([](auto&& buffer) { buffer.reset(); }, storage.get_buffer(i));
-                    if ((not tt::tt_metal::detail::InWorkerThread()) or
-                        not this->tensor_attributes->main_thread_tensor) {
-                        if (not this->tensor_attributes->main_thread_tensor) {
-                            TT_ASSERT(
-                                not this->tensor_attributes->main_thread_ref_count,
-                                "main_thread_ref_count for tensors created inside a worker thread must be 0");
+                    }
+                }
+            },
+            [force, this, &get_tensor_ref_count, deallocation_through_destructor](DeviceStorage& storage) {
+                if (not this->workers.at(0)->is_initialized()) {
+                    return;
+                }
+                if (tt::tt_metal::detail::InWorkerThread() and this->tensor_attributes->main_thread_tensor) {
+                    TT_FATAL(
+                        deallocation_through_destructor,
+                        "Device tensors created in the main thread cannot be explicitly deallocated in worker "
+                        "threads.");
+                    return;
+                }
+
+                TT_ASSERT(
+                    not this->tensor_attributes->main_thread_tensor and
+                        not this->tensor_attributes->main_thread_ref_count,
+                    "main_thread_ref_count for tensors created inside a worker thread must be 0");
+                const uint32_t ref_count_to_use = get_tensor_ref_count(*this);
+                if ((force or ref_count_to_use == 1) and not this->tensor_attributes->deallocated) {
+                    this->tensor_attributes->deallocated = true;
+                    this->workers.at(0)->push_work([force, attr = this->tensor_attributes]() mutable {
+                        // Cross worker synchronization: If the tensor being deallocated is shared across
+                        // workers (ex: all_gather op), wait until all workers are done with this tensor
+                        // before deallocating.
-                            bool num_threads_sharing_tensor = attr->num_sibling_workers_sharing_tensor;
-                            if (num_threads_sharing_tensor) {
-                                while (num_threads_sharing_tensor) {
-                                    num_threads_sharing_tensor = attr->num_sibling_workers_sharing_tensor;
+                        bool num_threads_sharing_tensor = attr->num_sibling_workers_sharing_tensor;
+                        if (num_threads_sharing_tensor) {
+                            while (num_threads_sharing_tensor) {
+                                num_threads_sharing_tensor = attr->num_sibling_workers_sharing_tensor;
+                            }
                             }
-                            std::visit(
-                                [force, attr](auto&& s) {
-                                    using type = std::decay_t;
-                                    if constexpr (std::is_same_v) {
-                                        if (force or s.buffer.use_count() == 1) {
-                                            DeallocateBuffer(*(s.buffer));
                                         }
+                        std::visit(
+                            [force, attr](auto&& s) {
+                                using type = std::decay_t;
+                                if constexpr (std::is_same_v) {
+                                    if (force or s.buffer.use_count() == 1) {
+                                        DeallocateBuffer(*(s.buffer));
+                                    }
+                                    // Safe to reset this buf object since this is the last reference (in
+                                    // the main thread) to the tensor attr object holding this buffer. If
+                                    // any other tensor handles hold this buffer, it will not be deleted,
+                                    // until the last handle goes out of scope or is deallocated.
+                                    s.buffer.reset();
+                                } else if constexpr (std::is_same_v) {
+                                    // Manage Dynamic Storage (due to autoformat in async mode): Main thread
+                                    // sees this tensor as a device tensor, since worker has not updated
+                                    // storage time. When the worker executes the dealloc request, the
+                                    // storage type has been appropriately updated to Owned.
+                                    TT_ASSERT(
+                                        attr->dynamic_storage,
+                                        "Tensor storage type changed during runtime (device -> host), but "
+                                        "dynamic storage was not marked.");
+                                    std::visit([](auto&& buffer) { buffer.reset(); }, s.buffer);
                                 }
-                                },
-                                attr->storage);
-                        });
-                    }
+                            },
+                            attr->storage);
+                    });
+                }
+            },
+            [force, this, &get_tensor_ref_count, deallocation_through_destructor](MultiDeviceStorage& storage) {
+                if (not this->workers.at(0)->is_initialized()) {
+                    return;
+                }
+                if (tt::tt_metal::detail::InWorkerThread() and this->tensor_attributes->main_thread_tensor) {
+                    TT_FATAL(
+                        deallocation_through_destructor,
+                        "Device tensors created in the main thread cannot be explicitly deallocated in worker "
+                        "threads.");
+                    return;
+                }
+                const uint32_t ref_count_to_use = get_tensor_ref_count(*this);
+                if ((force or ref_count_to_use == 1) and not this->tensor_attributes->deallocated) {
+                    this->tensor_attributes->deallocated = true;
+
+                    if (storage.mesh_buffer != nullptr) {
+                        // TODO: #17215 - Consider if it is possible to retain references to individual device buffers
+                        // after mesh buffer was deallocated.
+                        storage.mesh_buffer->deallocate();
                     } else {
-                        TT_FATAL(
-                            this->deallocate_through_destructor,
-                            "Device tensors created in the main thread cannot be explictly deallocated in worker "
-                            "threads.");
-                    }
-                } else if constexpr (std::is_same_v) {
-                    if (force) {
-                        TT_THROW("Cannot deallocate tensor with borrowed storage!");
-                    }
-                } else if constexpr (std::is_same_v) {
-                    if (not this->workers.at(0)->is_initialized()) {
-                        return;
-                    }
-                    if ((not tt::tt_metal::detail::InWorkerThread()) or
-                        not this->tensor_attributes->main_thread_tensor) {
-                        // If owned by the main thread, deallocate this tensor only from the main thread. If owned by
-                        // worker thread, allow deallocation in worker and use shared_ptr ref count, since this is a
-                        // thread_local tensor
-                        uint32_t ref_count_to_use =
-                            (this->workers.at(0)->get_worker_mode() == WorkExecutorMode::SYNCHRONOUS or
-                             not this->tensor_attributes->main_thread_tensor)
-                                ? this->tensor_attributes.use_count()
-                                : this->tensor_attributes->main_thread_ref_count;
-                        if ((force or ref_count_to_use == 1) and not this->tensor_attributes->deallocated) {
-                            this->tensor_attributes->deallocated = true;
-                            auto dealloc_lambda = std::make_shared>(
-                                [force, attr = this->tensor_attributes](IDevice* worker) mutable {
-                                    ZoneScopedN("ShardDeallocate");
-                                    TT_ASSERT(
-                                        std::holds_alternative(attr->storage),
-                                        "Unexpected type {}",
-                                        tt::stl::get_active_type_name_in_variant(attr->storage));
-                                    auto& s = std::get(attr->storage);
-                                    if (s.has_buffer_for_device(worker)) {
-                                        auto& device_buffer = s.get_buffer_for_device(worker);
-                                        if (force or device_buffer.use_count() == 1) {
-                                            DeallocateBuffer(*device_buffer);
-                                        }
-                                        device_buffer.reset();
+                        auto dealloc_lambda = std::make_shared>(
+                            [force, attr = this->tensor_attributes](IDevice* worker) mutable {
+                                ZoneScopedN("ShardDeallocate");
+                                TT_ASSERT(
+                                    std::holds_alternative(attr->storage),
+                                    "Unexpected type {}",
+                                    tt::stl::get_active_type_name_in_variant(attr->storage));
+                                auto& s = std::get(attr->storage);
+                                if (s.has_buffer_for_device(worker)) {
+                                    auto& device_buffer = s.get_buffer_for_device(worker);
+                                    if (force or device_buffer.use_count() == 1) {
+                                        DeallocateBuffer(*device_buffer);
                                     }
-                                });
+                                    device_buffer.reset();
+                                }
+                            });

-                            for (auto worker : this->workers) {
-                                worker->push_work([worker, dealloc_lambda]() mutable { (*dealloc_lambda)(worker); });
-                            }
+                        for (auto* worker : this->workers) {
+                            worker->push_work([worker, dealloc_lambda]() mutable { (*dealloc_lambda)(worker); });
                         }
-                    } else {
-                        TT_FATAL(
-                            this->deallocate_through_destructor,
-                            "Device tensors created in the main thread cannot be explictly deallocated in worker "
-                            "threads.");
                     }
-                } else if constexpr (std::is_same_v) {
-                    if (this->tensor_attributes.use_count() == 1) {
-                        // Same logic as above for host tensors
-                        for (int i = 0; i < storage.num_buffers(); i++) {
-                            auto& current_buffer = storage.get_buffer(i);
-                            std::visit([](auto&& buffer) { buffer.reset(); }, current_buffer);
-                        }
-                    }
-                } else {
-                    raise_unsupported_storage();
                 }
             },
-            this->tensor_attributes->storage);
-    }
+        },
+        this->tensor_attributes->storage);

     // GraphTracker::instance().track_function_end();
 }

@@ -482,7 +477,7 @@ void Tensor::perform_cleanup_for_async_mode() {
     // or move assignment operator
     if (this->tensor_attributes) {
         // Object has tensor_attributes that will be reassigned
-        if (this->workers.size() and (not tt::tt_metal::detail::InWorkerThread()) and
+        if (this->workers.size() and tt::tt_metal::detail::InMainThread() and
             this->workers.at(0)->get_worker_mode() == WorkExecutorMode::ASYNCHRONOUS) {
             // Operator called in main thread with async mode. Main thread Ref Count must be decremented.
             // This is the last tensor in the main thread holding these attributes. Deallocate the buffer
@@ -1049,6 +1044,37 @@ Tensor allocate_tensor_on_devices(const TensorSpec& tensor_spec, const std::vect
     return device_tensor;
 }

+Tensor allocate_tensor_on_mesh(const TensorSpec& tensor_spec, distributed::MeshDevice* mesh_device) {
+    // Allocate a mesh buffer synchronously.
+    TT_FATAL(
+        tt::tt_metal::detail::InMainThread(), "Allocation of a tensor on mesh must be called from the main thread");
+    auto mesh_buffer = tensor_impl::allocate_mesh_buffer_on_device(mesh_device, tensor_spec);
+
+    const auto [num_rows, num_cols] = mesh_device->shape();
+    std::vector ordered_device_ids;
+    std::unordered_map> buffers;
+    std::unordered_map specs;
+
+    ordered_device_ids.reserve(num_rows * num_cols);
+    buffers.reserve(num_rows * num_cols);
+    specs.reserve(num_rows * num_cols);
+
+    for (int row = 0; row < num_rows; ++row) {
+        for (int col = 0; col < num_cols; ++col) {
+            auto buffer = mesh_buffer->get_device_buffer(distributed::Coordinate{row, col});
+            const int device_id = buffer->device()->id();
+            ordered_device_ids.push_back(device_id);
+            buffers.emplace(device_id, std::move(buffer));
+            specs.emplace(device_id, tensor_spec);
+        }
+    }
+
+    MultiDeviceStorage multi_device_storage(
+        ReplicateTensor{}, std::move(ordered_device_ids), std::move(buffers), std::move(specs), std::move(mesh_buffer));
+
+    return Tensor(std::move(multi_device_storage), tensor_spec);
+}
+
 void write_tensor(const Tensor& host_tensor, Tensor device_tensor, uint8_t cq_id) {
     // Top level wrapper to copy a host tensor to a preallocated device tensor
     TT_ASSERT(device_tensor.workers.size(), "Workers must be specified for device_tensor in write_tensor");
diff --git a/ttnn/cpp/ttnn/tensor/tensor.hpp b/ttnn/cpp/ttnn/tensor/tensor.hpp
index bd7197290c0..273b79490ec 100644
--- a/ttnn/cpp/ttnn/tensor/tensor.hpp
+++ b/ttnn/cpp/ttnn/tensor/tensor.hpp
@@ -36,7 +36,8 @@ namespace distributed {
 class MeshDevice;
 }

-struct Tensor {
+class Tensor {
+public:
     struct TensorAttributes : public std::enable_shared_from_this {
         Storage storage;
         TensorSpec tensor_spec;
@@ -81,7 +82,6 @@ struct Tensor {
     // Tensor gets worker queue handle through the device
     std::vector workers = {};
-    bool deallocate_through_destructor = false;

     // ======================================================================================
     //                                  Hi Level APIs
@@ -123,7 +123,6 @@ struct Tensor {
             perform_cleanup_for_async_mode();
             this->workers = std::move(other.workers);
             this->tensor_attributes = std::move(other.tensor_attributes);
-            this->deallocate_through_destructor = std::move(other.deallocate_through_destructor);
         }
         return *this;
     }
@@ -302,7 +301,7 @@ struct Tensor {
             storage_type);
         return std::get(this->get_storage()).get_buffer().get();
     }
-    DeviceBuffer device_buffer() const { return std::get(this->get_storage()).get_buffer(); }
+    std::shared_ptr device_buffer() const { return std::get(this->get_storage()).get_buffer(); }

     IDevice* device() const {
         if (this->storage_type() == tt::tt_metal::StorageType::DEVICE) {
@@ -351,6 +350,7 @@ struct Tensor {

 private:
     void init(Storage storage, TensorSpec tensor_spec);
+    void deallocate_impl(bool force, bool deallocation_through_destructor);
 };

 Tensor create_device_tensor(const TensorSpec& tensor_spec, IDevice* device);
@@ -399,6 +399,10 @@ void memcpy(Tensor& dst, const void* src, const std::optional& reg
 void memcpy(Tensor& dst, const Tensor& src, const std::optional& region = std::nullopt);

 Tensor allocate_tensor_on_devices(const TensorSpec& spec, const std::vector& devices);
+
+// Allocates a tensor on a mesh device through mesh buffer.
+Tensor allocate_tensor_on_mesh(const TensorSpec& tensor_spec, distributed::MeshDevice* mesh_device);
+
 void write_tensor(const Tensor& host_tensor, Tensor device_tensor, uint8_t cq_id = ttnn::DefaultQueueId);

 Tensor set_tensor_id(const Tensor& tensor);
diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp
index eec7af3f4ef..fb0570ab4fa 100644
--- a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp
+++ b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp
@@ -5,6 +5,7 @@
 #include "ttnn/tensor/tensor_impl.hpp"

 #include
+#include "tt-metalium/mesh_buffer.hpp"
 #include "ttnn/tensor/tensor_impl_wrapper.hpp"
 #include "ttnn/tensor/layout/tensor_layout.hpp"
 #include "ttnn/tensor/types.hpp"
@@ -50,7 +51,7 @@ uint32_t element_size_bytes(DataType dtype) {
     }
 }

-DeviceBuffer allocate_buffer_on_device(IDevice* device, const TensorSpec& tensor_spec) {
+std::shared_ptr allocate_buffer_on_device(IDevice* device, const TensorSpec& tensor_spec) {
     auto buffer_size_bytes = tensor_spec.compute_packed_buffer_size_bytes();
     auto page_size_bytes = tensor_spec.compute_page_size_bytes();
     auto shard_spec_buffer = tensor_spec.compute_shard_spec_buffer();
@@ -65,6 +66,22 @@ DeviceBuffer allocate_buffer_on_device(IDevice* device, const TensorSpec& tensor
         shard_spec_buffer);
 }

+std::shared_ptr allocate_mesh_buffer_on_device(
+    distributed::MeshDevice* mesh_device, const TensorSpec& tensor_spec) {
+    const auto& memory_config = tensor_spec.tensor_layout().get_memory_config();
+    const distributed::DeviceLocalBufferConfig device_local_buffer_config{
+        .page_size = tensor_spec.compute_page_size_bytes(),
+        .buffer_type = memory_config.buffer_type,
+        .buffer_layout = memory_config.memory_layout,
+        .shard_parameters = tensor_spec.compute_shard_spec_buffer(),
+    };
+    const distributed::ReplicatedBufferConfig replicated_buffer_config{
+        .size = tensor_spec.compute_packed_buffer_size_bytes(),
+    };
+
+    return distributed::MeshBuffer::create(replicated_buffer_config, device_local_buffer_config, mesh_device);
+}
+
 void validate_on_device_dtype_and_layout(
     IDevice* device, const ttnn::SimpleShape& shape, DataType dtype, Layout layout) {
     // TODO: Get supported layout and dtypes from device
@@ -556,7 +573,8 @@ Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id) {
 // ======================================================================================

 template typename BufferType>
-void write_data_to_device_buffer(CommandQueue& cq, const BufferType& host_buffer, DeviceBuffer device_buffer) {
+void write_data_to_device_buffer(
+    CommandQueue& cq, const BufferType& host_buffer, std::shared_ptr device_buffer) {
     ZoneScoped;
     // TODO(arakhmati): can we use generators in this function to go from `data_to_write` to `uint32_data`?
     // And effectively get rid of any additional allocation
@@ -573,7 +591,7 @@ void write_data_to_device_buffer(const BufferType& host_buffer, Buffer& devic
 }

 template typename BufferType>
-DeviceBuffer initialize_data_on_device(
+std::shared_ptr initialize_data_on_device(
     BufferType& data_to_write,
     IDevice* device,
     const TensorSpec& tensor_spec,
@@ -593,9 +611,10 @@ DeviceBuffer initialize_data_on_device(
 }

 template
-DeviceBuffer to_device_buffer(const Storage& storage, IDevice* device, const TensorSpec& tensor_spec, uint8_t cq_id) {
+std::shared_ptr to_device_buffer(
+    const Storage& storage, IDevice* device, const TensorSpec& tensor_spec, uint8_t cq_id) {
     return std::visit(
-        [&device, &tensor_spec, cq_id](auto&& storage) -> DeviceBuffer {
+        [&device, &tensor_spec, cq_id](auto&& storage) -> std::shared_ptr {
             using StorageType = std::decay_t;
             if constexpr (std::is_same_v or std::is_same_v) {
                 auto data_to_write = host_buffer::get_as(storage.buffer);
diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.hpp b/ttnn/cpp/ttnn/tensor/tensor_impl.hpp
index 7bf44efc417..e08de0af376 100644
--- a/ttnn/cpp/ttnn/tensor/tensor_impl.hpp
+++ b/ttnn/cpp/ttnn/tensor/tensor_impl.hpp
@@ -185,16 +185,19 @@ void validate_on_device_dtype_and_layout(
 // Data reader, writer, and initializers
 // ======================================================================================

-DeviceBuffer allocate_buffer_on_device(IDevice* device, const TensorSpec& tensor_spec);
+std::shared_ptr allocate_buffer_on_device(IDevice* device, const TensorSpec& tensor_spec);
+
+std::shared_ptr allocate_mesh_buffer_on_device(
+    distributed::MeshDevice* mesh_device, const TensorSpec& tensor_spec);

 template
 inline void read_data_from_device_buffer(
-    CommandQueue& cq, DeviceBuffer device_buffer, void* host_buffer_data, bool blocking) {
+    CommandQueue& cq, std::shared_ptr device_buffer, void* host_buffer_data, bool blocking) {
     EnqueueReadBuffer(cq, device_buffer, host_buffer_data, blocking);
 }

 template
-inline void read_data_from_device_buffer(DeviceBuffer device_buffer, std::vector& host_buffer) {
+inline void read_data_from_device_buffer(std::shared_ptr device_buffer, std::vector& host_buffer) {
     ::tt::tt_metal::detail::ReadFromBuffer(device_buffer, host_buffer);
 }

diff --git a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp
index e895e583b17..f178316fe0f 100644
--- a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp
+++ b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp
@@ -120,30 +120,29 @@ uint32_t num_buffers_in_tensor(const Tensor& tensor) {

 Tensor get_shard_for_device(const Tensor& tensor, IDevice* target_device, std::optional buffer_index) {
     ZoneScopedN("GetShardForDevice");
-    Tensor shard = Tensor();
     auto& storage = tensor.tensor_attributes->storage;
-    std::visit(
-        [target_device, buffer_index, &tensor, &shard](auto&& s) {
+    return std::visit(
+        [target_device, buffer_index, &tensor](auto&& s) {
             using T = std::decay_t;
             // Stalling reads for tensor data-type and layout are needed here
             // since some worker might have raced ahead to these lookups, while
             // another worker is populating this metadata.
             if constexpr (std::is_same_v) {
-                shard = Tensor{
+                return Tensor{
                     DeviceStorage{s.get_buffer_for_device(target_device)}, s.get_tensor_spec_for_device(target_device)};
             } else if constexpr (std::is_same_v) {
-                shard =
-                    Tensor{OwnedStorage{s.get_buffer(buffer_index.value())}, s.get_tensor_spec(buffer_index.value())};
+                return Tensor{
+                    OwnedStorage{s.get_buffer(buffer_index.value())}, s.get_tensor_spec(buffer_index.value())};
             } else if constexpr (
                 std::is_same_v || std::is_same_v ||
                 std::is_same_v) {
-                shard = tensor;
+                return tensor;
             } else {
                 TT_THROW("get_shard_for_device only supports multi-device or device tensors");
+                return Tensor();
             }
         },
         storage);
-    return shard;
 }

 void insert_buffer_and_shape_for_device(
diff --git a/ttnn/cpp/ttnn/tensor/types.hpp b/ttnn/cpp/ttnn/tensor/types.hpp
index 993f311eec8..d9a772227d2 100644
--- a/ttnn/cpp/ttnn/tensor/types.hpp
+++ b/ttnn/cpp/ttnn/tensor/types.hpp
@@ -14,6 +14,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include