Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: enabled selctg to reorder eigen values from within gees/gges #9655

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 41 additions & 10 deletions base/linalg/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -685,9 +685,23 @@ immutable Schur{Ty<:BlasFloat, S<:AbstractMatrix} <: Factorization{Ty}
end
Schur{Ty}(T::AbstractMatrix{Ty}, Z::AbstractMatrix{Ty}, values::Vector) = Schur{Ty, typeof(T)}(T, Z, values)

schurfact!{T<:BlasFloat}(A::StridedMatrix{T}) = Schur(LinAlg.LAPACK.gees!('V', A)...)
schurfact{T<:BlasFloat}(A::StridedMatrix{T}) = schurfact!(copy(A))
schurfact{T}(A::StridedMatrix{T}) = (S = promote_type(Float32,typeof(one(T)/norm(one(T)))); S != T ? schurfact!(convert(AbstractMatrix{S},A)) : schurfact!(copy(A)))
function schurfact!{T<:BlasFloat}(A::StridedMatrix{T},
selctg::Union(Function, Ptr{Void})=C_NULL)
Schur(LinAlg.LAPACK.gees!('V', A, selctg)...)
end
function schurfact{T<:BlasFloat}(A::StridedMatrix{T},
selctg::Union(Function, Ptr{Void})=C_NULL)
schurfact!(copy(A), selctg)
end
function schurfact{T}(A::StridedMatrix{T},
selctg::Union(Function, Ptr{Void})=C_NULL)
S = promote_type(Float32,typeof(one(T)/norm(one(T))))
if S != T
schurfact!(convert(AbstractMatrix{S},A), selctg)
else
schurfact!(copy(A), selctg)
end
end

function getindex(F::Schur, d::Symbol)
(d == :T || d == :Schur) && return F.T
Expand All @@ -696,16 +710,18 @@ function getindex(F::Schur, d::Symbol)
throw(KeyError(d))
end

function schur(A::AbstractMatrix)
SchurF = schurfact(A)
function schur(A::AbstractMatrix, selctg::Union(Function, Ptr{Void})=C_NULL)
SchurF = schurfact(A, selctg)
SchurF[:T], SchurF[:Z], SchurF[:values]
end

# For ordering after computing the decomposition
ordschur!{Ty<:BlasFloat}(Q::StridedMatrix{Ty}, T::StridedMatrix{Ty}, select::Array{Int}) = Schur(LinAlg.LAPACK.trsen!(select, T , Q)...)
ordschur{Ty<:BlasFloat}(Q::StridedMatrix{Ty}, T::StridedMatrix{Ty}, select::Array{Int}) = ordschur!(copy(Q), copy(T), select)
ordschur!{Ty<:BlasFloat}(schur::Schur{Ty}, select::Array{Int}) = (res=ordschur!(schur.Z, schur.T, select); schur[:values][:]=res[:values]; res)
ordschur{Ty<:BlasFloat}(schur::Schur{Ty}, select::Array{Int}) = ordschur(schur.Z, schur.T, select)


