Skip to content

Commit

Permalink
#16066: rollback tensor_ops
Browse files Browse the repository at this point in the history
  • Loading branch information
BuiChiTrung committed Dec 24, 2024
1 parent 024c66c commit 569005e
Showing 1 changed file with 21 additions and 18 deletions.
39 changes: 21 additions & 18 deletions ttnn/cpp/ttnn/tensor/tensor_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,24 +40,27 @@ Tensor tensor_to(
// Record main thread ref count for tensors before pushing to queue.
uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count();
uint32_t original_tensor_ref_count = async_safe_tensor.tensor_attributes->record_main_thread_ref_count();
target_device->push_work(
[async_safe_tensor, device_tensor, mem_config, target_device, cq_id, sub_device_ids]() mutable {
if (async_safe_tensor.storage_type() == StorageType::DEVICE) {
TT_ASSERT(
async_safe_tensor.device() == target_device && "Currently do not support moving between devices");
device_tensor.populate_buffers_and_metadata(async_safe_tensor);
} else {
tensor_impl::validate_on_device_dtype_and_layout(
target_device,
async_safe_tensor.get_padded_shape(),
async_safe_tensor.get_dtype(),
async_safe_tensor.get_layout());
auto local_tensor =
tensor_impl::to_device_wrapper(async_safe_tensor, target_device, mem_config, cq_id, sub_device_ids);
// Populate device tensor
device_tensor.populate_buffers_and_metadata(local_tensor);
}
});
target_device->push_work([async_safe_tensor,
device_tensor,
mem_config,
target_device,
cq_id,
sub_device_ids]() mutable {
if (async_safe_tensor.storage_type() == StorageType::DEVICE) {
TT_ASSERT(async_safe_tensor.device() == target_device && "Currently do not support moving between devices");
device_tensor.populate_buffers_and_metadata(async_safe_tensor);
} else {
tensor_impl::validate_on_device_dtype_and_layout(
target_device,
async_safe_tensor.get_padded_shape(),
async_safe_tensor.get_dtype(),
async_safe_tensor.get_layout());
auto local_tensor =
tensor_impl::to_device_wrapper(async_safe_tensor, target_device, mem_config, cq_id, sub_device_ids);
// Populate device tensor
device_tensor.populate_buffers_and_metadata(local_tensor);
}
});
// Update main thread ref count for tensors after pushing to queue (update original tensor and returned tensor,
// since both can be on device).
device_tensor.tensor_attributes->update_main_thread_ref_count(device_tensor.workers.at(0), device_tensor_ref_count);
Expand Down

0 comments on commit 569005e

Please sign in to comment.