Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pivoted Cholesky decomposition for Diagonal #54585

Merged
merged 4 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions stdlib/LinearAlgebra/src/cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ function _chol!(x::Number, _)
rx = real(x)
iszero(rx) && return (rx, convert(BlasInt, 1))
rxr = sqrt(abs(rx))
rval = convert(promote_type(typeof(x), typeof(rxr)), rxr)
rval = convert(promote_type(typeof(x), typeof(rxr)), rxr)
return (rval, convert(BlasInt, rx != abs(x)))
end

Expand Down Expand Up @@ -400,6 +400,13 @@ function _cholpivoted!(A::AbstractMatrix, ::Type{LowerTriangular}, tol::Real, ch
return A, piv, convert(BlasInt, rank), convert(BlasInt, info)
end
end
function _cholpivoted!(x::Number, tol)
rx = real(x)
iszero(rx) && return (rx, convert(BlasInt, 1))
rxr = sqrt(abs(rx))
rval = convert(promote_type(typeof(x), typeof(rxr)), rxr)
return (rval, convert(BlasInt, !(rx == abs(x) > tol)))
end

# cholesky!. Destructive methods for computing Cholesky factorization of real symmetric
# or Hermitian matrix
Expand Down Expand Up @@ -465,12 +472,12 @@ e.g. for integer types.
function cholesky!(A::AbstractMatrix, ::RowMaximum; tol = 0.0, check::Bool = true)
checksquare(A)
if !ishermitian(A)
C = CholeskyPivoted(A, 'U', Vector{BlasInt}(),convert(BlasInt, 1),
C = CholeskyPivoted(A, 'U', Vector{BlasInt}(), convert(BlasInt, 1),
tol, convert(BlasInt, -1))
check && checkpositivedefinite(-1)
check && checkpositivedefinite(convert(BlasInt, -1))
return C
else
return cholesky!(Hermitian(A), RowMaximum(); tol = tol, check = check)
return cholesky!(Hermitian(A), RowMaximum(); tol, check)
end
end
@deprecate cholesky!(A::StridedMatrix, ::Val{true}; kwargs...) cholesky!(A, RowMaximum(); kwargs...) false
Expand Down
30 changes: 30 additions & 0 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,36 @@ end
@deprecate cholesky!(A::Diagonal, ::Val{false}; check::Bool = true) cholesky!(A::Diagonal, NoPivot(); check) false
@deprecate cholesky(A::Diagonal, ::Val{false}; check::Bool = true) cholesky(A::Diagonal, NoPivot(); check) false

function cholesky!(A::Diagonal, ::RowMaximum; tol=0.0, check=true)
if !ishermitian(A)
C = CholeskyPivoted(A, 'U', Vector{BlasInt}(), convert(BlasInt, 1),
tol, convert(BlasInt, -1))
check && checkpositivedefinite(convert(BlasInt, -1))
else
d = A.diag
n = length(d)
info = 0
rank = n
p = sortperm(d, rev = true, by = real)
tol = tol < 0 ? n*eps(eltype(A))*real(d[p[1]]) : tol # LAPACK behavior
permute!(d, p)
@inbounds for i in eachindex(d)
di = d[i]
rootdi, j = _cholpivoted!(di, tol)
if j == 0
d[i] = rootdi
else
rank = i - 1
info = 1
break
end
end
C = CholeskyPivoted(A, 'U', p, convert(BlasInt, rank), tol, convert(BlasInt, info))
check && chkfullrank(C)
end
return C
end

inv(C::Cholesky{<:Any,<:Diagonal}) = Diagonal(map(inv∘abs2, C.factors.diag))

cholcopy(A::Diagonal) = copymutable_oftype(A, choltype(A))
Expand Down
35 changes: 33 additions & 2 deletions stdlib/LinearAlgebra/test/cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,8 @@ end
@test_throws PosDefException cholesky!(copy(M))
@test_throws PosDefException cholesky(M; check = true)
@test_throws PosDefException cholesky!(copy(M); check = true)
@test !LinearAlgebra.issuccess(cholesky(M; check = false))
@test !LinearAlgebra.issuccess(cholesky!(copy(M); check = false))
@test !issuccess(cholesky(M; check = false))
@test !issuccess(cholesky!(copy(M); check = false))
end
for M in (A, Hermitian(A)) # hermitian, but not semi-positive definite
@test_throws RankDeficientException cholesky(M, RowMaximum())
Expand Down Expand Up @@ -377,15 +377,28 @@ end
@test CD.U ≈ Diagonal(.√d) ≈ CM.U
@test D ≈ CD.L * CD.U
@test CD.info == 0
CD = cholesky(D, RowMaximum())
CM = cholesky(Matrix(D), RowMaximum())
@test CD isa CholeskyPivoted{Float64}
@test CD.U ≈ Diagonal(.√sort(d, rev=true)) ≈ CM.U
@test D ≈ Matrix(CD)
@test CD.info == 0

F = cholesky(Hermitian(I(3)))
@test F isa Cholesky{Float64,<:Diagonal}
@test Matrix(F) ≈ I(3)
F = cholesky(I(3), RowMaximum())
@test F isa CholeskyPivoted{Float64,<:Diagonal}
@test Matrix(F) ≈ I(3)

# real, failing
@test_throws PosDefException cholesky(Diagonal([1.0, -2.0]))
@test_throws RankDeficientException cholesky(Diagonal([1.0, -2.0]), RowMaximum())
Dnpd = cholesky(Diagonal([1.0, -2.0]); check = false)
@test Dnpd.info == 2
Dnpd = cholesky(Diagonal([1.0, -2.0]), RowMaximum(); check = false)
@test Dnpd.info == 1
@test Dnpd.rank == 1

# complex
D = complex(D)
Expand All @@ -395,15 +408,33 @@ end
@test CD.U ≈ Diagonal(.√d) ≈ CM.U
@test D ≈ CD.L * CD.U
@test CD.info == 0
CD = cholesky(D, RowMaximum())
CM = cholesky(Matrix(D), RowMaximum())
@test CD isa CholeskyPivoted{ComplexF64,<:Diagonal}
@test CD.U ≈ Diagonal(.√sort(d, by=real, rev=true)) ≈ CM.U
@test D ≈ Matrix(CD)
@test CD.info == 0

# complex, failing
D[2, 2] = 0.0 + 0im
@test_throws PosDefException cholesky(D)
@test_throws RankDeficientException cholesky(D, RowMaximum())
Dnpd = cholesky(D; check = false)
@test Dnpd.info == 2
Dnpd = cholesky(D, RowMaximum(); check = false)
@test Dnpd.info == 1
@test Dnpd.rank == 2

# InexactError for Int
@test_throws InexactError cholesky!(Diagonal([2, 1]))

# tolerance
D = Diagonal([0.5, 1])
@test_throws RankDeficientException cholesky(D, RowMaximum(), tol=nextfloat(0.5))
CD = cholesky(D, RowMaximum(), tol=nextfloat(0.5), check=false)
@test rank(CD) == 1
@test !issuccess(CD)
@test Matrix(cholesky(D, RowMaximum(), tol=prevfloat(0.5))) ≈ D
end

@testset "Cholesky for AbstractMatrix" begin
Expand Down