Skip to content

Commit

Permalink
mdspan-ify rmat_rectangular_gen (#833)
Browse files Browse the repository at this point in the history
I added two overloads of `rmat_rectangular_gen`, one for each of the existing overloads, that take device mdspan instead of raw arrays. I've added a unit test that imitates the existing unit test for the raw-array overloads; it builds and passes.

I also opportunistically fixed some unrelated existing small build errors that were blocking forward progress.

Authors:
  - Mark Hoemmen (https://github.com/mhoemmen)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #833
  • Loading branch information
mhoemmen authored Sep 27, 2022
1 parent 1dd2feb commit faa4d9d
Show file tree
Hide file tree
Showing 4 changed files with 680 additions and 24 deletions.
108 changes: 108 additions & 0 deletions cpp/include/raft/random/detail/rmat_rectangular_generator.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

#pragma once

#include "rmat_rectangular_generator_types.cuh"

#include <raft/core/handle.hpp>
#include <raft/random/rng_device.cuh>
#include <raft/random/rng_state.hpp>
#include <raft/util/cuda_utils.cuh>
Expand Down Expand Up @@ -182,6 +185,111 @@ void rmat_rectangular_gen_caller(IdxT* out,
r.advance(n_edges, max_scale);
}

/**
 * @brief mdspan-based implementation behind `raft::random::rmat_rectangular_gen`
 *        (overload taking a per-level `theta` array).
 *
 * Checks the length of `theta`, unpacks whichever output views `output`
 * holds into raw pointers (`nullptr` for any view that is absent), and
 * forwards everything to `rmat_rectangular_gen_caller` on the stream
 * obtained from `handle`. If `output` is empty the function returns
 * immediately, before `theta` is validated.
 *
 * @tparam IdxT  type of each node index; must be integral (enforced by a
 *               `static_assert`)
 * @tparam ProbT data type used for probability distributions (either fp32 or fp64)
 * @param[in]  handle  RAFT handle, containing the CUDA stream on which to schedule work
 * @param[in]  r       underlying state of the random generator. Especially useful when
 *                     one wants to call this API multiple times in order to generate
 *                     a larger graph. For that case, just create this object with the
 *                     initial seed once and after every call continue to pass the same
 *                     object for the successive calls.
 * @param[in]  theta   distribution of each quadrant at each level of resolution.
 *                     Since these are probabilities, each of the 2x2 matrices for
 *                     each level of the RMAT must sum to one. [on device]
 *                     [dim = max(r_scale, c_scale) x 2 x 2]. Only the total length
 *                     (4 * max(r_scale, c_scale)) is verified here; the sum-to-one
 *                     property is assumed, not checked.
 * @param[out] output  Encapsulation of one, two, or three output vectors.
 * @param[in]  r_scale 2^r_scale represents the number of source nodes
 * @param[in]  c_scale 2^c_scale represents the number of destination nodes
 */
template <typename IdxT, typename ProbT>
void rmat_rectangular_gen_impl(const raft::handle_t& handle,
                               raft::random::RngState& r,
                               raft::device_vector_view<const ProbT, IdxT> theta,
                               raft::random::detail::rmat_rectangular_gen_output<IdxT> output,
                               IdxT r_scale,
                               IdxT c_scale)
{
  static_assert(std::is_integral_v<IdxT>,
                "rmat_rectangular_gen: "
                "Template parameter IdxT must be an integral type");

  // Zero-length output means there is nothing to generate; deliberately not an error.
  if (output.empty()) { return; }

  // theta carries one 2x2 block per level, for max(r_scale, c_scale) levels.
  const IdxT deeper_scale     = (c_scale > r_scale) ? c_scale : r_scale;
  const IdxT theta_len_wanted = IdxT(4) * deeper_scale;
  RAFT_EXPECTS(theta.extent(0) == theta_len_wanted,
               "rmat_rectangular_gen: "
               "theta.extent(0) = %zu != 2 * 2 * max(r_scale = %zu, c_scale = %zu) = %zu",
               static_cast<std::size_t>(theta.extent(0)),
               static_cast<std::size_t>(r_scale),
               static_cast<std::size_t>(c_scale),
               static_cast<std::size_t>(theta_len_wanted));

  // Each accessor yields std::nullopt when that vector was not supplied;
  // the raw-pointer callee expects nullptr in that case.
  const auto packed_view = output.out_view();
  const auto src_view    = output.out_src_view();
  const auto dst_view    = output.out_dst_view();

  IdxT* out_ptr = nullptr;
  if (packed_view.has_value()) { out_ptr = packed_view->data_handle(); }

  IdxT* out_src_ptr = nullptr;
  if (src_view.has_value()) { out_src_ptr = src_view->data_handle(); }

  IdxT* out_dst_ptr = nullptr;
  if (dst_view.has_value()) { out_dst_ptr = dst_view->data_handle(); }

  const IdxT n_edges = output.number_of_edges();

  rmat_rectangular_gen_caller(out_ptr,
                              out_src_ptr,
                              out_dst_ptr,
                              theta.data_handle(),
                              r_scale,
                              c_scale,
                              n_edges,
                              handle.get_stream(),
                              r);
}

/**
 * @brief Overload of the `rmat_rectangular_gen` implementation that assumes
 *        the same a, b, c, d probability distributions across all the scales.
 *
 * `a`, `b`, and `c` effectively replace the `theta` parameter of the
 * overload that takes a per-level `theta` array. Only `a`, `b`, and `c`
 * are passed on to `rmat_rectangular_gen_caller`; how `d` is obtained is
 * up to that callee (presumably 1 - a - b - c -- confirm there).
 *
 * If `output` is empty the function returns immediately without doing
 * any work.
 *
 * @tparam IdxT  type of each node index; must be integral (enforced by a
 *               `static_assert`)
 * @tparam ProbT data type of the probabilities `a`, `b`, and `c`
 * @param[in]  handle  RAFT handle; its stream is the one handed to the callee
 * @param[in]  r       state of the random generator, passed through to the callee
 * @param[out] output  Encapsulation of one, two, or three output vectors.
 * @param[in]  a       first quadrant probability, shared by every level
 * @param[in]  b       second quadrant probability, shared by every level
 * @param[in]  c       third quadrant probability, shared by every level
 * @param[in]  r_scale 2^r_scale represents the number of source nodes
 * @param[in]  c_scale 2^c_scale represents the number of destination nodes
 */
template <typename IdxT, typename ProbT>
void rmat_rectangular_gen_impl(const raft::handle_t& handle,
                               raft::random::RngState& r,
                               raft::random::detail::rmat_rectangular_gen_output<IdxT> output,
                               ProbT a,
                               ProbT b,
                               ProbT c,
                               IdxT r_scale,
                               IdxT c_scale)
{
  static_assert(std::is_integral_v<IdxT>,
                "rmat_rectangular_gen: "
                "Template parameter IdxT must be an integral type");

  // Zero-length output means there is nothing to generate; deliberately not an error.
  if (output.empty()) { return; }

  // Each accessor yields std::nullopt when that vector was not supplied;
  // the raw-pointer callee expects nullptr in that case.
  const auto packed_view = output.out_view();
  const auto src_view    = output.out_src_view();
  const auto dst_view    = output.out_dst_view();

  IdxT* out_ptr = nullptr;
  if (packed_view.has_value()) { out_ptr = packed_view->data_handle(); }

  IdxT* out_src_ptr = nullptr;
  if (src_view.has_value()) { out_src_ptr = src_view->data_handle(); }

  IdxT* out_dst_ptr = nullptr;
  if (dst_view.has_value()) { out_dst_ptr = dst_view->data_handle(); }

  const IdxT n_edges = output.number_of_edges();

  detail::rmat_rectangular_gen_caller(out_ptr,
                                      out_src_ptr,
                                      out_dst_ptr,
                                      a,
                                      b,
                                      c,
                                      r_scale,
                                      c_scale,
                                      n_edges,
                                      handle.get_stream(),
                                      r);
}

} // end namespace detail
} // end namespace random
} // end namespace raft
259 changes: 259 additions & 0 deletions cpp/include/raft/random/detail/rmat_rectangular_generator_types.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/device_mdspan.hpp>
#include <raft/random/rng_device.cuh>
#include <raft/random/rng_state.hpp>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

