Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: make FillArrays.jl a weakdep #163

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,22 +1,29 @@
name = "ArrayLayouts"
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
authors = ["Sheehan Olver <[email protected]>"]
version = "1.2.1"
version = "1.3.0"

[deps]
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[weakdeps]
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"

[extensions]
ArrayLayoutsFillArraysExt = "FillArrays"

[compat]
FillArrays = "1.2.1"
julia = "1.6"

[extras]
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Base64", "Random", "StableRNGs", "Test"]
test = ["Base64", "FillArrays", "Random", "StableRNGs", "Test"]
117 changes: 117 additions & 0 deletions ext/ArrayLayoutsFillArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
module ArrayLayoutsFillArraysExt

using FillArrays
using FillArrays: AbstractFill, getindex_value

using ArrayLayouts
using ArrayLayouts: OnesLayout, Mul, MulAdd, diagonal
import ArrayLayouts: MemoryLayout, _copyto!, sub_materialize, diagonaldata, mulzeros
export layoutfillmul

import Base: copy, *, +, -
import Base.Broadcast: materialize!
import LinearAlgebra
using LinearAlgebra: Adjoint, Transpose, Symmetric, Hermitian, Diagonal,
AdjointAbsVec, TransposeAbsVec, UniformScaling

macro layoutfillmul(Typ)
ret = quote
(*)(A::LinearAlgebra.AdjointAbsVec{<:Any,<:Zeros{<:Any,1}}, B::$Typ) = ArrayLayouts.mul(A,B)
(*)(A::LinearAlgebra.TransposeAbsVec{<:Any,<:Zeros{<:Any,1}}, B::$Typ) = ArrayLayouts.mul(A,B)
(*)(A::LinearAlgebra.Transpose{T,<:$Typ}, B::Zeros{T,1}) where T<:Real = ArrayLayouts.mul(A,B)
end
for Mod in (:Adjoint, :Transpose, :Symmetric, :Hermitian)
ret = quote
$ret

(*)(A::$Mod{<:Any,<:$Typ}, B::Zeros{<:Any,1}) = ArrayLayouts.mul(A,B)
end
end
esc(ret)
end

@layoutfillmul LayoutMatrix

*(a::Zeros{<:Any,2}, b::LayoutMatrix) = FillArrays.mult_zeros(a, b)
*(a::LayoutMatrix, b::Zeros{<:Any,2}) = FillArrays.mult_zeros(a, b)
*(a::LayoutMatrix, b::Zeros{<:Any,1}) = FillArrays.mult_zeros(a, b)
*(a::Transpose{T, <:LayoutMatrix{T}} where T, b::Zeros{<:Any, 2}) = FillArrays.mult_zeros(a, b)
*(a::Adjoint{T, <:LayoutMatrix{T}} where T, b::Zeros{<:Any, 2}) = FillArrays.mult_zeros(a, b)
*(A::Adjoint{<:Any, <:Zeros{<:Any,1}}, B::Diagonal{<:Any,<:LayoutVector}) = (B' * A')'
*(A::Transpose{<:Any, <:Zeros{<:Any,1}}, B::Diagonal{<:Any,<:LayoutVector}) = transpose(transpose(B) * transpose(A))
*(a::Adjoint{<:Number,<:LayoutVector}, b::Zeros{<:Number,1})= FillArrays._adjvec_mul_zeros(a, b)
function *(a::Transpose{T, <:LayoutVector{T}}, b::Zeros{T, 1}) where T<:Real
la, lb = length(a), length(b)
if la ≠ lb
throw(DimensionMismatch("dot product arguments have lengths $la and $lb"))
end
return zero(T)
end

# equivalent to rescaling
function materialize!(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout}})
M.B .= getindex_value(M.A.diag) .* M.B
M.B
end
# equivalent to rescaling
function materialize!(M::Rmul{<:Any,<:DiagonalLayout{<:AbstractFillLayout}})
M.A .= M.A .* getindex_value(M.B.diag)
M.A
end

copy(M::Ldiv{<:DiagonalLayout{<:AbstractFillLayout}}) = inv(getindex_value(M.A.diag)) .* M.B
copy(M::Ldiv{<:DiagonalLayout{<:AbstractFillLayout},<:DiagonalLayout}) = diagonal(inv(getindex_value(M.A.diag)) .* M.B.diag)

