Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CUBLAS: Don't use BLAS1 wrappers for strided arrays, only vectors. #2528

Merged
merged 1 commit into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions lib/cublas/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,24 @@ LinearAlgebra.rmul!(x::StridedCuArray{<:CublasFloat}, k::Number) =
LinearAlgebra.rmul!(x::DenseCuArray{<:CublasFloat}, k::Real) =
invoke(rmul!, Tuple{typeof(x), Number}, x, k)

function LinearAlgebra.dot(x::StridedCuArray{T}, y::StridedCuArray{T}) where T<:Union{Float16, CublasReal}
# Dot product of two strided GPU vectors that share an element type in
# `Union{Float16, CublasReal}`. Validates that both arguments have the same
# length, then hands off to the length-taking `dot(n, x, y)` method.
# Throws `DimensionMismatch` when the lengths differ.
function LinearAlgebra.dot(x::StridedCuVector{T},
                           y::StridedCuVector{T}) where T<:Union{Float16, CublasReal}
    n = length(x)
    if n != length(y)
        throw(DimensionMismatch("dot product arguments have lengths $(length(x)) and $(length(y))"))
    end
    return dot(n, x, y)
end

function LinearAlgebra.dot(x::StridedCuArray{T}, y::StridedCuArray{T}) where T<:Union{ComplexF16, CublasComplex}
# Dot product of two strided GPU vectors that share an element type in
# `Union{ComplexF16, CublasComplex}`. Validates that both arguments have the
# same length, then hands off to `dotc(n, x, y)`.
# Throws `DimensionMismatch` when the lengths differ.
function LinearAlgebra.dot(x::StridedCuVector{T},
                           y::StridedCuVector{T}) where T<:Union{ComplexF16, CublasComplex}
    n = length(x)
    if n != length(y)
        throw(DimensionMismatch("dot product arguments have lengths $(length(x)) and $(length(y))"))
    end
    return dotc(n, x, y)
end

# Resolve a dispatch ambiguity with the generic wrapper below: for two plain `CuArray`s
# sharing a Float32/Float64 element type, forward explicitly (via `invoke`) to whichever
# `dot` method is selected for the signature `Tuple{StridedCuArray{T}, StridedCuArray{T}}`.
# NOTE(review): the specialised `dot` methods visible in this excerpt are declared for
# `StridedCuVector`, not `StridedCuArray` — confirm which method this `invoke` resolves to.
LinearAlgebra.dot(x::CuArray{T}, y::CuArray{T}) where T<:Union{Float32, Float64} =
invoke(LinearAlgebra.dot, Tuple{StridedCuArray{T}, StridedCuArray{T}}, x, y)

# generic fallback
function LinearAlgebra.dot(x::AnyCuArray{T1}, y::AnyCuArray{T2}) where {T1,T2}
n = length(x)
Expand Down Expand Up @@ -97,14 +103,16 @@ function LinearAlgebra.dot(x::AnyCuArray{T1}, y::AnyCuArray{T2}) where {T1,T2}
end
end

function LinearAlgebra.:(*)(transx::Transpose{<:Any,<:StridedCuVector{T}}, y::StridedCuVector{T}) where T<:Union{ComplexF16, CublasComplex}
# Product of a transposed strided GPU vector with a strided GPU vector of the same
# element type in `Union{ComplexF16, CublasComplex}`. The transpose wrapper is
# unwrapped and the result is computed by `dotu` on the underlying vector and `y`.
# Throws `DimensionMismatch` when the two vectors differ in length.
function LinearAlgebra.:(*)(transx::Transpose{<:Any,<:StridedCuVector{T}},
                            y::StridedCuVector{T}) where T<:Union{ComplexF16, CublasComplex}
    x = parent(transx)
    n = length(x)
    if n != length(y)
        throw(DimensionMismatch("dot product arguments have lengths $(length(x)) and $(length(y))"))
    end
    return dotu(n, x, y)
end

