Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] [2/5] Header structure: remove specialization includes #1438

Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Separate fused_l2_nn_helpers
These types are not used in the ext header, but are useful to have.
ahendriksen committed Apr 20, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit 0c889dc64a676287f43db60b884e26d882a2a278
16 changes: 5 additions & 11 deletions cpp/include/raft/distance/fused_l2_nn-ext.cuh
Original file line number Diff line number Diff line change
@@ -16,23 +16,17 @@

#pragma once

#include <cstdint> // int64_t
#include <raft/core/device_resources.hpp> // raft::device_resources
#include <raft/core/kvp.hpp> // raft::KeyValuePair
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT
#include <cstdint> // int64_t
#include <raft/core/device_resources.hpp> // raft::device_resources
#include <raft/core/kvp.hpp> // raft::KeyValuePair
#include <raft/distance/fused_l2_nn_helpers.cuh> // include initialize and reduce operations
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT

#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY

namespace raft {
namespace distance {

template <typename DataT, typename OutT, typename IdxT, typename ReduceOpT>
void initialize(raft::device_resources const& handle,
OutT* min,
IdxT m,
DataT maxVal,
ReduceOpT redOp) RAFT_EXPLICIT;

template <typename DataT, typename OutT, typename IdxT>
void fusedL2NNMinReduce(OutT* min,
const DataT* x,
26 changes: 1 addition & 25 deletions cpp/include/raft/distance/fused_l2_nn-inl.cuh
Original file line number Diff line number Diff line change
@@ -23,38 +23,14 @@
#include <limits>
#include <raft/core/device_resources.hpp>
#include <raft/distance/detail/fused_l2_nn.cuh>
#include <raft/distance/fused_l2_nn_helpers.cuh>
#include <raft/linalg/contractions.cuh>
#include <raft/util/cuda_utils.cuh>
#include <stdint.h>
#include <type_traits>

namespace raft {
namespace distance {
/**
* \defgroup fused_l2_nn Fused 1-nearest neighbors
* @{
*/

template <typename LabelT, typename DataT>
using KVPMinReduce = detail::KVPMinReduceImpl<LabelT, DataT>;

template <typename LabelT, typename DataT>
using MinAndDistanceReduceOp = detail::MinAndDistanceReduceOpImpl<LabelT, DataT>;

template <typename LabelT, typename DataT>
using MinReduceOp = detail::MinReduceOpImpl<LabelT, DataT>;

/** @} */

/**
* Initialize array using init value from reduction op
*/
template <typename DataT, typename OutT, typename IdxT, typename ReduceOpT>
void initialize(
raft::device_resources const& handle, OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp)
{
detail::initialize<DataT, OutT, IdxT, ReduceOpT>(min, m, maxVal, redOp, handle.get_stream());
}

/**
* \ingroup fused_l2_nn
49 changes: 49 additions & 0 deletions cpp/include/raft/distance/fused_l2_nn_helpers.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/distance/detail/fused_l2_nn.cuh>

namespace raft::distance {

/**
* \defgroup fused_l2_nn Fused 1-nearest neighbors
* @{
*/

template <typename LabelT, typename DataT>
using KVPMinReduce = detail::KVPMinReduceImpl<LabelT, DataT>;

template <typename LabelT, typename DataT>
using MinAndDistanceReduceOp = detail::MinAndDistanceReduceOpImpl<LabelT, DataT>;

template <typename LabelT, typename DataT>
using MinReduceOp = detail::MinReduceOpImpl<LabelT, DataT>;

/** @} */

/**
* Initialize array using init value from reduction op
*/
template <typename DataT, typename OutT, typename IdxT, typename ReduceOpT>
void initialize(
raft::device_resources const& handle, OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp)
{
detail::initialize<DataT, OutT, IdxT, ReduceOpT>(min, m, maxVal, redOp, handle.get_stream());
}

} // namespace raft::distance