copy(M::Rdiv{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A .* inv(getindex_value(M.B.diag))
copy(M::Rdiv{<:DiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = diagonal(M.A.diag .* inv(getindex_value(M.B.diag)))

copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout}}) = getindex_value(diagonaldata(M.A)) * M.B
copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout},<:DiagonalLayout}) = getindex_value(diagonaldata(M.A)) * M.B
copy(M::Rmul{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B))
copy(M::Rmul{<:DualLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B))

copy(M::Rmul{<:BidiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B))
copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout},<:BidiagonalLayout}) = getindex_value(diagonaldata(M.A)) * M.B
copy(M::Rmul{<:TridiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B))
copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout},<:TridiagonalLayout}) = getindex_value(diagonaldata(M.A)) * M.B
copy(M::Rmul{<:SymTridiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B))
copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout},<:SymTridiagonalLayout}) = getindex_value(diagonaldata(M.A)) * M.B

MemoryLayout(::Type{<:AbstractFill}) = FillLayout()
MemoryLayout(::Type{<:Zeros}) = ZerosLayout()
MemoryLayout(::Type{<:Ones}) = OnesLayout()

_copyto!(_, ::AbstractFillLayout, dest::AbstractArray{<:Any,N}, src::AbstractArray{<:Any,N}) where N =
fill!(dest, getindex_value(src))

_fill_copyto!(dest, C::Zeros) = zero!(dest) # exploit special fill! overload

sub_materialize(::AbstractFillLayout, V, ax) = Fill(getindex_value(V), ax)
sub_materialize(::ZerosLayout, V, ax) = Zeros{eltype(V)}(ax)
sub_materialize(::OnesLayout, V, ax) = Ones{eltype(V)}(ax)

*(x::AdjointAbsVec{<:Any,<:Zeros{<:Any,1}}, D::Diagonal, y::LayoutVector) = FillArrays._triple_zeromul(x, D, y)
*(x::TransposeAbsVec{<:Any,<:Zeros{<:Any,1}}, D::Diagonal, y::LayoutVector) = FillArrays._triple_zeromul(x, D, y)

@inline LinearAlgebra.dot(a::LayoutVector, b::AbstractFill{<:Any,1}) = FillArrays._fill_dot_rev(a,b)
@inline LinearAlgebra.dot(a::AbstractFill{<:Any,1}, b::LayoutVector) = FillArrays._fill_dot(a,b)

# equivalent to rescaling
function materialize!(M::MulAdd{<:DiagonalLayout{<:AbstractFillLayout}})
checkdimensions(M)
M.C .= (M.α * getindex_value(M.A.diag)) .* M.B .+ M.β .* M.C
M.C
end

function materialize!(M::MulAdd{<:Any,<:DiagonalLayout{<:AbstractFillLayout}})
checkdimensions(M)
M.C .= M.α .* M.A .* getindex_value(M.B.diag) .+ M.β .* M.C
M.C
end

fillzeros(::Type{T}, ax) where T<:Number = Zeros{T}(ax)
mulzeros(::Type{T}, M) where T<:Number = fillzeros(T, axes(M))
mulzeros(::Type{T}, M::Mul{<:DualLayout,<:Any,<:Adjoint}) where T<:Number = fillzeros(T, axes(M,2))'
mulzeros(::Type{T}, M::Mul{<:DualLayout,<:Any,<:Transpose}) where T<:Number = transpose(fillzeros(T, axes(M,2)))

end
35 changes: 16 additions & 19 deletions src/ArrayLayouts.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module ArrayLayouts
using Base: _typed_hcat
using Base, Base.Broadcast, LinearAlgebra, FillArrays, SparseArrays
using Base, Base.Broadcast, LinearAlgebra, SparseArrays
using LinearAlgebra.BLAS

using Base: AbstractCartesianIndex, OneTo, oneto, RangeIndex, ReinterpretArray, ReshapedArray,
Expand All @@ -25,8 +25,6 @@ using LinearAlgebra.BLAS: BlasFloat, BlasReal, BlasComplex

AdjointQtype{T} = isdefined(LinearAlgebra, :AdjointQ) ? LinearAlgebra.AdjointQ{T} : Adjoint{T,<:AbstractQ}

using FillArrays: AbstractFill, getindex_value, axes_print_matrix_row

using Base: require_one_based_indexing

export materialize, materialize!, MulAdd, muladd!, Ldiv, Rdiv, Lmul, Rmul, Dot,
Expand Down Expand Up @@ -121,6 +119,11 @@ include("diagonal.jl")
include("triangular.jl")
include("factorizations.jl")