#include <optional>
#include <variant>

namespace raft {
namespace random {
namespace detail {

/**
 * @brief Implementation detail for checking output vector parameter(s)
 *   of `raft::random::rmat_rectangular_gen`.
 *
 * `raft::random::rmat_rectangular_gen` lets users specify
 * output vector(s) in three different ways.
 *
 * 1. One vector: `out`, an "array-of-structs" representation
 *    of the edge list.
 *
 * 2. Two vectors: `out_src` and `out_dst`, together forming
 *    a "struct of arrays" representation of the edge list.
 *
 * 3. Three vectors: `out`, `out_src`, and `out_dst`.
 *    `out` is as in (1),
 *    and `out_src` and `out_dst` are as in (2).
 *
 * This class prevents users from doing anything other than that,
 * and makes it easier for the three cases to share a common implementation.
 * It also prevents duplication of run-time vector length checking:
 * `out_src` and `out_dst` must have the same length, and `out`, being an
 * (n_edges x 2) matrix, must have as many rows as `out_src` / `out_dst`
 * have elements (i.e. twice as many elements in total).
 *
 * The views are stored by value; this class does not own the memory
 * they refer to.
 *
 * @tparam IdxT Type of each node index; must be integral.
 *
 * The following examples show how to create an output parameter.
 *
 * @code
 * rmat_rectangular_gen_output<IdxT> output1(out);
 * rmat_rectangular_gen_output<IdxT> output2(out_src, out_dst);
 * rmat_rectangular_gen_output<IdxT> output3(out, out_src, out_dst);
 * @endcode
 */
template <typename IdxT>
class rmat_rectangular_gen_output {
 public:
  /// (n_edges x 2) row-major matrix view: one row per edge.
  using out_view_type =
    raft::device_mdspan<IdxT, raft::extents<IdxT, raft::dynamic_extent, 2>, raft::row_major>;
  /// Vector view of source node ids.
  using out_src_view_type = raft::device_vector_view<IdxT, IdxT>;
  /// Vector view of destination node ids.
  using out_dst_view_type = raft::device_vector_view<IdxT, IdxT>;

