forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSortImpl.cu
35 lines (32 loc) · 1.06 KB
/
SortImpl.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#include <ATen/ATen.h>
#include <thrust/sort.h>
namespace at { namespace native {
// Infers strides for a dense tensor with the same sizes as `self`, laid out so
// that `dim` is the innermost (stride-1) dimension.
//
// The remaining dimensions are ordered by descending stride in `self`. Ties keep
// their original dimension order because the sort is stable. Contiguous strides
// (cumulative products of sizes) are then generated for that permuted order.
//
// Parameters:
//   self - tensor whose sizes and strides drive the layout.
//   dim  - dimension to place last; negative values count from the end.
// Returns:
//   one stride per dimension of `self`, indexed by original dimension;
//   an empty vector when `self` is 0-dim (there are no strides to permute).
// Throws:
//   std::out_of_range if `dim` does not name a dimension of `self`.
std::vector<int64_t> infer_dense_strides_dim_last(const Tensor & self, int64_t dim) {
  const int64_t ndim = self.dim();
  if (ndim == 0) {
    // A 0-dim tensor has no strides; indexing `strides` below would be out of bounds.
    return {};
  }
  if (dim < 0) {
    dim += ndim;
  }
  // Sort keys: self's strides, with `dim` forced to -1 so that it sorts last
  // under a descending sort (assumes real strides are >= 0).
  std::vector<int64_t> strides = self.strides().vec();
  // Bounds-checked: a dim still outside [0, ndim) converts to an out-of-range
  // index and throws std::out_of_range instead of corrupting memory.
  strides.at(dim) = -1;
  // original_dim[k] is the dimension of `self` that lands at sorted position k.
  std::vector<int64_t> original_dim(ndim);
  for (int64_t i = 0; i < ndim; i++) {
    original_dim[i] = i;
  }
  thrust::stable_sort_by_key(
    thrust::host, strides.data(), strides.data() + ndim, original_dim.data(),
    thrust::greater<int64_t>()
  );
  // Walk the permuted order from innermost to outermost, assigning contiguous
  // strides and writing each one straight back to its original dimension.
  const auto sizes = self.sizes();
  std::vector<int64_t> new_strides(ndim);
  int64_t cumprod = 1;
  for (int64_t pos = ndim - 1; pos >= 0; pos--) {
    const int64_t d = original_dim[pos];
    new_strides[d] = cumprod;
    cumprod *= sizes[d];
  }
  return new_strides;
}
}}