@static if !isdefined(Base, :get_extension)
include("../ext/ArrayLayoutsFillArraysExt.jl")
end


# Extend this function if you're only looking to dispatch on the axes
@inline sub_materialize_axes(V, _) = Array(V)
@inline sub_materialize(_, V, ax) = sub_materialize_axes(V, ax)
Expand Down Expand Up @@ -196,22 +199,6 @@ getindex(A::LayoutVector, kr::Colon) = layout_getindex(A, kr)
getindex(A::AdjOrTrans{<:Any,<:LayoutVector}, kr::Integer, jr::Colon) = layout_getindex(A, kr, jr)
getindex(A::AdjOrTrans{<:Any,<:LayoutVector}, kr::Integer, jr::AbstractVector) = layout_getindex(A, kr, jr)

*(a::Zeros{<:Any,2}, b::LayoutMatrix) = FillArrays.mult_zeros(a, b)
*(a::LayoutMatrix, b::Zeros{<:Any,2}) = FillArrays.mult_zeros(a, b)
*(a::LayoutMatrix, b::Zeros{<:Any,1}) = FillArrays.mult_zeros(a, b)
*(a::Transpose{T, <:LayoutMatrix{T}} where T, b::Zeros{<:Any, 2}) = FillArrays.mult_zeros(a, b)
*(a::Adjoint{T, <:LayoutMatrix{T}} where T, b::Zeros{<:Any, 2}) = FillArrays.mult_zeros(a, b)
*(A::Adjoint{<:Any, <:Zeros{<:Any,1}}, B::Diagonal{<:Any,<:LayoutVector}) = (B' * A')'
*(A::Transpose{<:Any, <:Zeros{<:Any,1}}, B::Diagonal{<:Any,<:LayoutVector}) = transpose(transpose(B) * transpose(A))
*(a::Adjoint{<:Number,<:LayoutVector}, b::Zeros{<:Number,1})= FillArrays._adjvec_mul_zeros(a, b)
function *(a::Transpose{T, <:LayoutVector{T}}, b::Zeros{T, 1}) where T<:Real
la, lb = length(a), length(b)
if la ≠ lb
throw(DimensionMismatch("dot product arguments have lengths $la and $lb"))
end
return zero(T)
end

*(A::Diagonal{<:Any,<:LayoutVector}, B::Diagonal{<:Any,<:LayoutVector}) = mul(A, B)
*(A::Diagonal{<:Any,<:LayoutVector}, B::AbstractMatrix) = mul(A, B)
*(A::AbstractMatrix, B::Diagonal{<:Any,<:LayoutVector}) = mul(A, B)
Expand Down Expand Up @@ -385,6 +372,16 @@ Base.replace_in_print_matrix(A::Union{LayoutVector,
UnitLowerTriangular{<:Any,<:AdjOrTrans{<:Any,<:LayoutMatrix}}}, i::Integer, j::Integer, s::AbstractString) =
layout_replace_in_print_matrix(MemoryLayout(A), A, i, j, s)

if VERSION < v"1.8-"
axes_print_matrix_row(lay, io, X, A, i, cols, sep) =
Base.invoke(Base.print_matrix_row, Tuple{IO,AbstractVecOrMat,Vector,Integer,AbstractVector,AbstractString},
io, X, A, i, cols, sep)
else
axes_print_matrix_row(lay, io, X, A, i, cols, sep, idxlast::Integer=last(axes(X, 2))) =
Base.invoke(Base.print_matrix_row, Tuple{IO,AbstractVecOrMat,Vector,Integer,AbstractVector,AbstractString,Integer},
io, X, A, i, cols, sep, idxlast)
end

Base.print_matrix_row(io::IO,
X::Union{LayoutMatrix,
LayoutVector,
Expand Down
26 changes: 0 additions & 26 deletions src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,6 @@ mulreduce(M::Mul{<:Any,<:DiagonalLayout}) = Rmul(M)

# Diagonal multiplication never changes structure
similar(M::Lmul{<:DiagonalLayout}, ::Type{T}, axes) where T = similar(M.B, T, axes)
# equivalent to rescaling
function materialize!(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout}})
M.B .= getindex_value(M.A.diag) .* M.B
M.B
end


copy(M::Lmul{<:DiagonalLayout,<:DiagonalLayout}) = diagonal(diagonaldata(M.A) .* diagonaldata(M.B))
Expand All @@ -32,11 +27,6 @@ end

