From 20a31bdcfd9652f48b0dcf6390b82c10ca24ed43 Mon Sep 17 00:00:00 2001
From: Louis Sugy
Date: Wed, 25 Jan 2023 19:07:48 +0100
Subject: [PATCH] Fix `unary_op` docs and add `map_offset` as an improved
 version of `write_only_unary_op` (#1149)

Follows [a discussion](https://github.com/rapidsai/raft/pull/1113#discussion_r1063227687)
that we had about `write_only_unary_op` / `writeOnlyUnaryOp`. For consistency, and to
simplify the use of this primitive, we shouldn't require dereferencing a pointer in the
functor. This new implementation uses `thrust::tabulate`, but we can add our own
optimized kernel later if needed.

Authors:
  - Louis Sugy (https://github.com/Nyrio)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Corey J. Nolet (https://github.com/cjnolet)
  - Artem M. Chirkin (https://github.com/achirkin)

URL: https://github.com/rapidsai/raft/pull/1149
---
 cpp/include/raft/core/operators.hpp           | 14 +++
 cpp/include/raft/linalg/map.cuh               | 35 +++++++
 cpp/include/raft/linalg/unary_op.cuh          | 83 ++++++++--------
 cpp/include/raft/random/detail/make_blobs.cuh | 16 ++--
 .../knn/detail/ann_kmeans_balanced.cuh        | 10 +-
 .../spatial/knn/detail/ivf_flat_build.cuh     | 25 ++---
 .../spatial/knn/detail/ivf_flat_search.cuh    |  1 +
 .../raft/spatial/knn/detail/ivf_pq_build.cuh  | 75 +++++++--------
 .../raft/spatial/knn/detail/ivf_pq_search.cuh | 15 +--
 cpp/test/linalg/map.cu                        | 89 ++++++++++-------
 cpp/test/linalg/unary_op.cu                   | 96 +++++++++----------
 cpp/test/linalg/unary_op.cuh                  |  4 +-
 12 files changed, 264 insertions(+), 199 deletions(-)

diff --git a/cpp/include/raft/core/operators.hpp b/cpp/include/raft/core/operators.hpp
index edb437c880..9fcf6657db 100644
--- a/cpp/include/raft/core/operators.hpp
+++ b/cpp/include/raft/core/operators.hpp
@@ -163,6 +163,14 @@ struct pow_op {
   }
 };
 
+struct mod_op {
+  template <typename T1, typename T2>
+  constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const
+  {
+    return a % b;
+  }
+};
+
 struct min_op {
   template <typename... Args>
   RAFT_INLINE_FUNCTION auto operator()(Args&&... args) const
@@ -313,6 +321,12 @@ using modulo_const_op = plug_const_op
 template <typename Type>
 using pow_const_op = plug_const_op<Type, pow_op>;
 
+template <typename Type>
+using mod_const_op = plug_const_op<Type, mod_op>;
+
+template <typename Type>
+using equal_const_op = plug_const_op<Type, equal_op>;
+
 /**
  * @brief Constructs an operator by composing a chain of operators.
  *
diff --git a/cpp/include/raft/linalg/map.cuh b/cpp/include/raft/linalg/map.cuh
index ad35cc5880..c2e2e6303a 100644
--- a/cpp/include/raft/linalg/map.cuh
+++ b/cpp/include/raft/linalg/map.cuh
@@ -22,6 +22,7 @@
 
 #include
 #include
+#include <thrust/tabulate.h>
 
 namespace raft {
 namespace linalg {
@@ -96,6 +97,40 @@ void map(const raft::handle_t& handle, InType in, OutType out, MapOp map, Args..
   }
 }
 
+/**
+ * @brief Perform an element-wise unary operation taking each element's offset (flat
+ *        index) as input and writing the result into the output array
+ *
+ * Usage example:
+ * @code{.cpp}
+ *  #include <raft/core/device_mdarray.hpp>
+ *  #include <raft/core/handle.hpp>
+ *  #include <raft/core/operators.hpp>
+ *  #include <raft/linalg/map.cuh>
+ *  ...
+ *  raft::handle_t handle;
+ *  auto squares = raft::make_device_vector<int>(handle, n);
+ *  raft::linalg::map_offset(handle, squares.view(), raft::sq_op());
+ * @endcode
+ *
+ * @tparam OutType Output mdspan type
+ * @tparam MapOp   The unary operation type, with signature `OutT func(const IdxT& idx);`
+ * @param[in]  handle The raft handle
+ * @param[out] out    Output array
+ * @param[in]  op     The unary operation
+ */
+template <typename OutType,
+          typename MapOp,
+          typename = raft::enable_if_output_device_mdspan<OutType>>
+void map_offset(const raft::handle_t& handle, OutType out, MapOp op)
+{
+  RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous");
+
+  using out_value_t = typename OutType::value_type;
+
+  thrust::tabulate(
+    handle.get_thrust_policy(), out.data_handle(), out.data_handle() + out.size(), op);
+}
+
 /** @} */  // end of map
 
 }  // namespace linalg
diff --git a/cpp/include/raft/linalg/unary_op.cuh b/cpp/include/raft/linalg/unary_op.cuh
index a90bda06d5..e39821cf80 100644
--- a/cpp/include/raft/linalg/unary_op.cuh
+++ b/cpp/include/raft/linalg/unary_op.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -30,17 +30,16 @@ namespace linalg {
 
 /**
  * @brief perform element-wise unary operation in the input array
  * @tparam InType input data-type
- * @tparam Lambda the device-lambda performing the actual operation
+ * @tparam Lambda Device lambda performing the actual operation, with the signature
+ *                `OutType func(const InType& val);`
  * @tparam OutType output data-type
  * @tparam IdxType Integer type used for addressing
  * @tparam TPB threads-per-block in the final kernel launched
- * @param out the output array
- * @param in the input array
- * @param len number of elements in the input array
- * @param op the device-lambda
- * @param stream cuda stream where to launch work
- * @note Lambda must be a functor with the following signature:
- *       `OutType func(const InType& val);`
+ * @param[out] out    Output array [on device], dim = [len]
+ * @param[in]  in     Input array [on device], dim = [len]
+ * @param[in]  len    Number of elements in the input array
+ * @param[in]  op     Device lambda
+ * @param[in]  stream cuda stream where to launch work
 */
 template
@@ -81,16 +80,15 @@ void writeOnlyUnaryOp(OutType* out, IdxType len, Lambda op, cudaStream_t stream)
  */
 
 /**
- * @brief perform element-wise binary operation on the input arrays
+ * @brief Perform an element-wise unary operation into the output array
  * @tparam InType Input Type raft::device_mdspan
- * @tparam Lambda the device-lambda performing the actual operation
+ * @tparam Lambda Device lambda performing the actual operation, with the signature
+ *                `out_value_t func(const in_value_t& val);`
  * @tparam OutType Output Type raft::device_mdspan
- * @param[in] handle raft::handle_t
- * @param[in] in Input
- * @param[out] out Output
- * @param[in] op the device-lambda
- * @note Lambda must be a functor with the following signature:
- *       `InType func(const InType& val);`
+ * @param[in]  handle The raft handle
+ * @param[in]  in     Input
+ * @param[out] out    Output
+ * @param[in]  op     Device lambda
 */
 template
-template <typename InType,
-          typename Lambda,
-          typename = raft::enable_if_output_device_mdspan<InType>>
-void write_only_unary_op(const raft::handle_t& handle, InType in, Lambda op)
+template <typename OutType,
+          typename Lambda,
+          typename = raft::enable_if_output_device_mdspan<OutType>>
+void write_only_unary_op(const raft::handle_t& handle, OutType out, Lambda op)
 {
-  RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous");
+  RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous");
 
-  using in_value_t = typename InType::value_type;
+  using out_value_t = typename OutType::value_type;
 
-  if (in.size() <= std::numeric_limits<std::uint32_t>::max()) {
-    writeOnlyUnaryOp<in_value_t, Lambda, std::uint32_t>(
-      in.data_handle(), in.size(), op, handle.get_stream());
+  if (out.size() <= std::numeric_limits<std::uint32_t>::max()) {
+    writeOnlyUnaryOp<out_value_t, Lambda, std::uint32_t>(
+      out.data_handle(), out.size(), op, handle.get_stream());
   } else {
-    writeOnlyUnaryOp<in_value_t, Lambda, std::uint64_t>(
-      in.data_handle(), in.size(), op, handle.get_stream());
+    writeOnlyUnaryOp<out_value_t, Lambda, std::uint64_t>(
+      out.data_handle(), out.size(), op, handle.get_stream());
   }
 }
diff --git a/cpp/include/raft/random/detail/make_blobs.cuh b/cpp/include/raft/random/detail/make_blobs.cuh
index 68c2d56599..fb4db5184e 100644
--- a/cpp/include/raft/random/detail/make_blobs.cuh
+++ b/cpp/include/raft/random/detail/make_blobs.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2022, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2023, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
 #pragma once
 
 #include "permute.cuh"
-#include <raft/linalg/unary_op.cuh>
+#include <raft/linalg/map.cuh>
 #include
 #include
 #include
@@ -39,16 +39,16 @@ void generate_labels(IdxT* labels,
                      raft::random::RngState& r,
                      cudaStream_t stream)
 {
+  raft::handle_t handle(stream);
   IdxT a, b;
   raft::random::affine_transform_params(r, n_clusters, a, b);
-  auto op = [=] __device__(IdxT * ptr, IdxT idx) {
-    if (shuffle) { idx = IdxT((a * int64_t(idx)) + b); }
+  auto op = [=] __device__(IdxT idx) {
+    if (shuffle) { idx = static_cast<IdxT>((a * int64_t(idx)) + b); }
     idx %= n_clusters;
-    // in the unlikely case of n_clusters > n_rows, make sure that the writes
-    // do not go out-of-bounds
-    if (idx < n_rows) { *ptr = idx; }
+    return idx;
   };
-  raft::linalg::writeOnlyUnaryOp(labels, n_rows, op, stream);
+  auto labels_view = raft::make_device_vector_view(labels, n_rows);
+  linalg::map_offset(handle, labels_view, op);
 }
 
 template
diff --git a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh
index c6a3aea0cf..ba88924da5 100644
--- a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh
+++ b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh
@@ -27,6 +27,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -731,10 +732,11 @@ void build_clusters(const handle_t& handle,
                "the chosen index type cannot represent all indices for the given dataset");
 
   // "randomly initialize labels"
-  auto f = [n_clusters] __device__(LabelT * out, IdxT i) {
-    *out = LabelT(i % static_cast<IdxT>(n_clusters));
-  };
-  linalg::writeOnlyUnaryOp(cluster_labels, n_rows, f, stream);
+  auto labels_view = raft::make_device_vector_view(cluster_labels, n_rows);
+  linalg::map_offset(
+    handle,
+    labels_view,
+    raft::compose_op(raft::cast_op<LabelT>(), raft::mod_const_op<IdxT>(n_clusters)));
 
   // update centers to match the initialized labels.
   calc_centers_and_sizes(handle,
diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh
index 0abd3825e6..6e038db68f 100644
--- a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh
+++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh
@@ -27,8 +27,8 @@
 #include
 #include
 #include
+#include
 #include
-#include
 #include
 #include
@@ -338,23 +338,24 @@ inline void fill_refinement_index(const handle_t& handle,
     "ivf_flat::fill_refinement_index(%zu, %u)",
     size_t(n_queries));
 
   rmm::device_uvector<LabelT> new_labels(n_queries * n_candidates, stream);
-  linalg::writeOnlyUnaryOp(
-    new_labels.data(),
-    n_queries * n_candidates,
-    [n_candidates] __device__(LabelT * out, uint32_t i) { *out = i / n_candidates; },
-    stream);
+  auto new_labels_view =
+    raft::make_device_vector_view(new_labels.data(), n_queries * n_candidates);
+  linalg::map_offset(
+    handle,
+    new_labels_view,
+    raft::compose_op(raft::cast_op<LabelT>(), raft::div_const_op<uint32_t>(n_candidates)));
 
   auto list_sizes_ptr   = refinement_index->list_sizes().data_handle();
   auto list_offsets_ptr = refinement_index->list_offsets().data_handle();
   // We do not fill centers and center norms, since we will not run coarse search.
 
   // Calculate new offsets
-  uint32_t n_roundup = Pow2<kIndexGroupSize>::roundUp(n_candidates);
-  linalg::writeOnlyUnaryOp(
-    refinement_index->list_offsets().data_handle(),
-    refinement_index->list_offsets().size(),
-    [n_roundup] __device__(IdxT * out, uint32_t i) { *out = i * n_roundup; },
-    stream);
+  uint32_t n_roundup     = Pow2<kIndexGroupSize>::roundUp(n_candidates);
+  auto list_offsets_view = raft::make_device_vector_view(
+    list_offsets_ptr, refinement_index->list_offsets().size());
+  linalg::map_offset(
+    handle,
+    list_offsets_view,
+    raft::compose_op(raft::cast_op<IdxT>(), raft::mul_const_op<uint32_t>(n_roundup)));
 
   IdxT index_size = n_roundup * n_lists;
   refinement_index->allocate(handle, index_size);
diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh
index ab445c75d4..82b7bcf81e 100644
--- a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh
+++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh
@@ -27,6 +27,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
diff --git a/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh
index ee020606c7..adbedf854f 100644
--- a/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh
+++ b/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh
@@ -31,7 +31,9 @@
 #include
 #include
 #include
+#include
 #include
+#include
 #include
 #include
 #include
@@ -209,8 +211,11 @@ inline void make_rotation_matrix(const handle_t& handle,
     }
   } else {
     uint32_t stride = n + 1;
-    auto f = [stride] __device__(float* out, uint32_t i) -> void { *out = float(i % stride == 0); };
-    linalg::writeOnlyUnaryOp(rotation_matrix, n * n, f, stream);
+    auto rotation_matrix_view =
+      raft::make_device_vector_view(rotation_matrix, n * n);
+    linalg::map_offset(handle, rotation_matrix_view, [stride] __device__(uint32_t i) {
+      return static_cast<float>(i % stride == 0u);
+    });
   }
 }
@@ -283,16 +288,13 @@ void flat_compute_residuals(
   auto dim     = rotation_matrix.extent(1);
   auto rot_dim = rotation_matrix.extent(0);
   rmm::device_uvector<float> tmp(n_rows * dim, stream, device_memory);
-  linalg::writeOnlyUnaryOp(
-    tmp.data(),
-    tmp.size(),
-    [centers, dataset, labels, dim] __device__(float* out, size_t i) {
-      auto row_ix = i / dim;
-      auto el_ix  = i % dim;
-      auto label  = labels[row_ix];
-      *out        = utils::mapping<float>{}(dataset[i]) - centers(label, el_ix);
-    },
-    stream);
+  auto tmp_view = raft::make_device_vector_view(tmp.data(), tmp.size());
+  linalg::map_offset(handle, tmp_view, [centers, dataset, labels, dim] __device__(size_t i) {
+    auto row_ix = i / dim;
+    auto el_ix  = i % dim;
+    auto label  = labels[row_ix];
+    return utils::mapping<float>{}(dataset[i]) - centers(label, el_ix);
+  });
 
   float alpha = 1.0f;
   float beta  = 0.0f;
@@ -368,29 +370,28 @@ auto calculate_offsets_and_indices(IdxT n_rows,
 }
 
 template <typename IdxT>
-void transpose_pq_centers(index<IdxT>& index,
-                          const float* pq_centers_source,
-                          rmm::cuda_stream_view stream)
+void transpose_pq_centers(const handle_t& handle,
+                          index<IdxT>& index,
+                          const float* pq_centers_source)
 {
+  auto stream  = handle.get_stream();
   auto extents = index.pq_centers().extents();
   static_assert(extents.rank() == 3);
   auto extents_source = make_extents<uint32_t>(extents.extent(0), extents.extent(2), extents.extent(1));
   auto span_source    = make_mdspan(pq_centers_source, extents_source);
-  linalg::writeOnlyUnaryOp(
-    index.pq_centers().data_handle(),
-    index.pq_centers().size(),
-    [span_source, extents] __device__(float* out, size_t i) {
-      uint32_t ii[3];
-      for (int r = 2; r > 0; r--) {
-        ii[r] = i % extents.extent(r);
-        i /= extents.extent(r);
-      }
-      ii[0] = i;
-      *out  = span_source(ii[0], ii[2], ii[1]);
-    },
-    stream);
+  auto pq_centers_view = raft::make_device_vector_view(
+    index.pq_centers().data_handle(), index.pq_centers().size());
+  linalg::map_offset(handle, pq_centers_view, [span_source, extents] __device__(size_t i) {
+    uint32_t ii[3];
+    for (int r = 2; r > 0; r--) {
+      ii[r] = i % extents.extent(r);
+      i /= extents.extent(r);
+    }
+    ii[0] = i;
+    return span_source(ii[0], ii[2], ii[1]);
+  });
 }
@@ -460,7 +461,7 @@ void train_per_subset(const handle_t& handle,
       stream,
       device_memory);
   }
-  transpose_pq_centers(index, pq_centers_tmp.data(), stream);
+  transpose_pq_centers(handle, index, pq_centers_tmp.data());
 }
@@ -537,7 +538,7 @@ void train_per_cluster(const handle_t& handle,
       stream,
      device_memory);
   }
-  transpose_pq_centers(index, pq_centers_tmp.data(), stream);
+  transpose_pq_centers(handle, index, pq_centers_tmp.data());
 }
 
 /**
@@ -1212,14 +1213,12 @@ auto build(
   if (dataset_attr.devicePointer != nullptr) {
     // data is available on device: just run the kernel to copy and map the data
     auto p = reinterpret_cast<const T*>(dataset_attr.devicePointer);
-    linalg::writeOnlyUnaryOp(
-      trainset.data(),
-      dim * n_rows_train,
-      [p, trainset_ratio, dim] __device__(float* out, size_t i) {
-        auto col = i % dim;
-        *out     = utils::mapping<float>{}(p[(i - col) * size_t(trainset_ratio) + col]);
-      },
-      stream);
+    auto trainset_view =
+      raft::make_device_vector_view(trainset.data(), dim * n_rows_train);
+    linalg::map_offset(handle, trainset_view, [p, trainset_ratio, dim] __device__(size_t i) {
+      auto col = i % dim;
+      return utils::mapping<float>{}(p[(i - col) * size_t(trainset_ratio) + col]);
+    });
   } else {
     // data is not available: first copy, then map inplace
     auto trainset_tmp = reinterpret_cast<T*>(reinterpret_cast<uint8_t*>(trainset.data()) +
diff --git a/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh
index aa69841f4b..b47ba32c58 100644
--- a/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh
+++ b/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh
@@ -28,6 +28,8 @@
 #include
 #include
 #include
+#include
+#include
 #include
 #include
 #include
@@ -175,15 +177,14 @@ void select_clusters(const handle_t& handle,
     case raft::distance::DistanceType::InnerProduct: norm_factor = 0.0; break;
     default: RAFT_FAIL("Unsupported distance type %d.", int(metric));
   }
-  linalg::writeOnlyUnaryOp(
-    float_queries,
-    dim_ext * n_queries,
-    [queries, dim, dim_ext, norm_factor] __device__(float* out, uint32_t ix) {
+  auto float_queries_view =
+    raft::make_device_vector_view(float_queries, dim_ext * n_queries);
+  linalg::map_offset(
+    handle, float_queries_view, [queries, dim, dim_ext, norm_factor] __device__(uint32_t ix) {
       uint32_t col = ix % dim_ext;
       uint32_t row = ix / dim_ext;
-      *out = col < dim ? utils::mapping<float>{}(queries[col + dim * row]) : norm_factor;
-    },
-    stream);
+      return col < dim ? utils::mapping<float>{}(queries[col + dim * row]) : norm_factor;
+    });
 
   float alpha;
   float beta;
diff --git a/cpp/test/linalg/map.cu b/cpp/test/linalg/map.cu
index 7e3a1562d9..1add9f7828 100644
--- a/cpp/test/linalg/map.cu
+++ b/cpp/test/linalg/map.cu
@@ -15,6 +15,7 @@
  */
 
 #include "../test_utils.cuh"
+#include "unary_op.cuh"
 #include
 #include
 #include
@@ -107,52 +108,70 @@ class MapTest : public ::testing::TestWithParam
   rmm::device_uvector<OutType> out_ref, out;
 };
 
+template <typename OutType, typename IdxType>
+class MapOffsetTest : public ::testing::TestWithParam<MapInputs<OutType, IdxType>> {
+ public:
+  MapOffsetTest()
+    : params(::testing::TestWithParam<MapInputs<OutType, IdxType>>::GetParam()),
+      stream(handle.get_stream()),
+      out_ref(params.len, stream),
+      out(params.len, stream)
+  {
+  }
+
+ protected:
+  void SetUp() override
+  {
+    IdxType len    = params.len;
+    OutType scalar = params.scalar;
+    naiveScale(out_ref.data(), (OutType*)nullptr, scalar, len, stream);
+    auto out_view = raft::make_device_vector_view(out.data(), len);
+    map_offset(handle,
+               out_view,
+               raft::compose_op(raft::cast_op<OutType>(), raft::mul_const_op<OutType>(scalar)));
+    RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+  }
+
+ protected:
+  raft::handle_t handle;
+  cudaStream_t stream;
+
+  MapInputs<OutType, IdxType> params;
+  rmm::device_uvector<OutType> out_ref, out;
+};
+
+#define MAP_TEST(test_type, test_name, inputs)                        \
+  typedef RAFT_DEPAREN(test_type) test_name;                          \
+  TEST_P(test_name, Result)                                           \
+  {                                                                   \
+    ASSERT_TRUE(devArrMatch(this->out_ref.data(),                     \
+                            this->out.data(),                         \
+                            this->params.len,                         \
+                            CompareApprox(this->params.tolerance)));  \
+  }                                                                   \
+  INSTANTIATE_TEST_SUITE_P(MapTests, test_name, ::testing::ValuesIn(inputs))
+
 const std::vector<MapInputs<float, int>> inputsf_i32 = {{0.000001f, 1024 * 1024, 1234ULL, 3.2}};
-typedef MapTest<float, int> MapTestF_i32;
-TEST_P(MapTestF_i32, Result)
-{
-  ASSERT_TRUE(
-    devArrMatch(out_ref.data(), out.data(), params.len, CompareApprox<float>(params.tolerance)));
-}
-INSTANTIATE_TEST_SUITE_P(MapTests, MapTestF_i32, ::testing::ValuesIn(inputsf_i32));
+MAP_TEST((MapTest<float, int>), MapTestF_i32, inputsf_i32);
+MAP_TEST((MapOffsetTest<float, int>), MapOffsetTestF_i32, inputsf_i32);
 
 const std::vector<MapInputs<float, size_t>> inputsf_i64 = {{0.000001f, 1024 * 1024, 1234ULL, 9.4}};
-typedef MapTest<float, size_t> MapTestF_i64;
-TEST_P(MapTestF_i64, Result)
-{
-  ASSERT_TRUE(
-    devArrMatch(out_ref.data(), out.data(), params.len, CompareApprox<float>(params.tolerance)));
-}
-INSTANTIATE_TEST_SUITE_P(MapTests, MapTestF_i64, ::testing::ValuesIn(inputsf_i64));
+MAP_TEST((MapTest<float, size_t>), MapTestF_i64, inputsf_i64);
+MAP_TEST((MapOffsetTest<float, size_t>), MapOffsetTestF_i64, inputsf_i64);
 
 const std::vector<MapInputs<float, int, double>> inputsf_i32_d = {
   {0.000001f, 1024 * 1024, 1234ULL, 5.9}};
-typedef MapTest<float, int, double> MapTestF_i32_D;
-TEST_P(MapTestF_i32_D, Result)
-{
-  ASSERT_TRUE(
-    devArrMatch(out_ref.data(), out.data(), params.len, CompareApprox<double>(params.tolerance)));
-}
-INSTANTIATE_TEST_SUITE_P(MapTests, MapTestF_i32_D, ::testing::ValuesIn(inputsf_i32_d));
+MAP_TEST((MapTest<float, int, double>), MapTestF_i32_D, inputsf_i32_d);
 
 const std::vector<MapInputs<double, int>> inputsd_i32 = {{0.00000001, 1024 * 1024, 1234ULL, 7.5}};
-typedef MapTest<double, int> MapTestD_i32;
-TEST_P(MapTestD_i32, Result)
-{
-  ASSERT_TRUE(
-    devArrMatch(out_ref.data(), out.data(), params.len, CompareApprox<double>(params.tolerance)));
-}
-INSTANTIATE_TEST_SUITE_P(MapTests, MapTestD_i32, ::testing::ValuesIn(inputsd_i32));
+MAP_TEST((MapTest<double, int>), MapTestD_i32, inputsd_i32);
+MAP_TEST((MapOffsetTest<double, int>), MapOffsetTestD_i32, inputsd_i32);
 
 const std::vector<MapInputs<double, size_t>> inputsd_i64 = {
   {0.00000001, 1024 * 1024, 1234ULL, 5.2}};
-typedef MapTest<double, size_t> MapTestD_i64;
-TEST_P(MapTestD_i64, Result)
-{
-  ASSERT_TRUE(
-    devArrMatch(out_ref.data(), out.data(), params.len, CompareApprox<double>(params.tolerance)));
-}
-INSTANTIATE_TEST_SUITE_P(MapTests, MapTestD_i64, ::testing::ValuesIn(inputsd_i64));
+MAP_TEST((MapTest<double, size_t>), MapTestD_i64, inputsd_i64);
+MAP_TEST((MapOffsetTest<double, size_t>), MapOffsetTestD_i64, inputsd_i64);
 
 }  // namespace linalg
 }  // namespace raft
diff --git a/cpp/test/linalg/unary_op.cu b/cpp/test/linalg/unary_op.cu
index 3ebf70e69f..341ae1a855 100644
--- a/cpp/test/linalg/unary_op.cu
+++ b/cpp/test/linalg/unary_op.cu
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2022, NVIDIA CORPORATION.
+ * Copyright (c) 2018-2023, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -24,27 +24,6 @@ namespace raft {
 namespace linalg {
 
-// Or else, we get the following compilation error
-// for an extended __device__ lambda cannot have private or protected access
-// within its class
-template <typename InType, typename IdxType, typename OutType = InType>
-void unaryOpLaunch(OutType* out, const InType* in, InType scalar, IdxType len, cudaStream_t stream)
-{
-  raft::handle_t handle{stream};
-  auto out_view = raft::make_device_vector_view(out, len);
-  auto in_view  = raft::make_device_vector_view(in, len);
-  if (in == nullptr) {
-    auto op = [scalar] __device__(OutType * ptr, IdxType idx) {
-      *ptr = static_cast<OutType>(scalar * idx);
-    };
-    write_only_unary_op(handle, out_view, op);
-  } else {
-    auto op = [scalar] __device__(InType in) { return static_cast<OutType>(in * scalar); };
-    unary_op(handle, in_view, out_view, op);
-  }
-}
-
 template <typename InType, typename IdxType = int, typename OutType = InType>
 class UnaryOpTest : public ::testing::TestWithParam<UnaryOpInputs<InType, IdxType, OutType>> {
  public:
@@ -71,10 +50,14 @@ class UnaryOpTest : public ::testing::TestWithParam
   void DoTest()
   {
     auto len    = params.len;
    auto scalar = params.scalar;
    naiveScale(in.data(), out_ref.data(), scalar, len, stream);
-    unaryOpLaunch(out.data(), in.data(), scalar, len, stream);
+    auto in_view  = raft::make_device_vector_view<const InType, IdxType>(in.data(), len);
+    auto out_view = raft::make_device_vector_view<OutType, IdxType>(out.data(), len);
+    unary_op(handle,
+             in_view,
+             out_view,
+             raft::compose_op(raft::cast_op<OutType>(), raft::mul_const_op<InType>(scalar)));
     handle.sync_stream(stream);
-    ASSERT_TRUE(devArrMatch(
-      out_ref.data(), out.data(), params.len, CompareApprox(params.tolerance)));
   }
 
 protected:
@@ -86,6 +69,19 @@ class UnaryOpTest : public ::testing::TestWithParam
   rmm::device_uvector<OutType> out_ref, out;
 };
 
+// Or else, we get the following compilation error:
+// The enclosing parent function ("DoTest") for an extended __device__ lambda cannot have private or
+// protected access within its class
+template <typename OutType, typename InType, typename IdxType>
+void launchWriteOnlyUnaryOp(const raft::handle_t& handle, OutType* out, InType scalar, IdxType len)
+{
+  auto out_view = raft::make_device_vector_view(out, len);
+  auto op       = [scalar] __device__(OutType * ptr, IdxType idx) {
+    *ptr = static_cast<OutType>(scalar * idx);
+  };
+  write_only_unary_op(handle, out_view, op);
+}
+
 template <typename OutType, typename IdxType>
 class WriteOnlyUnaryOpTest : public UnaryOpTest<OutType, IdxType> {
  protected:
@@ -94,50 +90,46 @@ class WriteOnlyUnaryOpTest : public UnaryOpTest
   void DoTest() override
   {
     auto len    = this->params.len;
     auto scalar = this->params.scalar;
     naiveScale(this->out_ref.data(), (OutType*)nullptr, scalar, len, this->stream);
-    unaryOpLaunch(this->out.data(), (OutType*)nullptr, scalar, len, this->stream);
-    RAFT_CUDA_TRY(cudaStreamSynchronize(this->stream));
-    ASSERT_TRUE(devArrMatch(this->out_ref.data(),
-                            this->out.data(),
-                            this->params.len,
-                            CompareApprox(this->params.tolerance)));
+
+    launchWriteOnlyUnaryOp(this->handle, this->out.data(), scalar, len);
+    this->handle.sync_stream(this->stream);
   }
 };
 
-#define UNARY_OP_TEST(Name, inputs)                                      \
-  TEST_P(Name, Result) { DoTest(); }                                     \
-  INSTANTIATE_TEST_SUITE_P(UnaryOpTests, Name, ::testing::ValuesIn(inputs))
+#define UNARY_OP_TEST(test_type, test_name, inputs)                      \
+  typedef RAFT_DEPAREN(test_type) test_name;                             \
+  TEST_P(test_name, Result)                                              \
+  {                                                                      \
+    DoTest();                                                            \
+    ASSERT_TRUE(devArrMatch(this->out_ref.data(),                        \
+                            this->out.data(),                            \
+                            this->params.len,                            \
+                            CompareApprox(this->params.tolerance)));     \
+  }                                                                      \
+  INSTANTIATE_TEST_SUITE_P(UnaryOpTests, test_name, ::testing::ValuesIn(inputs))
 
 const std::vector<UnaryOpInputs<float, int>> inputsf_i32 = {{0.000001f, 1024 * 1024, 2.f, 1234ULL}};
-typedef UnaryOpTest<float, int> UnaryOpTestF_i32;
-UNARY_OP_TEST(UnaryOpTestF_i32, inputsf_i32);
-typedef WriteOnlyUnaryOpTest<float, int> WriteOnlyUnaryOpTestF_i32;
-UNARY_OP_TEST(WriteOnlyUnaryOpTestF_i32, inputsf_i32);
+UNARY_OP_TEST((UnaryOpTest<float, int>), UnaryOpTestF_i32, inputsf_i32);
+UNARY_OP_TEST((WriteOnlyUnaryOpTest<float, int>), WriteOnlyUnaryOpTestF_i32, inputsf_i32);
 
 const std::vector<UnaryOpInputs<float, size_t>> inputsf_i64 = {
   {0.000001f, 1024 * 1024, 2.f, 1234ULL}};
-typedef UnaryOpTest<float, size_t> UnaryOpTestF_i64;
-UNARY_OP_TEST(UnaryOpTestF_i64, inputsf_i64);
-typedef WriteOnlyUnaryOpTest<float, size_t> WriteOnlyUnaryOpTestF_i64;
-UNARY_OP_TEST(WriteOnlyUnaryOpTestF_i64, inputsf_i64);
+UNARY_OP_TEST((UnaryOpTest<float, size_t>), UnaryOpTestF_i64, inputsf_i64);
+UNARY_OP_TEST((WriteOnlyUnaryOpTest<float, size_t>), WriteOnlyUnaryOpTestF_i64, inputsf_i64);
 
 const std::vector<UnaryOpInputs<float, int, double>> inputsf_i32_d = {
   {0.000001f, 1024 * 1024, 2.f, 1234ULL}};
-typedef UnaryOpTest<float, int, double> UnaryOpTestF_i32_D;
-UNARY_OP_TEST(UnaryOpTestF_i32_D, inputsf_i32_d);
+UNARY_OP_TEST((UnaryOpTest<float, int, double>), UnaryOpTestF_i32_D, inputsf_i32_d);
 
 const std::vector<UnaryOpInputs<double, int>> inputsd_i32 = {
   {0.00000001, 1024 * 1024, 2.0, 1234ULL}};
-typedef UnaryOpTest<double, int> UnaryOpTestD_i32;
-UNARY_OP_TEST(UnaryOpTestD_i32, inputsd_i32);
-typedef WriteOnlyUnaryOpTest<double, int> WriteOnlyUnaryOpTestD_i32;
-UNARY_OP_TEST(WriteOnlyUnaryOpTestD_i32, inputsd_i32);
+UNARY_OP_TEST((UnaryOpTest<double, int>), UnaryOpTestD_i32, inputsd_i32);
+UNARY_OP_TEST((WriteOnlyUnaryOpTest<double, int>), WriteOnlyUnaryOpTestD_i32, inputsd_i32);
 
 const std::vector<UnaryOpInputs<double, size_t>> inputsd_i64 = {
   {0.00000001, 1024 * 1024, 2.0, 1234ULL}};
-typedef UnaryOpTest<double, size_t> UnaryOpTestD_i64;
-UNARY_OP_TEST(UnaryOpTestD_i64, inputsd_i64);
-typedef WriteOnlyUnaryOpTest<double, size_t> WriteOnlyUnaryOpTestD_i64;
-UNARY_OP_TEST(WriteOnlyUnaryOpTestD_i64, inputsd_i64);
+UNARY_OP_TEST((UnaryOpTest<double, size_t>), UnaryOpTestD_i64, inputsd_i64);
+UNARY_OP_TEST((WriteOnlyUnaryOpTest<double, size_t>), WriteOnlyUnaryOpTestD_i64, inputsd_i64);
 
 }  // end namespace linalg
 }  // end namespace raft
diff --git a/cpp/test/linalg/unary_op.cuh b/cpp/test/linalg/unary_op.cuh
index 28bcc004a4..9d2bd6f7c9 100644
--- a/cpp/test/linalg/unary_op.cuh
+++ b/cpp/test/linalg/unary_op.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2022, NVIDIA CORPORATION.
+ * Copyright (c) 2018-2023, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -29,7 +29,7 @@ __global__ void naiveScaleKernel(OutType* out, const InType* in, InType scalar,
   IdxType idx = threadIdx.x + ((IdxType)blockIdx.x * (IdxType)blockDim.x);
   if (idx < len) {
     if (in == nullptr) {
-      // used for testing writeOnlyUnaryOp
+      // used for testing write_only_unary_op
      out[idx] = static_cast<OutType>(scalar * idx);
     } else {
      out[idx] = static_cast<OutType>(scalar * in[idx]);
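
---

For readers migrating their own call sites, here is a minimal before/after sketch of
the change this patch encourages. This is hypothetical caller code, not part of the
patch: it assumes the `writeOnlyUnaryOp` and `map_offset` signatures documented above,
and that `mul_const_op` takes the constant's type as its template argument.

#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/operators.hpp>
#include <raft/linalg/map.cuh>
#include <raft/linalg/unary_op.cuh>

// Fill out[i] = 2 * i, the write-only pattern exercised by the tests above.
void fill_doubled(const raft::handle_t& handle, raft::device_vector_view<int, int> out)
{
  // Old style: the functor receives a pointer to the output element plus an index,
  // and has to dereference the pointer itself.
  raft::linalg::writeOnlyUnaryOp(
    out.data_handle(),
    static_cast<int>(out.size()),
    [] __device__(int* ptr, int idx) { *ptr = 2 * idx; },
    handle.get_stream());

  // New style: the functor receives only the flat offset and returns the value;
  // no pointer dereference in user code. Internally this is a thrust::tabulate call.
  raft::linalg::map_offset(handle, out, raft::mul_const_op<int>(2));
}

Note that `map_offset` also drops the explicit stream argument: the handle's stream
(via its Thrust execution policy) is used instead, which is why the call sites in this
patch all gained a `handle` parameter.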
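The recurring `compose_op(cast_op<...>(), ..._const_op<...>(c))` pattern that replaces
the one-off device lambdas is worth spelling out once: `compose_op` applies its
operators right to left, so the k-means label initialization above expands as in this
standalone sketch (a restatement of the patch's own call under the same assumed
headers as the previous example):

#include <cstdint>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/operators.hpp>
#include <raft/linalg/map.cuh>

using LabelT = std::uint32_t;
using IdxT   = std::int64_t;

// "Randomly" initialize labels by tiling 0, 1, ..., n_clusters - 1 over the vector.
void init_labels(const raft::handle_t& handle,
                 raft::device_vector_view<LabelT, IdxT> labels,
                 IdxT n_clusters)
{
  // For every offset i, the composed operator computes
  //   cast_op<LabelT>()(mod_const_op<IdxT>(n_clusters)(i)),
  // i.e. labels[i] = static_cast<LabelT>(i % n_clusters),
  // exactly what the removed (LabelT* out, IdxT i) lambda wrote through its pointer.
  raft::linalg::map_offset(
    handle,
    labels,
    raft::compose_op(raft::cast_op<LabelT>(), raft::mod_const_op<IdxT>(n_clusters)));
}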