immutable GeneralizedSchur{Ty<:BlasFloat, M<:AbstractMatrix} <: Factorization{Ty}
S::M
T::M
Expand All @@ -717,9 +733,23 @@ immutable GeneralizedSchur{Ty<:BlasFloat, M<:AbstractMatrix} <: Factorization{Ty
end
GeneralizedSchur{Ty}(S::AbstractMatrix{Ty}, T::AbstractMatrix{Ty}, alpha::Vector, beta::Vector{Ty}, Q::AbstractMatrix{Ty}, Z::AbstractMatrix{Ty}) = GeneralizedSchur{Ty,typeof(S)}(S, T, alpha, beta, Q, Z)

schurfact!{T<:BlasFloat}(A::StridedMatrix{T}, B::StridedMatrix{T}) = GeneralizedSchur(LinAlg.LAPACK.gges!('V', 'V', A, B)...)
schurfact{T<:BlasFloat}(A::StridedMatrix{T},B::StridedMatrix{T}) = schurfact!(copy(A),copy(B))
schurfact{TA,TB}(A::StridedMatrix{TA}, B::StridedMatrix{TB}) = (S = promote_type(Float32,typeof(one(TA)/norm(one(TA))),TB); schurfact!(S != TA ? convert(AbstractMatrix{S},A) : copy(A), S != TB ? convert(AbstractMatrix{S},B) : copy(B)))
function schurfact!{T<:BlasFloat}(A::StridedMatrix{T}, B::StridedMatrix{T},
selctg::Union(Function, Ptr{Void})=C_NULL)
GeneralizedSchur(LinAlg.LAPACK.gges!('V', 'V', A, B, selctg)...)
end

function schurfact{T<:BlasFloat}(A::StridedMatrix{T},B::StridedMatrix{T},
selctg::Union(Function, Ptr{Void})=C_NULL)
schurfact!(copy(A), copy(B), selctg)
end

function schurfact{TA,TB}(A::StridedMatrix{TA}, B::StridedMatrix{TB},
selctg::Union(Function, Ptr{Void})=C_NULL)
S = promote_type(Float32,typeof(one(TA)/norm(one(TA))),TB)
schurfact!(S != TA ? convert(AbstractMatrix{S},A) : copy(A),
S != TB ? convert(AbstractMatrix{S},B) : copy(B),
selctg)
end

function getindex(F::GeneralizedSchur, d::Symbol)
d == :S && return F.S
Expand All @@ -732,8 +762,9 @@ function getindex(F::GeneralizedSchur, d::Symbol)
throw(KeyError(d))
end

function schur(A::AbstractMatrix, B::AbstractMatrix)
SchurF = schurfact(A, B)
function schur(A::AbstractMatrix, B::AbstractMatrix,
selctg::Union(Function, Ptr{Void})=C_NULL)
SchurF = schurfact(A, B, selctg)
SchurF[:S], SchurF[:T], SchurF[:Q], SchurF[:Z]
end

Expand Down
118 changes: 102 additions & 16 deletions base/linalg/lapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3358,12 +3358,22 @@ for (orghr, elty) in
end
end
end

# Schur forms
__selctg = C_NULL # use global so we can compile cfunction dynamically

function selctg2c{T}(a::T, b::T, c::T)
global __selctg # get current global for __selctg. Updated in gees/gges
return convert(BlasInt, __selctg(a, b, c))::BlasInt
end


for (gees, gges, elty) in
((:dgees_,:dgges_,:Float64),
(:sgees_,:sgges_,:Float32))
@eval begin
function gees!(jobvs::BlasChar, A::StridedMatrix{$elty})
function gees!(jobvs::BlasChar, A::StridedMatrix{$elty},
selctg::Union(Function, Ptr{Void})=C_NULL)
# .. Scalar Arguments ..
# CHARACTER JOBVS, SORT
# INTEGER INFO, LDA, LDVS, LWORK, N, SDIM
Expand All @@ -3382,16 +3392,35 @@ for (gees, gges, elty) in
work = Array($elty, 1)
lwork = blas_int(-1)
info = Array(BlasInt, 1)
sort = selctg == C_NULL ? 'N' : 'S'

# NOTE: type of bwork in ccall type tuple is always Ptr{BlasInt}.
# This works because if sort is 'N' (or selctg is C_NULL)
# then LAPACK never touches this object so incorrect
# type won't cause problems
bwork = selctg == C_NULL ? C_NULL : Array(BlasInt, n)

# create inner_selctg to be passed to function
if selctg == C_NULL
inner_selctg = C_NULL
else
# change __selctg global to our value
global __selctg
__selctg = selctg
inner_selctg = cfunction(selctg2c, BlasInt,
($elty, $elty, $elty))
end

for i = 1:2
ccall(($(blasfunc(gees)), liblapack), Void,
(Ptr{BlasChar}, Ptr{BlasChar}, Ptr{Void}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty},
Ptr{BlasInt}, Ptr{Void}, Ptr{BlasInt}),
&jobvs, &'N', C_NULL, &n,
Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
&jobvs, &sort, inner_selctg, &n,
A, &max(1, n), sdim, wr,
wi, vs, &ldvs, work,
&lwork, C_NULL, info)
&lwork, bwork, info)
@lapackerror
if lwork < 0
lwork = blas_int(real(work[1]))
Expand All @@ -3400,7 +3429,9 @@ for (gees, gges, elty) in
end
A, vs, all(wi .== 0) ? wr : complex(wr, wi)
end
function gges!(jobvsl::Char, jobvsr::Char, A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
function gges!(jobvsl::Char, jobvsr::Char, A::StridedMatrix{$elty},
B::StridedMatrix{$elty},
selctg::Union(Function, Ptr{Void})=C_NULL)
# * .. Scalar Arguments ..
# CHARACTER JOBVSL, JOBVSR, SORT
# INTEGER INFO, LDA, LDB, LDVSL, LDVSR, LWORK, N, SDIM
Expand All @@ -3424,19 +3455,34 @@ for (gees, gges, elty) in
work = Array($elty, 1)
lwork = blas_int(-1)
info = Array(BlasInt, 1)
sort = selctg == C_NULL ? 'N' : 'S'

# NOTE: See note above
bwork = selctg == C_NULL ? C_NULL : Array(BlasInt, n)

# create inner_selctg to be passed to blasfunc
if selctg == C_NULL
inner_selctg = C_NULL
else
# change __selctg global to our value
global __selctg
__selctg = selctg
inner_selctg = cfunction(selctg2c, BlasInt,
($elty, $elty, $elty))
end
for i = 1:2
ccall(($(blasfunc(gges)), liblapack), Void,
(Ptr{BlasChar}, Ptr{BlasChar}, Ptr{BlasChar}, Ptr{Void},
Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty},
Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty},
Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{Void},
Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{BlasInt}),
&jobvsl, &jobvsr, &'N', C_NULL,
&jobvsl, &jobvsr, &sort, inner_selctg,
&n, A, &max(1,n), B,
&max(1,n), &sdim, alphar, alphai,
beta, vsl, &ldvsl, vsr,
&ldvsr, work, &lwork, C_NULL,
&ldvsr, work, &lwork, bwork,
info)
if i == 1
lwork = blas_int(real(work[1]))
Expand All @@ -3448,11 +3494,19 @@ for (gees, gges, elty) in
end
end
end

