#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/Resize.h>
#include <ATen/native/ResizeCommon.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/TensorSubclassLikeUtils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/resize_as_native.h>
#include <ATen/ops/resize_native.h>
#include <ATen/ops/resize.h>
#include <ATen/ops/_resize_output.h>
#include <ATen/ops/_resize_output_native.h>
#endif
#include <c10/util/overflows.h>
namespace at::native {
// Returns true if resize is necessary
template <typename T>
bool _resize_output_check(const Tensor& output, ArrayRef<T> shape) {
// Tests for resizing of tensors with one or more elements
if (at::symint::sizes<T>(output).equals(shape)) {
return false;
}
if (at::symint::numel<T>(output) != 0) {
TORCH_WARN(
"An output with one or more elements was resized since it had ",
"shape ", at::symint::sizes<T>(output), ", which does not match the required ",
"output shape ", shape, ". ",
"This behavior is deprecated, and in a future PyTorch release outputs ",
"will not be resized unless they have zero elements. You can explicitly ",
"reuse an out tensor t by resizing it, inplace, to zero elements with ",
"t.resize_(0).");
}
return true;
}
bool resize_output_check(const Tensor& output, IntArrayRef shape) {
return _resize_output_check(output, shape);
}
bool resize_output_check_symint(const Tensor& output, SymIntArrayRef shape) {
return _resize_output_check(output, shape);
}
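// Illustrative sketch (not part of this file's logic): a non-empty `out` tensor
// with the wrong shape still works with out= ops but triggers the deprecation
// warning emitted above, while a zero-element `out` is resized silently.
//
//   at::Tensor a = at::rand({4, 4});
//   at::Tensor b = at::rand({4, 4});
//   at::Tensor out = at::empty({2, 2});   // non-empty, wrong shape
//   at::add_out(out, a, b);               // resized to {4, 4}; warning emitted
//   at::Tensor out0 = at::empty({0});
//   at::add_out(out0, a, b);              // zero elements: resized, no warning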
static void native_resize_(const Tensor& output, IntArrayRef shape) {
native::resize_(output, shape);
}
static void native_resize_(const Tensor& output, SymIntArrayRef shape) {
native::resize__symint(output, shape);
}
template <typename T>
bool _resize_output(const Tensor& output, ArrayRef<T> shape) {
if (_resize_output_check<T>(output, shape)) {
    // Avoid a redispatch for CPU.
    // TODO: when resize_cuda_ is rewritten to be unified with resize_,
    // we can provide the same benefit for CUDA.
    //
    // TODO(#61485): functorch wrapped tensors should not go through the
    // fast path. This is a hack; longer-term solutions are tracked in the issue.
if (output.is_cpu() && !isTensorSubclassLike(output)) {
native_resize_(output, shape);
} else {
at::symint::resize_<T>(output, shape);
}
return true;
} else {
return false;
}
}
bool resize_output(const Tensor& output, IntArrayRef shape) {
return _resize_output(output, shape);
}
bool resize_output_symint(const Tensor& output, SymIntArrayRef shape) {
return _resize_output(output, shape);
}
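// Illustrative usage (sketch): an out= kernel typically calls resize_output on
// its result before writing into it, so the zero-element and already-correct-shape
// cases above are handled uniformly. `my_add_out_kernel` below is hypothetical.
//
//   void my_add_out_kernel(const at::Tensor& a, const at::Tensor& b, const at::Tensor& result) {
//     at::native::resize_output(result, a.sizes());
//     // ... compute a + b into `result` ...
//   }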
const Tensor& _resize_output_(const Tensor& self, IntArrayRef shape, c10::Device device) {
TORCH_CHECK(self.device() == device, "out Tensor doesn't have the correct device set");
at::native::resize_output(self, shape);
return self;
}
void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes) {
TORCH_CHECK(storage->resizable(), "Trying to resize storage that is not resizable");
at::DataPtr new_data;
if (size_bytes != 0) {
new_data = storage->allocator()->allocate(size_bytes);
}
const at::DataPtr& old_data = storage->data_ptr();
const auto old_capacity = storage->nbytes();
const auto copy_capacity = std::min(size_bytes, old_capacity);
if (old_data != nullptr && copy_capacity > 0) {
memcpy(new_data.get(), old_data.get(), copy_capacity);
}
storage->set_data_ptr_noswap(std::move(new_data));
storage->set_nbytes(size_bytes);
}
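// Semantics sketch (illustrative, not exercised here): growing a CPU storage
// allocates a fresh buffer, copies the first min(old, new) bytes, and leaves the
// tail uninitialized; shrinking simply truncates.
//
//   auto t = at::empty({4}, at::kFloat);                        // 16-byte storage
//   resize_bytes_cpu(t.storage().unsafeGetStorageImpl(), 32);   // now 32 bytes;
//                                                               // first 16 preserved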
// Call the sparse implementation in SparseTensor.cpp directly.
// A dynamic dispatch here is NOT necessary, so I didn't put
// this function in native_functions.yaml
const Tensor& resize_as_sparse_(const Tensor& self, const Tensor& src);
// TODO(VitalyFedyunin): Move it to HTML docs.
//
// Strides of the output tensor of the `resize_as_` operator are defined by the
// input tensor's strides and the value of the memory_format argument.
//
// If memory_format is omitted and the input tensor has the same shape as the
// output tensor, the strides of the output remain unchanged. If the shapes
// differ, the strides are set to contiguous.
//
// If memory_format is equal to MemoryFormat::Contiguous (torch.contiguous_format),
// the output tensor will have contiguous strides.
//
// If memory_format is equal to MemoryFormat::ChannelsLast (torch.channels_last)
// and the input tensor is 4D, the output tensor will have a channels-last memory
// layout.
//
// If memory_format is equal to MemoryFormat::Preserve (torch.preserve_format),
// the output tensor's strides will be derived from the input tensor's strides,
// following the memory format preservation rule:
//
// - If the input tensor's strides are in channels-last format, the output
//   tensor will have a channels-last memory layout.
//
// - Otherwise, the output tensor will have a contiguous memory layout.
//
// (An illustrative example follows the function below.)
//
const Tensor& resize_as_(
const Tensor& self,
const Tensor& the_template,
std::optional<MemoryFormat> optional_memory_format) {
if (self.is_sparse() && the_template.is_sparse()) {
TORCH_CHECK(
!optional_memory_format.has_value(),
"Unsupported memory format for sparse tensor resize_as_ :",
optional_memory_format.value());
return at::native::resize_as_sparse_(self, the_template);
}
const Tensor& result = self.resize_(the_template.sizes());
if (optional_memory_format.has_value()) {
auto memory_format = optional_memory_format.value();
if (memory_format == MemoryFormat::Preserve) {
memory_format = the_template.suggest_memory_format();
}
self.unsafeGetTensorImpl()->empty_tensor_restride(memory_format);
}
namedinference::propagate_names(result, the_template);
return result;
}
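// Illustrative example of the memory-format rules documented above (a sketch,
// not exercised here):
//
//   at::Tensor src = at::rand({2, 3, 4, 5}).contiguous(at::MemoryFormat::ChannelsLast);
//   at::Tensor dst = at::empty({0});
//   dst.resize_as_(src);                              // no memory_format: contiguous strides
//   dst.resize_as_(src, at::MemoryFormat::Preserve);  // follows src: channels-last strides
//   bool cl = dst.is_contiguous(at::MemoryFormat::ChannelsLast);  // true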
void resize_bytes_meta(StorageImpl* storage, c10::SymInt size_bytes) {
TORCH_CHECK(storage->resizable(), "Trying to resize storage that is not resizable");
storage->set_nbytes(std::move(size_bytes));
}
static void maybe_resize_storage_meta(TensorImpl* self, c10::SymInt new_size_bytes) {
// It does not make sense to try to resize a storage
// to hold 0 elements, and this can break
// if storage_offset is positive but
// new_size is 0, so just bail in that case
// (same comment is in Resize.h)
if (self->sym_numel() == 0) {
return;
}
const Storage& storage = self->unsafe_storage();
if (!storage) {
TORCH_INTERNAL_ASSERT(0, "NYI, this should only be Caffe2");
} else if (new_size_bytes > storage.sym_nbytes()) {
resize_bytes_meta(storage.unsafeGetStorageImpl(), std::move(new_size_bytes));
}
}
static void _maybe_resize_storage(TensorImpl* self, int64_t new_size_bytes) {
maybe_resize_storage_cpu(self, new_size_bytes);
}
static void _maybe_resize_storage(TensorImpl* self, c10::SymInt new_size_bytes) {
if (self->is_cpu()) {
maybe_resize_storage_cpu(self, new_size_bytes.expect_int());
return;
}
TORCH_INTERNAL_ASSERT(self->is_meta());
maybe_resize_storage_meta(self, std::move(new_size_bytes));
}
template <typename T>
TensorImpl* _resize_impl_(
TensorImpl* self,
ArrayRef<T> size,
at::OptionalArrayRef<T> stride,
bool resize_storage) {
if (self->generic_sizes<T>() == size && (!stride || self->generic_strides<T>() == stride.value())) {
return self;
}
const auto itemsize = self->dtype().itemsize();
const auto storage_offset = self->generic_storage_offset<T>();
T storage_size = T(1);
if (stride) {
self->set_sizes_and_strides(size, *stride);
storage_size = at::detail::computeStorageNbytes(
size, *stride, itemsize, storage_offset);
} else {
self->generic_set_sizes_contiguous(size);
storage_size = at::detail::computeStorageNbytesContiguous(
size, itemsize, storage_offset);
}
if (resize_storage) {
_maybe_resize_storage(self, std::move(storage_size));
}
return self;
}
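// Worked example of the storage-size computation above (a sketch): resizing to
// size {2, 3} with no explicit strides, itemsize 4 (float), and storage_offset 0
// gives computeStorageNbytesContiguous = (0 + 2 * 3) * 4 = 24 bytes; with explicit
// strides {1, 2} (column-major), computeStorageNbytes = (0 + 1 + (2 - 1) * 1 +
// (3 - 1) * 2) * 4 = 24 bytes as well.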
TensorImpl* resize_impl_cpu_(
TensorImpl* self,
IntArrayRef size,
at::OptionalIntArrayRef stride,
bool resize_storage) {
return _resize_impl_(self, size, stride, resize_storage);
}
template <typename T>
const Tensor& _resize_(
const Tensor& self,
ArrayRef<T> size,
std::optional<MemoryFormat> optional_memory_format) {
auto* self_ = self.unsafeGetTensorImpl();
  int64_t old_storage_nbytes = self_->unsafe_storage()
      ? self_->unsafe_storage().sym_nbytes().maybe_as_int().value_or(-1)
      : 0;
_resize_impl_<T>(self_, size, /*stride=*/std::nullopt, true);
if (optional_memory_format.has_value()) {
auto memory_format =
optional_memory_format.value();
TORCH_CHECK(
memory_format != MemoryFormat::Preserve,
"Unsupported memory format",
memory_format);
self_->empty_tensor_restride(memory_format);
}
// See Note [Enabling Deterministic Operations]
  if (C10_UNLIKELY(
          at::globalContext().deterministicAlgorithms() &&
          at::globalContext().deterministicFillUninitializedMemory() &&
          old_storage_nbytes != -1)) {
at::native::fill_resize_deterministic_(self, old_storage_nbytes);
}
return self;
}
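// Illustrative sketch (the setter names below are taken from at::Context and
// should be treated as an assumption): with deterministic algorithms and
// deterministic fill enabled, the bytes gained by a growing resize_ are filled
// via fill_resize_deterministic_ (e.g. NaN for floating-point dtypes) rather
// than left uninitialized.
//
//   at::globalContext().setDeterministicAlgorithms(true, /*warn_only=*/false);
//   at::globalContext().setDeterministicFillUninitializedMemory(true);
//   auto t = at::empty({2});
//   t.resize_({8});   // the newly allocated tail is deterministically filled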
const Tensor& resize_(
const Tensor& self,
IntArrayRef size,
std::optional<MemoryFormat> optional_memory_format) {
if (self.has_names()) {
return resize_named_tensor_(self, size, optional_memory_format);
}
return _resize_(self, size, optional_memory_format);
}
const Tensor& resize__symint(
const Tensor& self,
c10::SymIntArrayRef size,
std::optional<MemoryFormat> optional_memory_format) {
  TORCH_INTERNAL_ASSERT(!self.has_names());
return _resize_(self, size, optional_memory_format);
}
void resize_bytes_nocuda(const Storage& storage, const c10::SymInt& newsize) {
// handles all devices except cuda (which needs to be in a different .so)
c10::DeviceType device_type = storage.device_type();
if (device_type == at::kCPU) {
at::native::resize_bytes_cpu(storage.unsafeGetStorageImpl(), newsize.expect_int());
} else if (device_type == at::kMeta) {
at::native::resize_bytes_meta(storage.unsafeGetStorageImpl(), newsize);
} else if (device_type == at::kPrivateUse1) {
at::detail::getPrivateUse1Hooks().resizePrivateUse1Bytes(
storage, newsize.expect_int());
} else if (device_type == at::kXPU || device_type == at::kHPU || device_type == at::kMTIA) {
ptrdiff_t size_bytes_i = newsize.expect_int();
TORCH_CHECK(
!c10::overflows<int64_t>(size_bytes_i),
"Requested storage size (",
size_bytes_i,
") cannot be represented as a int64_t");
const auto size_bytes = static_cast<int64_t>(size_bytes_i);
void* original_data_ptr = storage.data_ptr().get();
auto src_option =
c10::TensorOptions().device(storage.device()).dtype(at::kByte);
auto src_tensor = at::empty({0}, src_option).set_(storage);
src_tensor.resize_({size_bytes});
    // When resize_ is used in place of resize_bytes_xxx, the original data_ptr
    // may still be returned in some cases, which is inconsistent with the
    // behavior of resize_bytes_xxx. In those cases, an additional memory copy
    // and a storage update are required.
if (original_data_ptr == src_tensor.storage().data_ptr().get()) {
auto new_tensor = at::empty(src_tensor.sizes(), src_tensor.options());
new_tensor.copy_(src_tensor);
storage.set_data_ptr_noswap(
std::move(new_tensor.storage().mutable_data_ptr()));
storage.unsafeGetStorageImpl()->set_allocator(
new_tensor.storage().unsafeGetStorageImpl()->allocator());
storage.set_nbytes(new_tensor.storage().nbytes());
}
} else {
TORCH_CHECK(
false,
"UntypedStorage.resize_: got unexpected device type ",
device_type);
}
}
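// Illustrative call path (sketch): this function backs UntypedStorage.resize_()
// for non-CUDA devices; from C++ it can be reached directly as, e.g.,
//
//   auto t = at::empty({4}, at::kFloat);                  // CPU tensor
//   at::native::resize_bytes_nocuda(t.storage(), 64);     // takes the kCPU branch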
} // namespace at::native