Skip to content
This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

refactor random generator cpu #20

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions cpp/include/wholememory/wholegraph_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,31 @@ wholememory_error_code_t wholegraph_csr_weighted_sample_without_replacement(
wholememory_env_func_t* p_env_fns,
void* stream);

/**
 * Host (CPU) op that fills a tensor with pseudo-random integers
 * (formerly the raft_pcg_generator_random_int test op).
 *
 * The definition in raft_random_gen.cu builds one PCGenerator from
 * (random_seed, subsequence, 0) and writes one generated value per element of
 * the output, in index order, through the tensor's raw data pointer from a
 * plain host loop.
 * NOTE(review): the function name says the values are positive; the actual
 * range is decided by PCGenerator::next() and should be confirmed there.
 *
 * @param random_seed : random seed (cast to unsigned long long for PCGenerator)
 * @param subsequence : subsequence for generating random value (second
 *                      PCGenerator constructor argument)
 * @param output : Wholememory Tensor of output; must be 1-D with dtype
 *                 WHOLEMEMORY_DT_INT or WHOLEMEMORY_DT_INT64
 * @return : WHOLEMEMORY_SUCCESS once every element has been written,
 *           WHOLEMEMORY_INVALID_INPUT if the output is not 1-D or has another dtype
 */
wholememory_error_code_t generate_random_positive_int_cpu(
int64_t random_seed,
int64_t subsequence,
wholememory_tensor_t output
);

/**
 * Host (CPU) op that fills a float tensor with pseudo-random keys
 * (formerly the raft_pcg_generator_random_float test op).
 *
 * The definition in raft_random_gen.cu builds one PCGenerator from
 * (random_seed, subsequence, 0). For each output element it draws
 * u = -next_float(1.0f, 0.5f), scales u by 2^-k where k is the number of
 * leading zero bits found in freshly drawn 64-bit values (all-zero draws add
 * 64 and trigger another draw), and stores log1p(u) / log(2).
 * No per-element weight is applied here; callers that need a weighted key
 * scale the result themselves.
 * NOTE(review): the function name says the results are negative and
 * exponentially distributed; that depends on the range of next_float() and
 * should be confirmed against PCGenerator.
 *
 * @param random_seed : random seed (cast to unsigned long long for PCGenerator)
 * @param subsequence : subsequence for generating random value (second
 *                      PCGenerator constructor argument)
 * @param output : Wholememory Tensor of output; must be 1-D with dtype
 *                 WHOLEMEMORY_DT_FLOAT
 * @return : WHOLEMEMORY_SUCCESS once every element has been written,
 *           WHOLEMEMORY_INVALID_INPUT if the output is not 1-D or not float
 */
wholememory_error_code_t generate_exponential_distribution_negative_float_cpu(
int64_t random_seed,
int64_t subsequence,
wholememory_tensor_t output);

#ifdef __cplusplus
}
#endif
96 changes: 96 additions & 0 deletions cpp/src/wholegraph_ops/raft_random_gen.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <wholememory/wholegraph_op.h>
#include <cmath>
#include <wholememory_ops/raft_random.cuh>


#include "error.hpp"
#include "logger.hpp"

/**
 * Fill a 1-D int32/int64 tensor with pseudo-random integers on the host.
 *
 * A single PCGenerator seeded with (random_seed, subsequence, 0) produces one
 * value per output element, in index order, so the result is reproducible for
 * a given (seed, subsequence, length, dtype).
 *
 * @param random_seed : random seed (cast to unsigned long long for PCGenerator)
 * @param subsequence : subsequence for generating random value
 * @param output : 1-D tensor of dtype WHOLEMEMORY_DT_INT or WHOLEMEMORY_DT_INT64;
 *                 its data pointer is written directly from this host loop
 * @return : WHOLEMEMORY_SUCCESS on success, WHOLEMEMORY_INVALID_INPUT when the
 *           handle or its data pointer is null, the tensor is not 1-D, or the
 *           dtype is not int32/int64
 */