 private:
  /**
   * @brief Struct-of-arrays output: `out_src` and `out_dst` of equal length.
   */
  class output_pair {
   public:
    /**
     * @brief Store both views after checking that their lengths agree.
     *
     * Fails via `RAFT_EXPECTS` if `src.extent(0) != dst.extent(0)`.
     */
    output_pair(const out_src_view_type& src, const out_dst_view_type& dst) : src_(src), dst_(dst)
    {
      RAFT_EXPECTS(src.extent(0) == dst.extent(0),
                   "rmat_rectangular_gen: "
                   "out_src.extent(0) = %zu != out_dst.extent(0) = %zu",
                   static_cast<std::size_t>(src.extent(0)),
                   static_cast<std::size_t>(dst.extent(0)));
    }

    out_src_view_type out_src_view() const { return src_; }

    out_dst_view_type out_dst_view() const { return dst_; }

    /// Number of edges, taken from the length of `out_src`.
    IdxT number_of_edges() const { return src_.extent(0); }

    bool empty() const { return src_.extent(0) == 0 && dst_.extent(0) == 0; }

   private:
    out_src_view_type src_;
    out_dst_view_type dst_;
  };

  /**
   * @brief All three outputs: the packed `out` matrix plus an `output_pair`.
   */
  class output_triple {
   public:
    /**
     * @brief Store all three views after checking that their sizes agree.
     *
     * The `src` / `dst` length check of `output_pair` runs first (member
     * initialization), then the row count of `out` is checked against the
     * length of `dst`.
     */
    output_triple(const out_view_type& out,
                  const out_src_view_type& src,
                  const out_dst_view_type& dst)
      : out_(out), pair_(src, dst)
    {
      // `out` has a static second extent of 2, so one *row* holds one edge:
      // its row count (not its element count) must match out_src / out_dst.
      // The previous form of this check compared the row count against
      // 2 * n_edges, which rejected a correctly shaped `out` and disagreed
      // with the one-vector constructor (where extent(0) is the edge count).
      // The over-sized 2 * n_edges shape is still tolerated so that callers
      // written against the old check keep working; number_of_edges() is
      // taken from out_src / out_dst in either case.
      const IdxT n_edges = dst.extent(0);
      const IdxT n_rows  = out.extent(0);
      RAFT_EXPECTS(n_rows == n_edges || n_rows == IdxT(2) * n_edges,
                   "rmat_rectangular_gen: "
                   "out.extent(0) = %zu != out_dst.extent(0) = %zu",
                   static_cast<std::size_t>(n_rows),
                   static_cast<std::size_t>(n_edges));
    }

    out_view_type out_view() const { return out_; }

    out_src_view_type out_src_view() const { return pair_.out_src_view(); }

    out_dst_view_type out_dst_view() const { return pair_.out_dst_view(); }

    /// Number of edges, taken from the length of `out_src`.
    IdxT number_of_edges() const { return pair_.number_of_edges(); }

    bool empty() const { return out_.extent(0) == 0 && pair_.empty(); }

   private:
    out_view_type out_;
    output_pair pair_;
  };

 public:
  /**
   * @brief You're not allowed to construct this with no vectors.
   */
  rmat_rectangular_gen_output() = delete;

