Skip to content
This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Commit

Permalink
Merge pull request #20 from hello-11/refactor_random_generator_op
Browse files Browse the repository at this point in the history
refactor random genertor cpu
  • Loading branch information
BradReesWork authored Jun 5, 2023
2 parents e8350ba + f8af2dd commit d39454a
Show file tree
Hide file tree
Showing 9 changed files with 183 additions and 35 deletions.
25 changes: 25 additions & 0 deletions cpp/include/wholememory/wholegraph_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,31 @@ wholememory_error_code_t wholegraph_csr_weighted_sample_without_replacement(
wholememory_env_func_t* p_env_fns,
void* stream);

/**
* raft_pcg_generator_random_int cpu op
* @param random_seed : random seed
* @param subsequence : subsequence for generating random value
* @param output : Wholememory Tensor of output
* @return : wholememory_error_code_t
*/
wholememory_error_code_t generate_random_positive_int_cpu(
int64_t random_seed,
int64_t subsequence,
wholememory_tensor_t output
);

/**
* raft_pcg_generator_random_float cpu op
* @param random_seed : random seed
* @param subsequence : subsequence for generating random value
* @param output : Wholememory Tensor of output
* @return : wholememory_error_code_t
*/
wholememory_error_code_t generate_exponential_distribution_negative_float_cpu(
int64_t random_seed,
int64_t subsequence,
wholememory_tensor_t output);

#ifdef __cplusplus
}
#endif
96 changes: 96 additions & 0 deletions cpp/src/wholegraph_ops/raft_random_gen.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <wholememory/wholegraph_op.h>
#include <cmath>
#include <wholememory_ops/raft_random.cuh>


#include "error.hpp"
#include "logger.hpp"

wholememory_error_code_t generate_random_positive_int_cpu(
int64_t random_seed,
int64_t subsequence,
wholememory_tensor_t output
) {
auto output_tensor_desc = *wholememory_tensor_get_tensor_description(output);
if (output_tensor_desc.dim != 1) {
WHOLEMEMORY_ERROR("output should be 1D tensor.");
return WHOLEMEMORY_INVALID_INPUT;
}
if (output_tensor_desc.dtype != WHOLEMEMORY_DT_INT64 && output_tensor_desc.dtype != WHOLEMEMORY_DT_INT) {
WHOLEMEMORY_ERROR("output should be int64 or int32 tensor.");
return WHOLEMEMORY_INVALID_INPUT;
}

auto* output_ptr = wholememory_tensor_get_data_pointer(output);
PCGenerator rng((unsigned long long)random_seed, subsequence, 0);
for (int64_t i = 0; i < output_tensor_desc.sizes[0]; i++) {
if (output_tensor_desc.dtype == WHOLEMEMORY_DT_INT) {
int32_t random_num;
rng.next(random_num);
static_cast<int*>(output_ptr)[i] = random_num;
}
else {
int64_t random_num;
rng.next(random_num);
static_cast<int64_t*>(output_ptr)[i] = random_num;
}
}
return WHOLEMEMORY_SUCCESS;
}

