Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

alg keyword for svd and svd! #31057

Merged
merged 8 commits into from
Aug 15, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Standard library changes
* The BLAS submodule no longer exports `dot`, which conflicts with that in LinearAlgebra ([#31838]).
* `diagm` and `spdiagm` now accept optional `m,n` initial arguments to specify a size ([#31654]).
* `Hessenberg` factorizations `H` now support efficient shifted solves `(H+µI) \ b` and determinants, and use a specialized tridiagonal factorization for Hermitian matrices. There is also a new `UpperHessenberg` matrix type ([#31853]).
* Added keyword argument `alg` to `svd` and `svd!` that allows one to switch between different SVD algorithms ([#31057]).

#### SparseArrays

Expand Down
6 changes: 6 additions & 0 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,12 @@ else
const BlasInt = Int32
end


abstract type Algorithm end
struct DivideAndConquer <: Algorithm end
struct GolubReinsch <: Algorithm end
carstenbauer marked this conversation as resolved.
Show resolved Hide resolved


# Check that stride of matrix/vector is 1
# Writing like this to avoid splatting penalty when called with multiple arguments,
# see PR 16416
Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,8 @@ function svd!(M::Bidiagonal{<:BlasReal}; full::Bool = false)
d, e, U, Vt, Q, iQ = LAPACK.bdsdc!(M.uplo, 'I', M.dv, M.ev)
SVD(U, d, Vt)
end
function svd(M::Bidiagonal; full::Bool = false)
svd!(copy(M), full = full)
function svd(M::Bidiagonal; kw...)
svd!(copy(M), kw...)
end

####################
Expand Down
37 changes: 24 additions & 13 deletions stdlib/LinearAlgebra/src/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,15 @@ function SVD{T}(U::AbstractArray, S::AbstractVector{Tr}, Vt::AbstractArray) wher
convert(AbstractArray{T}, Vt))
end


# iteration for destructuring into components
Base.iterate(S::SVD) = (S.U, Val(:S))
Base.iterate(S::SVD, ::Val{:S}) = (S.S, Val(:V))
Base.iterate(S::SVD, ::Val{:V}) = (S.V, Val(:done))
Base.iterate(S::SVD, ::Val{:done}) = nothing

"""
svd!(A; full::Bool = false) -> SVD
svd!(A; full::Bool = false, alg::Algorithm = DivideAndConquer()) -> SVD
carstenbauer marked this conversation as resolved.
Show resolved Hide resolved

`svd!` is the same as [`svd`](@ref), but saves space by
overwriting the input `A`, instead of creating a copy.
Expand Down Expand Up @@ -92,18 +93,25 @@ julia> A
0.0 0.0 -2.0 0.0 0.0
```
"""
function svd!(A::StridedMatrix{T}; full::Bool = false) where T<:BlasFloat
function svd!(A::StridedMatrix{T}; full::Bool = false, alg::Algorithm = DivideAndConquer()) where T<:BlasFloat
m,n = size(A)
if m == 0 || n == 0
u,s,vt = (Matrix{T}(I, m, full ? m : n), real(zeros(T,0)), Matrix{T}(I, n, n))
else
u,s,vt = LAPACK.gesdd!(full ? 'A' : 'S', A)
if typeof(alg) == DivideAndConquer
carstenbauer marked this conversation as resolved.
Show resolved Hide resolved
u,s,vt = LAPACK.gesdd!(full ? 'A' : 'S', A)
elseif typeof(alg) == GolubReinsch
c = full ? 'A' : 'S'
u,s,vt = LAPACK.gesvd!(c, c, A)
else
throw(ArgumentError("Unsupported value for `alg` keyword."))
end
end
SVD(u,s,vt)
end

"""
svd(A; full::Bool = false) -> SVD
svd(A; full::Bool = false, alg::Algorithm = DivideAndConquer()) -> SVD

Compute the singular value decomposition (SVD) of `A` and return an `SVD` object.

Expand All @@ -120,6 +128,9 @@ and `V` is `N \\times N`, while in the thin factorization `U` is `M
\\times K` and `V` is `N \\times K`, where `K = \\min(M,N)` is the
number of singular values.

If `alg = DivideAndConquer()` (default) a divide-and-conquer algorithm is used to calculate the SVD.
Another (typically slower) option is `alg = GolubReinsch()`.
carstenbauer marked this conversation as resolved.
Show resolved Hide resolved

# Examples
```jldoctest
julia> A = [1. 0. 0. 0. 2.; 0. 0. 3. 0. 0.; 0. 0. 0. 0. 0.; 0. 2. 0. 0. 0.]
Expand All @@ -144,21 +155,21 @@ julia> u == F.U && s == F.S && v == F.V
true
```
"""
function svd(A::StridedVecOrMat{T}; full::Bool = false) where T
svd!(copy_oftype(A, eigtype(T)), full = full)
function svd(A::StridedVecOrMat{T}; full::Bool = false, alg::Algorithm = DivideAndConquer()) where T
svd!(copy_oftype(A, eigtype(T)), full = full, alg = alg)
end
function svd(x::Number; full::Bool = false)
function svd(x::Number; full::Bool = false, alg::Algorithm = DivideAndConquer())
SVD(x == 0 ? fill(one(x), 1, 1) : fill(x/abs(x), 1, 1), [abs(x)], fill(one(x), 1, 1))
end
function svd(x::Integer; full::Bool = false)
svd(float(x), full = full)
function svd(x::Integer; full::Bool = false, alg::Algorithm = DivideAndConquer())
svd(float(x), full = full, alg = alg)
end
function svd(A::Adjoint; full::Bool = false)
s = svd(A.parent, full = full)
function svd(A::Adjoint; full::Bool = false, alg::Algorithm = DivideAndConquer())
s = svd(A.parent, full = full, alg = alg)
return SVD(s.Vt', s.S, s.U')
end
function svd(A::Transpose; full::Bool = false)
s = svd(A.parent, full = full)
function svd(A::Transpose; full::Bool = false, alg::Algorithm = DivideAndConquer())
s = svd(A.parent, full = full, alg = alg)
return SVD(transpose(s.Vt), s.S, transpose(s.U))
end

Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2513,7 +2513,7 @@ eigen(A::AbstractTriangular) = Eigen(eigvals(A), eigvecs(A))
# Generic singular systems
for func in (:svd, :svd!, :svdvals)
@eval begin
($func)(A::AbstractTriangular) = ($func)(copyto!(similar(parent(A)), A))
($func)(A::AbstractTriangular; kwargs...) = ($func)(copyto!(similar(parent(A)), A); kwargs...)
end
end

Expand Down
31 changes: 31 additions & 0 deletions stdlib/LinearAlgebra/test/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,4 +139,35 @@ aimg = randn(n,n)/2
end
end



@testset "SVD Algorithms" begin
≊(x,y) = isapprox(x,y,rtol=1e-15)
carstenbauer marked this conversation as resolved.
Show resolved Hide resolved

allpos = (v) -> begin
carstenbauer marked this conversation as resolved.
Show resolved Hide resolved
for e in v
e < 0 && return false
end
return true
end

x = [0.1 0.2; 0.3 0.4]

for alg in [LinearAlgebra.GolubReinsch(), LinearAlgebra.DivideAndConquer()]
sx1 = svd(x, alg = alg)
@test sx1.U * Diagonal(sx1.S) * sx1.Vt ≊ x
@test sx1.V * sx1.Vt ≊ I
@test sx1.U * sx1.U' ≊ I
@test allpos(sx1.S)

sx2 = svd!(copy(x), alg = alg)
@test sx2.U * Diagonal(sx2.S) * sx2.Vt ≊ x
@test sx2.V * sx2.Vt ≊ I
@test sx2.U * sx2.U' ≊ I
@test allpos(sx2.S)
end
end



end # module TestSVD