  /**
   * @brief Constructor taking a single vector, that packs the source
   *   node ids and destination node ids in array-of-structs fashion.
   *
   * @param[out] out Generated edgelist [on device]. In each row, the
   *   first element is the source node id, and the second element is
   *   the destination node id.
   */
  rmat_rectangular_gen_output(const out_view_type& out) : data_(out) {}

  /**
   * @brief Constructor taking two vectors, that store the source node
   *   ids and the destination node ids separately, in
   *   struct-of-arrays fashion.
   *
   * Fails via `RAFT_EXPECTS` if the two vectors differ in length.
   *
   * @param[out] src Source node id's [on device] [len = n_edges].
   *
   * @param[out] dst Destination node id's [on device] [len = n_edges].
   */
  rmat_rectangular_gen_output(const out_src_view_type& src, const out_dst_view_type& dst)
    : data_(output_pair(src, dst))
  {
  }

  /**
   * @brief Constructor taking all three vectors.
   *
   * Fails via `RAFT_EXPECTS` if `src` and `dst` differ in length, or if
   * the number of rows of `out` does not match that length.
   *
   * @param[out] out Generated edgelist [on device] [n_edges rows].
   *   In each row, the first element is the source node id, and the
   *   second element is the destination node id.
   *
   * @param[out] src Source node id's [on device] [len = n_edges].
   *
   * @param[out] dst Destination node id's [on device] [len = n_edges].
   */
  rmat_rectangular_gen_output(const out_view_type& out,
                              const out_src_view_type& src,
                              const out_dst_view_type& dst)
    : data_(output_triple(out, src, dst))
  {
  }

  /**
   * @brief Whether the vector(s) are all length zero.
   */
  bool empty() const
  {
    if (std::holds_alternative<out_view_type>(data_)) {
      return std::get<out_view_type>(data_).extent(0) == 0;
    } else if (std::holds_alternative<output_pair>(data_)) {
      return std::get<output_pair>(data_).empty();
    } else {  // std::holds_alternative<output_triple>(data_)
      return std::get<output_triple>(data_).empty();
    }
  }

  /**
   * @brief Vector for the output single edgelist; the argument given
   *   to the one-argument constructor, or the first argument of the
   *   three-argument constructor; `std::nullopt` if not provided.
   */
  std::optional<out_view_type> out_view() const
  {
    if (std::holds_alternative<out_view_type>(data_)) {
      return std::get<out_view_type>(data_);
    } else if (std::holds_alternative<output_triple>(data_)) {
      return std::get<output_triple>(data_).out_view();
    } else {  // std::holds_alternative<output_pair>(data_)
      return std::nullopt;
    }
  }

  /**
   * @brief Vector for the output source edgelist; the first argument
   *   given to the two-argument constructor, or the second argument
   *   of the three-argument constructor; `std::nullopt` if not provided.
   */
  std::optional<out_src_view_type> out_src_view() const
  {
    if (std::holds_alternative<output_pair>(data_)) {
      return std::get<output_pair>(data_).out_src_view();
    } else if (std::holds_alternative<output_triple>(data_)) {
      return std::get<output_triple>(data_).out_src_view();
    } else {  // std::holds_alternative<out_view_type>(data_)
      return std::nullopt;
    }
  }

  /**
   * @brief Vector for the output destination edgelist; the second
   *   argument given to the two-argument constructor, or the third
   *   argument of the three-argument constructor;
   *   `std::nullopt` if not provided.
   */
  std::optional<out_dst_view_type> out_dst_view() const
  {
    if (std::holds_alternative<output_pair>(data_)) {
      return std::get<output_pair>(data_).out_dst_view();
    } else if (std::holds_alternative<output_triple>(data_)) {
      return std::get<output_triple>(data_).out_dst_view();
    } else {  // std::holds_alternative<out_view_type>(data_)
      return std::nullopt;
    }
  }

  /**
   * @brief Number of edges in the graph: the number of rows of `out`
   *   when only `out` was given, otherwise the length of `out_src`.
   *   Zero when the provided vector(s) have length zero.
   */
  IdxT number_of_edges() const
  {
    if (std::holds_alternative<out_view_type>(data_)) {
      return std::get<out_view_type>(data_).extent(0);
    } else if (std::holds_alternative<output_pair>(data_)) {
      return std::get<output_pair>(data_).number_of_edges();
    } else {  // std::holds_alternative<output_triple>(data_)
      return std::get<output_triple>(data_).number_of_edges();
    }
  }

 private:
  // Exactly one of the three supported output forms is held at any time.
  std::variant<out_view_type, output_pair, output_triple> data_;
};

} // end namespace detail
} // end namespace random
} // end namespace raft
Loading

0 comments on commit faa4d9d

Please sign in to comment.