From 3d7ddf8da1930dcbd4c30a64a1ad04ec001e8c0c Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Thu, 25 Aug 2022 08:54:34 +0200 Subject: [PATCH] Fix 3-arg `dot` for 1x1 structured matrices (#46473) (cherry picked from commit c3d500984fd7e31db9096d7a59093039fd4f0009) --- stdlib/LinearAlgebra/src/bidiag.jl | 11 +++--- stdlib/LinearAlgebra/src/tridiag.jl | 56 +++++++++++++++------------- stdlib/LinearAlgebra/test/bidiag.jl | 14 +++---- stdlib/LinearAlgebra/test/tridiag.jl | 6 ++- 4 files changed, 49 insertions(+), 38 deletions(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 317ed15af770ce..fb29f9ae595b9b 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -702,14 +702,15 @@ function dot(x::AbstractVector, B::Bidiagonal, y::AbstractVector) require_one_based_indexing(x, y) nx, ny = length(x), length(y) (nx == size(B, 1) == ny) || throw(DimensionMismatch()) - if iszero(nx) - return dot(zero(eltype(x)), zero(eltype(B)), zero(eltype(y))) + if nx ≤ 1 + nx == 0 && return dot(zero(eltype(x)), zero(eltype(B)), zero(eltype(y))) + return dot(x[1], B.dv[1], y[1]) end ev, dv = B.ev, B.dv - if B.uplo == 'U' + @inbounds if B.uplo == 'U' x₀ = x[1] r = dot(x[1], dv[1], y[1]) - @inbounds for j in 2:nx-1 + for j in 2:nx-1 x₋, x₀ = x₀, x[j] r += dot(adjoint(ev[j-1])*x₋ + adjoint(dv[j])*x₀, y[j]) end @@ -719,7 +720,7 @@ function dot(x::AbstractVector, B::Bidiagonal, y::AbstractVector) x₀ = x[1] x₊ = x[2] r = dot(adjoint(dv[1])*x₀ + adjoint(ev[1])*x₊, y[1]) - @inbounds for j in 2:nx-1 + for j in 2:nx-1 x₀, x₊ = x₊, x[j+1] r += dot(adjoint(dv[j])*x₀ + adjoint(ev[j])*x₊, y[j]) end diff --git a/stdlib/LinearAlgebra/src/tridiag.jl b/stdlib/LinearAlgebra/src/tridiag.jl index e5c31856d3f0ac..d6b44d85e68e9b 100644 --- a/stdlib/LinearAlgebra/src/tridiag.jl +++ b/stdlib/LinearAlgebra/src/tridiag.jl @@ -256,21 +256,24 @@ end function dot(x::AbstractVector, S::SymTridiagonal, y::AbstractVector) require_one_based_indexing(x, y) nx, ny = length(x), length(y) - (nx == size(S, 1) == ny) || throw(DimensionMismatch()) - if iszero(nx) - return dot(zero(eltype(x)), zero(eltype(S)), zero(eltype(y))) + (nx == size(S, 1) == ny) || throw(DimensionMismatch("dot")) + if nx ≤ 1 + nx == 0 && return dot(zero(eltype(x)), zero(eltype(S)), zero(eltype(y))) + return dot(x[1], S.dv[1], y[1]) end dv, ev = S.dv, S.ev - x₀ = x[1] - x₊ = x[2] - sub = transpose(ev[1]) - r = dot(adjoint(dv[1])*x₀ + adjoint(sub)*x₊, y[1]) - @inbounds for j in 2:nx-1 - x₋, x₀, x₊ = x₀, x₊, x[j+1] - sup, sub = transpose(sub), transpose(ev[j]) - r += dot(adjoint(sup)*x₋ + adjoint(dv[j])*x₀ + adjoint(sub)*x₊, y[j]) - end - r += dot(adjoint(transpose(sub))*x₀ + adjoint(dv[nx])*x₊, y[nx]) + @inbounds begin + x₀ = x[1] + x₊ = x[2] + sub = transpose(ev[1]) + r = dot(adjoint(dv[1])*x₀ + adjoint(sub)*x₊, y[1]) + for j in 2:nx-1 + x₋, x₀, x₊ = x₀, x₊, x[j+1] + sup, sub = transpose(sub), transpose(ev[j]) + r += dot(adjoint(sup)*x₋ + adjoint(dv[j])*x₀ + adjoint(sub)*x₊, y[j]) + end + r += dot(adjoint(transpose(sub))*x₀ + adjoint(dv[nx])*x₊, y[nx]) + end return r end @@ -841,18 +844,21 @@ function dot(x::AbstractVector, A::Tridiagonal, y::AbstractVector) require_one_based_indexing(x, y) nx, ny = length(x), length(y) (nx == size(A, 1) == ny) || throw(DimensionMismatch()) - if iszero(nx) - return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y))) - end - x₀ = x[1] - x₊ = x[2] - dl, d, du = A.dl, A.d, A.du - r = dot(adjoint(d[1])*x₀ + adjoint(dl[1])*x₊, y[1]) - @inbounds for j in 2:nx-1 - x₋, x₀, x₊ = x₀, x₊, x[j+1] - r += dot(adjoint(du[j-1])*x₋ + adjoint(d[j])*x₀ + adjoint(dl[j])*x₊, y[j]) - end - r += dot(adjoint(du[nx-1])*x₀ + adjoint(d[nx])*x₊, y[nx]) + if nx ≤ 1 + nx == 0 && return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y))) + return dot(x[1], A.d[1], y[1]) + end + @inbounds begin + x₀ = x[1] + x₊ = x[2] + dl, d, du = A.dl, A.d, A.du + r = dot(adjoint(d[1])*x₀ + adjoint(dl[1])*x₊, y[1]) + for j in 2:nx-1 + x₋, x₀, x₊ = x₀, x₊, x[j+1] + r += dot(adjoint(du[j-1])*x₋ + adjoint(d[j])*x₀ + adjoint(dl[j])*x₊, y[j]) + end + r += dot(adjoint(du[nx-1])*x₀ + adjoint(d[nx])*x₊, y[nx]) + end return r end diff --git a/stdlib/LinearAlgebra/test/bidiag.jl b/stdlib/LinearAlgebra/test/bidiag.jl index 422984d84eb6b5..c711bf3e1e1c19 100644 --- a/stdlib/LinearAlgebra/test/bidiag.jl +++ b/stdlib/LinearAlgebra/test/bidiag.jl @@ -623,14 +623,14 @@ end end @testset "generalized dot" begin - for elty in (Float64, ComplexF64) - dv = randn(elty, 5) - ev = randn(elty, 4) - x = randn(elty, 5) - y = randn(elty, 5) + for elty in (Float64, ComplexF64), n in (5, 1) + dv = randn(elty, n) + ev = randn(elty, n-1) + x = randn(elty, n) + y = randn(elty, n) for uplo in (:U, :L) B = Bidiagonal(dv, ev, uplo) - @test dot(x, B, y) ≈ dot(B'x, y) ≈ dot(x, Matrix(B), y) + @test dot(x, B, y) ≈ dot(B'x, y) ≈ dot(x, B*y) ≈ dot(x, Matrix(B), y) end dv = Vector{elty}(undef, 0) ev = Vector{elty}(undef, 0) @@ -638,7 +638,7 @@ end y = Vector{elty}(undef, 0) for uplo in (:U, :L) B = Bidiagonal(dv, ev, uplo) - @test dot(x, B, y) ≈ dot(zero(elty), zero(elty), zero(elty)) + @test dot(x, B, y) === zero(elty) end end end diff --git a/stdlib/LinearAlgebra/test/tridiag.jl b/stdlib/LinearAlgebra/test/tridiag.jl index ecdf6b416baa54..0698a583c8d455 100644 --- a/stdlib/LinearAlgebra/test/tridiag.jl +++ b/stdlib/LinearAlgebra/test/tridiag.jl @@ -434,7 +434,11 @@ end @testset "generalized dot" begin x = fill(convert(elty, 1), n) y = fill(convert(elty, 1), n) - @test dot(x, A, y) ≈ dot(A'x, y) + @test dot(x, A, y) ≈ dot(A'x, y) ≈ dot(x, A*y) + @test dot([1], SymTridiagonal([1], Int[]), [1]) == 1 + @test dot([1], Tridiagonal(Int[], [1], Int[]), [1]) == 1 + @test dot(Int[], SymTridiagonal(Int[], Int[]), Int[]) === 0 + @test dot(Int[], Tridiagonal(Int[], Int[], Int[]), Int[]) === 0 end end end