Skip to content

Commit

Permalink
fix styling
Browse files Browse the repository at this point in the history
  • Loading branch information
jinsolp committed Jun 6, 2024
1 parent 7ae655c commit e4398c3
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 25 deletions.
12 changes: 6 additions & 6 deletions cpp/include/raft/neighbors/detail/nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1352,9 +1352,9 @@ void GNND<Data_t, Index_t>::build(Data_t* data,
if (return_distances) {
for (size_t i = 0; i < (size_t)nrow_; i++) {
raft::copy(output_distances + i * build_config_.output_graph_degree,
graph_.h_dists.data_handle() + i * build_config_.node_degree,
build_config_.output_graph_degree,
raft::resource::get_cuda_stream(res));
graph_.h_dists.data_handle() + i * build_config_.node_degree,
build_config_.output_graph_degree,
raft::resource::get_cuda_stream(res));
}
}
Expand Down Expand Up @@ -1431,7 +1431,7 @@ void build(raft::resources const& res,
.internal_node_degree = extended_intermediate_degree,
.max_iterations = params.max_iterations,
.termination_threshold = params.termination_threshold,
.output_graph_degree = params.graph_degree};
.output_graph_degree = params.graph_degree};
GNND<const T, int> nnd(res, build_config);
nnd.build(dataset.data_handle(),
Expand All @@ -1454,8 +1454,8 @@ template <typename T,
typename Accessor =
host_device_accessor<std::experimental::default_accessor<T>, memory_type::host>>
index<DistData_t, IdxT> build(raft::resources const& res,
const index_params& params,
mdspan<const T, matrix_extent<int64_t>, row_major, Accessor> dataset)
const index_params& params,
mdspan<const T, matrix_extent<int64_t>, row_major, Accessor> dataset)
{
size_t intermediate_degree = params.intermediate_graph_degree;
size_t graph_degree = params.graph_degree;
Expand Down
10 changes: 5 additions & 5 deletions cpp/include/raft/neighbors/nn_descent.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -57,8 +57,8 @@ namespace raft::neighbors::experimental::nn_descent {
*/
template <typename T, typename IdxT = uint32_t>
index<detail::DistData_t, IdxT> build(raft::resources const& res,
index_params const& params,
raft::device_matrix_view<const T, int64_t, row_major> dataset)
index_params const& params,
raft::device_matrix_view<const T, int64_t, row_major> dataset)
{
return detail::build<T, IdxT>(res, params, dataset);
}
Expand Down Expand Up @@ -131,8 +131,8 @@ void build(raft::resources const& res,
*/
template <typename T, typename IdxT = uint32_t>
index<detail::DistData_t, IdxT> build(raft::resources const& res,
index_params const& params,
raft::host_matrix_view<const T, int64_t, row_major> dataset)
index_params const& params,
raft::host_matrix_view<const T, int64_t, row_major> dataset)
{
return detail::build<T, IdxT>(res, params, dataset);
}
Expand Down
18 changes: 8 additions & 10 deletions cpp/include/raft/neighbors/nn_descent_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@

#include "ann_types.hpp"

#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/mdspan_types.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
Expand Down Expand Up @@ -109,11 +109,11 @@ struct index : ann::index {
* @param res raft::resources is an object mangaging resources
* @param graph_view raft::host_matrix_view<IdxT, int64_t, raft::row_major> for storing knn-graph
*/
index(raft::resources const& res,
raft::host_matrix_view<IdxT, int64_t, raft::row_major> graph_view,
std::optional<raft::device_matrix_view<T, int64_t, row_major>> distances_view =
std::nullopt,
bool return_distances = false)
index(
raft::resources const& res,
raft::host_matrix_view<IdxT, int64_t, raft::row_major> graph_view,
std::optional<raft::device_matrix_view<T, int64_t, row_major>> distances_view = std::nullopt,
bool return_distances = false)
: ann::index(),
res_{res},
metric_{raft::distance::DistanceType::L2Expanded},
Expand All @@ -123,9 +123,7 @@ struct index : ann::index {
distances_view_(distances_view),
return_distances_(return_distances)
{
if(!distances_view.has_value()) {
distances_view_ = distances_.value().view();
}
if (!distances_view.has_value()) { distances_view_ = distances_.value().view(); }
}

/** Distance metric used for clustering. */
Expand Down
12 changes: 8 additions & 4 deletions cpp/test/neighbors/ann_nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,16 @@ class AnnNNDescentTest : public ::testing::TestWithParam<AnnNNDescentInputs> {
auto database_host_view = raft::make_host_matrix_view<const DataT, int64_t>(
(const DataT*)database_host.data_handle(), ps.n_rows, ps.dim);
auto index = nn_descent::build<DataT, IdxT>(handle_, index_params, database_host_view);
raft::copy(indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_);
raft::copy(distances_NNDescent.data(), index.distances().data_handle(), queries_size, stream_);
raft::copy(
indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_);
raft::copy(
distances_NNDescent.data(), index.distances().data_handle(), queries_size, stream_);
} else {
auto index = nn_descent::build<DataT, IdxT>(handle_, index_params, database_view);
raft::copy(indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_);
raft::copy(distances_NNDescent.data(), index.distances().data_handle(), queries_size, stream_);
raft::copy(
indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_);
raft::copy(
distances_NNDescent.data(), index.distances().data_handle(), queries_size, stream_);
};
}
resource::sync_stream(handle_);
Expand Down

0 comments on commit e4398c3

Please sign in to comment.