Skip to content

Commit

Permalink
Updating raft::linalg APIs to use mdspan (#809)
Browse files Browse the repository at this point in the history
Authors:
  - Divye Gala (https://github.com/divyegala)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #809
  • Loading branch information
divyegala authored Sep 30, 2022
1 parent d475fca commit 7adf15e
Show file tree
Hide file tree
Showing 72 changed files with 3,610 additions and 405 deletions.
52 changes: 45 additions & 7 deletions cpp/include/raft/core/device_mdspan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,23 @@ using managed_mdspan = mdspan<ElementType, Extents, LayoutPolicy, managed_access

namespace detail {
template <typename T, bool B>
struct is_device_accessible_mdspan : std::false_type {
struct is_device_mdspan : std::false_type {
};
template <typename T>
struct is_device_accessible_mdspan<T, true>
: std::bool_constant<T::accessor_type::is_device_accessible> {
struct is_device_mdspan<T, true> : std::bool_constant<T::accessor_type::is_device_accessible> {
};

/**
 * @brief Boolean to determine if template type T is either raft::device_mdspan or a derived type
 */
template <typename T>
using is_device_mdspan_t = is_device_mdspan<T, is_mdspan_v<T>>;

// Gate additionally on T being an input (const-element) mdspan.
template <typename T>
using is_input_device_mdspan_t = is_device_mdspan<T, is_input_mdspan_v<T>>;

// Gate additionally on T being an output (non-const-element) mdspan.
template <typename T>
using is_output_device_mdspan_t = is_device_mdspan<T, is_output_mdspan_v<T>>;

template <typename T, bool B>
struct is_managed_mdspan : std::false_type {
Expand All @@ -70,18 +75,37 @@ struct is_managed_mdspan<T, true> : std::bool_constant<T::accessor_type::is_mana
// Trait: resolves is_managed_mdspan with the gate "T is an mdspan at all".
template <typename T>
using is_managed_mdspan_t = is_managed_mdspan<T, is_mdspan_v<T>>;

// Variant gated on T additionally being an input (const-element) mdspan.
template <typename T>
using is_input_managed_mdspan_t = is_managed_mdspan<T, is_input_mdspan_v<T>>;

// Variant gated on T additionally being an output (non-const-element) mdspan.
template <typename T>
using is_output_managed_mdspan_t = is_managed_mdspan<T, is_output_mdspan_v<T>>;

} // end namespace detail

/**
 * @brief Boolean to determine if variadic template types Tn are either raft::device_mdspan or a
 * derived type
 */
template <typename... Tn>
inline constexpr bool is_device_mdspan_v = std::conjunction_v<detail::is_device_mdspan_t<Tn>...>;

// All of Tn must be device-accessible mdspans over const elements.
template <typename... Tn>
inline constexpr bool is_input_device_mdspan_v =
  std::conjunction_v<detail::is_input_device_mdspan_t<Tn>...>;

// All of Tn must be device-accessible mdspans over non-const elements.
template <typename... Tn>
inline constexpr bool is_output_device_mdspan_v =
  std::conjunction_v<detail::is_output_device_mdspan_t<Tn>...>;

// SFINAE helpers: substitution succeeds only when the matching trait above holds.
template <typename... Tn>
using enable_if_device_mdspan = std::enable_if_t<is_device_mdspan_v<Tn...>>;

template <typename... Tn>
using enable_if_input_device_mdspan = std::enable_if_t<is_input_device_mdspan_v<Tn...>>;

template <typename... Tn>
using enable_if_output_device_mdspan = std::enable_if_t<is_output_device_mdspan_v<Tn...>>;

/**
* @brief Boolean to determine if variadic template types Tn are either raft::managed_mdspan or a
Expand All @@ -90,9 +114,23 @@ using enable_if_device_mdspan = std::enable_if_t<is_device_accessible_mdspan_v<T
template <typename... Tn>
inline constexpr bool is_managed_mdspan_v = std::conjunction_v<detail::is_managed_mdspan_t<Tn>...>;

// All of Tn must be managed mdspans over const elements.
template <typename... Tn>
inline constexpr bool is_input_managed_mdspan_v =
  std::conjunction_v<detail::is_input_managed_mdspan_t<Tn>...>;

// All of Tn must be managed mdspans over non-const elements.
template <typename... Tn>
inline constexpr bool is_output_managed_mdspan_v =
  std::conjunction_v<detail::is_output_managed_mdspan_t<Tn>...>;

// SFINAE helpers: substitution succeeds only when the matching trait above holds.
template <typename... Tn>
using enable_if_managed_mdspan = std::enable_if_t<is_managed_mdspan_v<Tn...>>;

template <typename... Tn>
using enable_if_input_managed_mdspan = std::enable_if_t<is_input_managed_mdspan_v<Tn...>>;

template <typename... Tn>
using enable_if_output_managed_mdspan = std::enable_if_t<is_output_managed_mdspan_v<Tn...>>;

/**
* @brief Shorthand for 0-dim host mdspan (scalar).
* @tparam ElementType the data type of the scalar element
Expand Down
32 changes: 25 additions & 7 deletions cpp/include/raft/core/host_mdspan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,23 @@ using host_mdspan = mdspan<ElementType, Extents, LayoutPolicy, host_accessor<Acc
namespace detail {

template <typename T, bool B>
struct is_host_accessible_mdspan : std::false_type {
struct is_host_mdspan : std::false_type {
};
template <typename T>
struct is_host_accessible_mdspan<T, true>
: std::bool_constant<T::accessor_type::is_host_accessible> {
struct is_host_mdspan<T, true> : std::bool_constant<T::accessor_type::is_host_accessible> {
};

/**
 * @brief Boolean to determine if template type T is either raft::host_mdspan or a derived type
 */
template <typename T>
using is_host_mdspan_t = is_host_mdspan<T, is_mdspan_v<T>>;

// Gate additionally on T being an input (const-element) mdspan.
template <typename T>
using is_input_host_mdspan_t = is_host_mdspan<T, is_input_mdspan_v<T>>;

// Gate additionally on T being an output (non-const-element) mdspan.
template <typename T>
using is_output_host_mdspan_t = is_host_mdspan<T, is_output_mdspan_v<T>>;

} // namespace detail

Expand All @@ -57,11 +62,24 @@ using is_host_accessible_mdspan_t = is_host_accessible_mdspan<T, is_mdspan_v<T>>
* derived type
*/
template <typename... Tn>
inline constexpr bool is_host_mdspan_v = std::conjunction_v<detail::is_host_mdspan_t<Tn>...>;

// All of Tn must be host-accessible mdspans over const elements.
template <typename... Tn>
inline constexpr bool is_input_host_mdspan_v =
  std::conjunction_v<detail::is_input_host_mdspan_t<Tn>...>;

// All of Tn must be host-accessible mdspans over non-const elements.
template <typename... Tn>
inline constexpr bool is_output_host_mdspan_v =
  std::conjunction_v<detail::is_output_host_mdspan_t<Tn>...>;

// BUG FIX: previously used is_input_mdspan_v, which enabled this helper for any
// input mdspan regardless of host accessibility; it must gate on is_host_mdspan_v
// to mirror enable_if_device_mdspan / enable_if_managed_mdspan.
template <typename... Tn>
using enable_if_host_mdspan = std::enable_if_t<is_host_mdspan_v<Tn...>>;

template <typename... Tn>
using enable_if_input_host_mdspan = std::enable_if_t<is_input_host_mdspan_v<Tn...>>;

template <typename... Tn>
using enable_if_output_host_mdspan = std::enable_if_t<is_output_host_mdspan_v<Tn...>>;

/**
* @brief Shorthand for 0-dim host mdspan (scalar).
Expand Down
34 changes: 34 additions & 0 deletions cpp/include/raft/core/mdspan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,31 @@ struct is_mdspan<T, std::void_t<decltype(__takes_an_mdspan_ptr(std::declval<T*>(
: std::true_type {
};

// Trait: T is an "input" mdspan — an mdspan whose element_type is const (read-only view).
// Detection uses the same __takes_an_mdspan_ptr SFINAE probe as is_mdspan.
template <typename T, typename = void>
struct is_input_mdspan : std::false_type {
};
template <typename T>
struct is_input_mdspan<T, std::void_t<decltype(__takes_an_mdspan_ptr(std::declval<T*>()))>>
  : std::bool_constant<std::is_const_v<typename T::element_type>> {
};

// Trait: T is an "output" mdspan — an mdspan whose element_type is non-const (writable view).
template <typename T, typename = void>
struct is_output_mdspan : std::false_type {
};
template <typename T>
struct is_output_mdspan<T, std::void_t<decltype(__takes_an_mdspan_ptr(std::declval<T*>()))>>
  : std::bool_constant<not std::is_const_v<typename T::element_type>> {
};

// remove_const so that `const some_mdspan` still counts as an mdspan.
template <typename T>
using is_mdspan_t = is_mdspan<std::remove_const_t<T>>;

template <typename T>
using is_input_mdspan_t = is_input_mdspan<T>;

template <typename T>
using is_output_mdspan_t = is_output_mdspan<T>;

/**
* @brief Boolean to determine if variadic template types Tn are either
* raft::host_mdspan/raft::device_mdspan or their derived types
Expand All @@ -70,6 +92,18 @@ inline constexpr bool is_mdspan_v = std::conjunction_v<is_mdspan_t<Tn>...>;
template <typename... Tn>
using enable_if_mdspan = std::enable_if_t<is_mdspan_v<Tn...>>;

// All of Tn must be mdspans over const elements (read-only inputs).
template <typename... Tn>
inline constexpr bool is_input_mdspan_v = std::conjunction_v<is_input_mdspan_t<Tn>...>;

template <typename... Tn>
using enable_if_input_mdspan = std::enable_if_t<is_input_mdspan_v<Tn...>>;

// All of Tn must be mdspans over non-const elements (writable outputs).
template <typename... Tn>
inline constexpr bool is_output_mdspan_v = std::conjunction_v<is_output_mdspan_t<Tn>...>;

template <typename... Tn>
using enable_if_output_mdspan = std::enable_if_t<is_output_mdspan_v<Tn...>>;

// uint division optimization inspired by the CIndexer in cupy. Division operation is
// slow on both CPU and GPU, especially 64 bit integer. So here we first try to avoid 64
// bit when the index is smaller, then try to avoid division when it's exp of 2.
Expand Down
149 changes: 141 additions & 8 deletions cpp/include/raft/linalg/add.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@

#include "detail/add.cuh"

#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/util/input_validation.hpp>

namespace raft {
namespace linalg {

Expand All @@ -46,7 +50,7 @@ using detail::adds_scalar;
* @param stream cuda stream where to launch work
*/
template <typename InT, typename OutT = InT, typename IdxType = int>
void addScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_t stream)
void addScalar(OutT* out, const InT* in, const InT scalar, IdxType len, cudaStream_t stream)
{
detail::addScalar(out, in, scalar, len, stream);
}
Expand All @@ -72,24 +76,153 @@ void add(OutT* out, const InT* in1, const InT* in2, IdxType len, cudaStream_t st

/** Substract single value pointed by singleScalarDev parameter in device memory from inDev[i] and
* write result to outDev[i]
* @tparam math_t data-type upon which the math operation will be performed
* @tparam InT input data-type. Also the data-type upon which the math ops
* will be performed
* @tparam OutT output data-type
* @tparam IdxType Integer type used to for addressing
* @param outDev the output buffer
* @param inDev the input buffer
* @param singleScalarDev pointer to the scalar located in device memory
* @param len number of elements in the input and output buffer
* @param stream cuda stream
*/
template <typename math_t, typename IdxType = int>
void addDevScalar(math_t* outDev,
const math_t* inDev,
const math_t* singleScalarDev,
IdxType len,
cudaStream_t stream)
template <typename InT, typename OutT = InT, typename IdxType = int>
void addDevScalar(
OutT* outDev, const InT* inDev, const InT* singleScalarDev, IdxType len, cudaStream_t stream)
{
detail::addDevScalar(outDev, inDev, singleScalarDev, len, stream);
}

/**
* @defgroup add Addition Arithmetic
* @{
*/

/**
 * @brief Elementwise add operation: out[i] = in1[i] + in2[i]
 * @tparam InType Input Type raft::device_mdspan (must view const elements)
 * @tparam OutType Output Type raft::device_mdspan (must view non-const elements)
 * @param[in] handle raft::handle_t supplying the CUDA stream for the launch
 * @param[in] in1 First Input
 * @param[in] in2 Second Input
 * @param[out] out Output; must match the inputs in total size
 */
template <typename InType,
          typename OutType,
          typename = raft::enable_if_input_device_mdspan<InType>,
          typename = raft::enable_if_output_device_mdspan<OutType>>
void add(const raft::handle_t& handle, InType in1, InType in2, OutType out)
{
  using in_value_t  = typename InType::value_type;
  using out_value_t = typename OutType::value_type;

  // Contiguity is required because the raw-pointer kernel treats the data as flat arrays.
  RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous");
  RAFT_EXPECTS(raft::is_row_or_column_major(in1), "Input 1 must be contiguous");
  RAFT_EXPECTS(raft::is_row_or_column_major(in2), "Input 2 must be contiguous");
  RAFT_EXPECTS(out.size() == in1.size() && in1.size() == in2.size(),
               "Size mismatch between Output and Inputs");

  // Dispatch on index width: use 32-bit addressing whenever the element count fits,
  // falling back to 64-bit only for very large spans.
  if (out.size() <= std::numeric_limits<std::uint32_t>::max()) {
    add<in_value_t, out_value_t, std::uint32_t>(out.data_handle(),
                                                in1.data_handle(),
                                                in2.data_handle(),
                                                static_cast<std::uint32_t>(out.size()),
                                                handle.get_stream());
  } else {
    add<in_value_t, out_value_t, std::uint64_t>(out.data_handle(),
                                                in1.data_handle(),
                                                in2.data_handle(),
                                                static_cast<std::uint64_t>(out.size()),
                                                handle.get_stream());
  }
}

/**
 * @brief Elementwise addition of a device-resident scalar to the input:
 *        out[i] = in[i] + *scalar
 * @tparam InType Input Type raft::device_mdspan (must view const elements)
 * @tparam OutType Output Type raft::device_mdspan (must view non-const elements)
 * @tparam ScalarIdxType Index Type of scalar
 * @param[in] handle raft::handle_t supplying the CUDA stream for the launch
 * @param[in] in Input
 * @param[out] out Output; must match the input in total size
 * @param[in] scalar raft::device_scalar_view pointing at the addend in device memory
 */
template <typename InType,
          typename OutType,
          typename ScalarIdxType,
          typename = raft::enable_if_input_device_mdspan<InType>,
          typename = raft::enable_if_output_device_mdspan<OutType>>
void add_scalar(const raft::handle_t& handle,
                InType in,
                OutType out,
                raft::device_scalar_view<const typename InType::value_type, ScalarIdxType> scalar)
{
  using in_value_t  = typename InType::value_type;
  using out_value_t = typename OutType::value_type;

  // Contiguity is required because the raw-pointer kernel treats the data as flat arrays.
  RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous");
  RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous");
  RAFT_EXPECTS(out.size() == in.size(), "Size mismatch between Output and Input");

  // Prefer 32-bit indexing whenever the element count allows it.
  if (out.size() <= std::numeric_limits<std::uint32_t>::max()) {
    addDevScalar<in_value_t, out_value_t, std::uint32_t>(out.data_handle(),
                                                         in.data_handle(),
                                                         scalar.data_handle(),
                                                         static_cast<std::uint32_t>(out.size()),
                                                         handle.get_stream());
  } else {
    addDevScalar<in_value_t, out_value_t, std::uint64_t>(out.data_handle(),
                                                         in.data_handle(),
                                                         scalar.data_handle(),
                                                         static_cast<std::uint64_t>(out.size()),
                                                         handle.get_stream());
  }
}

/**
 * @brief Elementwise addition of a host-resident scalar to the input:
 *        out[i] = in[i] + *scalar
 * @tparam InType Input Type raft::device_mdspan (must view const elements)
 * @tparam OutType Output Type raft::device_mdspan (must view non-const elements)
 * @tparam ScalarIdxType Index Type of scalar
 * @param[in] handle raft::handle_t supplying the CUDA stream for the launch
 * @param[in] in Input
 * @param[out] out Output; must match the input in total size
 * @param[in] scalar raft::host_scalar_view pointing at the addend in host memory
 */
template <typename InType,
          typename OutType,
          typename ScalarIdxType,
          typename = raft::enable_if_input_device_mdspan<InType>,
          typename = raft::enable_if_output_device_mdspan<OutType>>
void add_scalar(const raft::handle_t& handle,
                const InType in,
                OutType out,
                raft::host_scalar_view<const typename InType::value_type, ScalarIdxType> scalar)
{
  using in_value_t  = typename InType::value_type;
  using out_value_t = typename OutType::value_type;

  // Contiguity is required because the raw-pointer kernel treats the data as flat arrays.
  RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous");
  RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous");
  RAFT_EXPECTS(out.size() == in.size(), "Size mismatch between Output and Input");

  // The scalar lives on the host, so it is dereferenced here and passed by value
  // (unlike the device-scalar overload, which forwards the device pointer).
  if (out.size() <= std::numeric_limits<std::uint32_t>::max()) {
    addScalar<in_value_t, out_value_t, std::uint32_t>(out.data_handle(),
                                                      in.data_handle(),
                                                      *scalar.data_handle(),
                                                      static_cast<std::uint32_t>(out.size()),
                                                      handle.get_stream());
  } else {
    addScalar<in_value_t, out_value_t, std::uint64_t>(out.data_handle(),
                                                      in.data_handle(),
                                                      *scalar.data_handle(),
                                                      static_cast<std::uint64_t>(out.size()),
                                                      handle.get_stream());
  }
}

/** @} */ // end of group add

}; // end namespace linalg
}; // end namespace raft

Expand Down
Loading

0 comments on commit 7adf15e

Please sign in to comment.