From 569005e01bd0c6b40dfd47bf507d0a62486be577 Mon Sep 17 00:00:00 2001 From: BuiChiTrung Date: Thu, 19 Dec 2024 10:50:57 +0000 Subject: [PATCH] #16066: rollback tensor_ops --- ttnn/cpp/ttnn/tensor/tensor_ops.cpp | 39 ++++++++++++++++------------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp index 3952b911028..c7234d8a1a7 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp @@ -40,24 +40,27 @@ Tensor tensor_to( // Record main thread ref count for tensors before pushing to queue. uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count(); uint32_t original_tensor_ref_count = async_safe_tensor.tensor_attributes->record_main_thread_ref_count(); - target_device->push_work( - [async_safe_tensor, device_tensor, mem_config, target_device, cq_id, sub_device_ids]() mutable { - if (async_safe_tensor.storage_type() == StorageType::DEVICE) { - TT_ASSERT( - async_safe_tensor.device() == target_device && "Currently do not support moving between devices"); - device_tensor.populate_buffers_and_metadata(async_safe_tensor); - } else { - tensor_impl::validate_on_device_dtype_and_layout( - target_device, - async_safe_tensor.get_padded_shape(), - async_safe_tensor.get_dtype(), - async_safe_tensor.get_layout()); - auto local_tensor = - tensor_impl::to_device_wrapper(async_safe_tensor, target_device, mem_config, cq_id, sub_device_ids); - // Populate device tensor - device_tensor.populate_buffers_and_metadata(local_tensor); - } - }); + target_device->push_work([async_safe_tensor, + device_tensor, + mem_config, + target_device, + cq_id, + sub_device_ids]() mutable { + if (async_safe_tensor.storage_type() == StorageType::DEVICE) { + TT_ASSERT(async_safe_tensor.device() == target_device && "Currently do not support moving between devices"); + device_tensor.populate_buffers_and_metadata(async_safe_tensor); + } else { + tensor_impl::validate_on_device_dtype_and_layout( + target_device, + async_safe_tensor.get_padded_shape(), + async_safe_tensor.get_dtype(), + async_safe_tensor.get_layout()); + auto local_tensor = + tensor_impl::to_device_wrapper(async_safe_tensor, target_device, mem_config, cq_id, sub_device_ids); + // Populate device tensor + device_tensor.populate_buffers_and_metadata(local_tensor); + } + }); // Update main thread ref count for tensors after pushing to queue (update original tensor and returned tensor, // since both can be on device). device_tensor.tensor_attributes->update_main_thread_ref_count(device_tensor.workers.at(0), device_tensor_ref_count);