Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] [1/5] Header structure: replace specializations #1437

Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
d9801e8
MV: add -inl suffix to header paths
ahendriksen Apr 20, 2023
50b374d
MV: raft_runtime src files
ahendriksen Apr 13, 2023
8974ae3
FIX: add missing includes
ahendriksen Apr 20, 2023
95ef31b
FIX: getWorkspaceSize
ahendriksen Apr 20, 2023
7edcb6a
PREP: Separate rbf_fin_op
ahendriksen Apr 20, 2023
71de7bd
PREP: registers: Add _types header
ahendriksen Apr 20, 2023
541cabc
Change RAFT_COMPILED from INTERFACE to PUBLIC
ahendriksen Apr 20, 2023
d81b14e
Define RAFT_EXPLICIT and RAFT_EXPLICIT_INSTANTIATE_ONLY
ahendriksen Apr 20, 2023
48ea769
Update docs
ahendriksen Apr 20, 2023
e6bb5d5
Replace specializations by split headers
ahendriksen Apr 20, 2023
ff79abf
Deprecate specialization headers
ahendriksen Apr 20, 2023
c9e7413
Add interleaved scan instances
ahendriksen Apr 20, 2023
0c889dc
Separate fused_l2_nn_helpers
ahendriksen Apr 20, 2023
f97b2a8
Remove pairwise_matrix_instantiation_point
ahendriksen Apr 20, 2023
fb637f7
Rename specialization => instantiation
ahendriksen Apr 20, 2023
7b065af
test/neighbors/selection.cu: Expose kFaissMaxK
ahendriksen Apr 20, 2023
bd5611e
Update docs/source/developer_guide.md
ahendriksen Apr 27, 2023
dba178f
Update docs/source/developer_guide.md
ahendriksen Apr 27, 2023
f10f7d3
Update docs/source/using_libraft.md
ahendriksen Apr 27, 2023
6ed6c3c
Update docs/source/using_libraft.md
ahendriksen Apr 27, 2023
564e4b2
dev guide: Point out to document correct header
ahendriksen Apr 27, 2023
fa523aa
Replace !defined by ifndef for consistency
ahendriksen Apr 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 88 additions & 135 deletions cpp/CMakeLists.txt

Large diffs are not rendered by default.

194 changes: 194 additions & 0 deletions cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <raft/core/operators.hpp> // raft::identity_op
#include <raft/distance/detail/distance_ops/all_ops.cuh> // ops::*
#include <raft/distance/detail/distance_ops/cutlass.cuh> // ops::has_cutlass_op
#include <raft/distance/detail/kernels/rbf_fin_op.cuh> // rbf_fin_op
#include <raft/distance/detail/pairwise_matrix/params.cuh> // pairwise_matrix_params
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT

#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY

namespace raft::distance::detail {

template <typename OpT,
typename DataT,
typename AccT,
typename OutT,
typename FinOpT,
typename IdxT = int>
void pairwise_matrix_dispatch(OpT distance_op,
IdxT m,
IdxT n,
IdxT k,
const DataT* x,
const DataT* y,
const DataT* x_norm,
const DataT* y_norm,
OutT* out,
FinOpT fin_op,
cudaStream_t stream,
bool is_row_major) RAFT_EXPLICIT;

}; // namespace raft::distance::detail

#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY

#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \
OpT, DataT, AccT, OutT, FinOpT, IdxT) \
extern template void raft::distance::detail:: \
pairwise_matrix_dispatch<OpT<DataT, AccT, IdxT>, DataT, AccT, OutT, FinOpT, IdxT>( \
OpT<DataT, AccT, IdxT> distance_op, \
IdxT m, \
IdxT n, \
IdxT k, \
const DataT* x, \
const DataT* y, \
const DataT* x_norm, \
const DataT* y_norm, \
OutT* out, \
FinOpT fin_op, \
cudaStream_t stream, \
bool is_row_major)

/*
* Hierarchy of instantiations:
*
* This file defines extern template instantiations of the distance kernels. The
* instantiation of the public API is handled in raft/distance/distance-ext.cuh.
*
* After adding an instance here, make sure to also add the instance there.
*/

// The following two instances are used in the RBF kernel object. Note the use of int64_t for the
// index type.
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::l2_unexp_distance_op,
float,
float,
float,
raft::distance::kernels::detail::rbf_fin_op<float>,
int64_t);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::l2_unexp_distance_op,
double,
double,
double,
raft::distance::kernels::detail::rbf_fin_op<double>,
int64_t);

// Rest of instances
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::canberra_distance_op, float, float, float, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::canberra_distance_op,
double,
double,
double,
raft::identity_op,
int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::correlation_distance_op,
float,
float,
float,
raft::identity_op,
int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::correlation_distance_op,
double,
double,
double,
raft::identity_op,
int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::cosine_distance_op, float, float, float, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::cosine_distance_op, double, double, double, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::hamming_distance_op, float, float, float, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::hamming_distance_op, double, double, double, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::hellinger_distance_op, float, float, float, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::hellinger_distance_op,
double,
double,
double,
raft::identity_op,
int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::jensen_shannon_distance_op,
float,
float,
float,
raft::identity_op,
int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::jensen_shannon_distance_op,
double,
double,
double,
raft::identity_op,
int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::kl_divergence_op, float, float, float, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::kl_divergence_op, double, double, double, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::l1_distance_op, float, float, float, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::l1_distance_op, double, double, double, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::l2_exp_distance_op, float, float, float, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::l2_exp_distance_op, double, double, double, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::l2_unexp_distance_op, float, float, float, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::l2_unexp_distance_op,
double,
double,
double,
raft::identity_op,
int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::l_inf_distance_op, float, float, float, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::l_inf_distance_op, double, double, double, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::lp_unexp_distance_op, float, float, float, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::lp_unexp_distance_op,
double,
double,
double,
raft::identity_op,
int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::russel_rao_distance_op, float, float, float, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::russel_rao_distance_op,
double,
double,
double,
raft::identity_op,
int);

#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch
24 changes: 24 additions & 0 deletions cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY)
#include "dispatch-inl.cuh"
#endif

#ifdef RAFT_COMPILED
#include "dispatch-ext.cuh"
#endif
Loading