bugfix and GPU support for syevd (#8426)
asmushetzel authored and piiswrong committed Oct 28, 2017
1 parent 68038ad commit c5afa1f
Showing 7 changed files with 116 additions and 47 deletions.
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
@@ -146,3 +146,4 @@ List of Contributors
 * [Xizhou Zhu](https://github.com/einsiedler0408/)
 * [Jean Kossaifi](https://github.com/JeanKossaifi/)
 * [Kenta Kubo](https://github.com/kkk669/)
+* [Manu Seth](https://github.com/mseth10/)
10 changes: 5 additions & 5 deletions src/operator/linalg.h
@@ -168,21 +168,21 @@ int linalg_gelqf_workspace_query(const Tensor<xpu, 2, DType>& A,
 // CPU/GPU-versions of LAPACK function "syevd". Please refer to the
 // LAPACK documentation for further details.
 // Note:
-// - The current implementation works for CPU only
 // - A is input and output parameter (overwritten by U)
 // - Input A is symmetric, we access the lower triangle only
-// - Requires two workspace arrays, one in DType, other in int.

 template<typename xpu, typename DType>
 void linalg_syevd(const Tensor<xpu, 2, DType>& A,
                   const Tensor<xpu, 1, DType>& L,
                   const Tensor<xpu, 1, DType>& work,
-                  const Tensor<xpu, 1, int>& iwork, Stream<xpu> *s = 0);
+                  Stream<xpu> *s = 0);

 // This function determines the amount of workspace needed for linalg_syevd
+// which is returned as number of elements of type DType.
 template<typename xpu, typename DType>
-void linalg_syevd_workspace_query(const Tensor<xpu, 2, DType>& A, int* lwork,
-                                  int* liwork, Stream<xpu> *s = 0);
+int linalg_syevd_workspace_query(const Tensor<xpu, 2, DType>& A,
+                                 const Tensor<xpu, 1, DType>& L,
+                                 Stream<xpu> *s = 0);

 #include "linalg_impl.h"
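Taken together, the new declarations mean a caller makes one workspace query and one allocation instead of two. A minimal sketch of the intended call sequence, mirroring how la_op_inline.h uses these functions (the wrapper name eigendecompose and the surrounding tensor setup are hypothetical; only the two linalg_* calls come from this header):

// Sketch only: A is an n x n symmetric tensor and L a length-n tensor,
// both already allocated. A is overwritten with the eigenvectors U,
// L receives the eigenvalues.
template<typename xpu, typename DType>
void eigendecompose(const Tensor<xpu, 2, DType>& A,
                    const Tensor<xpu, 1, DType>& L,
                    const OpContext& ctx, Stream<xpu>* s) {
  // One query now covers both the DType and the int scratch needs.
  int lwork = linalg_syevd_workspace_query(A, L, s);
  Tensor<xpu, 1, DType> work = ctx.requested[0]
      .get_space_typed<xpu, 1, DType>(Shape1(lwork), s);
  linalg_syevd(A, L, work, s);
}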
109 changes: 89 additions & 20 deletions src/operator/linalg_impl.h
@@ -985,42 +985,111 @@ template<> inline \
 void linalg_syevd<cpu, DType>(const Tensor<cpu, 2, DType>& A, \
                               const Tensor<cpu, 1, DType>& L, \
                               const Tensor<cpu, 1, DType>& work, \
-                              const Tensor<cpu, 1, int>& iwork, \
                               Stream<cpu> *s) { \
   check_syevd(A, L); \
+  int liwork(0); \
+  MXNET_LAPACK_##fname(MXNET_LAPACK_ROW_MAJOR, 'L', A.size(0), \
+                       A.dptr_, A.stride_, L.dptr_, work.dptr_, -1, &liwork, \
+                       -1); \
+  int lwork(static_cast<int>(*work.dptr_)); \
+  int *iwork = static_cast<int*>(static_cast<void*>(work.dptr_ + lwork)); \
   int ret(MXNET_LAPACK_##fname(MXNET_LAPACK_ROW_MAJOR, 'L', A.size(0), \
                                A.dptr_, A.stride_, L.dptr_, work.dptr_, \
-                               work.size(0), iwork.dptr_, iwork.size(0))); \
+                               lwork, iwork, liwork)); \
   CHECK_EQ(ret, 0) << #fname << " failed in lapack on cpu."; \
 }
-// LINALG_CPU_SYEVD(ssyevd, float)
+LINALG_CPU_SYEVD(ssyevd, float)
 LINALG_CPU_SYEVD(dsyevd, double)

-template<> inline
-void linalg_syevd<cpu, float>(const Tensor<cpu, 2, float>& A,
-                              const Tensor<cpu, 1, float>& L,
-                              const Tensor<cpu, 1, float>& work,
-                              const Tensor<cpu, 1, int>& iwork,
-                              Stream<cpu> *s) {
-  CHECK(false) << "linalg_syevd is not currently implemented for float32." << std::endl
-               << "Please use float64 for now. If the rest of your code runs on float32,"
-               << " please use the Cast operator.";
-}
-
+// Mangle temp storage requirements for DType and int into a single
+// request as we can only allocate one temp space per operator. We
+// partition this temp space into two chunks again when calling syevd.
+// Returned is the number of elements of type DType that the temp space
+// needs to accommodate. This also makes this function signature equivalent
+// to the workspace query on GPU.
 #define LINALG_CPU_SYEVD_WORKSPACE_QUERY(func, DType) \
 template<> inline \
-void linalg_syevd_workspace_query<cpu, DType>(const Tensor<cpu, 2, DType>& A, \
-                                              int* lwork, int* liwork, \
-                                              Stream<cpu> *s) { \
+int linalg_syevd_workspace_query<cpu, DType>(const Tensor<cpu, 2, DType>& A, \
+                                             const Tensor<cpu, 1, DType>& L, \
+                                             Stream<cpu> *s) { \
   DType work(0.0); \
   int iwork(0); \
   MXNET_LAPACK_##func(MXNET_LAPACK_ROW_MAJOR, 'L', A.size(0), \
-                      A.dptr_, A.stride_, &work, &work, -1, &iwork, \
+                      A.dptr_, A.stride_, L.dptr_, &work, -1, &iwork, \
                       -1); \
-  *lwork = static_cast<int>(work); \
-  *liwork = iwork; \
+  iwork = (sizeof(int) * iwork + sizeof(DType) - 1) / sizeof(DType); \
+  return static_cast<int>(work) + iwork; \
 }
 LINALG_CPU_SYEVD_WORKSPACE_QUERY(ssyevd, float)
 LINALG_CPU_SYEVD_WORKSPACE_QUERY(dsyevd, double)
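The rounding line in the CPU query is the crux of the single-allocation scheme, so a worked example may help. LAPACK's lwork = -1 convention makes the first call a size query only; suppose it reports lwork = 1000 doubles and liwork = 38 ints (made-up numbers for illustration, with the common sizeof(int) = 4 and sizeof(double) = 8):

// Illustration only; real values come from the lwork/liwork query above.
int lwork  = 1000;  // DType elements requested by LAPACK
int liwork = 38;    // int elements requested by LAPACK

// Round the int request up to whole DType slots:
// (4 * 38 + 8 - 1) / 8 = 19 doubles hold the 38 ints.
int iwork_slots = static_cast<int>(
    (sizeof(int) * liwork + sizeof(double) - 1) / sizeof(double));

// Single request handed back to the operator: 1019 doubles.
// linalg_syevd later re-splits the buffer at work.dptr_ + lwork.
int total = lwork + iwork_slots;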

+#ifdef __CUDACC__
+
+// SYEVD only available with cuda8 or higher.
+#if CUDA_VERSION >= 8000
+
+// Row-major vs. col-major handled by using upper triangular
+// in cusolver-call.
+#define LINALG_GPU_SYEVD(fname, DType) \
+template<> inline \
+void linalg_syevd<gpu, DType>(const Tensor<gpu, 2, DType>& A, \
+                              const Tensor<gpu, 1, DType>& L, \
+                              const Tensor<gpu, 1, DType>& work, \
+                              Stream<gpu> *s) { \
+  using namespace mxnet; \
+  using mshadow::gpu; \
+  CHECK_NOTNULL(s); \
+  check_syevd(A, L); \
+  Storage::Handle info = Storage::Get()->Alloc(sizeof(int), Context::GPU()); \
+  CUSOLVER_CALL(cusolver##fname(Stream<gpu>::GetSolverHandle(s), \
+                CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER, \
+                A.size(0), A.dptr_, A.stride_, L.dptr_, work.dptr_, \
+                work.size(0), static_cast<int *>(info.dptr))); \
+  Storage::Get()->Free(info); \
+}
+
+#define LINALG_GPU_SYEVD_WORKSPACE_QUERY(fname, DType) \
+template<> inline \
+int linalg_syevd_workspace_query<gpu, DType>(const Tensor<gpu, 2, DType>& A, \
+                                             const Tensor<gpu, 1, DType>& L, \
+                                             Stream<gpu> *s) { \
+  using namespace mxnet; \
+  using mshadow::gpu; \
+  int lwork(0); \
+  CUSOLVER_CALL(cusolver##fname##_bufferSize(Stream<gpu>::GetSolverHandle(s), \
+                CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER, \
+                A.size(0), A.dptr_, A.stride_, L.dptr_, &lwork)); \
+  return lwork; \
+}
+
+#else
+
+#define LINALG_GPU_SYEVD(fname, DType) \
+template<> inline \
+void linalg_syevd<gpu, DType>(const Tensor<gpu, 2, DType>& A, \
+                              const Tensor<gpu, 1, DType>& L, \
+                              const Tensor<gpu, 1, DType>& work, \
+                              Stream<gpu> *s) { \
+  LOG(FATAL) << "syevd requires CUDA version >= 8.0!"; \
+}
+
+#define LINALG_GPU_SYEVD_WORKSPACE_QUERY(fname, DType) \
+template<> inline \
+int linalg_syevd_workspace_query<gpu, DType>(const Tensor<gpu, 2, DType>& A, \
+                                             const Tensor<gpu, 1, DType>& L, \
+                                             Stream<gpu> *s) { \
+  LOG(FATAL) << "syevd requires CUDA version >= 8.0!"; \
+  return 0; \
+}
+
+#endif  // CUDA_VERSION >= 8000
+
+LINALG_GPU_SYEVD(DnSsyevd, float)
+LINALG_GPU_SYEVD(DnDsyevd, double)
+
+LINALG_GPU_SYEVD_WORKSPACE_QUERY(DnSsyevd, float)
+LINALG_GPU_SYEVD_WORKSPACE_QUERY(DnDsyevd, double)
+
+#endif  // __CUDACC__

 #endif  // MXNET_OPERATOR_LINALG_IMPL_H_
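A word on the "row-major vs. col-major" comment above: MXNet tensors are row-major while cuSOLVER assumes column-major, and reinterpreting a row-major buffer as column-major reads the transpose. The lower triangle that the CPU path accesses with 'L' therefore sits exactly where a column-major reader finds the upper triangle, which is why the GPU macros pass CUBLAS_FILL_MODE_UPPER. A small standalone check of that index identity (not part of the patch):

// The lower triangle written row-major occupies the same memory that a
// column-major reader (cuSOLVER) sees as the upper triangle.
#include <cassert>

int main() {
  const int n = 3, ld = 3;
  double a[n * ld] = {0};
  // Write the lower triangle in row-major order: a(i, j) at offset i*ld + j.
  for (int i = 0; i < n; ++i)
    for (int j = 0; j <= i; ++j)
      a[i * ld + j] = 10.0 * i + j;
  // Read the upper triangle in column-major order: a(r, c) at offset r + c*ld,
  // with r <= c. Entry (r, c) here is exactly entry (c, r) written above.
  for (int c = 0; c < n; ++c)
    for (int r = 0; r <= c; ++r)
      assert(a[r + c * ld] == 10.0 * c + r);
  return 0;
}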
4 changes: 0 additions & 4 deletions src/operator/tensor/la_op.cc
@@ -571,10 +571,6 @@ mode). In this case, *U* has *n* dimensions like *A*, and *L* has *n-1* dimensions
 .. note:: The operator supports float32 and float64 data types only.
-.. note:: For the time being, this operator supports the float64 data type only. If the
-   rest of your expression uses float32, please apply the Cast operator to inputs
-   and outputs.
 .. note:: Derivatives for this operator are defined only if *A* is such that all its
    eigenvalues are distinct, and the eigengaps are not too small. If you need
    gradients, do not apply this operator to matrices with multiple eigenvalues.
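The eigengap caveat in the remaining note has a concrete source. For a symmetric matrix with eigenpairs A u_i = lambda_i u_i and distinct eigenvalues, first-order perturbation theory gives the standard result (background, not spelled out in the patch):

d\lambda_i = u_i^\top \, dA \, u_i, \qquad
du_i = \sum_{j \neq i} \frac{u_j^\top \, dA \, u_i}{\lambda_i - \lambda_j} \, u_j

So any gradient through the eigenvectors carries factors 1/(lambda_i - lambda_j), which degrade as eigenvalues coalesce; the eps guards in syevd_back_helper_eps below exist to keep exactly these denominators finite.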
6 changes: 6 additions & 0 deletions src/operator/tensor/la_op.cu
@@ -83,6 +83,12 @@ NNVM_REGISTER_OP(_linalg_gelqf)
 NNVM_REGISTER_OP(_backward_linalg_gelqf)
 .set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 4, 1, gelqf_backward>);

+NNVM_REGISTER_OP(_linalg_syevd)
+.set_attr<FCompute>("FCompute<gpu>", LaOpForwSyevd<gpu, syevd>);
+
+NNVM_REGISTER_OP(_backward_linalg_syevd)
+.set_attr<FCompute>("FCompute<gpu>", LaOpBackwSyevd<gpu, syevd_backward>);
+
 #endif

 }  // namespace op
17 changes: 7 additions & 10 deletions src/operator/tensor/la_op_inline.h
@@ -324,16 +324,13 @@ struct syevd {
   linalg_check_batch_size(A.size(0), U.size(0), L.size(0));
   if (A.dptr_ != U.dptr_) Copy(U, A, s);
   // From here on, we work on U only
-  // Reserve workspaces (size determined by query)
-  int lwork(0), liwork(0);
-  linalg_syevd_workspace_query(U[0], &lwork, &liwork, s);
+  // Reserve workspace (size determined by query)
+  int lwork(linalg_syevd_workspace_query(U[0], L[0], s));
   Tensor<xpu, 1, DType> work = ctx.requested[0]
     .get_space_typed<xpu, 1, DType>(Shape1(lwork), s);
-  Tensor<xpu, 1, int> iwork = ctx.requested[0]
-    .get_space_typed<xpu, 1, int>(Shape1(liwork), s);
   // Loop over items in batch
   for (index_t i = 0; i < U.size(0); ++i) {
-    linalg_syevd(U[i], L[i], work, iwork, s);
+    linalg_syevd(U[i], L[i], work, s);
   }
   // Set signs of eigenvectors in a deterministic way
   using namespace mxnet_op;

@@ -603,13 +600,13 @@ struct gelqf_backward {
 template<typename DType>
 DType syevd_back_helper_eps(DType* X);

-template<> inline
-float syevd_back_helper_eps(float* X) {
+template<>
+MSHADOW_XINLINE float syevd_back_helper_eps(float* X) {
   return 1e-30;
 }

-template<> inline
-double syevd_back_helper_eps(double* X) {
+template<>
+MSHADOW_XINLINE double syevd_back_helper_eps(double* X) {
   return 1e-100;
 }
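Replacing plain inline with MSHADOW_XINLINE marks these specializations host-and-device, so the same eps guard now compiles into the GPU backward kernel. One plausible reading of how the constants are used downstream (a schematic only, assuming eps guards the eigengap denominators in syevd_back_helper; safe_inverse_gap is a hypothetical helper, not from the patch):

// Schematic under the assumption above.
template<typename DType>
MSHADOW_XINLINE DType safe_inverse_gap(DType li, DType lj) {
  // eps (1e-30 for float, 1e-100 for double, matching each type's range)
  // keeps 1 / gap finite when two eigenvalues (nearly) coincide.
  DType eps = syevd_back_helper_eps(static_cast<DType*>(nullptr));
  DType gap = li - lj;
  DType guarded = (gap >= 0 ? gap + eps : gap - eps);
  return DType(1.0) / guarded;
}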
16 changes: 8 additions & 8 deletions tests/python/unittest/test_operator.py
@@ -4134,9 +4134,11 @@ def _syevd_backward(grad_u, grad_l, u, l):
     return np.dot(temp3, u)

 def test_laop_3():
-    # Operators implemented for CPU only currently
+    # Currently disabled on GPU as syevd needs cuda8
+    # and MXNet builds use cuda 7.5
     if not (default_context() == mx.cpu()):
         return
+
     np.random.seed(1896893923)
     dtype = np.float64
     rtol_fw = 1e-6

@@ -4148,7 +4150,6 @@ def test_laop_3():
     grad_check = 1

     data1 = mx.symbol.Variable('data1')
-
     check_fw = lambda sym, location, expected :\
         check_symbolic_forward(sym, location, expected, rtol=rtol_fw,
                                atol=atol_fw, dtype=dtype)
@@ -4200,13 +4201,12 @@ def test_laop_3():
     check_grad(test_syevd_l_4, [a_batch])


-# Note: Currently, linalg.syevd is activated for float64 only, due to the issues
-# demonstrated by this unit test. For this reason, the second part of this test
-# (float32) is deactivated for now.
 def test_laop_4():
-    # Operators implemented for CPU only currently
-    if not(default_context() == mx.cpu()):
+    # Currently disabled on GPU as syevd needs cuda8
+    # and MXNet builds use cuda 7.5
+    if not (default_context() == mx.cpu()):
         return
+
     np.random.seed(1896893923)
     rtol_fw = 1e-6
     atol_fw = 1e-6

@@ -4226,7 +4226,7 @@ def test_laop_4():
     check_fw(test_syevd, [a_np], [u_np, l_np], np.float64)
     # float32
-    #print('float32')
-    #check_fw(test_syevd, [a_np], [u_np, l_np], np.float32)
+    check_fw(test_syevd, [a_np], [u_np, l_np], np.float32)


 def test_stack():
