Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for svd API #1190

Merged
merged 8 commits into from
Feb 4, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 33 additions & 10 deletions cpp/include/raft/linalg/svd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -214,22 +214,26 @@ void svd_qr(raft::device_resources const& handle,
std::forward<UType>(U_in);
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> V =
std::forward<VType>(V_in);
ValueType* left_sing_vecs_ptr = nullptr;
ValueType* right_sing_vecs_ptr = nullptr;

if (U) {
RAFT_EXPECTS(in.extent(0) == U.value().extent(0) && in.extent(1) == U.value().extent(1),
"U should have dimensions m * n");
left_sing_vecs_ptr = U.value().data_handle();
}
if (V) {
RAFT_EXPECTS(in.extent(1) == V.value().extent(0) && in.extent(1) == V.value().extent(1),
"V should have dimensions n * n");
right_sing_vecs_ptr = V.value().data_handle();
}
svdQR(handle,
const_cast<ValueType*>(in.data_handle()),
in.extent(0),
in.extent(1),
sing_vals.data_handle(),
U.value().data_handle(),
V.value().data_handle(),
left_sing_vecs_ptr,
right_sing_vecs_ptr,
false,
U.has_value(),
V.has_value(),
Expand Down Expand Up @@ -278,22 +282,26 @@ void svd_qr_transpose_right_vec(
std::forward<UType>(U_in);
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> V =
std::forward<VType>(V_in);
ValueType* left_sing_vecs_ptr = nullptr;
ValueType* right_sing_vecs_ptr = nullptr;

if (U) {
RAFT_EXPECTS(in.extent(0) == U.value().extent(0) && in.extent(1) == U.value().extent(1),
"U should have dimensions m * n");
left_sing_vecs_ptr = U.value().data_handle();
}
if (V) {
RAFT_EXPECTS(in.extent(1) == V.value().extent(0) && in.extent(1) == V.value().extent(1),
"V should have dimensions n * n");
right_sing_vecs_ptr = V.value().data_handle();
}
svdQR(handle,
const_cast<ValueType*>(in.data_handle()),
in.extent(0),
in.extent(1),
sing_vals.data_handle(),
U.value().data_handle(),
V.value().data_handle(),
left_sing_vecs_ptr,
right_sing_vecs_ptr,
true,
U.has_value(),
V.has_value(),
Expand All @@ -320,7 +328,7 @@ void svd_qr_transpose_right_vec(Args... args)
* @param[in] in input raft::device_matrix_view with layout raft::col_major of shape (M, N)
* @param[out] S singular values raft::device_vector_view of shape (K)
* @param[out] V right singular values of raft::device_matrix_view with layout
* raft::col_major and dimensions (m, n)
* raft::col_major and dimensions (n, n)
* @param[out] U optional left singular values of raft::device_matrix_view with layout
* raft::col_major and dimensions (m, n)
*/
Expand All @@ -332,38 +340,52 @@ void svd_eig(
raft::device_matrix_view<ValueType, IndexType, raft::col_major> V,
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> U = std::nullopt)
{
ValueType* left_sing_vecs_ptr = nullptr;
if (U) {
RAFT_EXPECTS(in.extent(0) == U.value().extent(0) && in.extent(1) == U.value().extent(1),
"U should have dimensions m * n");
left_sing_vecs_ptr = U.value().data_handle();
}
RAFT_EXPECTS(in.extent(0) == V.extent(0) && in.extent(1) == V.extent(1),
RAFT_EXPECTS(in.extent(1) == V.extent(0) && in.extent(1) == V.extent(1),
"V should have dimensions n * n");
svdEig(handle,
const_cast<ValueType*>(in.data_handle()),
in.extent(0),
in.extent(1),
S.data_handle(),
U.value().data_handle(),
V.value().data_handle(),
left_sing_vecs_ptr,
V.data_handle(),
U.has_value(),
handle.get_stream());
}

template <typename ValueType, typename IndexType, typename UType>
void svd_eig(const raft::handle_t& handle,
raft::device_matrix_view<const ValueType, IndexType, raft::col_major> in,
raft::device_vector_view<ValueType, IndexType> S,
raft::device_matrix_view<ValueType, IndexType, raft::col_major> V,
UType&& U = std::nullopt)
lowener marked this conversation as resolved.
Show resolved Hide resolved
{
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> U_optional =
std::forward<UType>(U);
svd_eig(handle, in, S, V, U_optional);
}

/**
* @brief reconstruct a matrix use left and right singular vectors and
* singular values
* @param[in] handle raft::device_resources
* @param[in] U left singular values of raft::device_matrix_view with layout
* raft::col_major and dimensions (m, k)
* @param[in] S singular values raft::device_vector_view of shape (k, k)
* @param[in] S square matrix with singular values on its diagonal of shape (k, k)
* @param[in] V right singular values of raft::device_matrix_view with layout
* raft::col_major and dimensions (k, n)
* @param[out] out output raft::device_matrix_view with layout raft::col_major of shape (m, n)
*/
template <typename ValueType, typename IndexType>
void svd_reconstruction(raft::device_resources const& handle,
raft::device_matrix_view<const ValueType, IndexType, raft::col_major> U,
raft::device_vector_view<const ValueType, IndexType> S,
raft::device_matrix_view<const ValueType, IndexType, raft::col_major> S,
raft::device_matrix_view<const ValueType, IndexType, raft::col_major> V,
raft::device_matrix_view<ValueType, IndexType, raft::col_major> out)
{
Expand All @@ -380,6 +402,7 @@ void svd_reconstruction(raft::device_resources const& handle,
const_cast<ValueType*>(U.data_handle()),
const_cast<ValueType*>(S.data_handle()),
const_cast<ValueType*>(V.data_handle()),
out.data_handle(),
out.extent(0),
out.extent(1),
S.extent(0),
Expand Down