# Complex Schur forms
function zselctg2c{T}(a::T, b::T)
global __selctg # get current __selctg. Updated in zgees/zgges
return convert(BlasInt, __selctg(a, b))::BlasInt
end

for (gees, gges, elty, relty) in
((:zgees_,:zgges_,:Complex128,:Float64),
(:cgees_,:cgges_,:Complex64,:Float32))
@eval begin
function gees!(jobvs::BlasChar, A::StridedMatrix{$elty})
function gees!(jobvs::BlasChar, A::StridedMatrix{$elty},
selctg::Union(Function, Ptr{Void})=C_NULL)
# * .. Scalar Arguments ..
# CHARACTER JOBVS, SORT
# INTEGER INFO, LDA, LDVS, LWORK, N, SDIM
Expand All @@ -3472,16 +3526,31 @@ for (gees, gges, elty, relty) in
lwork = blas_int(-1)
rwork = Array($relty, n)
info = Array(BlasInt, 1)
sort = selctg == C_NULL ? 'N' : 'S'

# NOTE: See note above
bwork = selctg == C_NULL ? C_NULL : Array(BlasInt, n)

# create inner_selctg to be passed to blasfunc
if selctg == C_NULL
inner_selctg = C_NULL
else
# change __selctg global to our value
global __selctg
__selctg = selctg
inner_selctg = cfunction(zselctg2c, BlasInt,
($elty, $elty))
end
for i = 1:2
ccall(($(blasfunc(gees)), liblapack), Void,
(Ptr{BlasChar}, Ptr{BlasChar}, Ptr{Void}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$relty}, Ptr{Void}, Ptr{BlasInt}),
&jobvs, &sort, C_NULL, &n,
Ptr{$relty}, Ptr{BlasInt}, Ptr{BlasInt}),
&jobvs, &sort, inner_selctg, &n,
A, &max(1, n), &sdim, w,
vs, &ldvs, work, &lwork,
rwork, C_NULL, info)
rwork, bwork, info)
@lapackerror
if lwork < 0
lwork = blas_int(real(work[1]))
Expand All @@ -3490,7 +3559,9 @@ for (gees, gges, elty, relty) in
end
A, vs, w
end
function gges!(jobvsl::Char, jobvsr::Char, A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
function gges!(jobvsl::Char, jobvsr::Char, A::StridedMatrix{$elty},
B::StridedMatrix{$elty},
selctg::Union(Function, Ptr{Void})=C_NULL)
# * .. Scalar Arguments ..
# CHARACTER JOBVSL, JOBVSR, SORT
# INTEGER INFO, LDA, LDB, LDVSL, LDVSR, LWORK, N, SDIM
Expand All @@ -3515,19 +3586,34 @@ for (gees, gges, elty, relty) in
lwork = blas_int(-1)
rwork = Array($relty, 8n)
info = Array(BlasInt, 1)
sort = selctg == C_NULL ? 'N' : 'S'

# NOTE: See note above
bwork = selctg == C_NULL ? C_NULL : Array(BlasInt, n)

# create inner_selctg to be passed to blasfunc
if selctg == C_NULL
inner_selctg = C_NULL
else
# change __selctg global to our value
global __selctg
__selctg = selctg
inner_selctg = cfunction(zselctg2c, BlasInt,
($elty, $elty))
end
for i = 1:2
ccall(($(blasfunc(gges)), liblapack), Void,
(Ptr{BlasChar}, Ptr{BlasChar}, Ptr{BlasChar}, Ptr{Void},
Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty},
Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$relty}, Ptr{Void},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$relty}, Ptr{BlasInt},
Ptr{BlasInt}),
&jobvsl, &jobvsr, &'N', C_NULL,
&jobvsl, &jobvsr, &sort, inner_selctg,
&n, A, &max(1,n), B,
&max(1,n), &sdim, alpha, beta,
vsl, &ldvsl, vsr, &ldvsr,
work, &lwork, rwork, C_NULL,
work, &lwork, rwork, bwork,
info)
if i == 1
lwork = blas_int(real(work[1]))
Expand Down
29 changes: 29 additions & 0 deletions test/linalg1.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,20 @@ debug && println("Schur")
@test istriu(f[:Schur]) || iseltype(a,Real)
end

