forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path Resize.h
62 lines (51 loc) · 1.85 KB
/
Resize.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#pragma once
#include <ATen/ATen.h>
#include <ATen/native/ResizeCommon.h>
#include <c10/cuda/CUDAGuard.h>

#include <limits>
namespace at { namespace native {
// Declaration only; the definition lives in another translation unit and is
// exported via TORCH_CUDA_CPP_API.
// maybe_resize_storage_cuda below calls this only to grow `storage` to
// `size_bytes` bytes (when that exceeds the storage's current nbytes()).
// NOTE(review): presumably reallocates the CUDA buffer; confirm in the .cpp
// whether existing contents are preserved.
TORCH_CUDA_CPP_API void resize_bytes_cuda(StorageImpl* storage, size_t size_bytes);
// Grow the storage backing `self` so it can hold `new_size` elements located
// after the tensor's storage offset. The storage is never shrunk: it is only
// resized when the required byte count exceeds its current nbytes().
//
// @param self      tensor whose storage may be grown; its storage must be
//                  non-null.
// @param new_size  number of elements (not bytes) the tensor must be able to
//                  address, counted from its storage offset.
// @throws c10::Error (via TORCH_CHECK) if the storage offset is negative, if
//         the required byte count overflows uint64_t or size_t, or if the
//         storage is null.
static inline void maybe_resize_storage_cuda(TensorImpl* self, uint64_t new_size) {
  // It does not make sense to try to resize a storage
  // to hold 0 elements, and this can break
  // if storage_offset is positive but
  // new_size is 0, so just bail in that case
  // (same comment is in Resize.h)
  // NOTE(review): presumably refers to the CPU-side ATen/native/Resize.h.
  if (new_size == 0) {
    return;
  }

  const int64_t storage_offset = self->storage_offset();
  TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
  const uint64_t offset = static_cast<uint64_t>(storage_offset);
  const uint64_t itemsize = self->dtype().itemsize();

  // Unsigned arithmetic wraps silently on overflow, and on 64-bit platforms
  // size_t is as wide as uint64_t, so the size_t check further down can never
  // fire there. Detect overflow of both the addition and the multiplication
  // explicitly, so a wrapped (too small) byte count is never used.
  constexpr uint64_t kMaxU64 = std::numeric_limits<uint64_t>::max();
  const bool overflowed = new_size > kMaxU64 - offset ||
      (itemsize != 0 && new_size + offset > kMaxU64 / itemsize);
  TORCH_CHECK(!overflowed, "Storage size calculation overflowed with new_size=",
              new_size, ", storage_offset=", storage_offset, ", itemsize=", itemsize);

  const uint64_t new_size_bytes_i = (new_size + offset) * itemsize;
  // Still needed where size_t is narrower than 64 bits (e.g. 32-bit builds).
  TORCH_CHECK(!overflows<size_t>(new_size_bytes_i), "Requested storage size (",
              new_size_bytes_i, ") cannot be represented as a size_t");
  const auto new_size_bytes = static_cast<size_t>(new_size_bytes_i);

  const Storage& storage = self->unsafe_storage();
  TORCH_CHECK(storage, "Tensor: invalid null storage");
  // Only grow; a storage that is already large enough is left untouched.
  if (new_size_bytes > storage.nbytes()) {
    resize_bytes_cuda(storage.unsafeGetStorageImpl(), new_size_bytes);
  }
}
// Resize `self` in place to `size`, using `stride` when provided and
// contiguous strides otherwise. Then ask maybe_resize_storage_cuda to grow
// the backing storage if the new geometry needs more room.
//
// @param self          tensor to resize; returned untouched when its sizes
//                      (and its strides, if `stride` is given) already match.
// @param size          requested sizes.
// @param stride        requested strides; nullopt selects a contiguous layout.
// @param device_guard  when true, an OptionalCUDAGuard is pointed at the
//                      storage's device index before any resizing happens.
// @return `self`.
inline TensorImpl* resize_impl_cuda_(
    TensorImpl* self,
    IntArrayRef size,
    c10::optional<IntArrayRef> stride,
    bool device_guard = true) {
  // Fast path: the requested geometry is already in place. The nesting keeps
  // strides() from being queried unless the sizes already match.
  if (self->sizes() == size) {
    if (!stride.has_value() || self->strides() == stride) {
      return self;
    }
  }

  // NB: callers coming from TH do not need the device guard to be held.
  cuda::OptionalCUDAGuard cuda_guard;
  if (device_guard) {
    cuda_guard.set_index(self->storage().device().index());
  }

  int64_t required_elems = 0;
  if (!stride.has_value()) {
    self->set_sizes_contiguous(size);
    required_elems = self->numel();
  } else {
    self->set_sizes_and_strides(size, *stride);
    // With explicit strides, the storage footprint can differ from numel().
    required_elems = storage_size_for(size, *stride);
  }

  maybe_resize_storage_cuda(self, required_elems);
  return self;
}
}}