wholememory_error_code_t generate_exponential_distribution_negative_float_cpu(
int64_t random_seed,
int64_t subsequence,
wholememory_tensor_t output
) {
auto output_tensor_desc = *wholememory_tensor_get_tensor_description(output);
if (output_tensor_desc.dim != 1) {
WHOLEMEMORY_ERROR("output should be 1D tensor.");
return WHOLEMEMORY_INVALID_INPUT;
}
if (output_tensor_desc.dtype != WHOLEMEMORY_DT_FLOAT) {
WHOLEMEMORY_ERROR("output should be float.");
return WHOLEMEMORY_INVALID_INPUT;
}
auto* output_ptr = wholememory_tensor_get_data_pointer(output);
PCGenerator rng((unsigned long long)random_seed, subsequence, 0);
for (int64_t i = 0; i < output_tensor_desc.sizes[0]; i++) {
float u = -rng.next_float(1.0f, 0.5f);
uint64_t random_num2 = 0;
int seed_count = -1;
do {
rng.next(random_num2);
seed_count++;
} while (!random_num2);
auto count_one = [](unsigned long long num) {
int32_t c = 0;
while (num) {
num >>= 1;
c++;
}
return 64 - c;
};
int32_t one_bit = count_one(random_num2) + seed_count * 64;
u *= pow(2, -one_bit);
// float logk = (log1pf(u) / logf(2.0)) * (1.0f / (float)weight);
float logk = (log1p(u) / log(2.0));
static_cast<float*>(output_ptr)[i] = logk;
}
return WHOLEMEMORY_SUCCESS;
}
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ __global__ void large_sample_kernel(wholememory_gref_t wm_csr_row_ptr,
}
__syncthreads();
for (int idx = max_sample_count + threadIdx.x; idx < neighbor_count; idx += blockDim.x) {
int rand_num;
int32_t rand_num;
rng.next(rand_num);
rand_num %= idx + 1;
if (rand_num < max_sample_count) { atomicMax((int*)(output + offset + rand_num), idx); }
Expand Down Expand Up @@ -192,10 +192,10 @@ __global__ void unweighted_sample_without_replacement_kernel(
} shared_data;
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; i++) {
uint32_t idx = i * BLOCK_DIM + threadIdx.x;
uint32_t random_num;
int idx = i * BLOCK_DIM + threadIdx.x;
int32_t random_num;
rng.next(random_num);
uint32_t r = idx < M ? (random_num % (N - idx)) : N;
int32_t r = idx < M ? (random_num % (N - idx)) : N;
sa_p[i] = ((uint64_t)r << 32UL) | idx;
}
__syncthreads();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ template <typename WeightType>
__device__ __forceinline__ float gen_key_from_weight(const WeightType weight, PCGenerator& rng)
{
float u = -rng.next_float(1.0f, 0.5f);
int64_t random_num2 = 0;
uint64_t random_num2 = 0;
int seed_count = -1;
do {
rng.next(random_num2);
Expand Down
22 changes: 4 additions & 18 deletions cpp/tests/wholegraph_ops/graph_sampling_test_utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ REGISTER_DISPATCH_TWO_TYPES(HOSTSAMPLEALL, host_sample_all, SINT3264, SINT3264)

template <int Offset = 0>
void random_sample_without_replacement_cpu_base(std::vector<int>* a,
const std::vector<uint32_t>& r,
const std::vector<int32_t>& r,
int M,
int N)
{
Expand All @@ -319,20 +319,6 @@ void random_sample_without_replacement_cpu_base(std::vector<int>* a,
}
}

void random_sample_without_replacement_cpu_base_2(std::vector<int>& a,
const std::vector<int>& r,
int M,
int N)
{
std::vector<int> Q(N);
for (int i = 0; i < N; ++i) {
Q[i] = i;
}
for (int i = 0; i < M; ++i) {
a[i] = Q[r[i]];
Q[r[i]] = Q[N - i - 1];
}
}

template <typename IdType, typename ColIdType>
void host_unweighted_sample_without_replacement(
Expand Down Expand Up @@ -395,14 +381,14 @@ void host_unweighted_sample_without_replacement(
output_local_id++;
}
} else {
std::vector<uint32_t> r(neighbor_count);
std::vector<int32_t> r(neighbor_count);
for (int j = 0; j < device_num_threads; j++) {
int local_gidx = gidx + j;
PCGenerator rng(random_seed, (uint64_t)local_gidx, (uint64_t)0);

for (int k = 0; k < items_per_thread; k++) {
int id = k * device_num_threads + j;
uint32_t random_num;
int32_t random_num;
rng.next(random_num);
if (id < neighbor_count) { r[id] = id < M ? (random_num % (N - id)) : N; }
}
Expand Down Expand Up @@ -560,7 +546,7 @@ template <typename WeightType>
float host_gen_key_from_weight(const WeightType weight, PCGenerator& rng)
{
float u = -rng.next_float(1.0f, 0.5f);
int64_t random_num2 = 0;
uint64_t random_num2 = 0;
int seed_count = -1;
do {
rng.next(random_num2);
Expand Down
33 changes: 33 additions & 0 deletions pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1790,6 +1790,16 @@ cdef extern from "wholememory/wholegraph_op.h":
unsigned long long random_seed,
wholememory_env_func_t * p_env_fns,
void * stream)

cdef wholememory_error_code_t generate_random_positive_int_cpu(
int64_t random_seed,
int64_t subsequence,
wholememory_tensor_t output)

cdef wholememory_error_code_t generate_exponential_distribution_negative_float_cpu(
int64_t random_seed,
int64_t subsequence,
wholememory_tensor_t output)

cpdef void csr_unweighted_sample_without_replacement(
PyWholeMemoryTensor wm_csr_row_ptr_tensor,
Expand Down Expand Up @@ -1845,6 +1855,29 @@ cpdef void csr_weighted_sample_without_replacement(
<wholememory_env_func_t *> p_env_fns_int,
<void *> stream_int))

cpdef void host_generate_random_positive_int(
int64_t random_seed,
int64_t subsequence,
WrappedLocalTensor output
):
check_wholememory_error_code(generate_random_positive_int_cpu(
random_seed,
subsequence,
<wholememory_tensor_t> <int64_t> output.get_c_handle()
))


cpdef void host_generate_exponential_distribution_negative_float(
int64_t random_seed,
int64_t subsequence,
WrappedLocalTensor output
):
check_wholememory_error_code(generate_exponential_distribution_negative_float_cpu(
random_seed,
subsequence,
<wholememory_tensor_t> <int64_t> output.get_c_handle()
))


cdef extern from "wholememory/graph_op.h":
cdef wholememory_error_code_t graph_append_unique(wholememory_tensor_t target_nodes_tensor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,7 @@ def host_unweighted_sample_without_replacement_func(
random_values = torch.empty((N,), dtype=torch.int32)
for j in range(block_threads):
local_gidx = gidx + j
random_nums = torch.ops.wholegraph_test.raft_pcg_generator_random(
random_seed, local_gidx, items_per_thread
)
random_nums = wg_ops.generate_random_positive_int_cpu(random_seed, local_gidx, items_per_thread)
for k in range(items_per_thread):
id = k * block_threads + j
if id < neighbor_count:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,10 @@ def host_weighted_sample_without_replacement_func(
edge_weight_corresponding_ids,
torch.tensor([id], dtype=col_id_dtype),
)
)
generated_random_weight = (
torch.ops.wholegraph_test.raft_pcg_generator_random_from_weight(
random_seed,
local_gidx,
local_edge_weights,
generated_edge_weight_count,
)
)
)
random_values = wg_ops.generate_exponential_distribution_negative_float_cpu(random_seed, local_gidx, generated_edge_weight_count)
generated_random_weight = torch.tensor([(1.0/local_edge_weights[i]) * random_values[i] for i in range(generated_edge_weight_count)])

total_neighbor_generated_weights = torch.cat(
(total_neighbor_generated_weights, generated_random_weight)
)
Expand Down
15 changes: 15 additions & 0 deletions pylibwholegraph/pylibwholegraph/torch/wholegraph_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,18 @@ def weighted_sample_without_replacement(
)
else:
return output_sample_offset_tensor, output_dest_context.get_tensor()


def generate_random_positive_int_cpu(random_seed,
sub_sequence,
output_random_value_count):
output = torch.empty((output_random_value_count,), dtype=torch.int)
wmb.host_generate_random_positive_int(random_seed, sub_sequence, wrap_torch_tensor(output))
return output

def generate_exponential_distribution_negative_float_cpu(random_seed: int,
sub_sequence: int,
output_random_value_count: int):
output = torch.empty((output_random_value_count,), dtype = torch.float)
wmb.host_generate_exponential_distribution_negative_float(random_seed, sub_sequence, wrap_torch_tensor(output))
return output

0 comments on commit d39454a

Please sign in to comment.