Skip to content

Commit

Permalink
Fix 3-arg dot for 1x1 structured matrices (#46473)
Browse files Browse the repository at this point in the history
(cherry picked from commit c3d5009)
  • Loading branch information
dkarrasch authored and KristofferC committed Aug 26, 2022
1 parent ffaaccf commit 3d7ddf8
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 38 deletions.
11 changes: 6 additions & 5 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -702,14 +702,15 @@ function dot(x::AbstractVector, B::Bidiagonal, y::AbstractVector)
require_one_based_indexing(x, y)
nx, ny = length(x), length(y)
(nx == size(B, 1) == ny) || throw(DimensionMismatch())
if iszero(nx)
return dot(zero(eltype(x)), zero(eltype(B)), zero(eltype(y)))
if nx 1
nx == 0 && return dot(zero(eltype(x)), zero(eltype(B)), zero(eltype(y)))
return dot(x[1], B.dv[1], y[1])
end
ev, dv = B.ev, B.dv
if B.uplo == 'U'
@inbounds if B.uplo == 'U'
x₀ = x[1]
r = dot(x[1], dv[1], y[1])
@inbounds for j in 2:nx-1
for j in 2:nx-1
x₋, x₀ = x₀, x[j]
r += dot(adjoint(ev[j-1])*x₋ + adjoint(dv[j])*x₀, y[j])
end
Expand All @@ -719,7 +720,7 @@ function dot(x::AbstractVector, B::Bidiagonal, y::AbstractVector)
x₀ = x[1]
x₊ = x[2]
r = dot(adjoint(dv[1])*x₀ + adjoint(ev[1])*x₊, y[1])
@inbounds for j in 2:nx-1
for j in 2:nx-1
x₀, x₊ = x₊, x[j+1]
r += dot(adjoint(dv[j])*x₀ + adjoint(ev[j])*x₊, y[j])
end
Expand Down
56 changes: 31 additions & 25 deletions stdlib/LinearAlgebra/src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,21 +256,24 @@ end
function dot(x::AbstractVector, S::SymTridiagonal, y::AbstractVector)
require_one_based_indexing(x, y)
nx, ny = length(x), length(y)
(nx == size(S, 1) == ny) || throw(DimensionMismatch())
if iszero(nx)
return dot(zero(eltype(x)), zero(eltype(S)), zero(eltype(y)))
(nx == size(S, 1) == ny) || throw(DimensionMismatch("dot"))
if nx 1
nx == 0 && return dot(zero(eltype(x)), zero(eltype(S)), zero(eltype(y)))
return dot(x[1], S.dv[1], y[1])
end
dv, ev = S.dv, S.ev
x₀ = x[1]
x₊ = x[2]
sub = transpose(ev[1])
r = dot(adjoint(dv[1])*x₀ + adjoint(sub)*x₊, y[1])
@inbounds for j in 2:nx-1
x₋, x₀, x₊ = x₀, x₊, x[j+1]
sup, sub = transpose(sub), transpose(ev[j])
r += dot(adjoint(sup)*x₋ + adjoint(dv[j])*x₀ + adjoint(sub)*x₊, y[j])
end
r += dot(adjoint(transpose(sub))*x₀ + adjoint(dv[nx])*x₊, y[nx])
@inbounds begin
x₀ = x[1]
x₊ = x[2]
sub = transpose(ev[1])
r = dot(adjoint(dv[1])*x₀ + adjoint(sub)*x₊, y[1])
for j in 2:nx-1
x₋, x₀, x₊ = x₀, x₊, x[j+1]
sup, sub = transpose(sub), transpose(ev[j])
r += dot(adjoint(sup)*x₋ + adjoint(dv[j])*x₀ + adjoint(sub)*x₊, y[j])
end
r += dot(adjoint(transpose(sub))*x₀ + adjoint(dv[nx])*x₊, y[nx])
end
return r
end

Expand Down Expand Up @@ -841,18 +844,21 @@ function dot(x::AbstractVector, A::Tridiagonal, y::AbstractVector)
require_one_based_indexing(x, y)
nx, ny = length(x), length(y)
(nx == size(A, 1) == ny) || throw(DimensionMismatch())
if iszero(nx)
return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
end
x₀ = x[1]
x₊ = x[2]
dl, d, du = A.dl, A.d, A.du
r = dot(adjoint(d[1])*x₀ + adjoint(dl[1])*x₊, y[1])
@inbounds for j in 2:nx-1
x₋, x₀, x₊ = x₀, x₊, x[j+1]
r += dot(adjoint(du[j-1])*x₋ + adjoint(d[j])*x₀ + adjoint(dl[j])*x₊, y[j])
end
r += dot(adjoint(du[nx-1])*x₀ + adjoint(d[nx])*x₊, y[nx])
if nx 1
nx == 0 && return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
return dot(x[1], A.d[1], y[1])
end
@inbounds begin
x₀ = x[1]
x₊ = x[2]
dl, d, du = A.dl, A.d, A.du
r = dot(adjoint(d[1])*x₀ + adjoint(dl[1])*x₊, y[1])
for j in 2:nx-1
x₋, x₀, x₊ = x₀, x₊, x[j+1]
r += dot(adjoint(du[j-1])*x₋ + adjoint(d[j])*x₀ + adjoint(dl[j])*x₊, y[j])
end
r += dot(adjoint(du[nx-1])*x₀ + adjoint(d[nx])*x₊, y[nx])
end
return r
end

Expand Down
14 changes: 7 additions & 7 deletions stdlib/LinearAlgebra/test/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -623,22 +623,22 @@ end
end

@testset "generalized dot" begin
for elty in (Float64, ComplexF64)
dv = randn(elty, 5)
ev = randn(elty, 4)
x = randn(elty, 5)
y = randn(elty, 5)
for elty in (Float64, ComplexF64), n in (5, 1)
dv = randn(elty, n)
ev = randn(elty, n-1)
x = randn(elty, n)
y = randn(elty, n)
for uplo in (:U, :L)
B = Bidiagonal(dv, ev, uplo)
@test dot(x, B, y) dot(B'x, y) dot(x, Matrix(B), y)
@test dot(x, B, y) dot(B'x, y) dot(x, B*y) dot(x, Matrix(B), y)
end
dv = Vector{elty}(undef, 0)
ev = Vector{elty}(undef, 0)
x = Vector{elty}(undef, 0)
y = Vector{elty}(undef, 0)
for uplo in (:U, :L)
B = Bidiagonal(dv, ev, uplo)
@test dot(x, B, y) dot(zero(elty), zero(elty), zero(elty))
@test dot(x, B, y) === zero(elty)
end
end
end
Expand Down
6 changes: 5 additions & 1 deletion stdlib/LinearAlgebra/test/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,11 @@ end
@testset "generalized dot" begin
x = fill(convert(elty, 1), n)
y = fill(convert(elty, 1), n)
@test dot(x, A, y) dot(A'x, y)
@test dot(x, A, y) dot(A'x, y) dot(x, A*y)
@test dot([1], SymTridiagonal([1], Int[]), [1]) == 1
@test dot([1], Tridiagonal(Int[], [1], Int[]), [1]) == 1
@test dot(Int[], SymTridiagonal(Int[], Int[]), Int[]) === 0
@test dot(Int[], Tridiagonal(Int[], Int[], Int[]), Int[]) === 0
end
end
end
Expand Down

0 comments on commit 3d7ddf8

Please sign in to comment.