Fix unary_op docs and add map_offset as an improved version of `write_only_unary_op` (#1149)

Follows [a discussion](#1113 (comment)) that we had about `write_only_unary_op` / `writeOnlyUnaryOp`.

For consistency and to simplify the use of this primitive, we shouldn't require dereferencing a pointer in the functor.

This new implementation uses `thrust::tabulate`, but we can add our own optimized kernel later if needed.
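
For illustration, a minimal sketch of the two functor shapes (the lambdas and values below are ours, not from the PR):

```cpp
// Before: write_only_unary_op's functor receives a raw output pointer and an
// index, and must dereference the pointer itself.
auto old_op = [] __device__(int* out_location, int idx) { *out_location = idx * idx; };

// After: map_offset's functor simply maps an offset to a value; the write is
// handled by the primitive.
auto new_op = [] __device__(int idx) { return idx * idx; };
```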

Authors:
  - Louis Sugy (https://github.com/Nyrio)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Corey J. Nolet (https://github.com/cjnolet)
  - Artem M. Chirkin (https://github.com/achirkin)

URL: #1149
Nyrio authored Jan 25, 2023
1 parent 7c12b1e commit 20a31bd
Showing 12 changed files with 264 additions and 199 deletions.
14 changes: 14 additions & 0 deletions cpp/include/raft/core/operators.hpp
@@ -163,6 +163,14 @@ struct pow_op {
}
};

struct mod_op {
template <typename T1, typename T2>
constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const
{
return a % b;
}
};

struct min_op {
template <typename... Args>
RAFT_INLINE_FUNCTION auto operator()(Args&&... args) const
@@ -313,6 +321,12 @@ using modulo_const_op = plug_const_op<Type, modulo_op>;
template <typename Type>
using pow_const_op = plug_const_op<Type, pow_op>;

template <typename Type>
using mod_const_op = plug_const_op<Type, mod_op>;

template <typename Type>
using equal_const_op = plug_const_op<Type, equal_op>;

/**
* @brief Constructs an operator by composing a chain of operators.
*
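A quick usage sketch of the new aliases (our own example, assuming these functors are host-callable like the other `RAFT_INLINE_FUNCTION` operators, and that `plug_const_op` binds the constant as the second operand as the existing `*_const_op` aliases do):

```cpp
#include <raft/core/operators.hpp>

int check()
{
  raft::mod_const_op<int> mod3(3);      // computes x % 3
  raft::equal_const_op<int> is_one(1);  // computes x == 1
  int r = mod3(10);                     // 10 % 3 == 1
  return is_one(r) ? 0 : 1;             // returns 0
}
```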
35 changes: 35 additions & 0 deletions cpp/include/raft/linalg/map.cuh
@@ -22,6 +22,7 @@

#include <raft/core/device_mdspan.hpp>
#include <raft/util/input_validation.hpp>
#include <thrust/tabulate.h>

namespace raft {
namespace linalg {
@@ -96,6 +97,40 @@ void map(const raft::handle_t& handle, InType in, OutType out, MapOp map, Args..
}
}

/**
* @brief Perform an element-wise unary operation on each element's offset, writing the result into the output array
*
* Usage example:
* @code{.cpp}
* #include <raft/core/device_mdarray.hpp>
* #include <raft/core/handle.hpp>
* #include <raft/core/operators.hpp>
* #include <raft/linalg/map.cuh>
* ...
* raft::handle_t handle;
* auto squares = raft::make_device_vector<int>(handle, n);
* raft::linalg::map_offset(handle, squares.view(), raft::sq_op());
* @endcode
*
* @tparam OutType Output mdspan type
* @tparam MapOp The unary operation type with signature `OutT func(const IdxT& idx);`
* @param[in] handle The raft handle
* @param[out] out Output array
* @param[in] op The unary operation
*/
template <typename OutType,
typename MapOp,
typename = raft::enable_if_output_device_mdspan<OutType>>
void map_offset(const raft::handle_t& handle, OutType out, MapOp op)
{
RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous");

using out_value_t = typename OutType::value_type;

thrust::tabulate(
handle.get_thrust_policy(), out.data_handle(), out.data_handle() + out.size(), op);
}

/** @} */ // end of map

} // namespace linalg
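A tiny complementary example to the docstring above (our own, assuming `raft::identity_op` from `raft/core/operators.hpp`): filling a vector with `0..n-1` reduces to mapping each offset through the identity.

```cpp
#include <raft/core/device_mdarray.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/operators.hpp>
#include <raft/linalg/map.cuh>

void iota_example(const raft::handle_t& handle, int n)
{
  auto v = raft::make_device_vector<int>(handle, n);
  // v[i] = i
  raft::linalg::map_offset(handle, v.view(), raft::identity_op());
}
```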
83 changes: 42 additions & 41 deletions cpp/include/raft/linalg/unary_op.cuh
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -30,17 +30,16 @@ namespace linalg {
/**
* @brief perform element-wise unary operation in the input array
* @tparam InType input data-type
* @tparam Lambda the device-lambda performing the actual operation
* @tparam Lambda Device lambda performing the actual operation, with the signature
* `OutType func(const InType& val);`
* @tparam OutType output data-type
* @tparam IdxType Integer type used for addressing
* @tparam TPB threads-per-block in the final kernel launched
* @param out the output array
* @param in the input array
* @param len number of elements in the input array
* @param op the device-lambda
* @param stream cuda stream where to launch work
* @note Lambda must be a functor with the following signature:
* `OutType func(const InType& val);`
* @param[out] out Output array [on device], dim = [len]
* @param[in] in Input array [on device], dim = [len]
* @param[in] len Number of elements in the input array
* @param[in] op Device lambda
* @param[in] stream CUDA stream on which to launch the work
*/
template <typename InType,
typename Lambda,
@@ -58,15 +57,15 @@ void unaryOp(OutType* out, const InType* in, IdxType len, Lambda op, cudaStream_
* Compared to `unaryOp()`, this method does not read from any input
*
* @tparam OutType output data-type
* @tparam Lambda the device-lambda performing the actual operation
* @tparam Lambda Device lambda performing the actual operation, with the signature
* `void func(OutType* outLocationOffset, IdxType idx);`
* where outLocationOffset will be out + idx.
* @tparam IdxType Integer type used for addressing
* @tparam TPB threads-per-block in the final kernel launched
*
* @param[out] out the output array [on device] [len = len]
* @param[in] len number of elements in the input array
* @param[in] op the device-lambda which must be of the form:
* `void func(OutType* outLocationOffset, IdxType idx);`
* where outLocationOffset will be out + idx.
* @param[out] out Output array [on device], dim = [len]
* @param[in] len Number of elements in the input array
* @param[in] op Device lambda
* @param[in] stream CUDA stream on which to launch the work
*/
template <typename OutType, typename Lambda, typename IdxType = int, int TPB = 256>
@@ -81,16 +80,15 @@ void writeOnlyUnaryOp(OutType* out, IdxType len, Lambda op, cudaStream_t stream)
*/

/**
* @brief perform element-wise binary operation on the input arrays
* @brief Perform an element-wise unary operation into the output array
* @tparam InType Input Type raft::device_mdspan
* @tparam Lambda the device-lambda performing the actual operation
* @tparam Lambda Device lambda performing the actual operation, with the signature
* `out_value_t func(const in_value_t& val);`
* @tparam OutType Output Type raft::device_mdspan
* @param[in] handle raft::handle_t
* @param[in] in Input
* @param[out] out Output
* @param[in] op the device-lambda
* @note Lambda must be a functor with the following signature:
* `InType func(const InType& val);`
* @param[in] handle The raft handle
* @param[in] in Input
* @param[out] out Output
* @param[in] op Device lambda
*/
template <typename InType,
typename Lambda,
@@ -116,29 +114,32 @@ void unary_op(const raft::handle_t& handle, InType in, OutType out, Lambda op)
}

/**
* @brief perform element-wise binary operation on the input arrays
* This function does not read from the input
* @tparam InType Input Type raft::device_mdspan
* @tparam Lambda the device-lambda performing the actual operation
* @param[in] handle raft::handle_t
* @param[inout] in Input/Output
* @param[in] op the device-lambda
* @note Lambda must be a functor with the following signature:
* `InType func(const InType& val);`
* @brief Perform an element-wise unary operation on each element's index, writing the result into the output array
*
* @note This operation is deprecated. Please use map_offset in `raft/linalg/map.cuh` instead.
*
* @tparam OutType Output Type raft::device_mdspan
* @tparam Lambda Device lambda performing the actual operation, with the signature
* `void func(out_value_t* out_location, index_t idx);`
* @param[in] handle The raft handle
* @param[out] out Output
* @param[in] op Device lambda
*/
template <typename InType, typename Lambda, typename = raft::enable_if_output_device_mdspan<InType>>
void write_only_unary_op(const raft::handle_t& handle, InType in, Lambda op)
template <typename OutType,
typename Lambda,
typename = raft::enable_if_output_device_mdspan<OutType>>
void write_only_unary_op(const raft::handle_t& handle, OutType out, Lambda op)
{
RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous");
RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous");

using in_value_t = typename InType::value_type;
using out_value_t = typename OutType::value_type;

if (in.size() <= std::numeric_limits<std::uint32_t>::max()) {
writeOnlyUnaryOp<in_value_t, Lambda, std::uint32_t>(
in.data_handle(), in.size(), op, handle.get_stream());
if (out.size() <= std::numeric_limits<std::uint32_t>::max()) {
writeOnlyUnaryOp<out_value_t, Lambda, std::uint32_t>(
out.data_handle(), out.size(), op, handle.get_stream());
} else {
writeOnlyUnaryOp<in_value_t, Lambda, std::uint64_t>(
in.data_handle(), in.size(), op, handle.get_stream());
writeOnlyUnaryOp<out_value_t, Lambda, std::uint64_t>(
out.data_handle(), out.size(), op, handle.get_stream());
}
}

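For completeness, a hedged usage sketch of the mdspan-based `unary_op` covered by these doc fixes (the function, vector names, and size are our own):

```cpp
#include <raft/core/device_mdarray.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/operators.hpp>
#include <raft/linalg/unary_op.cuh>

void square_all(const raft::handle_t& handle, int n)
{
  auto in  = raft::make_device_vector<float>(handle, n);
  auto out = raft::make_device_vector<float>(handle, n);
  // Square each input element: out[i] = in[i] * in[i]
  raft::linalg::unary_op(handle, in.view(), out.view(), raft::sq_op());
}
```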
16 changes: 8 additions & 8 deletions cpp/include/raft/random/detail/make_blobs.cuh
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
#pragma once

#include "permute.cuh"
#include <raft/linalg/unary_op.cuh>
#include <raft/linalg/map.cuh>
#include <raft/random/rng.cuh>
#include <raft/random/rng_device.cuh>
#include <raft/util/cuda_utils.cuh>
@@ -39,16 +39,16 @@ void generate_labels(IdxT* labels,
raft::random::RngState& r,
cudaStream_t stream)
{
raft::handle_t handle(stream);
IdxT a, b;
raft::random::affine_transform_params(r, n_clusters, a, b);
auto op = [=] __device__(IdxT * ptr, IdxT idx) {
if (shuffle) { idx = IdxT((a * int64_t(idx)) + b); }
auto op = [=] __device__(IdxT idx) {
if (shuffle) { idx = static_cast<IdxT>((a * int64_t(idx)) + b); }
idx %= n_clusters;
// in the unlikely case of n_clusters > n_rows, make sure that the writes
// do not go out-of-bounds
if (idx < n_rows) { *ptr = idx; }
return idx;
};
raft::linalg::writeOnlyUnaryOp<IdxT, decltype(op), IdxT>(labels, n_rows, op, stream);
auto labels_view = raft::make_device_vector_view<IdxT, IdxT>(labels, n_rows);
linalg::map_offset(handle, labels_view, op);
}

template <typename DataT, typename IdxT>
10 changes: 6 additions & 4 deletions cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh
@@ -27,6 +27,7 @@
#include <raft/distance/fused_l2_nn.cuh>
#include <raft/linalg/add.cuh>
#include <raft/linalg/gemm.cuh>
#include <raft/linalg/map.cuh>
#include <raft/linalg/matrix_vector_op.cuh>
#include <raft/linalg/norm.cuh>
#include <raft/linalg/normalize.cuh>
@@ -731,10 +732,11 @@ void build_clusters(const handle_t& handle,
"the chosen index type cannot represent all indices for the given dataset");

// "randomly initialize labels"
auto f = [n_clusters] __device__(LabelT * out, IdxT i) {
*out = LabelT(i % static_cast<IdxT>(n_clusters));
};
linalg::writeOnlyUnaryOp<LabelT, decltype(f), IdxT>(cluster_labels, n_rows, f, stream);
auto labels_view = raft::make_device_vector_view<LabelT, IdxT>(cluster_labels, n_rows);
linalg::map_offset(
handle,
labels_view,
raft::compose_op(raft::cast_op<LabelT>(), raft::mod_const_op<IdxT>(n_clusters)));

// update centers to match the initialized labels.
calc_centers_and_sizes(handle,
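To see why the composed operator above matches the old functor, here is a sketch (our own) of the value each computes for offset `i`; `compose_op` applies the rightmost operator first (mod, then cast):

```cpp
// Old style: write through a pointer at the given index.
auto f = [n_clusters] __device__(LabelT* out, IdxT i) {
  *out = LabelT(i % static_cast<IdxT>(n_clusters));
};
// New style: compose_op(cast_op<LabelT>(), mod_const_op<IdxT>(n_clusters))
// computes the same value and lets map_offset perform the write:
auto g = [n_clusters] __device__(IdxT i) {
  return static_cast<LabelT>(i % static_cast<IdxT>(n_clusters));
};
```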
25 changes: 13 additions & 12 deletions cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh
@@ -27,8 +27,8 @@
#include <raft/core/nvtx.hpp>
#include <raft/core/operators.hpp>
#include <raft/linalg/add.cuh>
#include <raft/linalg/map.cuh>
#include <raft/linalg/norm.cuh>
#include <raft/linalg/unary_op.cuh>
#include <raft/stats/histogram.cuh>
#include <raft/util/pow2_utils.cuh>

@@ -338,23 +338,24 @@ inline void fill_refinement_index(const handle_t& handle,
"ivf_flat::fill_refinement_index(%zu, %u)", size_t(n_queries));

rmm::device_uvector<LabelT> new_labels(n_queries * n_candidates, stream);
linalg::writeOnlyUnaryOp(
new_labels.data(),
n_queries * n_candidates,
[n_candidates] __device__(LabelT * out, uint32_t i) { *out = i / n_candidates; },
stream);
auto new_labels_view =
raft::make_device_vector_view<LabelT, IdxT>(new_labels.data(), n_queries * n_candidates);
linalg::map_offset(
handle,
new_labels_view,
raft::compose_op(raft::cast_op<LabelT>(), raft::div_const_op<IdxT>(n_candidates)));

auto list_sizes_ptr = refinement_index->list_sizes().data_handle();
auto list_offsets_ptr = refinement_index->list_offsets().data_handle();
// We do not fill centers and center norms, since we will not run coarse search.

// Calculate new offsets
uint32_t n_roundup = Pow2<kIndexGroupSize>::roundUp(n_candidates);
linalg::writeOnlyUnaryOp(
refinement_index->list_offsets().data_handle(),
refinement_index->list_offsets().size(),
[n_roundup] __device__(IdxT * out, uint32_t i) { *out = i * n_roundup; },
stream);
uint32_t n_roundup = Pow2<kIndexGroupSize>::roundUp(n_candidates);
auto list_offsets_view = raft::make_device_vector_view<IdxT, IdxT>(
list_offsets_ptr, refinement_index->list_offsets().size());
linalg::map_offset(handle,
list_offsets_view,
raft::compose_op(raft::cast_op<IdxT>(), raft::mul_const_op<IdxT>(n_roundup)));

IdxT index_size = n_roundup * n_lists;
refinement_index->allocate(handle, index_size);
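As a small worked example of the two composed ops above (values ours): with `n_candidates = 3`, the label op maps offsets `0..5` to labels `0, 0, 0, 1, 1, 1` (`i / n_candidates`, then cast to `LabelT`); with `n_roundup = 32`, the offset op maps list indices `0, 1, 2` to offsets `0, 32, 64` (`i * n_roundup`, then cast to `IdxT`).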
1 change: 1 addition & 0 deletions cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh
@@ -27,6 +27,7 @@
#include <raft/distance/distance.cuh>
#include <raft/distance/distance_types.hpp>
#include <raft/linalg/norm.cuh>
#include <raft/linalg/unary_op.cuh>
#include <raft/matrix/detail/select_k.cuh>
#include <raft/matrix/detail/select_warpsort.cuh>
#include <raft/util/cuda_utils.cuh>