From 45e6ff846dffc3f827ff7f7a90d075101ece1f9c Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 22 Oct 2024 13:32:08 +0200 Subject: [PATCH] CUBLAS: Don't use BLAS1 wrappers for strided arrays, only vectors. --- lib/cublas/linalg.jl | 16 +++++++--- lib/cublas/wrappers.jl | 70 +++++++++++++++++++++++------------------- test/base/linalg.jl | 6 +++- 3 files changed, 56 insertions(+), 36 deletions(-) diff --git a/lib/cublas/linalg.jl b/lib/cublas/linalg.jl index 8c7fcb79a9..8957c58e6d 100644 --- a/lib/cublas/linalg.jl +++ b/lib/cublas/linalg.jl @@ -13,18 +13,24 @@ LinearAlgebra.rmul!(x::StridedCuArray{<:CublasFloat}, k::Number) = LinearAlgebra.rmul!(x::DenseCuArray{<:CublasFloat}, k::Real) = invoke(rmul!, Tuple{typeof(x), Number}, x, k) -function LinearAlgebra.dot(x::StridedCuArray{T}, y::StridedCuArray{T}) where T<:Union{Float16, CublasReal} +function LinearAlgebra.dot(x::StridedCuVector{T}, + y::StridedCuVector{T}) where T<:Union{Float16, CublasReal} n = length(x) n==length(y) || throw(DimensionMismatch("dot product arguments have lengths $(length(x)) and $(length(y))")) dot(n, x, y) end -function LinearAlgebra.dot(x::StridedCuArray{T}, y::StridedCuArray{T}) where T<:Union{ComplexF16, CublasComplex} +function LinearAlgebra.dot(x::StridedCuVector{T}, + y::StridedCuVector{T}) where T<:Union{ComplexF16, CublasComplex} n = length(x) n==length(y) || throw(DimensionMismatch("dot product arguments have lengths $(length(x)) and $(length(y))")) dotc(n, x, y) end +# resolve ambiguities with generic wrapper below +LinearAlgebra.dot(x::CuArray{T}, y::CuArray{T}) where T<:Union{Float32, Float64} = + invoke(LinearAlgebra.dot, Tuple{StridedCuArray{T}, StridedCuArray{T}}, x, y) + # generic fallback function LinearAlgebra.dot(x::AnyCuArray{T1}, y::AnyCuArray{T2}) where {T1,T2} n = length(x) @@ -97,14 +103,16 @@ function LinearAlgebra.dot(x::AnyCuArray{T1}, y::AnyCuArray{T2}) where {T1,T2} end end -function LinearAlgebra.:(*)(transx::Transpose{<:Any,<:StridedCuVector{T}}, y::StridedCuVector{T}) where T<:Union{ComplexF16, CublasComplex} +function LinearAlgebra.:(*)(transx::Transpose{<:Any,<:StridedCuVector{T}}, + y::StridedCuVector{T}) where T<:Union{ComplexF16, CublasComplex} x = transx.parent n = length(x) n==length(y) || throw(DimensionMismatch("dot product arguments have lengths $(length(x)) and $(length(y))")) return dotu(n, x, y) end -function LinearAlgebra.norm(x::DenseCuArray{<:Union{Float16, ComplexF16, CublasFloat}}, p::Real=2) +function LinearAlgebra.norm(x::DenseCuArray{<:Union{Float16, ComplexF16, CublasFloat}}, + p::Real=2) if p == 2 return nrm2(x) else diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl index 150a01f66e..11e56d1530 100644 --- a/lib/cublas/wrappers.jl +++ b/lib/cublas/wrappers.jl @@ -78,6 +78,13 @@ function juliaStorageType(T::Type{<:Complex}, ct::cublasComputeType_t) end # Level 1 + +# most level 1 routines are intended for use on vectors, so only accept a single stride. +# however, it is often convenient to also use these routines on arbitrary arrays, +# interpreting them as vectors. this does not work with arbitrary strides, so we +# define a union matching arrays with only a non-unit stride in the first dimension. +const StridedCuVecOrDenseMat{T} = Union{StridedCuVector{T}, DenseCuArray{T}} + ## copy for (fname, fname_64, elty) in ((:cublasDcopy_v2, :cublasDcopy_v2_64, :Float64), (:cublasScopy_v2, :cublasScopy_v2_64, :Float32), @@ -85,8 +92,8 @@ for (fname, fname_64, elty) in ((:cublasDcopy_v2, :cublasDcopy_v2_64, :Float64), (:cublasCcopy_v2, :cublasCcopy_v2_64, :ComplexF32)) @eval begin function copy!(n::Integer, - x::StridedCuArray{$elty}, - y::StridedCuArray{$elty},) + x::StridedCuVecOrDenseMat{$elty}, + y::StridedCuVecOrDenseMat{$elty},) if CUBLAS.version() >= v"12.0" $fname_64(handle(), n, x, stride(x, 1), y, stride(y, 1)) else @@ -96,7 +103,8 @@ for (fname, fname_64, elty) in ((:cublasDcopy_v2, :cublasDcopy_v2_64, :Float64), end end end -function copy!(n::Integer, x::StridedCuArray{T}, y::StridedCuArray{T}) where {T <: Union{Float16, ComplexF16}} +function copy!(n::Integer, x::StridedCuVecOrDenseMat{T}, + y::StridedCuVecOrDenseMat{T}) where {T <: Union{Float16, ComplexF16}} copyto!(y, x) # bad end @@ -108,7 +116,7 @@ for (fname, fname_64, elty) in ((:cublasDscal_v2, :cublasDscal_v2_64, :Float64), @eval begin function scal!(n::Integer, alpha::Number, - x::StridedCuArray{$elty}) + x::StridedCuVecOrDenseMat{$elty}) if CUBLAS.version() >= v"12.0" $fname_64(handle(), n, alpha, x, stride(x, 1)) else @@ -118,7 +126,7 @@ for (fname, fname_64, elty) in ((:cublasDscal_v2, :cublasDscal_v2_64, :Float64), end end end -function scal!(n::Integer, alpha::Number, x::StridedCuArray{Float16}) +function scal!(n::Integer, alpha::Number, x::StridedCuVecOrDenseMat{Float16}) α = convert(Float32, alpha) cublasScalEx(handle(), n, Ref{Float32}(α), Float32, x, Float16, stride(x, 1), Float32) return x @@ -129,7 +137,7 @@ for (fname, fname_64, elty, celty) in ((:cublasCsscal_v2, :cublasCsscal_v2_64, : @eval begin function scal!(n::Integer, alpha::$elty, - x::StridedCuArray{$celty}) + x::StridedCuVecOrDenseMat{$celty}) if CUBLAS.version() >= v"12.0" $fname_64(handle(), n, alpha, x, stride(x, 1)) else @@ -139,7 +147,7 @@ for (fname, fname_64, elty, celty) in ((:cublasCsscal_v2, :cublasCsscal_v2_64, : end end end -function scal!(n::Integer, alpha::Number, x::StridedCuArray{ComplexF16}) +function scal!(n::Integer, alpha::Number, x::StridedCuVecOrDenseMat{ComplexF16}) wide_x = widen.(x) scal!(n, alpha, wide_x) thin_x = convert(typeof(x), wide_x) @@ -156,8 +164,8 @@ for (jname, fname, fname_64, elty) in ((:dot, :cublasDdot_v2, :cublasDdot_v2_64, (:dotu, :cublasCdotu_v2, :cublasCdotu_v2_64, :ComplexF32)) @eval begin function $jname(n::Integer, - x::StridedCuArray{$elty}, - y::StridedCuArray{$elty}) + x::StridedCuVecOrDenseMat{$elty}, + y::StridedCuVecOrDenseMat{$elty}) result = Ref{$elty}() if CUBLAS.version() >= v"12.0" $fname_64(handle(), n, x, stride(x, 1), y, stride(y, 1), result) @@ -168,15 +176,15 @@ for (jname, fname, fname_64, elty) in ((:dot, :cublasDdot_v2, :cublasDdot_v2_64, end end end -function dot(n::Integer, x::StridedCuArray{Float16}, y::StridedCuArray{Float16}) +function dot(n::Integer, x::StridedCuVecOrDenseMat{Float16}, y::StridedCuVecOrDenseMat{Float16}) result = Ref{Float16}() cublasDotEx(handle(), n, x, Float16, stride(x, 1), y, Float16, stride(y, 1), result, Float16, Float32) return result[] end -function dotc(n::Integer, x::StridedCuArray{ComplexF16}, y::StridedCuArray{ComplexF16}) +function dotc(n::Integer, x::StridedCuVecOrDenseMat{ComplexF16}, y::StridedCuVecOrDenseMat{ComplexF16}) convert(ComplexF16, dotc(n, convert(CuArray{ComplexF32}, x), convert(CuArray{ComplexF32}, y))) end -function dotu(n::Integer, x::StridedCuArray{ComplexF16}, y::DenseCuArray{ComplexF16}) +function dotu(n::Integer, x::StridedCuVecOrDenseMat{ComplexF16}, y::StridedCuVecOrDenseMat{ComplexF16}) convert(ComplexF16, dotu(n, convert(CuArray{ComplexF32}, x), convert(CuArray{ComplexF32}, y))) end @@ -187,7 +195,7 @@ for (fname, fname_64, elty, ret_type) in ((:cublasDnrm2_v2, :cublasDnrm2_v2_64, (:cublasScnrm2_v2, :cublasScnrm2_v2_64, :ComplexF32, :Float32)) @eval begin function nrm2(n::Integer, - X::StridedCuArray{$elty}) + X::StridedCuVecOrDenseMat{$elty}) result = Ref{$ret_type}() if CUBLAS.version() >= v"12.0" $fname_64(handle(), n, X, stride(X, 1), result) @@ -198,14 +206,14 @@ for (fname, fname_64, elty, ret_type) in ((:cublasDnrm2_v2, :cublasDnrm2_v2_64, end end end -nrm2(x::StridedCuArray) = nrm2(length(x), x) +nrm2(x::StridedCuVecOrDenseMat) = nrm2(length(x), x) -function nrm2(n::Integer, x::StridedCuArray{Float16}) +function nrm2(n::Integer, x::StridedCuVecOrDenseMat{Float16}) result = Ref{Float16}() cublasNrm2Ex(handle(), n, x, Float16, stride(x, 1), result, Float16, Float32) return result[] end -function nrm2(n::Integer, x::StridedCuArray{ComplexF16}) +function nrm2(n::Integer, x::StridedCuVecOrDenseMat{ComplexF16}) wide_x = widen.(x) nrm = nrm2(n, wide_x) return convert(Float16, nrm) @@ -218,7 +226,7 @@ for (fname, fname_64, elty, ret_type) in ((:cublasDasum_v2, :cublasDasum_v2_64, (:cublasScasum_v2, :cublasScasum_v2_64, :ComplexF32, :Float32)) @eval begin function asum(n::Integer, - x::StridedCuArray{$elty}) + x::StridedCuVecOrDenseMat{$elty}) result = Ref{$ret_type}() if CUBLAS.version() >= v"12.0" $fname_64(handle(), n, x, stride(x, 1), result) @@ -238,8 +246,8 @@ for (fname, fname_64, elty) in ((:cublasDaxpy_v2, :cublasDaxpy_v2_64, :Float64), @eval begin function axpy!(n::Integer, alpha::Number, - dx::StridedCuArray{$elty}, - dy::StridedCuArray{$elty}) + dx::StridedCuVecOrDenseMat{$elty}, + dy::StridedCuVecOrDenseMat{$elty}) if CUBLAS.version() >= v"12.0" $fname_64(handle(), n, alpha, dx, stride(dx, 1), dy, stride(dy, 1)) else @@ -250,12 +258,12 @@ for (fname, fname_64, elty) in ((:cublasDaxpy_v2, :cublasDaxpy_v2_64, :Float64), end end -function axpy!(n::Integer, alpha::Number, dx::StridedCuArray{Float16}, dy::StridedCuArray{Float16}) +function axpy!(n::Integer, alpha::Number, dx::StridedCuVecOrDenseMat{Float16}, dy::StridedCuVecOrDenseMat{Float16}) α = convert(Float32, alpha) cublasAxpyEx(handle(), n, Ref{Float32}(α), Float32, dx, Float16, stride(dx, 1), dy, Float16, stride(dy, 1), Float32) return dy end -function axpy!(n::Integer, alpha::Number, dx::StridedCuArray{ComplexF16}, dy::StridedCuArray{ComplexF16}) +function axpy!(n::Integer, alpha::Number, dx::StridedCuVecOrDenseMat{ComplexF16}, dy::StridedCuVecOrDenseMat{ComplexF16}) wide_x = widen.(dx) wide_y = widen.(dy) axpy!(n, alpha, wide_x, wide_y) @@ -273,8 +281,8 @@ for (fname, fname_64, elty, sty) in ((:cublasSrot_v2, :cublasSrot_v2_64, :Float3 (:cublasZdrot_v2, :cublasZdrot_v2_64, :ComplexF64, :Real)) @eval begin function rot!(n::Integer, - x::StridedCuArray{$elty}, - y::StridedCuArray{$elty}, + x::StridedCuVecOrDenseMat{$elty}, + y::StridedCuVecOrDenseMat{$elty}, c::Real, s::$sty) if CUBLAS.version() >= v"12.0" @@ -294,8 +302,8 @@ for (fname, fname_64, elty) in ((:cublasSswap_v2, :cublasSswap_v2_64, :Float32), (:cublasZswap_v2, :cublasZswap_v2_64, :ComplexF64)) @eval begin function swap!(n::Integer, - x::StridedCuArray{$elty}, - y::StridedCuArray{$elty}) + x::StridedCuVecOrDenseMat{$elty}, + y::StridedCuVecOrDenseMat{$elty}) if CUBLAS.version() >= v"12.0" $fname_64(handle(), n, x, stride(x, 1), y, stride(y, 1)) else @@ -308,9 +316,9 @@ end function axpby!(n::Integer, alpha::Number, - dx::StridedCuArray{T}, + dx::StridedCuVecOrDenseMat{T}, beta::Number, - dy::StridedCuArray{T}) where T <: Union{Float16, ComplexF16, CublasFloat} + dy::StridedCuVecOrDenseMat{T}) where T <: Union{Float16, ComplexF16, CublasFloat} scal!(n, beta, dy) axpy!(n, alpha, dx, dy) dy @@ -324,7 +332,7 @@ for (fname, fname_64, elty) in ((:cublasIdamax_v2, :cublasIdamax_v2_64, :Float64 (:cublasIcamax_v2, :cublasIcamax_v2_64, :ComplexF32)) @eval begin function iamax(n::Integer, - dx::StridedCuArray{$elty}) + dx::StridedCuVecOrDenseMat{$elty}) if CUBLAS.version() >= v"12.0" result = Ref{Int64}() $fname_64(handle(), n, dx, stride(dx, 1), result) @@ -336,7 +344,7 @@ for (fname, fname_64, elty) in ((:cublasIdamax_v2, :cublasIdamax_v2_64, :Float64 end end end -iamax(dx::StridedCuArray) = iamax(length(dx), dx) +iamax(dx::StridedCuVecOrDenseMat) = iamax(length(dx), dx) ## iamin # iamin is not in standard blas is a CUBLAS extension @@ -346,7 +354,7 @@ for (fname, fname_64, elty) in ((:cublasIdamin_v2, :cublasIdamin_v2_64, :Float64 (:cublasIcamin_v2, :cublasIcamin_v2_64, :ComplexF32)) @eval begin function iamin(n::Integer, - dx::StridedCuArray{$elty},) + dx::StridedCuVecOrDenseMat{$elty},) if CUBLAS.version() >= v"12.0" result = Ref{Int64}() $fname_64(handle(), n, dx, stride(dx, 1), result) @@ -358,7 +366,7 @@ for (fname, fname_64, elty) in ((:cublasIdamin_v2, :cublasIdamin_v2_64, :Float64 end end end -iamin(dx::StridedCuArray) = iamin(length(dx), dx) +iamin(dx::StridedCuVecOrDenseMat) = iamin(length(dx), dx) # Level 2 ## mv diff --git a/test/base/linalg.jl b/test/base/linalg.jl index e1d91be502..32a93ffa78 100644 --- a/test/base/linalg.jl +++ b/test/base/linalg.jl @@ -16,6 +16,10 @@ end end @test testf(dot, rand(Bool, 1024, 1024), rand(Float64, 1024, 1024)) + + # https://discourse.julialang.org/t/result-of-inner-product-of-two-cuarray-with-views-is-incorrect/121539 + @test testf(dot, view(rand(Float32, 100, 100), 2:99, 2:99), + view(rand(Float32, 100, 100), 2:99, 2:99)) end @testset "kron" begin @@ -33,4 +37,4 @@ end @test Array(kron(A, B)) ≈ kron(Array(A), Array(B)) @test Array(kron(B, A)) ≈ kron(Array(B), Array(A)) end -end \ No newline at end of file +end