Skip to content

Commit

Permalink
Mdspanify permute (#834)
Browse files Browse the repository at this point in the history
I added an overload of `raft::random::permute` that takes device mdspan instead of raw arrays, and `std::optional<mdspan<...>>` for optional output arrays.  I also added two overloads that let users pass in `std::nullopt`.

I've added a unit test that imitates the existing unit test for the raw-array overloads; it builds and passes.  The test ensures that the two `std::nullopt` overloads also compile.

I also opportunistically fixed some unrelated existing small build errors that were blocking forward progress.  That commit is included in PRs #830 and #833 as well, so if it merges, the change should rebase correctly.

Authors:
  - Mark Hoemmen (https://github.com/mhoemmen)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #834
  • Loading branch information
mhoemmen authored Sep 23, 2022
1 parent a3b6593 commit bfd90a7
Show file tree
Hide file tree
Showing 3 changed files with 282 additions and 40 deletions.
177 changes: 155 additions & 22 deletions cpp/include/raft/random/permute.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,163 @@

#include "detail/permute.cuh"

#include <optional>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <type_traits>

namespace raft::random {

/**
* @brief Generate permutations of the input array. Pretty useful primitive for
* shuffling the input datasets in ML algos. See note at the end for some of its
* limitations!
* @tparam Type Data type of the array to be shuffled
* @tparam IntType Integer type used for ther perms array
* @tparam IdxType Integer type used for addressing indices
* @tparam TPB threads per block
* @param perms the output permutation indices. Typically useful only when
* one wants to refer back. If you don't need this, pass a nullptr
* @param out the output shuffled array. Pass nullptr if you don't want this to
* be written. For eg: when you only want the perms array to be filled.
* @param in input array (in-place is not supported due to race conditions!)
* @param D number of columns of the input array
* @param N length of the input array (or number of rows)
* @param rowMajor whether the input/output matrices are row or col major
* @param stream cuda stream where to launch the work
*
* @note This is NOT a uniform permutation generator! In fact, it only generates
* very small percentage of permutations. If your application really requires a
* high quality permutation generator, it is recommended that you pick
* Knuth Shuffle.
* @brief Randomly permute the rows of the input matrix.
*
* We do not support in-place permutation, so that we can compute
* in parallel without race conditions. This function is useful
* for shuffling input data sets in machine learning algorithms.
*
* @tparam InputOutputValueType Type of each element of the input matrix,
* and the type of each element of the output matrix (if provided)
* @tparam IntType Integer type of each element of `permsOut`
* @tparam IdxType Integer type of the extents of the mdspan parameters
* @tparam Layout Either `raft::row_major` or `raft::col_major`
*
* @param[in] handle RAFT handle containing the CUDA stream
* on which to run.
* @param[in] in input matrix
* @param[out] permsOut If provided, the indices of the permutation.
* @param[out] out If provided, the output matrix, containing the
* permuted rows of the input matrix `in`. (Not providing this
* is only useful if you provide `permsOut`.)
*
* @pre If `permsOut.has_value()` is `true`,
* then `(*permsOut).extent(0) == in.extent(0)` is `true`.
*
* @pre If `out.has_value()` is `true`,
* then `(*out).extents() == in.extents()` is `true`.
*
* @note This is NOT a uniform permutation generator!
* It only generates a small fraction of all possible random permutations.
* If your application needs a high-quality permutation generator,
* then we recommend Knuth Shuffle.
*/
template <typename InputOutputValueType, typename IntType, typename IdxType, typename Layout>
void permute(const raft::handle_t& handle,
raft::device_matrix_view<const InputOutputValueType, IdxType, Layout> in,
std::optional<raft::device_vector_view<IntType, IdxType>> permsOut,
std::optional<raft::device_matrix_view<InputOutputValueType, IdxType, Layout>> out)
{
static_assert(std::is_integral_v<IntType>,
"permute: The type of each element "
"of permsOut (if provided) must be an integral type.");
static_assert(std::is_integral_v<IdxType>,
"permute: The index type "
"of each mdspan argument must be an integral type.");
constexpr bool is_row_major = std::is_same_v<Layout, raft::row_major>;
constexpr bool is_col_major = std::is_same_v<Layout, raft::col_major>;
static_assert(is_row_major || is_col_major,
"permute: Layout must be either "
"raft::row_major or raft::col_major (or one of their aliases)");

const bool permsOut_has_value = permsOut.has_value();
const bool out_has_value = out.has_value();

RAFT_EXPECTS(!permsOut_has_value || (*permsOut).extent(0) == in.extent(0),
"permute: If 'permsOut' is provided, then its extent(0) "
"must equal the number of rows of the input matrix 'in'.");
RAFT_EXPECTS(!out_has_value || (*out).extents() == in.extents(),
"permute: If 'out' is provided, then both its extents "
"must match the extents of the input matrix 'in'.");

IntType* permsOut_ptr = permsOut_has_value ? (*permsOut).data_handle() : nullptr;
InputOutputValueType* out_ptr = out_has_value ? (*out).data_handle() : nullptr;

if (permsOut_ptr != nullptr || out_ptr != nullptr) {
const IdxType N = in.extent(0);
const IdxType D = in.extent(1);
detail::permute<InputOutputValueType, IntType, IdxType>(
permsOut_ptr, out_ptr, in.data_handle(), D, N, is_row_major, handle.get_stream());
}
}

namespace permute_impl {

template <typename T, typename InputOutputValueType, typename IdxType, typename Layout>
struct perms_out_view {
};

template <typename InputOutputValueType, typename IdxType, typename Layout>
struct perms_out_view<std::nullopt_t, InputOutputValueType, IdxType, Layout> {
// permsOut won't have a value anyway,
// so we can pick any integral value type we want.
using type = raft::device_vector_view<IdxType, IdxType>;
};

template <typename PermutationIndexType,
typename InputOutputValueType,
typename IdxType,
typename Layout>
struct perms_out_view<std::optional<raft::device_vector_view<PermutationIndexType, IdxType>>,
InputOutputValueType,
IdxType,
Layout> {
using type = raft::device_vector_view<PermutationIndexType, IdxType>;
};

template <typename T, typename InputOutputValueType, typename IdxType, typename Layout>
using perms_out_view_t = typename perms_out_view<T, InputOutputValueType, IdxType, Layout>::type;

} // namespace permute_impl

/**
* @brief Overload of `permute` that compiles if users pass in `std::nullopt`
* for either or both of `permsOut` and `out`.
*/
template <typename InputOutputValueType,
typename IdxType,
typename Layout,
typename PermsOutType,
typename OutType>
void permute(const raft::handle_t& handle,
raft::device_matrix_view<const InputOutputValueType, IdxType, Layout> in,
PermsOutType&& permsOut,
OutType&& out)
{
// If PermsOutType is std::optional<device_vector_view<T, IdxType>>
// for some T, then that type T need not be related to any of the
// other template parameters. Thus, we have to deduce it specially.
using perms_out_view_type = permute_impl::
perms_out_view_t<std::decay_t<PermsOutType>, InputOutputValueType, IdxType, Layout>;
using out_view_type = raft::device_matrix_view<InputOutputValueType, IdxType, Layout>;

static_assert(std::is_same_v<std::decay_t<OutType>, std::nullopt_t> ||
std::is_same_v<std::decay_t<OutType>, std::optional<out_view_type>>,
"permute: The type of 'out' must be either std::optional<"
"raft::device_matrix_view<InputOutputViewType, IdxType, Layout>>, "
"or std::nullopt.");

std::optional<perms_out_view_type> permsOut_arg = std::forward<PermsOutType>(permsOut);
std::optional<out_view_type> out_arg = std::forward<OutType>(out);
permute(handle, in, permsOut_arg, out_arg);
}

/**
* @brief Legacy overload of `permute` that takes raw arrays instead of mdspan.
*
* @tparam Type Type of each element of the input matrix to be permuted
* @tparam IntType Integer type of each element of the permsOut matrix
* @tparam IdxType Integer type of the dimensions of the matrices
* @tparam TPB threads per block (do not use any value other than the default)
*
* @param[out] perms If nonnull, the indices of the permutation
* @param[out] out If nonnull, the output matrix, containing the
* permuted rows of the input matrix @c in. (Not providing this
* is only useful if you provide @c perms.)
* @param[in] in input matrix
* @param[in] D number of columns in the matrices
* @param[in] N number of rows in the matrices
* @param[in] rowMajor true if the matrices are row major,
* false if they are column major
* @param[in] stream CUDA stream on which to run
*/
template <typename Type, typename IntType = int, typename IdxType = int, int TPB = 256>
void permute(IntType* perms,
Expand All @@ -60,4 +193,4 @@ void permute(IntType* perms,

}; // end namespace raft::random

#endif
#endif
4 changes: 2 additions & 2 deletions cpp/test/nvtx.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,7 +15,7 @@
*/
#ifdef NVTX_ENABLED
#include <gtest/gtest.h>
#include <raft/common/detail/nvtx.hpp>
#include <raft/core/detail/nvtx.hpp>
/**
* tests for the functionality of generating next color based on string
* entered in the NVTX Range marker wrappers
Expand Down
141 changes: 125 additions & 16 deletions cpp/test/random/permute.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ template <typename T>

template <typename T>
class PermTest : public ::testing::TestWithParam<PermInputs<T>> {
public:
using test_data_type = T;

protected:
PermTest()
: in(0, handle.get_stream()), out(0, handle.get_stream()), outPerms(0, handle.get_stream())
Expand Down Expand Up @@ -81,6 +84,89 @@ class PermTest : public ::testing::TestWithParam<PermInputs<T>> {
int* outPerms_ptr = nullptr;
};

template <typename T>
class PermMdspanTest : public ::testing::TestWithParam<PermInputs<T>> {
public:
using test_data_type = T;

protected:
PermMdspanTest()
: in(0, handle.get_stream()), out(0, handle.get_stream()), outPerms(0, handle.get_stream())
{
}

private:
using index_type = int;

template <class ElementType, class Layout>
using matrix_view_t = raft::device_matrix_view<ElementType, index_type, Layout>;

template <class ElementType>
using vector_view_t = raft::device_vector_view<ElementType, index_type>;

protected:
void SetUp() override
{
auto stream = handle.get_stream();
params = ::testing::TestWithParam<PermInputs<T>>::GetParam();
// forcefully set needPerms, since we need it for unit-testing!
if (params.needShuffle) { params.needPerms = true; }
raft::random::RngState r(params.seed);
int N = params.N;
int D = params.D;
int len = N * D;
if (params.needPerms) {
outPerms.resize(N, stream);
outPerms_ptr = outPerms.data();
}
if (params.needShuffle) {
in.resize(len, stream);
out.resize(len, stream);
in_ptr = in.data();
out_ptr = out.data();
uniform(handle, r, in_ptr, len, T(-1.0), T(1.0));
}

auto set_up_views_and_test = [&](auto layout) {
using layout_type = std::decay_t<decltype(layout)>;

matrix_view_t<const T, layout_type> in_view(in_ptr, N, D);
std::optional<matrix_view_t<T, layout_type>> out_view;
if (out_ptr != nullptr) { out_view.emplace(out_ptr, N, D); }
std::optional<vector_view_t<index_type>> outPerms_view;
if (outPerms_ptr != nullptr) { outPerms_view.emplace(outPerms_ptr, N); }

permute(handle, in_view, outPerms_view, out_view);

// None of these three permute calls should have an effect.
// The point is to test whether the function can deduce the
// element type of outPerms if given nullopt.
std::optional<matrix_view_t<T, layout_type>> out_view_empty;
std::optional<vector_view_t<index_type>> outPerms_view_empty;
permute(handle, in_view, std::nullopt, out_view_empty);
permute(handle, in_view, outPerms_view_empty, std::nullopt);
permute(handle, in_view, std::nullopt, std::nullopt);
};

if (params.rowMajor) {
set_up_views_and_test(raft::row_major{});
} else {
set_up_views_and_test(raft::col_major{});
}

handle.sync_stream();
}

protected:
raft::handle_t handle;
PermInputs<T> params;
rmm::device_uvector<T> in, out;
T* in_ptr = nullptr;
T* out_ptr = nullptr;
rmm::device_uvector<int> outPerms;
int* outPerms_ptr = nullptr;
};

template <typename T, typename L>
::testing::AssertionResult devArrMatchRange(
const T* actual, size_t size, T start, L eq_compare, bool doSort = true, cudaStream_t stream = 0)
Expand Down Expand Up @@ -169,19 +255,38 @@ const std::vector<PermInputs<float>> inputsf = {
{100000, 32, true, true, false, 1234567890ULL},
{100001, 33, true, true, false, 1234567890ULL}};

typedef PermTest<float> PermTestF;
#define _PERMTEST_BODY(DATA_TYPE) \
do { \
if (params.needPerms) { \
ASSERT_TRUE(devArrMatchRange(outPerms_ptr, params.N, 0, raft::Compare<int>())); \
} \
if (params.needShuffle) { \
ASSERT_TRUE(devArrMatchShuffle(outPerms_ptr, \
out_ptr, \
in_ptr, \
params.D, \
params.N, \
params.rowMajor, \
raft::Compare<DATA_TYPE>())); \
} \
} while (false)

using PermTestF = PermTest<float>;
TEST_P(PermTestF, Result)
{
if (params.needPerms) {
ASSERT_TRUE(devArrMatchRange(outPerms_ptr, params.N, 0, raft::Compare<int>()));
}
if (params.needShuffle) {
ASSERT_TRUE(devArrMatchShuffle(
outPerms_ptr, out_ptr, in_ptr, params.D, params.N, params.rowMajor, raft::Compare<float>()));
}
using test_data_type = PermTestF::test_data_type;
_PERMTEST_BODY(test_data_type);
}
INSTANTIATE_TEST_CASE_P(PermTests, PermTestF, ::testing::ValuesIn(inputsf));

using PermMdspanTestF = PermMdspanTest<float>;
TEST_P(PermMdspanTestF, Result)
{
using test_data_type = PermTestF::test_data_type;
_PERMTEST_BODY(test_data_type);
}
INSTANTIATE_TEST_CASE_P(PermMdspanTests, PermMdspanTestF, ::testing::ValuesIn(inputsf));

const std::vector<PermInputs<double>> inputsd = {
// only generate permutations
{32, 8, true, false, true, 1234ULL},
Expand Down Expand Up @@ -219,18 +324,22 @@ const std::vector<PermInputs<double>> inputsd = {
{100000, 32, true, true, false, 1234ULL},
{100000, 32, true, true, false, 1234567890ULL},
{100001, 33, true, true, false, 1234567890ULL}};
typedef PermTest<double> PermTestD;

using PermTestD = PermTest<double>;
TEST_P(PermTestD, Result)
{
if (params.needPerms) {
ASSERT_TRUE(devArrMatchRange(outPerms_ptr, params.N, 0, raft::Compare<int>()));
}
if (params.needShuffle) {
ASSERT_TRUE(devArrMatchShuffle(
outPerms_ptr, out_ptr, in_ptr, params.D, params.N, params.rowMajor, raft::Compare<double>()));
}
using test_data_type = PermTestF::test_data_type;
_PERMTEST_BODY(test_data_type);
}
INSTANTIATE_TEST_CASE_P(PermTests, PermTestD, ::testing::ValuesIn(inputsd));

using PermMdspanTestD = PermMdspanTest<double>;
TEST_P(PermMdspanTestD, Result)
{
using test_data_type = PermTestF::test_data_type;
_PERMTEST_BODY(test_data_type);
}
INSTANTIATE_TEST_CASE_P(PermMdspanTests, PermMdspanTestD, ::testing::ValuesIn(inputsd));

} // end namespace random
} // end namespace raft

0 comments on commit bfd90a7

Please sign in to comment.