diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 9c74addd6b69c..0a008256316d7 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -869,7 +869,7 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) @noinline function _generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S}, - _add::MulAddMul) where {T,S,R} + _add::MulAddMul{ais1}) where {T,S,R,ais1} AxM = axes(A, 1) AxK = axes(A, 2) # we use two `axes` calls in case of `AbstractVector` BxK = axes(B, 1) @@ -885,11 +885,13 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A if BxN != CxN throw(DimensionMismatch(lazy"matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)")) end + _rmul_alpha = MulAddMul{ais1,true,typeof(_add.alpha),Bool}(_add.alpha,false) if isbitstype(R) && sizeof(R) ≤ 16 && !(A isa Adjoint || A isa Transpose) _rmul_or_fill!(C, _add.beta) (iszero(_add.alpha) || isempty(A) || isempty(B)) && return C @inbounds for n in BxN, k in BxK - Balpha = B[k,n]*_add.alpha + # Balpha = B[k,n] * alpha, but we skip the multiplication in case isone(alpha) + Balpha = _rmul_alpha(B[k,n]) @simd for m in AxM C[m,n] = muladd(A[m,k], Balpha, C[m,n]) end diff --git a/stdlib/LinearAlgebra/test/matmul.jl b/stdlib/LinearAlgebra/test/matmul.jl index db61fbe0ab45a..aab535cbe0303 100644 --- a/stdlib/LinearAlgebra/test/matmul.jl +++ b/stdlib/LinearAlgebra/test/matmul.jl @@ -1107,4 +1107,22 @@ end end end +@testset "issue #56085" begin + struct Thing + data::Float64 + end + + Base.zero(::Type{Thing}) = Thing(0.) + Base.zero(::Thing) = Thing(0.) + Base.one(::Type{Thing}) = Thing(1.) + Base.one(::Thing) = Thing(1.) + Base.:+(t::Thing...) = +(getfield.(t, :data)...) + Base.:*(t::Thing...) = *(getfield.(t, :data)...) + + M = Float64[1 2; 3 4] + A = Thing.(M) + + @test A * A ≈ M * M +end + end # module TestMatmul