# Diagonal multiplication never changes structure
similar(M::Rmul{<:Any,<:DiagonalLayout}, ::Type{T}, axes) where T = similar(M.A, T, axes)
# equivalent to rescaling
function materialize!(M::Rmul{<:Any,<:DiagonalLayout{<:AbstractFillLayout}})
M.A .= M.A .* getindex_value(M.B.diag)
M.A
end


function materialize!(M::Ldiv{<:DiagonalLayout})
Expand All @@ -46,13 +36,9 @@ end

copy(M::Ldiv{<:DiagonalLayout,<:DiagonalLayout}) = diagonal(M.A.diag .\ M.B.diag)
copy(M::Ldiv{<:DiagonalLayout}) = M.A.diag .\ M.B
copy(M::Ldiv{<:DiagonalLayout{<:AbstractFillLayout}}) = inv(getindex_value(M.A.diag)) .* M.B
copy(M::Ldiv{<:DiagonalLayout{<:AbstractFillLayout},<:DiagonalLayout}) = diagonal(inv(getindex_value(M.A.diag)) .* M.B.diag)

copy(M::Rdiv{<:DiagonalLayout,<:DiagonalLayout}) = diagonal(M.A.diag .* inv.(M.B.diag))
copy(M::Rdiv{<:Any,<:DiagonalLayout}) = M.A .* inv.(permutedims(M.B.diag))
copy(M::Rdiv{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A .* inv(getindex_value(M.B.diag))
copy(M::Rdiv{<:DiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = diagonal(M.A.diag .* inv(getindex_value(M.B.diag)))



Expand All @@ -71,18 +57,6 @@ copy(M::Lmul{DiagonalLayout{OnesLayout},DiagonalLayout{OnesLayout}}) = _copy_oft
copy(M::Rmul{<:Any,DiagonalLayout{OnesLayout}}) = _copy_oftype(M.A, eltype(M))
copy(M::Rmul{<:DualLayout,DiagonalLayout{OnesLayout}}) = _copy_oftype(M.A, eltype(M))

copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout}}) = getindex_value(diagonaldata(M.A)) * M.B
copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout},<:DiagonalLayout}) = getindex_value(diagonaldata(M.A)) * M.B
copy(M::Rmul{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B))
copy(M::Rmul{<:DualLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B))

copy(M::Rmul{<:BidiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B))
copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout},<:BidiagonalLayout}) = getindex_value(diagonaldata(M.A)) * M.B
copy(M::Rmul{<:TridiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B))
copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout},<:TridiagonalLayout}) = getindex_value(diagonaldata(M.A)) * M.B
copy(M::Rmul{<:SymTridiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B))
copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout},<:SymTridiagonalLayout}) = getindex_value(diagonaldata(M.A)) * M.B


copy(M::Rmul{<:BidiagonalLayout,DiagonalLayout{OnesLayout}}) = _copy_oftype(M.A, eltype(M))
copy(M::Lmul{DiagonalLayout{OnesLayout},<:BidiagonalLayout}) = _copy_oftype(M.B, eltype(M))
Expand Down
11 changes: 0 additions & 11 deletions src/memorylayout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -531,23 +531,12 @@ struct ZerosLayout <: AbstractFillLayout end
struct OnesLayout <: AbstractFillLayout end
struct EyeLayout <: MemoryLayout end

MemoryLayout(::Type{<:AbstractFill}) = FillLayout()
MemoryLayout(::Type{<:Zeros}) = ZerosLayout()
MemoryLayout(::Type{<:Ones}) = OnesLayout()

# all sub arrays are same
sublayout(L::AbstractFillLayout, inds::Type) = L
reshapedlayout(L::AbstractFillLayout, _) = L
adjointlayout(::Type, L::AbstractFillLayout) = L
transposelayout(L::AbstractFillLayout) = L

_copyto!(_, ::AbstractFillLayout, dest::AbstractArray{<:Any,N}, src::AbstractArray{<:Any,N}) where N =
fill!(dest, getindex_value(src))

sub_materialize(::AbstractFillLayout, V, ax) = Fill(getindex_value(V), ax)
sub_materialize(::ZerosLayout, V, ax) = Zeros{eltype(V)}(ax)
sub_materialize(::OnesLayout, V, ax) = Ones{eltype(V)}(ax)

abstract type AbstractBandedLayout <: MemoryLayout end
abstract type AbstractTridiagonalLayout <: AbstractBandedLayout end

