From 6a7e125bbb5ffdf195cd0fc50ec44aa28e6900bf Mon Sep 17 00:00:00 2001 From: Micka Date: Wed, 1 Feb 2023 18:40:39 +0100 Subject: [PATCH] Add function to convert mdspan to a const view (#1188) `make_const_mdspan` is a helper function to convert `mdspan` into `mdspan`. I added examples of it's usage @mhoemmen @Nyrio Authors: - Micka (https://github.com/lowener) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Louis Sugy (https://github.com/Nyrio) - Corey J. Nolet (https://github.com/cjnolet) - Mark Hoemmen (https://github.com/mhoemmen) URL: https://github.com/rapidsai/raft/pull/1188 --- cpp/bench/matrix/argmin.cu | 4 +-- cpp/bench/matrix/gather.cu | 11 +++--- cpp/include/raft/core/mdspan.hpp | 50 ++++++++++++++++++++++++++- cpp/test/cluster/kmeans.cu | 18 ++++------ cpp/test/core/mdspan_utils.cu | 27 +++++++++++++++ docs/source/cpp_api/mdspan_mdspan.rst | 3 ++ 6 files changed, 90 insertions(+), 23 deletions(-) diff --git a/cpp/bench/matrix/argmin.cu b/cpp/bench/matrix/argmin.cu index 52f5aab7f3..3869f0c5e1 100644 --- a/cpp/bench/matrix/argmin.cu +++ b/cpp/bench/matrix/argmin.cu @@ -46,9 +46,7 @@ struct Argmin : public fixture { void run_benchmark(::benchmark::State& state) override { loop_on_state(state, [this]() { - auto matrix_const_view = raft::make_device_matrix_view( - matrix.data_handle(), matrix.extent(0), matrix.extent(1)); - raft::matrix::argmin(handle, matrix_const_view, indices.view()); + raft::matrix::argmin(handle, raft::make_const_mdspan(matrix.view()), indices.view()); }); } diff --git a/cpp/bench/matrix/gather.cu b/cpp/bench/matrix/gather.cu index 97812c20a1..c5d80744cd 100644 --- a/cpp/bench/matrix/gather.cu +++ b/cpp/bench/matrix/gather.cu @@ -64,14 +64,11 @@ struct Gather : public fixture { state.SetLabel(label_stream.str()); loop_on_state(state, [this]() { - auto matrix_const_view = raft::make_device_matrix_view( - matrix.data_handle(), matrix.extent(0), matrix.extent(1)); - auto map_const_view = - raft::make_device_vector_view(map.data_handle(), map.extent(0)); + auto matrix_const_view = raft::make_const_mdspan(matrix.view()); + auto map_const_view = raft::make_const_mdspan(map.view()); if constexpr (Conditional) { - auto stencil_const_view = - raft::make_device_vector_view(stencil.data_handle(), stencil.extent(0)); - auto pred_op = raft::plug_const_op(T(0.0), raft::greater_op()); + auto stencil_const_view = raft::make_const_mdspan(stencil.view()); + auto pred_op = raft::plug_const_op(T(0.0), raft::greater_op()); raft::matrix::gather_if( handle, matrix_const_view, out.view(), map_const_view, stencil_const_view, pred_op); } else { diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 786ce69f89..f805d20064 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -304,4 +304,52 @@ RAFT_INLINE_FUNCTION auto unravel_index(Idx idx, } } +/** + * @brief Const accessor specialization for default_accessor + * + * @tparam ElementType + * @param a + * @return std::experimental::default_accessor> + */ +template +std::experimental::default_accessor> accessor_of_const( + std::experimental::default_accessor a) +{ + return {a}; +} + +/** + * @brief Const accessor specialization for host_device_accessor + * + * @tparam ElementType the data type of the mdspan elements + * @tparam MemType the type of memory where the elements are stored. + * @param a host_device_accessor + * @return host_device_accessor>, + * MemType> + */ +template +host_device_accessor>, MemType> +accessor_of_const(host_device_accessor, MemType> a) +{ + return {a}; +} + +/** + * @brief Create a copy of the given mdspan with const element type + * + * @tparam ElementType the const-qualified data type of the mdspan elements + * @tparam Extents raft::extents for dimensions + * @tparam Layout policy for strides and layout ordering + * @tparam Accessor Accessor policy for the input and output + * @param mds raft::mdspan object + * @return raft::mdspan + */ +template +auto make_const_mdspan(mdspan mds) +{ + auto acc_c = accessor_of_const(mds.accessor()); + return mdspan, Extents, Layout, decltype(acc_c)>{ + mds.data_handle(), mds.mapping(), acc_c}; +} + } // namespace raft diff --git a/cpp/test/cluster/kmeans.cu b/cpp/test/cluster/kmeans.cu index f4345aff82..685bd1f965 100644 --- a/cpp/test/cluster/kmeans.cu +++ b/cpp/test/cluster/kmeans.cu @@ -112,8 +112,7 @@ class KmeansTest : public ::testing::TestWithParam> { rmm::device_uvector workspace(0, stream); rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); rmm::device_uvector inRankCp(0, stream); - auto X_view = - raft::make_device_matrix_view(X.data_handle(), X.extent(0), X.extent(1)); + auto X_view = raft::make_const_mdspan(X.view()); auto centroids_view = raft::make_device_matrix_view(d_centroids.data(), params.n_clusters, n_features); auto miniX = raft::make_device_matrix(handle, n_samples / 4, n_features); @@ -126,12 +125,8 @@ class KmeansTest : public ::testing::TestWithParam> { miniX.extent(0), params.rng_state.seed); - raft::cluster::kmeans::init_plus_plus(handle, - params, - raft::make_device_matrix_view( - miniX.data_handle(), miniX.extent(0), miniX.extent(1)), - centroids_view, - workspace); + raft::cluster::kmeans::init_plus_plus( + handle, params, raft::make_const_mdspan(miniX.view()), centroids_view, workspace); auto minClusterDistance = raft::make_device_vector(handle, n_samples); auto minClusterAndDistance = @@ -285,10 +280,9 @@ class KmeansTest : public ::testing::TestWithParam> { raft::copy(d_labels_ref.data(), labels.data_handle(), n_samples, stream); - T inertia = 0; - int n_iter = 0; - auto X_view = - raft::make_device_matrix_view(X.data_handle(), X.extent(0), X.extent(1)); + T inertia = 0; + int n_iter = 0; + auto X_view = raft::make_const_mdspan(X.view()); raft::cluster::kmeans_fit_predict( handle, diff --git a/cpp/test/core/mdspan_utils.cu b/cpp/test/core/mdspan_utils.cu index f428da4b31..4bb689c8c0 100644 --- a/cpp/test/core/mdspan_utils.cu +++ b/cpp/test/core/mdspan_utils.cu @@ -214,4 +214,31 @@ void test_reshape() TEST(MDArray, Reshape) { test_reshape(); } +void test_const_mdspan() +{ + // 3d host array + { + using two_d_extents = extents; + using two_d_mdarray = host_mdarray; + + typename two_d_mdarray::mapping_type layout{two_d_extents{}}; + typename two_d_mdarray::container_policy_type policy; + two_d_mdarray mda{layout, policy}; + + auto const_mda = make_const_mdspan(mda.view()); + + static_assert(std::is_same_v, + "elements not the same"); + static_assert(std::is_same_v, + "extents not the same"); + static_assert(std::is_same_v, + "layouts not the same"); + ASSERT_EQ(mda.size(), const_mda.size()); + } +} + +TEST(MDSpan, ConstMDSpan) { test_const_mdspan(); } + } // namespace raft \ No newline at end of file diff --git a/docs/source/cpp_api/mdspan_mdspan.rst b/docs/source/cpp_api/mdspan_mdspan.rst index 272a724833..619150f538 100644 --- a/docs/source/cpp_api/mdspan_mdspan.rst +++ b/docs/source/cpp_api/mdspan_mdspan.rst @@ -22,6 +22,9 @@ mdspan: Multi-dimensional Non-owning View .. doxygenfunction:: raft::unravel_index :project: RAFT +.. doxygenfunction:: raft::make_const_mdspan(mdspan_type mds) + :project: RAFT + Device Vocabulary -----------------