wholememory_error_code_t generate_random_positive_int_cpu(
  int64_t random_seed,
  int64_t subsequence,
  wholememory_tensor_t output)
{
  // The description is dereferenced below, so reject a null handle up front.
  if (output == nullptr) {
    WHOLEMEMORY_ERROR("output tensor should not be null.");
    return WHOLEMEMORY_INVALID_INPUT;
  }
  auto output_tensor_desc = *wholememory_tensor_get_tensor_description(output);
  if (output_tensor_desc.dim != 1) {
    WHOLEMEMORY_ERROR("output should be 1D tensor.");
    return WHOLEMEMORY_INVALID_INPUT;
  }
  const bool is_int32 = output_tensor_desc.dtype == WHOLEMEMORY_DT_INT;
  if (!is_int32 && output_tensor_desc.dtype != WHOLEMEMORY_DT_INT64) {
    WHOLEMEMORY_ERROR("output should be int64 or int32 tensor.");
    return WHOLEMEMORY_INVALID_INPUT;
  }

  const int64_t output_count = output_tensor_desc.sizes[0];
  auto* output_ptr           = wholememory_tensor_get_data_pointer(output);
  if (output_count > 0 && output_ptr == nullptr) {
    WHOLEMEMORY_ERROR("output tensor has no data pointer.");
    return WHOLEMEMORY_INVALID_INPUT;
  }

  PCGenerator rng((unsigned long long)random_seed, subsequence, 0);
  // The dtype is fixed for the whole tensor, so pick the element type once
  // instead of re-testing it for every element.
  if (is_int32) {
    auto* out32 = static_cast<int*>(output_ptr);
    for (int64_t i = 0; i < output_count; i++) {
      int32_t random_num;
      rng.next(random_num);
      out32[i] = random_num;
    }
  } else {
    auto* out64 = static_cast<int64_t*>(output_ptr);
    for (int64_t i = 0; i < output_count; i++) {
      int64_t random_num;
      rng.next(random_num);
      out64[i] = random_num;
    }
  }
  return WHOLEMEMORY_SUCCESS;
}

/**
 * Fill a 1-D float tensor with pseudo-random keys on the host.
 *
 * A single PCGenerator seeded with (random_seed, subsequence, 0) is advanced
 * for every element. Each element is computed as log1p(u) / log(2), where
 * u = -next_float(1.0f, 0.5f) scaled by 2^-k and k is the number of leading
 * zero bits seen in freshly drawn 64-bit values (an all-zero draw contributes
 * 64 and triggers another draw).
 *
 * No per-element weight is applied here: callers that need a weighted key
 * multiply the returned value by 1 / weight themselves.
 *
 * @param random_seed : random seed (cast to unsigned long long for PCGenerator)
 * @param subsequence : subsequence for generating random value
 * @param output : 1-D tensor of dtype WHOLEMEMORY_DT_FLOAT; its data pointer is
 *                 written directly from this host loop
 * @return : WHOLEMEMORY_SUCCESS on success, WHOLEMEMORY_INVALID_INPUT when the
 *           handle or its data pointer is null, the tensor is not 1-D, or the
 *           dtype is not float
 */
