// Sorting.cpp
#include <ATen/native/cuda/Sorting.h>
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/native/SortingUtils.h>
#include <ATen/native/ReduceOpsUtils.h>

namespace at {
namespace native {

namespace {

std::tuple<Tensor&, Tensor&> kthvalue_out_impl_cuda(
Tensor& values,
Tensor& indices,
const Tensor& self,
int64_t k,
int64_t dim_,
bool keepdim) {
int64_t dim = maybe_wrap_dim(dim_, self.dim());
int64_t slicesize = self.dim() == 0 ? 1 : self.size(dim);
zero_numel_check_dims(self, dim, "kthvalue()");
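// `k` is 1-based: k == 1 selects the smallest element along `dim` and
// k == slicesize selects the largest.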
TORCH_CHECK(k >= 1 && k <= slicesize,
"kthvalue(): selected number k out of range for dimension ", dim);
at::assert_no_overlap(self, values);
_reduction_with_indices_allocate_or_resize_output(
values, indices, self, dim, keepdim);
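// Degenerate case: a 0-dim tensor holds a single element, which is trivially
// the kth value, so no kernel launch is needed.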
if (self.dim() == 0 && self.numel() == 1) {
values.copy_(self);
indices.zero_();
return std::forward_as_tuple(values, indices);
}
TORCH_CHECK(
self.dim() <= MAX_TENSORINFO_DIMS,
"cannot operate on more than ",
MAX_TENSORINFO_DIMS,
" dimensions");
// Based on required index size, run the algorithm with the
// appropriate index type
if (self.numel() != 0) {
launch_kthvalue_kernel(values, indices, self, dim, k);
}
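// With keepdim == false the reduced dimension is dropped below, e.g. a (4, 5)
// input reduced over dim == 1 produces outputs of shape (4) instead of (4, 1).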
if (!keepdim) {
values.squeeze_(dim);
indices.squeeze_(dim);
}
return std::forward_as_tuple(values, indices);
}
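
// Illustrative usage sketch of the kthvalue path above, assuming the libtorch
// C++ frontend and a CUDA build:
//
//   auto t = torch::tensor({3.0, 1.0, 2.0}, torch::kCUDA);
//   auto [val, idx] = t.kthvalue(/*k=*/2);
//   // val == 2.0 (the 2nd smallest element), idx == 2 (its position in `t`)
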
std::tuple<Tensor&, Tensor&> median_with_indices_impl(
Tensor& values,
Tensor& indices,
const Tensor& self,
int64_t dim,
bool keepdim,
bool ignore_nan) {
// See note [Writing Nondeterministic Operations]
// If there are duplicate elements of a median value, the procedure for choosing which
// of the duplicates to use for the indices output is nondeterministic.
at::globalContext().alertNotDeterministic("median CUDA with indices output");
NoNamesGuard guard;
dim = at::maybe_wrap_dim(dim, self.dim());
Tensor in = self.dim() > 0 ? self.contiguous() : self.unsqueeze(0);
checkDeviceType("median", {values, indices}, self.device().type());
checkScalarType("median", {indices, "indices", 1}, kLong);
checkSameType("median", {values, "values", 0}, {self, "self", 2});
TORCH_CHECK(
self.dim() <= MAX_TENSORINFO_DIMS,
"median() cannot operate on more than ",
MAX_TENSORINFO_DIMS,
" dimensions");
std::vector<int64_t> out_shape = self.sizes().vec();
zero_numel_check_dims(self, dim, "median()");
if (self.dim() > 0) {
assert(dim >= 0);
assert(dim < static_cast<int64_t>(out_shape.size()));
if (keepdim) {
out_shape[dim] = 1;
} else {
out_shape.erase(out_shape.begin() + dim);
}
}
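// e.g. a (2, 3, 4) input reduced over dim == 1 gives out_shape (2, 1, 4) with
// keepdim, or (2, 4) without.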
values.resize_(out_shape);
indices.resize_(out_shape);
// Only launch kernel for non-empty tensors
if (self.numel() > 0) {
// Ensure the number of dimensions is the same for all tensors used in the reduction
Tensor vals = keepdim && self.dim() > 0 ? values : values.unsqueeze(dim);
Tensor inds = keepdim && self.dim() > 0 ? indices : indices.unsqueeze(dim);
launch_median_kernel(vals, inds, in, dim, ignore_nan);
}
guard.reset();
namedinference::propagate_names_for_reduction(values, self, dim, keepdim);
namedinference::propagate_names_for_reduction(indices, self, dim, keepdim);
return std::forward_as_tuple(values, indices);
}

Tensor median_impl(const Tensor& self, bool ignore_nan) {
NoNamesGuard guard;
int64_t size = self.numel();
// Return nan for empty tensors
if (size <= 0) {
return at::full({}, std::numeric_limits<float>::quiet_NaN()).to(self.options());
}
// Sort input tensor to efficiently query for median element
Tensor sorted = std::get<0>(self.flatten().sort());
if (!ignore_nan) {
// For torch.median, return either the middle element or nan (which sorts as
// the largest value) if any nans are present
int64_t k = (size - 1) / 2;
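// e.g. size == 5 gives k == 2 (the exact middle); size == 4 gives k == 1,
// i.e. the lower of the two middle elements.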
return at::where(sorted[-1].isnan(), sorted[-1], sorted[k]);
} else {
// For torch.nanmedian return the middle element among the non-nan values
int64_t k = ((size - 1) - sorted.isnan().sum().item<int64_t>()) / 2;
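// nans sort to the end, so the first (size - num_nan) entries of `sorted` are
// the non-nan values; e.g. 7 elements with 2 nans gives k == (6 - 2) / 2 == 2.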
return sorted[k].clone(); // Clone so we aren't keeping `sorted` alive
}
}
} // namespace (anonymous)

std::tuple<Tensor&, Tensor&> kthvalue_out_cuda(
const Tensor& self,
int64_t k,
int64_t dim,
bool keepdim,
Tensor& values,
Tensor& indices) {
// See note [Writing Nondeterministic Operations]
// If there are duplicate elements of the kth value, the procedure for choosing which
// of the duplicates to use for the indices output is nondeterministic.
at::globalContext().alertNotDeterministic("kthvalue CUDA");
auto result = [&]() {
NoNamesGuard guard;
// `kthvalue_out_impl_cuda` expects its input `self` to be contiguous.
return kthvalue_out_impl_cuda(values, indices, self.contiguous(), k, dim, keepdim);
}();
namedinference::propagate_names_for_reduction(values, self, dim, keepdim);
namedinference::propagate_names_for_reduction(indices, self, dim, keepdim);
return result;
}

// Mark: median
std::tuple<Tensor&, Tensor&> median_out_cuda(
const Tensor& self,
int64_t dim,
bool keepdim,
Tensor& values,
Tensor& indices) {
return median_with_indices_impl(
values, indices, self, dim, keepdim, /*ignore_nan=*/false);
}

Tensor median_cuda(const Tensor& self) {
return median_impl(self, /*ignore_nan=*/false);
}

std::tuple<Tensor&, Tensor&> nanmedian_out_cuda(
const Tensor& self,
int64_t dim,
bool keepdim,
Tensor& values,
Tensor& indices) {
return median_with_indices_impl(
values, indices, self, dim, keepdim, /*ignore_nan=*/true);
}

Tensor nanmedian_cuda(const Tensor& self) {
return median_impl(self, /*ignore_nan=*/true);
}
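
// Illustrative sketch of the dim-less median entry points above, assuming the
// libtorch C++ frontend and a CUDA build:
//
//   auto t = torch::tensor({1.0, NAN, 3.0, 2.0}, torch::kCUDA);
//   auto m  = t.median();     // nan: median propagates nan if any is present
//   auto nm = t.nanmedian();  // 2.0: the median of the non-nan values
//                             //      {1.0, 2.0, 3.0}
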
} // namespace native
} // namespace at