Skip to content

Commit

Permalink
Add unwrapping mechanism for triangular matrices
Browse files Browse the repository at this point in the history
(cherry picked from commit e67ddaa7216ec0a56d6140e5013eb56eeff712f5)
(cherry picked from commit 6aa1ab3396a8dbdd5fd7deb8936e12202ef1229a)
  • Loading branch information
dkarrasch authored and KristofferC committed Dec 2, 2024
1 parent 2e1d85a commit 528326f
Show file tree
Hide file tree
Showing 6 changed files with 550 additions and 504 deletions.
8 changes: 7 additions & 1 deletion src/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ end
Adjoint(A) = Adjoint{Base.promote_op(adjoint,eltype(A)),typeof(A)}(A)
Transpose(A) = Transpose{Base.promote_op(transpose,eltype(A)),typeof(A)}(A)

# TODO: remove, is already replaced by wrapperop
"""
adj_or_trans(::AbstractArray) -> adjoint|transpose|identity
adj_or_trans(::Type{<:AbstractArray}) -> adjoint|transpose|identity
Return [`adjoint`](@ref) from an `Adjoint` type or object and
[`transpose`](@ref) from a `Transpose` type or object. Otherwise,
return [`identity`](@ref). Note that `Adjoint` and `Transpose` have
Expand All @@ -94,9 +94,15 @@ inplace_adj_or_trans(::Type{<:AbstractArray}) = copyto!
inplace_adj_or_trans(::Type{<:Adjoint}) = adjoint!
inplace_adj_or_trans(::Type{<:Transpose}) = transpose!

# unwraps Adjoint, Transpose, Symmetric, Hermitian
_unwrap(A::Adjoint) = parent(A)
_unwrap(A::Transpose) = parent(A)

# unwraps Adjoint and Transpose only
_unwrap_at(A) = A
_unwrap_at(A::Adjoint) = parent(A)
_unwrap_at(A::Transpose) = parent(A)

Base.dataids(A::Union{Adjoint, Transpose}) = Base.dataids(A.parent)
Base.unaliascopy(A::Union{Adjoint,Transpose}) = typeof(A)(Base.unaliascopy(A.parent))

Expand Down
4 changes: 2 additions & 2 deletions src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,7 @@ function ldiv!(c::AbstractVecOrMat, A::Bidiagonal, b::AbstractVecOrMat)
end
ldiv!(A::AdjOrTrans{<:Any,<:Bidiagonal}, b::AbstractVecOrMat) = @inline ldiv!(b, A, b)
ldiv!(c::AbstractVecOrMat, A::AdjOrTrans{<:Any,<:Bidiagonal}, b::AbstractVecOrMat) =
(t = adj_or_trans(A); _rdiv!(t(c), t(b), t(A)); return c)
(t = wrapperop(A); _rdiv!(t(c), t(b), t(A)); return c)

### Generic promotion methods and fallbacks
\(A::Bidiagonal, B::AbstractVecOrMat) = ldiv!(_initarray(\, eltype(A), eltype(B), B), A, B)
Expand Down Expand Up @@ -833,7 +833,7 @@ end
rdiv!(A::AbstractMatrix, B::Bidiagonal) = @inline _rdiv!(A, A, B)
rdiv!(A::AbstractMatrix, B::AdjOrTrans{<:Any,<:Bidiagonal}) = @inline _rdiv!(A, A, B)
_rdiv!(C::AbstractMatrix, A::AbstractMatrix, B::AdjOrTrans{<:Any,<:Bidiagonal}) =
(t = adj_or_trans(B); ldiv!(t(C), t(B), t(A)); return C)
(t = wrapperop(B); ldiv!(t(C), t(B), t(A)); return C)

/(A::AbstractMatrix, B::Bidiagonal) = _rdiv!(_initarray(/, eltype(A), eltype(B), A), A, B)

Expand Down
4 changes: 2 additions & 2 deletions src/hessenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,11 @@ for T = (:Number, :UniformScaling, :Diagonal)
end

function *(H::UpperHessenberg, U::UpperOrUnitUpperTriangular)
HH = _mulmattri!(_initarray(*, eltype(H), eltype(U), H), H, U)
HH = mul!(_initarray(*, eltype(H), eltype(U), H), H, U)
UpperHessenberg(HH)
end
function *(U::UpperOrUnitUpperTriangular, H::UpperHessenberg)
HH = _multrimat!(_initarray(*, eltype(U), eltype(H), H), U, H)
HH = mul!(_initarray(*, eltype(U), eltype(H), H), U, H)
UpperHessenberg(HH)
end

Expand Down
16 changes: 6 additions & 10 deletions src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@ AdjOrTransStridedMat{T} = Union{Adjoint{<:Any, <:StridedMatrix{T}}, Transpose{<:
StridedMaybeAdjOrTransMat{T} = Union{StridedMatrix{T}, Adjoint{<:Any, <:StridedMatrix{T}}, Transpose{<:Any, <:StridedMatrix{T}}}
StridedMaybeAdjOrTransVecOrMat{T} = Union{StridedVecOrMat{T}, AdjOrTrans{<:Any, <:StridedVecOrMat{T}}}

_parent(A) = A
_parent(A::Adjoint) = parent(A)
_parent(A::Transpose) = parent(A)

matprod(x, y) = x*y + x*y

# dot products
Expand Down Expand Up @@ -115,14 +111,14 @@ end
function (*)(A::StridedMaybeAdjOrTransMat{<:BlasReal}, B::StridedMaybeAdjOrTransMat{<:BlasReal})
TS = promote_type(eltype(A), eltype(B))
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
wrapperop(A)(convert(AbstractArray{TS}, _parent(A))),
wrapperop(B)(convert(AbstractArray{TS}, _parent(B))))
wrapperop(A)(convert(AbstractArray{TS}, _unwrap(A))),
wrapperop(B)(convert(AbstractArray{TS}, _unwrap(B))))
end
function (*)(A::StridedMaybeAdjOrTransMat{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:BlasComplex})
TS = promote_type(eltype(A), eltype(B))
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
wrapperop(A)(convert(AbstractArray{TS}, _parent(A))),
wrapperop(B)(convert(AbstractArray{TS}, _parent(B))))
wrapperop(A)(convert(AbstractArray{TS}, _unwrap(A))),
wrapperop(B)(convert(AbstractArray{TS}, _unwrap(B))))
end

# Complex Matrix times real matrix: We use that it is generally faster to reinterpret the
Expand All @@ -131,13 +127,13 @@ function (*)(A::StridedMatrix{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:Bla
TS = promote_type(eltype(A), eltype(B))
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
convert(AbstractArray{TS}, A),
wrapperop(B)(convert(AbstractArray{real(TS)}, _parent(B))))
wrapperop(B)(convert(AbstractArray{real(TS)}, _unwrap(B))))
end
function (*)(A::AdjOrTransStridedMat{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:BlasReal})
TS = promote_type(eltype(A), eltype(B))
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
copymutable_oftype(A, TS), # remove AdjOrTrans to use reinterpret trick below
wrapperop(B)(convert(AbstractArray{real(TS)}, _parent(B))))
wrapperop(B)(convert(AbstractArray{real(TS)}, _unwrap(B))))
end
# the following case doesn't seem to benefit from the translation A*B = (B' * A')'
function (*)(A::StridedMatrix{<:BlasReal}, B::StridedMatrix{<:BlasComplex})
Expand Down
Loading

0 comments on commit 528326f

Please sign in to comment.