Expand Down
20 changes: 2 additions & 18 deletions src/mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,6 @@ macro layoutmul(Typ)
(*)(A::AbstractMatrix, B::$Typ) = ArrayLayouts.mul(A,B)
(*)(A::LinearAlgebra.AdjointAbsVec, B::$Typ) = ArrayLayouts.mul(A,B)
(*)(A::LinearAlgebra.TransposeAbsVec, B::$Typ) = ArrayLayouts.mul(A,B)
(*)(A::LinearAlgebra.AdjointAbsVec{<:Any,<:Zeros{<:Any,1}}, B::$Typ) = ArrayLayouts.mul(A,B)
(*)(A::LinearAlgebra.TransposeAbsVec{<:Any,<:Zeros{<:Any,1}}, B::$Typ) = ArrayLayouts.mul(A,B)
(*)(A::LinearAlgebra.Transpose{T,<:$Typ}, B::Zeros{T,1}) where T<:Real = ArrayLayouts.mul(A,B)

(*)(A::LinearAlgebra.AbstractQ, B::$Typ) = ArrayLayouts.mul(A,B)
(*)(A::$Typ, B::LinearAlgebra.AbstractQ) = ArrayLayouts.mul(A,B)
Expand Down Expand Up @@ -278,7 +275,6 @@ macro layoutmul(Typ)
(*)(A::LinearAlgebra.TransposeAbsVec, B::$Mod{<:Any,<:$Typ}) = ArrayLayouts.mul(A,B)
(*)(A::$Mod{<:Any,<:$Typ}, B::AbstractVector) = ArrayLayouts.mul(A,B)
(*)(A::$Mod{<:Any,<:$Typ}, B::ArrayLayouts.LayoutVector) = ArrayLayouts.mul(A,B)
(*)(A::$Mod{<:Any,<:$Typ}, B::Zeros{<:Any,1}) = ArrayLayouts.mul(A,B)

(*)(A::$Mod{<:Any,<:$Typ}, B::$Typ) = ArrayLayouts.mul(A,B)
(*)(A::$Typ, B::$Mod{<:Any,<:$Typ}) = ArrayLayouts.mul(A,B)
Expand Down Expand Up @@ -306,8 +302,6 @@ end
*(x::Transpose{<:Any,<:LayoutVector}, D::Diagonal{<:Any,<:LayoutVector}) = mul(x, D)
*(x::AdjointAbsVec, D::Diagonal, y::LayoutVector) = x * mul(D,y)
*(x::TransposeAbsVec, D::Diagonal, y::LayoutVector) = x * mul(D,y)
*(x::AdjointAbsVec{<:Any,<:Zeros{<:Any,1}}, D::Diagonal, y::LayoutVector) = FillArrays._triple_zeromul(x, D, y)
*(x::TransposeAbsVec{<:Any,<:Zeros{<:Any,1}}, D::Diagonal, y::LayoutVector) = FillArrays._triple_zeromul(x, D, y)


*(A::UpperOrLowerTriangular{<:Any,<:LayoutMatrix}, B::UpperOrLowerTriangular{<:Any,<:LayoutMatrix}) = mul(A, B)
Expand Down Expand Up @@ -358,8 +352,6 @@ dot(a, b) = materialize(Dot(a, b))
@inline LinearAlgebra.dot(a::LayoutArray, b::LayoutArray) = dot(a,b)
@inline LinearAlgebra.dot(a::LayoutArray, b::AbstractArray) = dot(a,b)
@inline LinearAlgebra.dot(a::AbstractArray, b::LayoutArray) = dot(a,b)
@inline LinearAlgebra.dot(a::LayoutVector, b::AbstractFill{<:Any,1}) = FillArrays._fill_dot_rev(a,b)
@inline LinearAlgebra.dot(a::AbstractFill{<:Any,1}, b::LayoutVector) = FillArrays._fill_dot(a,b)
@inline LinearAlgebra.dot(a::LayoutArray{<:Number}, b::SparseArrays.SparseVectorUnion{<:Number}) = dot(a,b)
@inline LinearAlgebra.dot(a::SparseArrays.SparseVectorUnion{<:Number}, b::LayoutArray{<:Number}) = dot(a,b)