function LinearAlgebra.norm(x::DenseCuArray{<:Union{Float16, ComplexF16, CublasFloat}}, p::Real=2)
function LinearAlgebra.norm(x::DenseCuArray{<:Union{Float16, ComplexF16, CublasFloat}},
p::Real=2)
if p == 2
return nrm2(x)
else
Expand Down
70 changes: 39 additions & 31 deletions lib/cublas/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,22 @@ function juliaStorageType(T::Type{<:Complex}, ct::cublasComputeType_t)
end

# Level 1

# Most level 1 routines are intended for use on vectors, so they only accept a single
# stride (the wrappers below pass `stride(x, 1)`). It is nevertheless often convenient
# to apply these routines to arbitrary arrays by interpreting them as vectors. That does
# not work for arrays with arbitrary strides in several dimensions, so this union admits
# only strided vectors (one dimension, any stride) and dense arrays.
# Note: despite the `DenseMat` in the name, the second member is `DenseCuArray{T}`,
# which is not restricted to matrices.
const StridedCuVecOrDenseMat{T} = Union{StridedCuVector{T}, DenseCuArray{T}}

## copy
for (fname, fname_64, elty) in ((:cublasDcopy_v2, :cublasDcopy_v2_64, :Float64),
(:cublasScopy_v2, :cublasScopy_v2_64, :Float32),
(:cublasZcopy_v2, :cublasZcopy_v2_64, :ComplexF64),
(:cublasCcopy_v2, :cublasCcopy_v2_64, :ComplexF32))
@eval begin
function copy!(n::Integer,
x::StridedCuArray{$elty},
y::StridedCuArray{$elty},)
x::StridedCuVecOrDenseMat{$elty},
y::StridedCuVecOrDenseMat{$elty},)
if CUBLAS.version() >= v"12.0"
$fname_64(handle(), n, x, stride(x, 1), y, stride(y, 1))
else
Expand All @@ -96,7 +103,8 @@ for (fname, fname_64, elty) in ((:cublasDcopy_v2, :cublasDcopy_v2_64, :Float64),
end
end
end
function copy!(n::Integer, x::StridedCuArray{T}, y::StridedCuArray{T}) where {T <: Union{Float16, ComplexF16}}
# Half-precision `copy!` (Float16 / ComplexF16): no CUBLAS routine is called here;
# the work is delegated to `copyto!(y, x)`. The length argument `n` is accepted for
# parity with the other `copy!` methods but is not used by this method.
# NOTE(review): the previous comment on the `copyto!` line just said "bad" —
# presumably marking this fallback as a stopgap; confirm intent.
function copy!(n::Integer, x::StridedCuVecOrDenseMat{T},
               y::StridedCuVecOrDenseMat{T}) where {T <: Union{Float16, ComplexF16}}
    return copyto!(y, x)
end

Expand All @@ -108,7 +116,7 @@ for (fname, fname_64, elty) in ((:cublasDscal_v2, :cublasDscal_v2_64, :Float64),
@eval begin
function scal!(n::Integer,
alpha::Number,
x::StridedCuArray{$elty})
x::StridedCuVecOrDenseMat{$elty})
if CUBLAS.version() >= v"12.0"
$fname_64(handle(), n, alpha, x, stride(x, 1))
else
Expand All @@ -118,7 +126,7 @@ for (fname, fname_64, elty) in ((:cublasDscal_v2, :cublasDscal_v2_64, :Float64),
end
end
end
function scal!(n::Integer, alpha::Number, x::StridedCuArray{Float16})
function scal!(n::Integer, alpha::Number, x::StridedCuVecOrDenseMat{Float16})
α = convert(Float32, alpha)
cublasScalEx(handle(), n, Ref{Float32}(α), Float32, x, Float16, stride(x, 1), Float32)
return x
Expand All @@ -129,7 +137,7 @@ for (fname, fname_64, elty, celty) in ((:cublasCsscal_v2, :cublasCsscal_v2_64, :
@eval begin
function scal!(n::Integer,
alpha::$elty,
x::StridedCuArray{$celty})
x::StridedCuVecOrDenseMat{$celty})
if CUBLAS.version() >= v"12.0"
$fname_64(handle(), n, alpha, x, stride(x, 1))
else
Expand All @@ -139,7 +147,7 @@ for (fname, fname_64, elty, celty) in ((:cublasCsscal_v2, :cublasCsscal_v2_64, :
end
end
end
function scal!(n::Integer, alpha::Number, x::StridedCuArray{ComplexF16})
function scal!(n::Integer, alpha::Number, x::StridedCuVecOrDenseMat{ComplexF16})
wide_x = widen.(x)
scal!(n, alpha, wide_x)
thin_x = convert(typeof(x), wide_x)
Expand All @@ -156,8 +164,8 @@ for (jname, fname, fname_64, elty) in ((:dot, :cublasDdot_v2, :cublasDdot_v2_64,
(:dotu, :cublasCdotu_v2, :cublasCdotu_v2_64, :ComplexF32))
@eval begin
function $jname(n::Integer,
x::StridedCuArray{$elty},
y::StridedCuArray{$elty})
x::StridedCuVecOrDenseMat{$elty},
y::StridedCuVecOrDenseMat{$elty})
result = Ref{$elty}()
if CUBLAS.version() >= v"12.0"
$fname_64(handle(), n, x, stride(x, 1), y, stride(y, 1), result)
Expand All @@ -168,15 +176,15 @@ for (jname, fname, fname_64, elty) in ((:dot, :cublasDdot_v2, :cublasDdot_v2_64,
end
end
end
function dot(n::Integer, x::StridedCuArray{Float16}, y::StridedCuArray{Float16})
# Float16 dot product over `n` elements, computed through `cublasDotEx`.
# Each input is passed together with its element type and its first-dimension stride;
# the value written into the `Ref` is returned.
# NOTE(review): the trailing `Float16, Float32` arguments are presumably the result
# type and the execution (accumulation) type — confirm against the cuBLAS docs.
function dot(n::Integer, x::StridedCuVecOrDenseMat{Float16}, y::StridedCuVecOrDenseMat{Float16})
    out = Ref{Float16}()
    cublasDotEx(handle(), n,
                x, Float16, stride(x, 1),
                y, Float16, stride(y, 1),
                out, Float16, Float32)
    return out[]
end
function dotc(n::Integer, x::StridedCuArray{ComplexF16}, y::StridedCuArray{ComplexF16})
# ComplexF16 `dotc`: both inputs are converted to `CuArray{ComplexF32}`, the
# ComplexF32 `dotc` method is applied to the converted arrays, and the result is
# converted back to ComplexF16.
function dotc(n::Integer, x::StridedCuVecOrDenseMat{ComplexF16}, y::StridedCuVecOrDenseMat{ComplexF16})
    x32 = convert(CuArray{ComplexF32}, x)
    y32 = convert(CuArray{ComplexF32}, y)
    return convert(ComplexF16, dotc(n, x32, y32))
end
function dotu(n::Integer, x::StridedCuArray{ComplexF16}, y::DenseCuArray{ComplexF16})
# ComplexF16 `dotu`: both inputs are converted to `CuArray{ComplexF32}`, the
# ComplexF32 `dotu` method is applied to the converted arrays, and the result is
# converted back to ComplexF16.
function dotu(n::Integer, x::StridedCuVecOrDenseMat{ComplexF16}, y::StridedCuVecOrDenseMat{ComplexF16})
    x32 = convert(CuArray{ComplexF32}, x)
    y32 = convert(CuArray{ComplexF32}, y)
    return convert(ComplexF16, dotu(n, x32, y32))
end

Expand All @@ -187,7 +195,7 @@ for (fname, fname_64, elty, ret_type) in ((:cublasDnrm2_v2, :cublasDnrm2_v2_64,
(:cublasScnrm2_v2, :cublasScnrm2_v2_64, :ComplexF32, :Float32))
@eval begin
function nrm2(n::Integer,
X::StridedCuArray{$elty})
X::StridedCuVecOrDenseMat{$elty})
result = Ref{$ret_type}()
if CUBLAS.version() >= v"12.0"
$fname_64(handle(), n, X, stride(X, 1), result)
Expand All @@ -198,14 +206,14 @@ for (fname, fname_64, elty, ret_type) in ((:cublasDnrm2_v2, :cublasDnrm2_v2_64,
end
end
end
nrm2(x::StridedCuArray) = nrm2(length(x), x)
# Convenience method: run `nrm2` over every element of `x` by supplying `length(x)`
# as the element count.
function nrm2(x::StridedCuVecOrDenseMat)
    return nrm2(length(x), x)
end

function nrm2(n::Integer, x::StridedCuArray{Float16})
# Float16 `nrm2` over `n` elements, computed through `cublasNrm2Ex`.
# The input is passed with its element type and first-dimension stride; the value
# written into the `Ref` is returned.
# NOTE(review): the trailing `Float16, Float32` arguments are presumably the result
# type and the execution (accumulation) type — confirm against the cuBLAS docs.
function nrm2(n::Integer, x::StridedCuVecOrDenseMat{Float16})
    out = Ref{Float16}()
    cublasNrm2Ex(handle(), n,
                 x, Float16, stride(x, 1),
                 out, Float16, Float32)
    return out[]
end
function nrm2(n::Integer, x::StridedCuArray{ComplexF16})
function nrm2(n::Integer, x::StridedCuVecOrDenseMat{ComplexF16})
wide_x = widen.(x)
nrm = nrm2(n, wide_x)
return convert(Float16, nrm)
Expand All @@ -218,7 +226,7 @@ for (fname, fname_64, elty, ret_type) in ((:cublasDasum_v2, :cublasDasum_v2_64,
(:cublasScasum_v2, :cublasScasum_v2_64, :ComplexF32, :Float32))
@eval begin
function asum(n::Integer,
x::StridedCuArray{$elty})
x::StridedCuVecOrDenseMat{$elty})
result = Ref{$ret_type}()
if CUBLAS.version() >= v"12.0"
$fname_64(handle(), n, x, stride(x, 1), result)
Expand All @@ -238,8 +246,8 @@ for (fname, fname_64, elty) in ((:cublasDaxpy_v2, :cublasDaxpy_v2_64, :Float64),
@eval begin
function axpy!(n::Integer,
alpha::Number,
dx::StridedCuArray{$elty},
dy::StridedCuArray{$elty})
dx::StridedCuVecOrDenseMat{$elty},
dy::StridedCuVecOrDenseMat{$elty})
if CUBLAS.version() >= v"12.0"
$fname_64(handle(), n, alpha, dx, stride(dx, 1), dy, stride(dy, 1))
else
Expand All @@ -250,12 +258,12 @@ for (fname, fname_64, elty) in ((:cublasDaxpy_v2, :cublasDaxpy_v2_64, :Float64),
end
end

function axpy!(n::Integer, alpha::Number, dx::StridedCuArray{Float16}, dy::StridedCuArray{Float16})
# Float16 `axpy!` over `n` elements, computed through `cublasAxpyEx`.
# `alpha` is converted to Float32 and passed by reference; `dx` and `dy` are passed
# with their element type and first-dimension stride. Returns `dy`.
# NOTE(review): the final `Float32` argument is presumably the execution type —
# confirm against the cuBLAS docs.
function axpy!(n::Integer, alpha::Number, dx::StridedCuVecOrDenseMat{Float16}, dy::StridedCuVecOrDenseMat{Float16})
    alpha32 = Ref{Float32}(convert(Float32, alpha))
    cublasAxpyEx(handle(), n, alpha32, Float32,
                 dx, Float16, stride(dx, 1),
                 dy, Float16, stride(dy, 1),
                 Float32)
    return dy
end
function axpy!(n::Integer, alpha::Number, dx::StridedCuArray{ComplexF16}, dy::StridedCuArray{ComplexF16})
function axpy!(n::Integer, alpha::Number, dx::StridedCuVecOrDenseMat{ComplexF16}, dy::StridedCuVecOrDenseMat{ComplexF16})
wide_x = widen.(dx)
wide_y = widen.(dy)
axpy!(n, alpha, wide_x, wide_y)
Expand All @@ -273,8 +281,8 @@ for (fname, fname_64, elty, sty) in ((:cublasSrot_v2, :cublasSrot_v2_64, :Float3
(:cublasZdrot_v2, :cublasZdrot_v2_64, :ComplexF64, :Real))
@eval begin
function rot!(n::Integer,
x::StridedCuArray{$elty},
y::StridedCuArray{$elty},
x::StridedCuVecOrDenseMat{$elty},
y::StridedCuVecOrDenseMat{$elty},
c::Real,
s::$sty)
if CUBLAS.version() >= v"12.0"
Expand All @@ -294,8 +302,8 @@ for (fname, fname_64, elty) in ((:cublasSswap_v2, :cublasSswap_v2_64, :Float32),
(:cublasZswap_v2, :cublasZswap_v2_64, :ComplexF64))
@eval begin
function swap!(n::Integer,
x::StridedCuArray{$elty},
y::StridedCuArray{$elty})
x::StridedCuVecOrDenseMat{$elty},
y::StridedCuVecOrDenseMat{$elty})
if CUBLAS.version() >= v"12.0"
$fname_64(handle(), n, x, stride(x, 1), y, stride(y, 1))
else
Expand All @@ -308,9 +316,9 @@ end

function axpby!(n::Integer,
alpha::Number,
dx::StridedCuArray{T},
dx::StridedCuVecOrDenseMat{T},
beta::Number,
dy::StridedCuArray{T}) where T <: Union{Float16, ComplexF16, CublasFloat}
dy::StridedCuVecOrDenseMat{T}) where T <: Union{Float16, ComplexF16, CublasFloat}
scal!(n, beta, dy)
axpy!(n, alpha, dx, dy)
dy
Expand All @@ -324,7 +332,7 @@ for (fname, fname_64, elty) in ((:cublasIdamax_v2, :cublasIdamax_v2_64, :Float64
(:cublasIcamax_v2, :cublasIcamax_v2_64, :ComplexF32))
@eval begin
function iamax(n::Integer,
dx::StridedCuArray{$elty})
dx::StridedCuVecOrDenseMat{$elty})
if CUBLAS.version() >= v"12.0"
result = Ref{Int64}()
$fname_64(handle(), n, dx, stride(dx, 1), result)
Expand All @@ -336,7 +344,7 @@ for (fname, fname_64, elty) in ((:cublasIdamax_v2, :cublasIdamax_v2_64, :Float64
end
end
end
iamax(dx::StridedCuArray) = iamax(length(dx), dx)
# Convenience method: run `iamax` over every element of `dx` by supplying
# `length(dx)` as the element count.
function iamax(dx::StridedCuVecOrDenseMat)
    return iamax(length(dx), dx)
end

## iamin
# iamin is not part of standard BLAS; it is a CUBLAS extension
Expand All @@ -346,7 +354,7 @@ for (fname, fname_64, elty) in ((:cublasIdamin_v2, :cublasIdamin_v2_64, :Float64
(:cublasIcamin_v2, :cublasIcamin_v2_64, :ComplexF32))
@eval begin
function iamin(n::Integer,
dx::StridedCuArray{$elty},)
dx::StridedCuVecOrDenseMat{$elty},)
if CUBLAS.version() >= v"12.0"
result = Ref{Int64}()
$fname_64(handle(), n, dx, stride(dx, 1), result)
Expand All @@ -358,7 +366,7 @@ for (fname, fname_64, elty) in ((:cublasIdamin_v2, :cublasIdamin_v2_64, :Float64
end
end
end
iamin(dx::StridedCuArray) = iamin(length(dx), dx)
# Convenience method: run `iamin` over every element of `dx` by supplying
# `length(dx)` as the element count.
function iamin(dx::StridedCuVecOrDenseMat)
    return iamin(length(dx), dx)
end

# Level 2
## mv
Expand Down
6 changes: 5 additions & 1 deletion test/base/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ end
end

@test testf(dot, rand(Bool, 1024, 1024), rand(Float64, 1024, 1024))

# https://discourse.julialang.org/t/result-of-inner-product-of-two-cuarray-with-views-is-incorrect/121539
@test testf(dot, view(rand(Float32, 100, 100), 2:99, 2:99),
view(rand(Float32, 100, 100), 2:99, 2:99))
end

@testset "kron" begin
Expand All @@ -33,4 +37,4 @@ end
@test Array(kron(A, B)) ≈ kron(Array(A), Array(B))
@test Array(kron(B, A)) ≈ kron(Array(B), Array(A))
end
end
end