// Randperm.cuh (forked from pytorch/pytorch)
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <ATen/Utils.h>

#include <curand.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>

#include <mutex>  // std::lock_guard, std::mutex
#include <tuple>  // std::get

namespace {

// See note [Algorithm of randperm]
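//
// randperm draws random keys, sorts them, and uses the sort order as the
// permutation. When the masked keys contain ties ("islands" of equal values),
// the order within an island is not random, so this kernel re-shuffles the
// data inside every island. Only the thread sitting at the start of an island
// does the shuffle; all other threads return early.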
template<typename T, typename scalar_t>
__global__ void randperm_handle_duplicate_keys_kernel(T *keys, scalar_t *data, T mask, int n, at::PhiloxCudaState philox_args) {
  int tid = threadIdx.x + blockDim.x * blockIdx.x;

  // find the beginning of islands
  if (tid >= n - 1) return;  // out of range
  if ((keys[tid] & mask) != (keys[tid + 1] & mask)) return;  // not in an island
  if (tid != 0 && (keys[tid] & mask) == (keys[tid - 1] & mask)) return;  // not the beginning of an island

  // find the size of islands
  int island_size = 0;
  do { island_size++; }
  while ((tid + island_size < n) && (keys[tid + island_size] & mask) == (keys[tid] & mask));

  // do random permutation inside each island.
  data += tid;
  auto seeds = at::cuda::philox::unpack(philox_args);
  curandStatePhilox4_32_10_t state;
  curand_init(std::get<0>(seeds), tid, std::get<1>(seeds), &state);
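  // Fisher-Yates shuffle over this island's data segment; each island-start
  // thread uses its own Philox subsequence (seeded with tid), so islands are
  // shuffled independently.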
  for (int i = island_size - 1; i > 0; i--) {
    unsigned int r = curand(&state) % (i + 1);
    if (i != r) {
      scalar_t tmp = data[i];
      data[i] = data[r];
      data[r] = tmp;
    }
  }
}

// See note [Algorithm of randperm]
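//
// Host-side wrapper: acquires the CUDA generator, reserves n Philox counter
// values for this launch (an upper bound on how many random numbers any one
// thread draws), and launches one thread per element of keys/data.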
template<typename T, typename scalar_t>
void randperm_handle_duplicate_keys(T *keys, scalar_t *data, int bits, int64_t n, c10::optional<at::Generator> &gen_) {
  auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());
  int64_t counter_offset = n;
  at::PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_cuda_state(counter_offset);
  }
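  // The mask keeps the low `bits` bits of each key; the kernel treats
  // neighbors whose masked keys are equal as duplicates.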
  T mask = static_cast<T>((1UL << bits) - 1);
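  // One thread per element, 512 threads per block.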
  randperm_handle_duplicate_keys_kernel<<<(n + 511) / 512, 512, 0, at::cuda::getCurrentCUDAStream()>>>(
    keys, data, mask, n, rng_engine_inputs);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

}  // anonymous namespace