debug && println("Sorted Schur (with selctg)")
if eltya != BigFloat && eltyb != BigFloat # Revisit when implemented in julia
if eltya <: Complex
selctg(a, b) = abs(a / b) <= 1 ? 1 : 0
else
selctg(a, b, c) = (a+b)/c <= 1 ? 1 : 0
end
f = schurfact(a, selctg)
@test_approx_eq f[:vectors]*f[:Schur]*f[:vectors]' a
@test_approx_eq sort(real(f[:values])) sort(real(d))
@test_approx_eq sort(imag(f[:values])) sort(imag(d))
@test istriu(f[:Schur]) || iseltype(a,Real)
end

debug && println("Reorder Schur")
if eltya != BigFloat && eltyb != BigFloat # Revisit when implemented in julia
# use asym for real schur to enforce tridiag structure
Expand All @@ -200,6 +214,21 @@ debug && println("Reorder Schur")
@test_approx_eq O[:vectors]*O[:Schur]*O[:vectors]' ordschura
end

debug && println("Sorted Generalized (with selctg)")
if eltya != BigFloat && eltyb != BigFloat # Revisit when implemented in julia
if eltya <: Complex
selctg(a, b) = abs(a) == 0.0 ? 1 : 0
else
selctg(a, b, c) = (a+b)/c <= 1 ? 1 : 0
end

f = schurfact(a[1:5,1:5], a[6:10,6:10], selctg)
@test_approx_eq f[:Q]*f[:S]*f[:Z]' a[1:5,1:5]
@test_approx_eq f[:Q]*f[:T]*f[:Z]' a[6:10,6:10]
@test istriu(f[:S]) || iseltype(a,Real)
@test istriu(f[:T]) || iseltype(a,Real)
end

debug && println("Generalized Schur")
if eltya != BigFloat && eltyb != BigFloat # Revisit when implemented in julia
f = schurfact(a[1:5,1:5], a[6:10,6:10])
Expand Down