Expand All @@ -379,17 +371,9 @@ LinearAlgebra.dot(x::AbstractVector, A::Symmetric{<:Real,<:LayoutMatrix}, y::Abs

# allow overloading for infinite or lazy case
@inline _power_by_squaring(_, _, A, p) = invoke(Base.power_by_squaring, Tuple{AbstractMatrix,Integer}, A, p)
# TODO: Remove unnecessary _apply
_apply(_, _, op, Λ::UniformScaling, A::AbstractMatrix) = op(Diagonal(Fill(Λ.λ,(axes(A,1),))), A)
_apply(_, _, op, A::AbstractMatrix, Λ::UniformScaling) = op(A, Diagonal(Fill(Λ.λ,(axes(A,1),))))

for Typ in (:LayoutMatrix, :(Symmetric{<:Any,<:LayoutMatrix}), :(Hermitian{<:Any,<:LayoutMatrix}),
:(Adjoint{<:Any,<:LayoutMatrix}), :(Transpose{<:Any,<:LayoutMatrix}))
@eval begin
@inline Base.power_by_squaring(A::$Typ, p::Integer) = _power_by_squaring(MemoryLayout(A), size(A), A, p)
@inline +(A::$Typ, Λ::UniformScaling) = _apply(MemoryLayout(A), size(A), +, A, Λ)
@inline +(Λ::UniformScaling, A::$Typ) = _apply(MemoryLayout(A), size(A), +, Λ, A)
@inline -(A::$Typ, Λ::UniformScaling) = _apply(MemoryLayout(A), size(A), -, A, Λ)
@inline -(Λ::UniformScaling, A::$Typ) = _apply(MemoryLayout(A), size(A), -, Λ, A)
end
@eval @inline Base.power_by_squaring(A::$Typ, p::Integer) =
_power_by_squaring(MemoryLayout(A), size(A), A, p)
end
18 changes: 0 additions & 18 deletions src/muladd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ materialize(M::MulAdd) = copy(instantiate(M))
copy(M::MulAdd) = copyto!(similar(M), M)

_fill_copyto!(dest, C) = copyto!(dest, C)
_fill_copyto!(dest, C::Zeros) = zero!(dest) # exploit special fill! overload

@inline copyto!(dest::AbstractArray{T}, M::MulAdd) where T =
muladd!(M.α, unalias(dest,M.A), unalias(dest,M.B), M.β, _fill_copyto!(dest, M.C))
Expand Down Expand Up @@ -362,18 +361,6 @@ materialize!(M::BlasMatMulVecAdd{<:HermitianLayout{<:AbstractRowMajor},<:Abstrac
similar(M::MulAdd{<:DiagonalLayout,<:DiagonalLayout}, ::Type{T}, axes) where T = similar(M.B, T, axes)
similar(M::MulAdd{<:DiagonalLayout}, ::Type{T}, axes) where T = similar(M.B, T, axes)
similar(M::MulAdd{<:Any,<:DiagonalLayout}, ::Type{T}, axes) where T = similar(M.A, T, axes)
# equivalent to rescaling
function materialize!(M::MulAdd{<:DiagonalLayout{<:AbstractFillLayout}})
checkdimensions(M)
M.C .= (M.α * getindex_value(M.A.diag)) .* M.B .+ M.β .* M.C
M.C
end

function materialize!(M::MulAdd{<:Any,<:DiagonalLayout{<:AbstractFillLayout}})
checkdimensions(M)
M.C .= M.α .* M.A .* getindex_value(M.B.diag) .+ M.β .* M.C
M.C
end


BroadcastStyle(::Type{<:MulAdd}) = ApplyBroadcastStyle()
Expand All @@ -383,11 +370,6 @@ scalarone(::Type{A}) where {A<:AbstractArray} = scalarone(eltype(A))
scalarzero(::Type{T}) where T = zero(T)
scalarzero(::Type{A}) where {A<:AbstractArray} = scalarzero(eltype(A))

fillzeros(::Type{T}, ax) where T<:Number = Zeros{T}(ax)
mulzeros(::Type{T}, M) where T<:Number = fillzeros(T, axes(M))
mulzeros(::Type{T}, M::Mul{<:DualLayout,<:Any,<:Adjoint}) where T<:Number = fillzeros(T, axes(M,2))'
mulzeros(::Type{T}, M::Mul{<:DualLayout,<:Any,<:Transpose}) where T<:Number = transpose(fillzeros(T, axes(M,2)))

# initiate array-valued MulAdd
function _mulzeros!(dest::AbstractVector{T}, A, B) where T
for k in axes(dest,1)
Expand Down