forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTensorShapeCUDA.cpp
37 lines (31 loc) · 1.26 KB
/
TensorShapeCUDA.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/Resize.h>
#include <ATen/native/cuda/Resize.h>
namespace at {
namespace native {
// this needs to be split along CPU/CUDA lines because we don't have a consistent
// way of getting the allocator to use for a device (c10::GetAllocator is not
// the same as at::cuda::getCUDADeviceAllocator().
// Resets `result` to an empty (numel == 0) tensor backed by a brand-new,
// resizable CUDA storage. The tensor's dtype must survive the reset; the
// assert below guards that invariant since the fresh storage is raw bytes.
Tensor& set_cuda_(Tensor& result) {
  const caffe2::TypeMeta saved_dtype = result.dtype();
  Storage fresh_storage(
      Storage::use_byte_size_t(),
      /*size_bytes=*/0,
      at::cuda::getCUDADeviceAllocator(),
      /*resizable=*/true);
  result.set_(fresh_storage, /*storage_offset=*/0, /*size=*/{0}, /*stride=*/{});
  // set_ must not have clobbered the dtype.
  TORCH_INTERNAL_ASSERT(saved_dtype == result.dtype());
  return result;
}
// unify with cuda implementation? This is not done to avoid a dispatch in resize_impl_cpu_
// Rebinds `result` onto `storage` at `storage_offset` with the given
// size/stride, then resizes via the CUDA-specific resize path.
// NOTE(review): the storage swap itself appears to happen inside
// checkSetStorage in this version — confirm against Resize.h.
Tensor& set_storage_cuda_(Tensor& result, Storage storage, int64_t storage_offset, IntArrayRef size, IntArrayRef stride) {
  checkSetStorage(result, storage, storage_offset, size, stride);

  auto* impl = result.unsafeGetTensorImpl();
  impl->set_storage_offset(storage_offset);

  // A null stride pointer means "no strides supplied"; resize_impl_cuda_
  // computes contiguous strides when the optional is empty.
  c10::optional<IntArrayRef> maybe_stride;
  if (stride.data() != nullptr) {
    maybe_stride = stride;
  }
  at::native::resize_impl_cuda_(impl, size, maybe_stride);
  return result;
}
}
}