diff --git a/lib/cusolver/dense_generic.jl b/lib/cusolver/dense_generic.jl
index 90995be7d2..93997424b5 100644
--- a/lib/cusolver/dense_generic.jl
+++ b/lib/cusolver/dense_generic.jl
@@ -190,6 +190,7 @@ end
 
 # Xlarft!
 function larft!(direct::Char, storev::Char, v::StridedCuMatrix{T}, tau::StridedCuVector{T}, t::StridedCuMatrix{T}) where {T <: BlasFloat}
+    CUSOLVER.version() < v"11.6.0" && throw(ErrorException("This operation is not supported by the current CUDA version."))
     n, k = size(v)
     ktau = length(tau)
     mt, nt = size(t)
@@ -449,6 +450,7 @@ end
 
 # Xgeev
 function Xgeev!(jobvl::Char, jobvr::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
+    CUSOLVER.version() < v"11.7.1" && throw(ErrorException("This operation is not supported by the current CUDA version."))
     n = checksquare(A)
     VL = if jobvl == 'V'
         CuMatrix{T}(undef, n, n)
@@ -492,6 +494,60 @@ function Xgeev!(jobvl::Char, jobvr::Char, A::StridedCuMatrix{T}) where {T <: Bla
     return W, VL, VR
 end
 
+# XsyevBatched
+# Batched eigensolver wrapper around `cusolverDnXsyevBatched`.
+#
+# `A` has `n` rows and `n * batch_size` columns: the `batch_size` square matrices
+# are stored side by side. `jobz == 'N'` returns the vector `W` (length
+# `n * batch_size`, element type `real(T)`); `jobz == 'V'` returns `(W, A)`, where
+# `A` is the same array that was handed to cuSOLVER.
+#
+# Throws `ErrorException` on cuSOLVER < 11.7.1, `ArgumentError` for an unknown
+# `jobz`, and `DimensionMismatch` when the batch size cannot be derived from the
+# shape of `A`.
+function XsyevBatched!(jobz::Char, uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
+    CUSOLVER.version() < v"11.7.1" && throw(ErrorException("This operation is not supported by the current CUDA version."))
+    chkuplo(uplo)
+    jobz in ('N', 'V') || throw(ArgumentError("jobz must be 'N' or 'V', got '$jobz'"))
+    n, num_columns = size(A)
+    # The batch size is derived from the shape of A, so reject ambiguous shapes
+    # instead of dividing by zero or silently ignoring trailing columns.
+    n > 0 || throw(DimensionMismatch("A must have at least one row, got size $(size(A))"))
+    num_columns % n == 0 || throw(DimensionMismatch(
+        "the number of columns of A ($num_columns) must be a multiple of its number of rows ($n)"))
+    batch_size = num_columns ÷ n
+    R = real(T)
+    lda = max(1, stride(A, 2))
+    W = CuVector{R}(undef, n * batch_size)
+    params = CuSolverParameters()
+    dh = dense_handle()
+    # One status flag per matrix of the batch.
+    resize!(dh.info, batch_size)
+
+    function bufferSize()
+        out_cpu = Ref{Csize_t}(0)
+        out_gpu = Ref{Csize_t}(0)
+        cusolverDnXsyevBatched_bufferSize(dh, params, jobz, uplo, n,
+            T, A, lda, R, W, T, out_gpu, out_cpu, batch_size)
+        out_gpu[], out_cpu[]
+    end
+    with_workspaces(dh.workspace_gpu, dh.workspace_cpu, bufferSize()...) do buffer_gpu, buffer_cpu
+        cusolverDnXsyevBatched(dh, params, jobz, uplo, n, T, A,
+            lda, R, W, T, buffer_gpu, sizeof(buffer_gpu),
+            buffer_cpu, sizeof(buffer_cpu), dh.info, batch_size)
+    end
+
+    # Copy every status flag to the host in a single transfer.
+    # NOTE(review): only `chkargsok` is applied to each flag; confirm whether
+    # positive status values also need to be reported to the caller.
+    info = Array(dh.info)
+    for flag in info
+        chkargsok(BlasInt(flag))
+    end
+
+    return jobz == 'N' ? W : (W, A)
+end
+
 # LAPACK
 for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
     @eval begin
diff --git a/test/libraries/cusolver/dense_generic.jl b/test/libraries/cusolver/dense_generic.jl
index ffe0ac549c..77a97c4155 100644
--- a/test/libraries/cusolver/dense_generic.jl
+++ b/test/libraries/cusolver/dense_generic.jl
@@ -31,6 +31,44 @@ p = 5
                 end
             end
         end
+
+        @testset "syevBatched!" begin
+            batch_size = 5
+            for uplo in ('L', 'U')
+                # NOTE(review): this combination is skipped without a stated reason - confirm why.
+                (uplo == 'L') && (elty == ComplexF32) && continue
+
+                A = rand(elty, n, n * batch_size)
+                B = rand(elty, n, n * batch_size)
+                for i = 1:batch_size
+                    S = rand(elty,n,n)
+                    S = S * S' + I
+                    B[:,(i-1)*n+1:i*n] .= S
+                    S = uplo == 'L' ? tril(S) : triu(S)
+                    A[:,(i-1)*n+1:i*n] .= S
+                end
+                d_A = CuMatrix(A)
+                d_W, d_V = CUSOLVER.XsyevBatched!('V', uplo, d_A)
+                W = collect(d_W)
+                V = collect(d_V)
+                for i = 1:batch_size
+                    Bᵢ = B[:,(i-1)*n+1:i*n]
+                    Wᵢ = Diagonal(W[(i-1)*n+1:i*n])
+                    Vᵢ = V[:,(i-1)*n+1:i*n]
+                    @test Bᵢ * Vᵢ ≈ Vᵢ * Wᵢ
+                end
+
+                # Values-only mode must agree with the values from the 'V' run.
+                d_A = CuMatrix(A)
+                d_W = CUSOLVER.XsyevBatched!('N', uplo, d_A)
+                @test collect(d_W) ≈ W
+            end
+
+            # Invalid arguments are rejected before cuSOLVER is called.
+            d_bad = CuMatrix(rand(elty, n, n * batch_size + 1))
+            @test_throws DimensionMismatch CUSOLVER.XsyevBatched!('V', 'U', d_bad)
+            @test_throws ArgumentError CUSOLVER.XsyevBatched!('X', 'U', CuMatrix(rand(elty, n, n)))
+        end
     end
 
     if CUSOLVER.version() >= v"11.6.0"