Skip to content

Commit

Permalink
Call MulAddMul instead of multiplication in _generic_matmatmul! (#5…
Browse files Browse the repository at this point in the history
…6089)

Fix https://github.com/JuliaLang/julia/issues/56085 by calling a newly
created `MulAddMul` object that only wraps the `alpha` (with `beta` set
to `false`). This avoids the explicit multiplication if `alpha` is known
to be `isone`.

(cherry picked from commit 0af99e6)
  • Loading branch information
jishnub authored and KristofferC committed Oct 18, 2024
1 parent 9dda314 commit 415294a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
6 changes: 4 additions & 2 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -869,7 +869,7 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)

@noinline function _generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S},
_add::MulAddMul) where {T,S,R}
_add::MulAddMul{ais1}) where {T,S,R,ais1}
AxM = axes(A, 1)
AxK = axes(A, 2) # we use two `axes` calls in case of `AbstractVector`
BxK = axes(B, 1)
Expand All @@ -885,11 +885,13 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A
if BxN != CxN
throw(DimensionMismatch(lazy"matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)"))
end
_rmul_alpha = MulAddMul{ais1,true,typeof(_add.alpha),Bool}(_add.alpha,false)
if isbitstype(R) && sizeof(R) 16 && !(A isa Adjoint || A isa Transpose)
_rmul_or_fill!(C, _add.beta)
(iszero(_add.alpha) || isempty(A) || isempty(B)) && return C
@inbounds for n in BxN, k in BxK
Balpha = B[k,n]*_add.alpha
# Balpha = B[k,n] * alpha, but we skip the multiplication in case isone(alpha)
Balpha = _rmul_alpha(B[k,n])
@simd for m in AxM
C[m,n] = muladd(A[m,k], Balpha, C[m,n])
end
Expand Down
18 changes: 18 additions & 0 deletions stdlib/LinearAlgebra/test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1107,4 +1107,22 @@ end
end
end

@testset "issue #56085" begin
struct Thing
data::Float64
end

Base.zero(::Type{Thing}) = Thing(0.)
Base.zero(::Thing) = Thing(0.)
Base.one(::Type{Thing}) = Thing(1.)
Base.one(::Thing) = Thing(1.)
Base.:+(t::Thing...) = +(getfield.(t, :data)...)
Base.:*(t::Thing...) = *(getfield.(t, :data)...)

M = Float64[1 2; 3 4]
A = Thing.(M)

@test A * A M * M
end

end # module TestMatmul

0 comments on commit 415294a

Please sign in to comment.