diff --git a/NEWS.md b/NEWS.md index 2ff018ce96ce4..012b3b788650e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -39,6 +39,7 @@ Standard library changes * `qr` and `qr!` functions support `blocksize` keyword argument ([#33053]). +* `dot` now admits a 3-argument method `dot(x, A, y)` to compute generalized dot products `dot(x, A*y)`, but without computing and storing the intermediate result `A*y` ([#32739]). #### SparseArrays diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 318c0e5ca6dfd..8e707c4516308 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -647,6 +647,36 @@ function *(A::SymTridiagonal, B::Diagonal) A_mul_B_td!(Tridiagonal(zeros(TS, size(A, 1)-1), zeros(TS, size(A, 1)), zeros(TS, size(A, 1)-1)), A, B) end +function dot(x::AbstractVector, B::Bidiagonal, y::AbstractVector) + require_one_based_indexing(x, y) + nx, ny = length(x), length(y) + (nx == size(B, 1) == ny) || throw(DimensionMismatch()) + if iszero(nx) + return dot(zero(eltype(x)), zero(eltype(B)), zero(eltype(y))) + end + ev, dv = B.ev, B.dv + if B.uplo == 'U' + x₀ = x[1] + r = dot(x[1], dv[1], y[1]) + @inbounds for j in 2:nx-1 + x₋, x₀ = x₀, x[j] + r += dot(adjoint(ev[j-1])*x₋ + adjoint(dv[j])*x₀, y[j]) + end + r += dot(adjoint(ev[nx-1])*x₀ + adjoint(dv[nx])*x[nx], y[nx]) + return r + else # B.uplo == 'L' + x₀ = x[1] + x₊ = x[2] + r = dot(adjoint(dv[1])*x₀ + adjoint(ev[1])*x₊, y[1]) + @inbounds for j in 2:nx-1 + x₀, x₊ = x₊, x[j+1] + r += dot(adjoint(dv[j])*x₀ + adjoint(ev[j])*x₊, y[j]) + end + r += dot(x₊, dv[nx], y[nx]) + return r + end +end + #Linear solvers ldiv!(A::Union{Bidiagonal, AbstractTriangular}, b::AbstractVector) = naivesub!(A, b) ldiv!(A::Transpose{<:Any,<:Bidiagonal}, b::AbstractVector) = ldiv!(copy(A), b) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index cd17474407d6f..bc531d9ac4801 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -637,11 +637,14 @@ end # disambiguation methods: * of Diagonal and Adj/Trans AbsVec *(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal) = Adjoint(map((t,s) -> t'*s, D.diag, parent(x))) +*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal) = Transpose(map((t,s) -> transpose(t)*s, D.diag, parent(x))) *(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) = mapreduce(t -> t[1]*t[2]*t[3], +, zip(x, D.diag, y)) -*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal) = Transpose(map((t,s) -> transpose(t)*s, D.diag, parent(x))) *(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) = mapreduce(t -> t[1]*t[2]*t[3], +, zip(x, D.diag, y)) +function dot(x::AbstractVector, D::Diagonal, y::AbstractVector) + mapreduce(t -> dot(t[1], t[2], t[3]), +, zip(x, D.diag, y)) +end function cholesky!(A::Diagonal, ::Val{false} = Val(false); check::Bool = true) info = 0 diff --git a/stdlib/LinearAlgebra/src/generic.jl b/stdlib/LinearAlgebra/src/generic.jl index b2267463aab45..1ffdf95ba6978 100644 --- a/stdlib/LinearAlgebra/src/generic.jl +++ b/stdlib/LinearAlgebra/src/generic.jl @@ -874,6 +874,51 @@ function dot(x::AbstractArray, y::AbstractArray) s end +""" + dot(x, A, y) + +Compute the generalized dot product `dot(x, A*y)` between two vectors `x` and `y`, +without storing the intermediate result of `A*y`. As for the two-argument +[`dot(_,_)`](@ref), this acts recursively. Moreover, for complex vectors, the +first vector is conjugated. + +!!! compat "Julia 1.4" + Three-argument `dot` requires at least Julia 1.4. + +# Examples +```jldoctest +julia> dot([1; 1], [1 2; 3 4], [2; 3]) +26 + +julia> dot(1:5, reshape(1:25, 5, 5), 2:6) +4850 + +julia> ⋅(1:5, reshape(1:25, 5, 5), 2:6) == dot(1:5, reshape(1:25, 5, 5), 2:6) +true +``` +""" +dot(x, A, y) = dot(x, A*y) # generic fallback for cases that are not covered by specialized methods + +function dot(x::AbstractVector, A::AbstractMatrix, y::AbstractVector) + (axes(x)..., axes(y)...) == axes(A) || throw(DimensionMismatch()) + T = typeof(dot(first(x), first(A), first(y))) + s = zero(T) + i₁ = first(eachindex(x)) + x₁ = first(x) + @inbounds for j in eachindex(y) + yj = y[j] + if !iszero(yj) + temp = zero(adjoint(A[i₁,j]) * x₁) + @simd for i in eachindex(x) + temp += adjoint(A[i,j]) * x[i] + end + s += dot(temp, yj) + end + end + return s +end +dot(x::AbstractVector, adjA::Adjoint, y::AbstractVector) = adjoint(dot(y, adjA.parent, x)) +dot(x::AbstractVector, transA::Transpose{<:Real}, y::AbstractVector) = adjoint(dot(y, transA.parent, x)) ########################################################################################### diff --git a/stdlib/LinearAlgebra/src/hessenberg.jl b/stdlib/LinearAlgebra/src/hessenberg.jl index 341d2c71bf332..5f9b2be1d34a2 100644 --- a/stdlib/LinearAlgebra/src/hessenberg.jl +++ b/stdlib/LinearAlgebra/src/hessenberg.jl @@ -284,6 +284,37 @@ function logabsdet(F::UpperHessenberg; shift::Number=false) return (logdeterminant, P) end +function dot(x::AbstractVector, H::UpperHessenberg, y::AbstractVector) + require_one_based_indexing(x, y) + m = size(H, 1) + (length(x) == m == length(y)) || throw(DimensionMismatch()) + if iszero(m) + return dot(zero(eltype(x)), zero(eltype(H)), zero(eltype(y))) + end + x₁ = x[1] + r = dot(x₁, H[1,1], y[1]) + r += dot(x[2], H[2,1], y[1]) + @inbounds for j in 2:m-1 + yj = y[j] + if !iszero(yj) + temp = adjoint(H[1,j]) * x₁ + @simd for i in 2:j+1 + temp += adjoint(H[i,j]) * x[i] + end + r += dot(temp, yj) + end + end + ym = y[m] + if !iszero(ym) + temp = adjoint(H[1,m]) * x₁ + @simd for i in 2:m + temp += adjoint(H[i,m]) * x[i] + end + r += dot(temp, ym) + end + return r +end + ###################################################################################### # Hessenberg factorizations Q(H+μI)Q' of A+μI: diff --git a/stdlib/LinearAlgebra/src/symmetric.jl b/stdlib/LinearAlgebra/src/symmetric.jl index 918e5747e3c69..1f3829d3727bf 100644 --- a/stdlib/LinearAlgebra/src/symmetric.jl +++ b/stdlib/LinearAlgebra/src/symmetric.jl @@ -457,6 +457,31 @@ end *(A::HermOrSym, B::HermOrSym) = A * copyto!(similar(parent(B)), B) +function dot(x::AbstractVector, A::RealHermSymComplexHerm, y::AbstractVector) + require_one_based_indexing(x, y) + (length(x) == length(y) == size(A, 1)) || throw(DimensionMismatch()) + data = A.data + r = zero(eltype(x)) * zero(eltype(A)) * zero(eltype(y)) + if A.uplo == 'U' + @inbounds for j = 1:length(y) + r += dot(x[j], real(data[j,j]), y[j]) + @simd for i = 1:j-1 + Aij = data[i,j] + r += dot(x[i], Aij, y[j]) + dot(x[j], adjoint(Aij), y[i]) + end + end + else # A.uplo == 'L' + @inbounds for j = 1:length(y) + r += dot(x[j], real(data[j,j]), y[j]) + @simd for i = j+1:length(y) + Aij = data[i,j] + r += dot(x[i], Aij, y[j]) + dot(x[j], adjoint(Aij), y[i]) + end + end + end + return r +end + # Fallbacks to avoid generic_matvecmul!/generic_matmatmul! ## Symmetric{<:Number} and Hermitian{<:Real} are invariant to transpose; peel off the t *(transA::Transpose{<:Any,<:RealHermSymComplexSym}, B::AbstractVector) = transA.parent * B diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index 903e8954319ee..1c156fa6bf326 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -545,6 +545,90 @@ end rmul!(A::Union{UpperTriangular,LowerTriangular}, c::Number) = mul!(A, A, c) lmul!(c::Number, A::Union{UpperTriangular,LowerTriangular}) = mul!(A, c, A) +function dot(x::AbstractVector, A::UpperTriangular, y::AbstractVector) + require_one_based_indexing(x, y) + m = size(A, 1) + (length(x) == m == length(y)) || throw(DimensionMismatch()) + if iszero(m) + return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y))) + end + x₁ = x[1] + r = dot(x₁, A[1,1], y[1]) + @inbounds for j in 2:m + yj = y[j] + if !iszero(yj) + temp = adjoint(A[1,j]) * x₁ + @simd for i in 2:j + temp += adjoint(A[i,j]) * x[i] + end + r += dot(temp, yj) + end + end + return r +end +function dot(x::AbstractVector, A::UnitUpperTriangular, y::AbstractVector) + require_one_based_indexing(x, y) + m = size(A, 1) + (length(x) == m == length(y)) || throw(DimensionMismatch()) + if iszero(m) + return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y))) + end + x₁ = first(x) + r = dot(x₁, y[1]) + @inbounds for j in 2:m + yj = y[j] + if !iszero(yj) + temp = adjoint(A[1,j]) * x₁ + @simd for i in 2:j-1 + temp += adjoint(A[i,j]) * x[i] + end + r += dot(temp, yj) + r += dot(x[j], yj) + end + end + return r +end +function dot(x::AbstractVector, A::LowerTriangular, y::AbstractVector) + require_one_based_indexing(x, y) + m = size(A, 1) + (length(x) == m == length(y)) || throw(DimensionMismatch()) + if iszero(m) + return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y))) + end + r = zero(typeof(dot(first(x), first(A), first(y)))) + @inbounds for j in 1:m + yj = y[j] + if !iszero(yj) + temp = adjoint(A[j,j]) * x[j] + @simd for i in j+1:m + temp += adjoint(A[i,j]) * x[i] + end + r += dot(temp, yj) + end + end + return r +end +function dot(x::AbstractVector, A::UnitLowerTriangular, y::AbstractVector) + require_one_based_indexing(x, y) + m = size(A, 1) + (length(x) == m == length(y)) || throw(DimensionMismatch()) + if iszero(m) + return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y))) + end + r = zero(typeof(dot(first(x), first(y)))) + @inbounds for j in 1:m + yj = y[j] + if !iszero(yj) + temp = x[j] + @simd for i in j+1:m + temp += adjoint(A[i,j]) * x[i] + end + r += dot(temp, yj) + end + end + return r +end + fillstored!(A::LowerTriangular, x) = (fillband!(A.data, x, 1-size(A,1), 0); A) fillstored!(A::UnitLowerTriangular, x) = (fillband!(A.data, x, 1-size(A,1), -1); A) fillstored!(A::UpperTriangular, x) = (fillband!(A.data, x, 0, size(A,2)-1); A) diff --git a/stdlib/LinearAlgebra/src/tridiag.jl b/stdlib/LinearAlgebra/src/tridiag.jl index 72d8226f476e9..670ef8dcd1fe0 100644 --- a/stdlib/LinearAlgebra/src/tridiag.jl +++ b/stdlib/LinearAlgebra/src/tridiag.jl @@ -202,6 +202,27 @@ end return C end +function dot(x::AbstractVector, S::SymTridiagonal, y::AbstractVector) + require_one_based_indexing(x, y) + nx, ny = length(x), length(y) + (nx == size(S, 1) == ny) || throw(DimensionMismatch()) + if iszero(nx) + return dot(zero(eltype(x)), zero(eltype(S)), zero(eltype(y))) + end + dv, ev = S.dv, S.ev + x₀ = x[1] + x₊ = x[2] + sub = transpose(ev[1]) + r = dot(adjoint(dv[1])*x₀ + adjoint(sub)*x₊, y[1]) + @inbounds for j in 2:nx-1 + x₋, x₀, x₊ = x₀, x₊, x[j+1] + sup, sub = transpose(sub), transpose(ev[j]) + r += dot(adjoint(sup)*x₋ + adjoint(dv[j])*x₀ + adjoint(sub)*x₊, y[j]) + end + r += dot(adjoint(transpose(sub))*x₀ + adjoint(dv[nx])*x₊, y[nx]) + return r +end + (\)(T::SymTridiagonal, B::StridedVecOrMat) = ldlt(T)\B # division with optional shift for use in shifted-Hessenberg solvers (hessenberg.jl): @@ -657,3 +678,22 @@ end Base._sum(A::Tridiagonal, ::Colon) = sum(A.d) + sum(A.dl) + sum(A.du) Base._sum(A::SymTridiagonal, ::Colon) = sum(A.dv) + 2sum(A.ev) + +function dot(x::AbstractVector, A::Tridiagonal, y::AbstractVector) + require_one_based_indexing(x, y) + nx, ny = length(x), length(y) + (nx == size(A, 1) == ny) || throw(DimensionMismatch()) + if iszero(nx) + return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y))) + end + x₀ = x[1] + x₊ = x[2] + dl, d, du = A.dl, A.d, A.du + r = dot(adjoint(d[1])*x₀ + adjoint(dl[1])*x₊, y[1]) + @inbounds for j in 2:nx-1 + x₋, x₀, x₊ = x₀, x₊, x[j+1] + r += dot(adjoint(du[j-1])*x₋ + adjoint(d[j])*x₀ + adjoint(dl[j])*x₊, y[j]) + end + r += dot(adjoint(du[nx-1])*x₀ + adjoint(d[nx])*x₊, y[nx]) + return r +end diff --git a/stdlib/LinearAlgebra/src/uniformscaling.jl b/stdlib/LinearAlgebra/src/uniformscaling.jl index 7b6c3a58000dc..d268120262e94 100644 --- a/stdlib/LinearAlgebra/src/uniformscaling.jl +++ b/stdlib/LinearAlgebra/src/uniformscaling.jl @@ -400,3 +400,7 @@ Array(s::UniformScaling, dims::Dims{2}) = Matrix(s, dims) ## Diagonal construction from UniformScaling Diagonal{T}(s::UniformScaling, m::Integer) where {T} = Diagonal{T}(fill(T(s.λ), m)) Diagonal(s::UniformScaling, m::Integer) = Diagonal{eltype(s)}(s, m) + +dot(x::AbstractVector, J::UniformScaling, y::AbstractVector) = dot(x, J.λ, y) +dot(x::AbstractVector, a::Number, y::AbstractVector) = sum(t -> dot(t[1], a, t[2]), zip(x, y)) +dot(x::AbstractVector, a::Union{Real,Complex}, y::AbstractVector) = a*dot(x, y) diff --git a/stdlib/LinearAlgebra/test/bidiag.jl b/stdlib/LinearAlgebra/test/bidiag.jl index 166449018d1f0..d55fbc32a7242 100644 --- a/stdlib/LinearAlgebra/test/bidiag.jl +++ b/stdlib/LinearAlgebra/test/bidiag.jl @@ -455,4 +455,17 @@ end @test A * Tridiagonal(ones(1, 1)) == A end +@testset "generalized dot" begin + for elty in (Float64, ComplexF64) + dv = randn(elty, 5) + ev = randn(elty, 4) + x = randn(elty, 5) + y = randn(elty, 5) + for uplo in (:U, :L) + B = Bidiagonal(dv, ev, uplo) + @test dot(x, B, y) ≈ dot(B'x, y) ≈ dot(x, Matrix(B), y) + end + end +end + end # module TestBidiagonal diff --git a/stdlib/LinearAlgebra/test/generic.jl b/stdlib/LinearAlgebra/test/generic.jl index 8880325d21e03..113fa15db3366 100644 --- a/stdlib/LinearAlgebra/test/generic.jl +++ b/stdlib/LinearAlgebra/test/generic.jl @@ -409,4 +409,32 @@ end @test all(!isnan, lmul!(false, Any[NaN])) end +@testset "generalized dot #32739" begin + for elty in (Int, Float32, Float64, BigFloat, Complex{Float32}, Complex{Float64}, Complex{BigFloat}) + n = 10 + if elty <: Int + A = rand(-n:n, n, n) + x = rand(-n:n, n) + y = rand(-n:n, n) + elseif elty <: Real + A = convert(Matrix{elty}, randn(n,n)) + x = rand(elty, n) + y = rand(elty, n) + else + A = convert(Matrix{elty}, complex.(randn(n,n), randn(n,n))) + x = rand(elty, n) + y = rand(elty, n) + end + @test dot(x, A, y) ≈ dot(A'x, y) ≈ *(x', A, y) ≈ (x'A)*y + @test dot(x, A', y) ≈ dot(A*x, y) ≈ *(x', A', y) ≈ (x'A')*y + elty <: Real && @test dot(x, transpose(A), y) ≈ dot(x, transpose(A)*y) ≈ *(x', transpose(A), y) ≈ (x'*transpose(A))*y + B = reshape([A], 1, 1) + x = [x] + y = [y] + @test dot(x, B, y) ≈ dot(B'x, y) + @test dot(x, B', y) ≈ dot(B*x, y) + elty <: Real && @test dot(x, transpose(B), y) ≈ dot(x, transpose(B)*y) + end +end + end # module TestGeneric diff --git a/stdlib/LinearAlgebra/test/hessenberg.jl b/stdlib/LinearAlgebra/test/hessenberg.jl index 4db168dfe9189..b7cc07d09446f 100644 --- a/stdlib/LinearAlgebra/test/hessenberg.jl +++ b/stdlib/LinearAlgebra/test/hessenberg.jl @@ -88,6 +88,11 @@ let n = 10 @test det(H + shift*I) ≈ det(A + shift*I) @test logabsdet(H + shift*I) ≅ logabsdet(A + shift*I) end + + HM = Matrix(h) + @test dot(b, h, b) ≈ dot(h'b, b) ≈ dot(b, HM, b) ≈ dot(HM'b, b) + c = b .+ 1 + @test dot(b, h, c) ≈ dot(h'b, c) ≈ dot(b, HM, c) ≈ dot(HM'b, c) end end diff --git a/stdlib/LinearAlgebra/test/symmetric.jl b/stdlib/LinearAlgebra/test/symmetric.jl index a7538fcd12382..ebdfc9c93c72f 100644 --- a/stdlib/LinearAlgebra/test/symmetric.jl +++ b/stdlib/LinearAlgebra/test/symmetric.jl @@ -365,6 +365,18 @@ end @test Symmetric(asym)\b ≈ asym\b end end + @testset "generalized dot product" begin + for uplo in (:U, :L) + @test dot(x, Hermitian(aherm, uplo), y) ≈ dot(x, Hermitian(aherm, uplo)*y) ≈ dot(x, Matrix(Hermitian(aherm, uplo)), y) + @test dot(x, Hermitian(aherm, uplo), x) ≈ dot(x, Hermitian(aherm, uplo)*x) ≈ dot(x, Matrix(Hermitian(aherm, uplo)), x) + end + if eltya <: Real + for uplo in (:U, :L) + @test dot(x, Symmetric(aherm, uplo), y) ≈ dot(x, Symmetric(aherm, uplo)*y) ≈ dot(x, Matrix(Symmetric(aherm, uplo)), y) + @test dot(x, Symmetric(aherm, uplo), x) ≈ dot(x, Symmetric(aherm, uplo)*x) ≈ dot(x, Matrix(Symmetric(aherm, uplo)), x) + end + end + end end end diff --git a/stdlib/LinearAlgebra/test/triangular.jl b/stdlib/LinearAlgebra/test/triangular.jl index 612fe05374dcd..5f6ec8771a030 100644 --- a/stdlib/LinearAlgebra/test/triangular.jl +++ b/stdlib/LinearAlgebra/test/triangular.jl @@ -236,6 +236,17 @@ for elty1 in (Float32, Float64, BigFloat, ComplexF32, ComplexF64, Complex{BigFlo end end + # generalized dot + for eltyb in (Float32, Float64, BigFloat, ComplexF32, ComplexF64, Complex{BigFloat}) + b1 = convert(Vector{eltyb}, (elty1 <: Complex ? real(A1) : A1)*fill(1., n)) + b2 = convert(Vector{eltyb}, (elty1 <: Complex ? real(A1) : A1)*randn(n)) + if elty1 in (BigFloat, Complex{BigFloat}) || eltyb in (BigFloat, Complex{BigFloat}) + @test dot(b1, A1, b2) ≈ dot(A1'b1, b2) atol=sqrt(max(eps(real(float(one(elty1)))),eps(real(float(one(eltyb))))))*n*n + else + @test dot(b1, A1, b2) ≈ dot(A1'b1, b2) atol=sqrt(max(eps(real(float(one(elty1)))),eps(real(float(one(eltyb))))))*n*n + end + end + # Binary operations @test A1*0.5 == Matrix(A1)*0.5 @test 0.5*A1 == 0.5*Matrix(A1) diff --git a/stdlib/LinearAlgebra/test/tridiag.jl b/stdlib/LinearAlgebra/test/tridiag.jl index 686a1010914ab..8431c81fd3ef8 100644 --- a/stdlib/LinearAlgebra/test/tridiag.jl +++ b/stdlib/LinearAlgebra/test/tridiag.jl @@ -390,6 +390,11 @@ end end end end + @testset "generalized dot" begin + x = fill(convert(elty, 1), n) + y = fill(convert(elty, 1), n) + @test dot(x, A, y) ≈ dot(A'x, y) + end end end diff --git a/stdlib/LinearAlgebra/test/uniformscaling.jl b/stdlib/LinearAlgebra/test/uniformscaling.jl index 21a8cf0bb542c..63096fac75124 100644 --- a/stdlib/LinearAlgebra/test/uniformscaling.jl +++ b/stdlib/LinearAlgebra/test/uniformscaling.jl @@ -4,6 +4,10 @@ module TestUniformscaling using Test, LinearAlgebra, Random, SparseArrays +const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test") +isdefined(Main, :Quaternions) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "Quaternions.jl")) +using .Main.Quaternions + Random.seed!(123) @testset "basic functions" begin @@ -330,4 +334,15 @@ end @test I(3) == [1 0 0; 0 1 0; 0 0 1] end +@testset "generalized dot" begin + x = rand(-10:10, 3) + y = rand(-10:10, 3) + λ = rand(-10:10) + J = UniformScaling(λ) + @test dot(x, J, y) == λ*dot(x, y) + λ = Quaternion(0.44567, 0.755871, 0.882548, 0.423612) + x, y = Quaternion(rand(4)...), Quaternion(rand(4)...) + @test dot([x], λ*I, [y]) ≈ dot(x, λ, y) ≈ dot(x, λ*y) +end + end # module TestUniformscaling diff --git a/stdlib/SparseArrays/src/linalg.jl b/stdlib/SparseArrays/src/linalg.jl index 437c5a8683eea..2576a0752086e 100644 --- a/stdlib/SparseArrays/src/linalg.jl +++ b/stdlib/SparseArrays/src/linalg.jl @@ -321,6 +321,54 @@ function dot(A::AbstractSparseMatrixCSC{T1,S1},B::AbstractSparseMatrixCSC{T2,S2} return r end +function dot(x::AbstractVector, A::AbstractSparseMatrixCSC, y::AbstractVector) + require_one_based_indexing(x, y) + m, n = size(A) + (length(x) == m && n == length(y)) || throw(DimensionMismatch()) + if iszero(m) || iszero(n) + return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y))) + end + T = promote_type(eltype(x), eltype(A), eltype(y)) + r = zero(T) + rvals = getrowval(A) + nzvals = getnzval(A) + @inbounds for col in 1:n + ycol = y[col] + if !iszero(ycol) + temp = zero(T) + for k in nzrange(A, col) + temp += adjoint(x[rvals[k]]) * nzvals[k] + end + r += temp * ycol + end + end + return r +end +function dot(x::SparseVector, A::AbstractSparseMatrixCSC, y::SparseVector) + m, n = size(A) + length(x) == m && n == length(y) || throw(DimensionMismatch()) + if iszero(m) || iszero(n) + return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y))) + end + r = zero(promote_type(eltype(x), eltype(A), eltype(y))) + xnzind = nonzeroinds(x) + xnzval = nonzeros(x) + ynzind = nonzeroinds(y) + ynzval = nonzeros(y) + Acolptr = getcolptr(A) + Arowval = getrowval(A) + Anzval = getnzval(A) + for (yi, yv) in zip(ynzind, ynzval) + A_ptr_lo = Acolptr[yi] + A_ptr_hi = Acolptr[yi+1] - 1 + if A_ptr_lo <= A_ptr_hi + r += _spdot(dot, 1, length(xnzind), xnzind, xnzval, + A_ptr_lo, A_ptr_hi, Arowval, Anzval) * yv + end + end + r +end + ## triangular sparse handling possible_adjoint(adj::Bool, a::Real ) = a diff --git a/stdlib/SparseArrays/test/sparse.jl b/stdlib/SparseArrays/test/sparse.jl index bd85f28609c44..6f21e18ae0bc3 100644 --- a/stdlib/SparseArrays/test/sparse.jl +++ b/stdlib/SparseArrays/test/sparse.jl @@ -424,6 +424,17 @@ end @test_throws DimensionMismatch dot(sprand(5,5,0.2),sprand(5,6,0.2)) end +@testset "generalized dot product" begin + for i = 1:5 + A = sprand(ComplexF64, 10, 15, 0.4) + Av = view(A, :, :) + x = sprand(ComplexF64, 10, 0.5) + y = sprand(ComplexF64, 15, 0.5) + @test dot(x, A, y) ≈ dot(Vector(x), A, Vector(y)) ≈ (Vector(x)' * Matrix(A)) * Vector(y) + @test dot(x, A, y) ≈ dot(x, Av, y) + end +end + const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test") isdefined(Main, :Quaternions) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "Quaternions.jl")) using .Main.Quaternions