// Resize.cpp
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/PeerToPeerAccess.h>
#include <torch/library.h>
#include <ATen/native/cuda/Resize.h>
#include <ATen/native/ResizeCommon.h>
namespace at {
namespace native {
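
// Resizes the raw byte buffer backing a CUDA StorageImpl. Shrinking to zero
// releases the allocation; any other size allocates a fresh buffer and copies
// over min(old, new) bytes on the current CUDA stream.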
void resize_bytes_cuda(StorageImpl* storage, size_t size_bytes) {
  TORCH_CHECK(storage->resizable(), "Trying to resize storage that is not resizable");
  auto allocator = storage->allocator();
  TORCH_CHECK(allocator != nullptr, "Trying to resize storage without an allocator");

  auto device = at::cuda::current_device();

  if (size_bytes == 0) {
    storage->set_data_ptr_noswap(at::DataPtr(nullptr, at::Device(at::DeviceType::CUDA, device)));
    storage->set_nbytes(0);
    return;
  }
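
  // Allocate the new buffer up front; the old contents (if any) are copied
  // below before the storage's data pointer is swapped out.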
  at::DataPtr data = allocator->allocate(size_bytes);
  if (storage->data_ptr()) {
    // Enable p2p access when the memcpy is across devices
    at::globalContext().lazyInitCUDA();
    at::cuda::get_p2p_access(device, storage->device().index());

    C10_CUDA_CHECK(
        cudaMemcpyAsync(
            data.get(),
            storage->data(),
            std::min(storage->nbytes(), size_bytes),
            cudaMemcpyDeviceToDevice,
            c10::cuda::getCurrentCUDAStream()));
  }

  // Destructively overwrite data_ptr
  storage->set_data_ptr_noswap(std::move(data));
  storage->set_nbytes(size_bytes);
}
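
// In-place resize for CUDA tensors (the CUDA backend kernel for aten::resize_).
// Named tensors are delegated to resize_named_tensor_; otherwise the TensorImpl
// is resized (growing the storage if needed) and optionally restrided to the
// requested memory format.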
const Tensor& resize_cuda_(
    const Tensor& self,
    IntArrayRef size,
    c10::optional<MemoryFormat> optional_memory_format) {
  if (self.has_names()) {
    return resize_named_tensor_(self, size, optional_memory_format);
  }
  auto* self_ = self.unsafeGetTensorImpl();
  resize_impl_cuda_(self_, size, /*strides=*/c10::nullopt);
  if (optional_memory_format.has_value()) {
    auto memory_format =
        optional_memory_format.value();
    TORCH_CHECK(
        memory_format != MemoryFormat::Preserve,
        "Unsupported memory format",
        memory_format);
    self_->empty_tensor_restride(memory_format);
  }
  return self;
}
} // namespace native
} // namespace at
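
// Usage sketch (illustrative, not part of this file; assumes a CUDA-enabled
// libtorch build): an in-place resize on a CUDA tensor dispatches here.
//
//   at::Tensor t = at::empty({2, 3}, at::device(at::kCUDA));
//   t.resize_({4, 3});  // grows the underlying storage via resize_cuda_
//   t.resize_({4, 3}, at::MemoryFormat::Contiguous);  // restrides as requested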