Skip to content

Commit

Permalink
tenstorrent#17215: Initial MeshBuffer integration with TTNN (tenstorr…
Browse files Browse the repository at this point in the history
…ent#17259)

### Ticket
tenstorrent#17215 

### Problem description
See tenstorrent#17215

### What's changed
Extend `MultiDeviceStorage` with `MeshBuffer`, which is optionally
created to back the individual per-device shards. This allows us to
incrementally switch over to the `MeshBuffer`-backed variant without
breaking any of the existing ops.

Long term plan for tensor storage:
* `MeshBuffer` backed `MultiDeviceStorage` will become the default in
TTNN. It will eventually be renamed to `MeshDeviceStorage`, with the
`DeviceStorage` being removed.
* Interactions with `MeshBuffer` will be entirely synchronous and will
be done on the main thread. This allows us to get rid of the async
code in `Tensor`.

Next steps in terms of integrating with `MeshBuffer`:
- [X] Implement explicit dealloc routine for `MeshBuffer` (done in
tenstorrent#17265 and integrated here).
- [ ] Implement read / write shards APIs for `MeshBuffer`. From the TTNN
perspective, interacting with these APIs will be entirely synchronous.
- [ ] Use read / write shards APIs when writing data to `MeshBuffer`
backed `MultiDeviceStorage`.
- [ ] When launching multi-device operations, create a `MeshBuffer`
backed `MultiDeviceStorage` first, then supply the individual shards
into ops. This approach allows allocation to be performed in lock-step across the
mesh, while maintaining compatibility with the existing ops infra. Note
this will change with the introduction of `MeshWorkload`, and this will
require further exploration.

### Checklist
- [X] [Post commit CI
passes](https://github.com/tenstorrent/tt-metal/actions/runs/13036677827)
- [X] [T3K unit
tests](https://github.com/tenstorrent/tt-metal/actions/runs/13036686228)
- [X] New/Existing tests provide coverage for changes
  • Loading branch information
omilyutin-tt authored and nikileshx committed Feb 3, 2025
1 parent ce25236 commit 6b3730d
Show file tree
Hide file tree
Showing 13 changed files with 316 additions and 200 deletions.
1 change: 1 addition & 0 deletions tests/ttnn/unit_tests/gtests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ set(TTNN_TENSOR_UNIT_TESTS_SRC
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_create_tensor_multi_device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_create_tensor_with_layout.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_distributed_tensor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_mesh_tensor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_partition.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_shape_base.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_tensor_sharding.cpp
Expand Down
47 changes: 47 additions & 0 deletions tests/ttnn/unit_tests/gtests/tensor/test_mesh_tensor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include <gtest/gtest.h>
#include <gmock/gmock.h>

#include "ttnn/tensor/tensor.hpp"
#include "ttnn_test_fixtures.hpp"
#include <ttnn/distributed/types.hpp>
#include <ttnn/distributed/distributed_tensor.hpp>

namespace ttnn::distributed::test {
namespace {

using MeshTensorTest = T3kMultiDeviceFixture;

// Exercises the allocate -> inspect -> deallocate lifecycle of a tensor
// allocated across a device mesh. The tensor is expected to be backed by a
// MultiDeviceStorage that carries a MeshBuffer, with one shard per device,
// all shards sharing the mesh buffer's address.
TEST_F(MeshTensorTest, Lifecycle) {
    const TensorSpec spec(
        ttnn::SimpleShape{1, 1, 32, 32}, TensorLayout(DataType::FLOAT32, Layout::ROW_MAJOR, MemoryConfig{}));

    Tensor tensor = allocate_tensor_on_mesh(spec, mesh_device_.get());

    EXPECT_EQ(tensor.workers.size(), mesh_device_->num_devices());
    EXPECT_TRUE(tensor.is_allocated());

    // The storage variant must be MultiDeviceStorage backed by a MeshBuffer.
    const auto& storage = tensor.get_storage();
    auto* device_storage = std::get_if<tt::tt_metal::MultiDeviceStorage>(&storage);
    ASSERT_NE(device_storage, nullptr);
    EXPECT_NE(device_storage->mesh_buffer, nullptr);

    // Each per-device shard is allocated at the mesh buffer's address.
    const auto expected_address = device_storage->mesh_buffer->address();
    for (auto* device : mesh_device_->get_devices()) {
        auto shard_buffer = device_storage->get_buffer_for_device(device);
        ASSERT_NE(shard_buffer, nullptr);
        EXPECT_TRUE(shard_buffer->is_allocated());
        EXPECT_EQ(shard_buffer->address(), expected_address);
    }

    // Explicit deallocation must release all shards.
    tensor.deallocate();
    EXPECT_FALSE(tensor.is_allocated());
}

}  // namespace
}  // namespace ttnn::distributed::test
1 change: 1 addition & 0 deletions tt_metal/api/tt-metalium/tt_metal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ namespace detail {
bool DispatchStateCheck(bool isFastDispatch);

bool InWorkerThread();
inline bool InMainThread() { return not InWorkerThread(); }

std::map<chip_id_t, IDevice*> CreateDevices(
// TODO: delete this in favour of DevicePool
Expand Down
9 changes: 5 additions & 4 deletions ttnn/cpp/ttnn/distributed/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ Tensor aggregate_as_tensor(
} else {
std::vector<int> ordered_device_ids;
std::unordered_map<int, ttnn::TensorSpec> specs;
std::unordered_map<int, DeviceBuffer> device_buffers;
std::unordered_map<int, std::shared_ptr<Buffer>> device_buffers;
for (const auto& shard : tensor_shards) {
IDevice* device = std::get<DeviceStorage>(shard.get_storage()).buffer->device();
auto device_id = device->id();
Expand All @@ -116,7 +116,8 @@ Tensor aggregate_as_tensor(
shard_tile.get_width());
}
}
auto storage = MultiDeviceStorage{config, ordered_device_ids, std::move(device_buffers), specs};
auto storage =
MultiDeviceStorage{config, ordered_device_ids, std::move(device_buffers), specs, /*mesh_buffer_=*/nullptr};
return Tensor(std::move(storage), reference_shard.get_tensor_spec());
}
}
Expand Down Expand Up @@ -247,7 +248,7 @@ Tensor create_multi_device_tensor(
if (storage_type == StorageType::MULTI_DEVICE) {
std::vector<int> ordered_device_ids;
std::unordered_map<int, ttnn::TensorSpec> specs;
std::unordered_map<int, DeviceBuffer> device_buffers;
std::unordered_map<int, std::shared_ptr<Buffer>> device_buffers;
for (const auto& tensor : tensors) {
TT_ASSERT(
std::holds_alternative<DeviceStorage>(tensor.get_storage()),
Expand All @@ -260,7 +261,7 @@ Tensor create_multi_device_tensor(
specs.insert({device_id, tensor.get_tensor_spec()});
}
return Tensor{
MultiDeviceStorage{strategy, ordered_device_ids, device_buffers, specs},
MultiDeviceStorage{strategy, ordered_device_ids, device_buffers, specs, /*mesh_buffer_=*/nullptr},
TensorSpec(
tensors.at(0).get_logical_shape(),
TensorLayout::fromPaddedShape(
Expand Down
4 changes: 2 additions & 2 deletions ttnn/cpp/ttnn/operations/experimental/reshape/view.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,15 @@ Tensor tensor_reshape(
if (input_tensor.get_layout() == Layout::ROW_MAJOR) {
if (tensor.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) {
DeviceStorage device_storage = std::get<T>(tensor.get_storage());
DeviceBuffer device_buffer = device_storage.get_buffer();
auto device_buffer = device_storage.get_buffer();
const auto& tensor_spec = tensor.tensor_spec();
auto page_size_bytes = tensor_spec.compute_page_size_bytes();
device_buffer->set_page_size(page_size_bytes);
device_storage.insert_buffer(device_buffer);
return Tensor(device_storage, new_spec);
} else {
DeviceStorage device_storage = std::get<T>(tensor.get_storage());
DeviceBuffer device_buffer = device_storage.get_buffer();
auto device_buffer = device_storage.get_buffer();
ShardSpecBuffer shard_spec_buffer = device_buffer->shard_spec();

auto shard_spec = shard_spec_buffer.tensor_shard_spec;
Expand Down
4 changes: 2 additions & 2 deletions ttnn/cpp/ttnn/tensor/storage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

namespace tt::tt_metal {

std::vector<DeviceBuffer> MultiDeviceStorage::get_buffers() const {
std::vector<std::shared_ptr<Buffer>> MultiDeviceStorage::get_buffers() const {
std::lock_guard<std::mutex> lock(buffer_mtx);
std::vector<DeviceBuffer> buf_vec;
std::vector<std::shared_ptr<Buffer>> buf_vec;
buf_vec.reserve(buffers.size());
for (const auto& pair : buffers) {
buf_vec.push_back(pair.second);
Expand Down
64 changes: 39 additions & 25 deletions ttnn/cpp/ttnn/tensor/storage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ struct OwnedStorage {
}
};

using DeviceBuffer = std::shared_ptr<Buffer>;
// TODO: #17215 - Replace `DeviceStorage` with "mesh storage".
struct DeviceStorage {
DeviceBuffer buffer;
std::shared_ptr<Buffer> buffer;
DeviceStorage() = default;
DeviceStorage(DeviceBuffer buffer_) : buffer(std::move(buffer_)) {}
DeviceStorage(std::shared_ptr<Buffer> buffer_) : buffer(std::move(buffer_)) {}

const MemoryConfig memory_config() const {
MemoryConfig memory_config() const {
if (this->buffer.get() == nullptr) {
TT_THROW("MemoryConfig can only be obtained if the buffer is not null");
}
Expand All @@ -59,9 +59,9 @@ struct DeviceStorage {
.shard_spec = shard_spec};
}

inline void insert_buffer(DeviceBuffer buffer_) { this->buffer = buffer_; }
inline void insert_buffer(const std::shared_ptr<Buffer>& buffer_) { this->buffer = buffer_; }

inline DeviceBuffer get_buffer() const { return this->buffer; }
inline std::shared_ptr<Buffer> get_buffer() const { return this->buffer; }
static constexpr auto attribute_names = std::forward_as_tuple("memory_config");
const auto attribute_values() const { return std::make_tuple(this->memory_config()); }

Expand Down Expand Up @@ -149,7 +149,7 @@ struct MultiDeviceHostStorage {
MultiDeviceHostStorage() = default;
MultiDeviceHostStorage(
DistributedTensorConfig strategy_, std::vector<OwnedBuffer> buffers_, std::vector<TensorSpec> specs_) :
strategy(strategy_), buffers(buffers_), specs(specs_) {}
strategy(strategy_), buffers(std::move(buffers_)), specs(std::move(specs_)) {}
MultiDeviceHostStorage(MultiDeviceHostStorage&& other) { swap(*this, other); }
// unfortunately we need to have this code written manually.
MultiDeviceHostStorage(const MultiDeviceHostStorage& other) {
Expand Down Expand Up @@ -222,8 +222,13 @@ struct MultiDeviceHostStorage {
struct MultiDeviceStorage {
DistributedTensorConfig strategy;
std::vector<int> ordered_device_ids;
std::unordered_map<int, DeviceBuffer> buffers;
std::unordered_map<int, std::shared_ptr<Buffer>> buffers;
std::unordered_map<int, TensorSpec> specs;

// TODO: #17215 - This isn't populated by default. Switch to creating MeshBuffer backed storage, when TTNN is ready
// to consume it.
// Eventually, `MultiDeviceStorage` will be renamed to `MeshDeviceStorage`, and unified with `DeviceStorage`.
std::shared_ptr<distributed::MeshBuffer> mesh_buffer;
mutable std::mutex buffer_mtx;
mutable std::mutex shape_mtx;
MultiDeviceStorage() = default;
Expand All @@ -235,17 +240,20 @@ struct MultiDeviceStorage {
swap(first.ordered_device_ids, second.ordered_device_ids);
swap(first.buffers, second.buffers);
swap(first.specs, second.specs);
swap(first.mesh_buffer, second.mesh_buffer);
}

MultiDeviceStorage(
DistributedTensorConfig strategy_,
std::vector<int> ordered_device_ids_,
std::unordered_map<int, DeviceBuffer> buffers_,
std::unordered_map<int, TensorSpec> specs_) :
std::unordered_map<int, std::shared_ptr<Buffer>> buffers_,
std::unordered_map<int, TensorSpec> specs_,
std::shared_ptr<distributed::MeshBuffer> mesh_buffer_) :
strategy(std::move(strategy_)),
ordered_device_ids(std::move(ordered_device_ids_)),
buffers(std::move(buffers_)),
specs(std::move(specs_)) {}
specs(std::move(specs_)),
mesh_buffer(std::move(mesh_buffer_)) {}

MultiDeviceStorage(MultiDeviceStorage&& other) { swap(*this, other); }

Expand All @@ -255,6 +263,7 @@ struct MultiDeviceStorage {
strategy = other.strategy;
buffers = other.buffers;
specs = other.specs;
mesh_buffer = other.mesh_buffer;
}

MultiDeviceStorage& operator=(const MultiDeviceStorage& other) {
Expand All @@ -270,10 +279,10 @@ struct MultiDeviceStorage {

bool operator==(const MultiDeviceStorage& other) {
return this->ordered_device_ids == other.ordered_device_ids and this->strategy == other.strategy and
this->buffers == other.buffers and this->specs == other.specs;
this->buffers == other.buffers and this->specs == other.specs and this->mesh_buffer == other.mesh_buffer;
}

inline const MemoryConfig memory_config() const {
MemoryConfig memory_config() const {
std::lock_guard<std::mutex> lock(buffer_mtx);
TT_FATAL(
!this->ordered_device_ids.empty(), "No device ids in list. Please ensure fields are initialized properly.");
Expand All @@ -296,18 +305,20 @@ struct MultiDeviceStorage {

// Helper Functions - Getters and setters to get/modify storage attributes. These are needed to
// preinitialize empty tensor handles and use/populate them in the worker threads.
std::vector<DeviceBuffer> get_buffers() const;
std::vector<std::shared_ptr<Buffer>> get_buffers() const;

inline void insert_buffer_and_spec_for_device(IDevice* device, const DeviceBuffer buffer, TensorSpec spec) {
inline void insert_buffer_and_spec_for_device(
IDevice* device, const std::shared_ptr<Buffer>& buffer, TensorSpec spec) {
std::scoped_lock lock(buffer_mtx, shape_mtx);
TT_FATAL(mesh_buffer == nullptr, "MeshBuffer backed storage does not support inserting individual buffers");
TT_ASSERT(
device == buffer->device(),
"Mismatch between device derived from buffer and device derived from MultiDeviceStorage.");
buffers.insert({device->id(), buffer});
specs.insert({device->id(), std::move(spec)});
}

inline DeviceBuffer get_buffer_for_device(IDevice* device) const {
inline std::shared_ptr<Buffer> get_buffer_for_device(IDevice* device) const {
std::lock_guard<std::mutex> lock(buffer_mtx);
TT_ASSERT(buffers.find(device->id()) != buffers.end(), "Buffer not found for device {}", device->id());
TT_ASSERT(
Expand All @@ -316,7 +327,7 @@ struct MultiDeviceStorage {
return buffers.at(device->id());
}

inline DeviceBuffer& get_buffer_for_device(IDevice* device) {
inline std::shared_ptr<Buffer>& get_buffer_for_device(IDevice* device) {
std::lock_guard<std::mutex> lock(buffer_mtx);
TT_ASSERT(buffers.find(device->id()) != buffers.end(), "Buffer not found for device {}", device->id());
TT_ASSERT(
Expand All @@ -325,7 +336,7 @@ struct MultiDeviceStorage {
return buffers.at(device->id());
}

inline DeviceBuffer get_buffer_for_device_id(uint32_t device_id) const {
inline std::shared_ptr<Buffer> get_buffer_for_device_id(uint32_t device_id) const {
std::lock_guard<std::mutex> lock(buffer_mtx);
return buffers.at(device_id);
}
Expand All @@ -352,13 +363,16 @@ struct MultiDeviceStorage {
}

inline bool is_allocated() const {
std::lock_guard<std::mutex> lock(buffer_mtx);

return std::all_of(
ordered_device_ids.begin(), ordered_device_ids.end(), [&buffers = this->buffers](auto&& device_id) {
const auto& buffer = buffers.at(device_id);
return buffer && buffer->is_allocated();
});
if (mesh_buffer != nullptr) {
return mesh_buffer->is_allocated();
} else {
std::lock_guard<std::mutex> lock(buffer_mtx);
return std::all_of(
ordered_device_ids.begin(), ordered_device_ids.end(), [&buffers = this->buffers](auto&& device_id) {
const auto& buffer = buffers.at(device_id);
return buffer && buffer->is_allocated();
});
}
}
};

Expand Down
Loading

0 comments on commit 6b3730d

Please sign in to comment.