wholememory_error_code_t generate_exponential_distribution_negative_float_cpu(
  int64_t random_seed,
  int64_t subsequence,
  wholememory_tensor_t output)
{
  // The description is dereferenced below, so reject a null handle up front.
  if (output == nullptr) {
    WHOLEMEMORY_ERROR("output tensor should not be null.");
    return WHOLEMEMORY_INVALID_INPUT;
  }
  auto output_tensor_desc = *wholememory_tensor_get_tensor_description(output);
  if (output_tensor_desc.dim != 1) {
    WHOLEMEMORY_ERROR("output should be 1D tensor.");
    return WHOLEMEMORY_INVALID_INPUT;
  }
  if (output_tensor_desc.dtype != WHOLEMEMORY_DT_FLOAT) {
    WHOLEMEMORY_ERROR("output should be float.");
    return WHOLEMEMORY_INVALID_INPUT;
  }

  const int64_t output_count = output_tensor_desc.sizes[0];
  auto* output_ptr           = wholememory_tensor_get_data_pointer(output);
  if (output_count > 0 && output_ptr == nullptr) {
    WHOLEMEMORY_ERROR("output tensor has no data pointer.");
    return WHOLEMEMORY_INVALID_INPUT;
  }
  auto* keys = static_cast<float*>(output_ptr);

  // Number of leading zero bits of a 64-bit value: 64 minus its bit length.
  // Built once instead of once per element.
  auto count_leading_zeros = [](unsigned long long num) {
    int32_t bit_length = 0;
    while (num) {
      num >>= 1;
      bit_length++;
    }
    return 64 - bit_length;
  };

  PCGenerator rng((unsigned long long)random_seed, subsequence, 0);
  for (int64_t i = 0; i < output_count; i++) {
    float u              = -rng.next_float(1.0f, 0.5f);
    uint64_t random_num2 = 0;
    int seed_count       = -1;
    // Redraw until a non-zero value appears; seed_count ends up as the number
    // of all-zero draws that were discarded.
    do {
      rng.next(random_num2);
      seed_count++;
    } while (!random_num2);
    int32_t one_bit = count_leading_zeros(random_num2) + seed_count * 64;
    // The arithmetic below is intentionally left exactly as it was so that
    // generated values do not change.
    u *= pow(2, -one_bit);
    float logk = (log1p(u) / log(2.0));
    keys[i]    = logk;
  }
  return WHOLEMEMORY_SUCCESS;
}
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ __global__ void large_sample_kernel(wholememory_gref_t wm_csr_row_ptr,
}
__syncthreads();
for (int idx = max_sample_count + threadIdx.x; idx < neighbor_count; idx += blockDim.x) {
int rand_num;
int32_t rand_num;
rng.next(rand_num);
rand_num %= idx + 1;
if (rand_num < max_sample_count) { atomicMax((int*)(output + offset + rand_num), idx); }
Expand Down Expand Up @@ -192,10 +192,10 @@ __global__ void unweighted_sample_without_replacement_kernel(
} shared_data;
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; i++) {
uint32_t idx = i * BLOCK_DIM + threadIdx.x;
uint32_t random_num;
int idx = i * BLOCK_DIM + threadIdx.x;
int32_t random_num;
rng.next(random_num);
uint32_t r = idx < M ? (random_num % (N - idx)) : N;
int32_t r = idx < M ? (random_num % (N - idx)) : N;
sa_p[i] = ((uint64_t)r << 32UL) | idx;
}
__syncthreads();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ template <typename WeightType>
__device__ __forceinline__ float gen_key_from_weight(const WeightType weight, PCGenerator& rng)
{
float u = -rng.next_float(1.0f, 0.5f);
int64_t random_num2 = 0;
uint64_t random_num2 = 0;
int seed_count = -1;
do {
rng.next(random_num2);
Expand Down
22 changes: 4 additions & 18 deletions cpp/tests/wholegraph_ops/graph_sampling_test_utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ REGISTER_DISPATCH_TWO_TYPES(HOSTSAMPLEALL, host_sample_all, SINT3264, SINT3264)

template <int Offset = 0>
void random_sample_without_replacement_cpu_base(std::vector<int>* a,
const std::vector<uint32_t>& r,
const std::vector<int32_t>& r,
int M,
int N)
{
Expand All @@ -319,20 +319,6 @@ void random_sample_without_replacement_cpu_base(std::vector<int>* a,
}
}

// Reference sampler: picks M distinct values out of [0, N) driven by the
// pre-generated indices in `r`.
//
// A pool holding 0..N-1 is kept; at step k the entry at position r[k] is
// written to a[k], and that position is then refilled with the entry at
// position N-k-1, so a value already handed out is overwritten.
//
// a : output, must already hold at least M elements
// r : per-step pool positions; r[k] is used as an index into the pool
// M : number of samples to produce
// N : size of the candidate pool
void random_sample_without_replacement_cpu_base_2(std::vector<int>& a,
                                                  const std::vector<int>& r,
                                                  int M,
                                                  int N)
{
  std::vector<int> pool(N);
  int next_value = 0;
  for (auto& slot : pool) {
    slot = next_value++;
  }

  int tail = N - 1;  // position N - pick - 1, updated alongside pick
  for (int pick = 0; pick < M; ++pick, --tail) {
    a[pick]       = pool[r[pick]];
    pool[r[pick]] = pool[tail];
  }
}

template <typename IdType, typename ColIdType>
void host_unweighted_sample_without_replacement(
Expand Down Expand Up @@ -395,14 +381,14 @@ void host_unweighted_sample_without_replacement(
output_local_id++;
}
} else {
std::vector<uint32_t> r(neighbor_count);
std::vector<int32_t> r(neighbor_count);
for (int j = 0; j < device_num_threads; j++) {
int local_gidx = gidx + j;
PCGenerator rng(random_seed, (uint64_t)local_gidx, (uint64_t)0);

for (int k = 0; k < items_per_thread; k++) {
int id = k * device_num_threads + j;
uint32_t random_num;
int32_t random_num;
rng.next(random_num);
if (id < neighbor_count) { r[id] = id < M ? (random_num % (N - id)) : N; }
}
Expand Down Expand Up @@ -560,7 +546,7 @@ template <typename WeightType>
float host_gen_key_from_weight(const WeightType weight, PCGenerator& rng)
{
float u = -rng.next_float(1.0f, 0.5f);
int64_t random_num2 = 0;
uint64_t random_num2 = 0;
int seed_count = -1;
do {
rng.next(random_num2);
Expand Down
33 changes: 33 additions & 0 deletions pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1790,6 +1790,16 @@ cdef extern from "wholememory/wholegraph_op.h":
unsigned long long random_seed,
wholememory_env_func_t * p_env_fns,
void * stream)

cdef wholememory_error_code_t generate_random_positive_int_cpu(
int64_t random_seed,
int64_t subsequence,
wholememory_tensor_t output)

cdef wholememory_error_code_t generate_exponential_distribution_negative_float_cpu(
int64_t random_seed,
int64_t subsequence,
wholememory_tensor_t output)

cpdef void csr_unweighted_sample_without_replacement(
PyWholeMemoryTensor wm_csr_row_ptr_tensor,
Expand Down Expand Up @@ -1845,6 +1855,29 @@ cpdef void csr_weighted_sample_without_replacement(
<wholememory_env_func_t *> p_env_fns_int,
<void *> stream_int))

cpdef void host_generate_random_positive_int(
        int64_t random_seed,
        int64_t subsequence,
        WrappedLocalTensor output
):
    """Fill ``output`` with random integers via generate_random_positive_int_cpu.

    The C handle obtained from ``output.get_c_handle()`` is reinterpreted as a
    ``wholememory_tensor_t`` and passed to the C op together with the seed and
    subsequence. The returned status code is handed to
    ``check_wholememory_error_code``.
    """
    cdef wholememory_tensor_t output_handle = \
        <wholememory_tensor_t> <int64_t> output.get_c_handle()
    cdef wholememory_error_code_t status = generate_random_positive_int_cpu(
        random_seed,
        subsequence,
        output_handle)
    check_wholememory_error_code(status)


cpdef void host_generate_exponential_distribution_negative_float(
        int64_t random_seed,
        int64_t subsequence,
        WrappedLocalTensor output
):
    """Fill ``output`` via generate_exponential_distribution_negative_float_cpu.

    The C handle obtained from ``output.get_c_handle()`` is reinterpreted as a
    ``wholememory_tensor_t`` and passed to the C op together with the seed and
    subsequence. The returned status code is handed to
    ``check_wholememory_error_code``.
    """
    cdef wholememory_tensor_t output_handle = \
        <wholememory_tensor_t> <int64_t> output.get_c_handle()
    cdef wholememory_error_code_t status = \
        generate_exponential_distribution_negative_float_cpu(
            random_seed,
            subsequence,
            output_handle)
    check_wholememory_error_code(status)


cdef extern from "wholememory/graph_op.h":
cdef wholememory_error_code_t graph_append_unique(wholememory_tensor_t target_nodes_tensor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,7 @@ def host_unweighted_sample_without_replacement_func(
random_values = torch.empty((N,), dtype=torch.int32)
for j in range(block_threads):
local_gidx = gidx + j
random_nums = torch.ops.wholegraph_test.raft_pcg_generator_random(
random_seed, local_gidx, items_per_thread
)
random_nums = wg_ops.generate_random_positive_int_cpu(random_seed, local_gidx, items_per_thread)
for k in range(items_per_thread):
id = k * block_threads + j
if id < neighbor_count:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,10 @@ def host_weighted_sample_without_replacement_func(
edge_weight_corresponding_ids,
torch.tensor([id], dtype=col_id_dtype),
)
)
generated_random_weight = (
torch.ops.wholegraph_test.raft_pcg_generator_random_from_weight(
random_seed,
local_gidx,
local_edge_weights,
generated_edge_weight_count,
)
)
)
random_values = wg_ops.generate_exponential_distribution_negative_float_cpu(random_seed, local_gidx, generated_edge_weight_count)
generated_random_weight = torch.tensor([(1.0/local_edge_weights[i]) * random_values[i] for i in range(generated_edge_weight_count)])

total_neighbor_generated_weights = torch.cat(
(total_neighbor_generated_weights, generated_random_weight)
)
Expand Down
15 changes: 15 additions & 0 deletions pylibwholegraph/pylibwholegraph/torch/wholegraph_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,18 @@ def weighted_sample_without_replacement(
)
else:
return output_sample_offset_tensor, output_dest_context.get_tensor()


def generate_random_positive_int_cpu(random_seed,
                                     sub_sequence,
                                     output_random_value_count):
    """Return a 1-D int32 tensor of host-generated random integers.

    A tensor with ``output_random_value_count`` elements is allocated, wrapped
    with ``wrap_torch_tensor`` and filled in place by
    ``wmb.host_generate_random_positive_int`` using ``random_seed`` and
    ``sub_sequence``.
    """
    result_shape = (output_random_value_count,)
    random_ints = torch.empty(result_shape, dtype=torch.int32)
    wrapped_output = wrap_torch_tensor(random_ints)
    wmb.host_generate_random_positive_int(random_seed, sub_sequence, wrapped_output)
    return random_ints

def generate_exponential_distribution_negative_float_cpu(random_seed: int,
                                                         sub_sequence: int,
                                                         output_random_value_count: int):
    """Return a 1-D float32 tensor of host-generated random keys.

    A tensor with ``output_random_value_count`` elements is allocated, wrapped
    with ``wrap_torch_tensor`` and filled in place by
    ``wmb.host_generate_exponential_distribution_negative_float`` using
    ``random_seed`` and ``sub_sequence``.
    """
    result_shape = (output_random_value_count,)
    random_keys = torch.empty(result_shape, dtype=torch.float32)
    wrapped_output = wrap_torch_tensor(random_keys)
    wmb.host_generate_exponential_distribution_negative_float(
        random_seed, sub_sequence, wrapped_output
    )
    return random_keys