forked from tenstorrent/tt-metal
tenstorrent#17215: Initial MeshBuffer integration with TTNN (tenstorrent#17259)

### Ticket
tenstorrent#17215

### Problem description
See tenstorrent#17215

### What's changed
Extend `MultiDeviceStorage` with `MeshBuffer`, which is optionally created to back the individual per-device shards. This makes it possible to switch over to the `MeshBuffer` backed variant incrementally, without breaking any of the existing ops.

Long term plan for tensor storage:
* `MeshBuffer` backed `MultiDeviceStorage` will become the default in TTNN. It will eventually be renamed to `MeshDeviceStorage`, and `DeviceStorage` will be removed.
* Interactions with `MeshBuffer` will be entirely synchronous and will be done on the main thread. This makes it possible to remove all of the async code in `Tensor`.

Next steps in terms of integrating with `MeshBuffer`:
- [X] Implement explicit dealloc routine for `MeshBuffer` (done in tenstorrent#17265 and integrated here).
- [ ] Implement read / write shards APIs for `MeshBuffer`. From the TTNN perspective, interacting with these APIs will be entirely synchronous.
- [ ] Use read / write shards APIs when writing data to `MeshBuffer` backed `MultiDeviceStorage`.
- [ ] When launching multi-device operations, create a `MeshBuffer` backed `MultiDeviceStorage` first, then supply the individual shards into ops. This performs allocation in lock-step across the mesh while maintaining compatibility with the existing ops infra. Note that this will change with the introduction of `MeshWorkload`, which will require further exploration.

### Checklist
- [X] [Post commit CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/13036677827)
- [X] [T3K unit tests](https://github.com/tenstorrent/tt-metal/actions/runs/13036686228)
- [X] New/Existing tests provide coverage for changes
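The storage layout described above can be illustrated with a minimal, self-contained sketch. Every name in it (`ToyDevice`, `ToyMeshBuffer`, `ToyShard`, `ToyMultiDeviceStorage`, `allocate_on_mesh`) is hypothetical and is not the tt-metal API. The bump allocator is an assumption, chosen only to show why allocating in lock-step gives every per-device shard the same address.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <map>
#include <memory>
#include <utility>
#include <vector>

// NOTE: all types here are hypothetical stand-ins, not tt-metal classes.

// Toy device with a bump allocator: `next_free` is the first unused address.
struct ToyDevice {
    int id;
    std::uint64_t next_free;
};

// One logical buffer that occupies the same address on every device.
class ToyMeshBuffer {
public:
    ToyMeshBuffer(std::vector<ToyDevice*> devices, std::uint64_t size)
        : devices_(std::move(devices)), size_(size) {
        // Lock-step allocation: choose an address that is free on all devices,
        // then advance every device allocator past the new buffer.
        for (const ToyDevice* d : devices_) address_ = std::max(address_, d->next_free);
        for (ToyDevice* d : devices_) d->next_free = address_ + size_;
        allocated_ = true;
    }

    std::uint64_t address() const { return address_; }
    bool is_allocated() const { return allocated_; }

    // Explicit, synchronous deallocation (the toy allocator does not reuse memory).
    void deallocate() { allocated_ = false; }

private:
    std::vector<ToyDevice*> devices_;
    std::uint64_t size_;
    std::uint64_t address_ = 0;
    bool allocated_ = false;
};

// Per-device view of the buffer.
struct ToyShard {
    int device_id;
    std::uint64_t address;
};

// Multi-device storage with an optional mesh buffer backing the shards.
// A null `mesh_buffer` models the legacy, non-mesh-backed variant.
struct ToyMultiDeviceStorage {
    std::map<int, ToyShard> shards;
    std::shared_ptr<ToyMeshBuffer> mesh_buffer;
};

// Create the mesh buffer first, then derive the individual shards from it.
ToyMultiDeviceStorage allocate_on_mesh(const std::vector<ToyDevice*>& devices, std::uint64_t size) {
    ToyMultiDeviceStorage storage;
    storage.mesh_buffer = std::make_shared<ToyMeshBuffer>(devices, size);
    for (const ToyDevice* d : devices) {
        storage.shards.emplace(d->id, ToyShard{d->id, storage.mesh_buffer->address()});
    }
    return storage;
}
```

In this sketch, two devices whose allocators start at offsets 64 and 128 both receive a 32-byte buffer at address 128, and both allocators advance to 160. Existing code can keep consuming the individual shards, while the mesh buffer remains the single owner of the allocation.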
1 parent ce25236, commit 6b3730d
Showing 13 changed files with 316 additions and 200 deletions.
    @@ -0,0 +1,47 @@
    // SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
    //
    // SPDX-License-Identifier: Apache-2.0

    #include <gtest/gtest.h>
    #include <gmock/gmock.h>

    #include "ttnn/tensor/tensor.hpp"
    #include "ttnn_test_fixtures.hpp"
    #include <ttnn/distributed/types.hpp>
    #include <ttnn/distributed/distributed_tensor.hpp>

    namespace ttnn::distributed::test {
    namespace {

    using MeshTensorTest = T3kMultiDeviceFixture;

    TEST_F(MeshTensorTest, Lifecycle) {
        const TensorSpec tensor_spec =
            TensorSpec(ttnn::SimpleShape{1, 1, 32, 32}, TensorLayout(DataType::FLOAT32, Layout::ROW_MAJOR, MemoryConfig{}));

        Tensor input_tensor = allocate_tensor_on_mesh(tensor_spec, mesh_device_.get());

        EXPECT_EQ(input_tensor.workers.size(), mesh_device_->num_devices());
        EXPECT_TRUE(input_tensor.is_allocated());

        const auto& storage = input_tensor.get_storage();
        auto* multi_device_storage = std::get_if<tt::tt_metal::MultiDeviceStorage>(&storage);

        ASSERT_NE(multi_device_storage, nullptr);
        EXPECT_NE(multi_device_storage->mesh_buffer, nullptr);

        // Buffer address is the same across all device buffers.
        const auto buffer_address = multi_device_storage->mesh_buffer->address();
        for (auto* device : mesh_device_->get_devices()) {
            auto buffer = multi_device_storage->get_buffer_for_device(device);
            ASSERT_NE(buffer, nullptr);
            EXPECT_TRUE(buffer->is_allocated());
            EXPECT_EQ(buffer->address(), buffer_address);
        }

        input_tensor.deallocate();
        EXPECT_FALSE(input_tensor.is_allocated());
    }

    }  // namespace
    }  // namespace ttnn::distributed::test
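The test reaches the mesh buffer by calling `std::get_if` on the tensor's storage and then checking the optional `mesh_buffer` pointer. A minimal sketch of that two-step check on a plain `std::variant` is shown here. The `Sketch*` types and `mesh_buffer_or_null` are hypothetical stand-ins introduced for illustration and are assumptions, not the real TTNN definitions.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <variant>

// NOTE: hypothetical stand-ins, not the TTNN storage types.
struct SketchHostStorage {};

struct SketchMeshBuffer {
    std::uint64_t address;
};

struct SketchMultiDeviceStorage {
    // Null when the storage is not backed by a mesh buffer.
    std::shared_ptr<SketchMeshBuffer> mesh_buffer;
};

using SketchStorage = std::variant<SketchHostStorage, SketchMultiDeviceStorage>;

// Returns the backing mesh buffer, or nullptr if the storage holds a different
// alternative or is multi-device storage without a mesh buffer.
const SketchMeshBuffer* mesh_buffer_or_null(const SketchStorage& storage) {
    // std::get_if yields nullptr instead of throwing when the variant
    // currently holds another alternative.
    const auto* multi = std::get_if<SketchMultiDeviceStorage>(&storage);
    if (multi == nullptr) {
        return nullptr;
    }
    return multi->mesh_buffer.get();
}
```

Both checks matter during an incremental migration: the first distinguishes the storage kind, and the second distinguishes mesh-backed storage from the variant that existing ops still produce.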