Skip to content

Commit

Permalink
Fix aqua tests
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jun 14, 2023
1 parent d8fd62d commit 163d507
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 14 deletions.
8 changes: 6 additions & 2 deletions lib/cusparse/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,13 @@ for SparseMatrixType in [:CuSparseMatrixCSC, :CuSparseMatrixCSR]
LinearAlgebra.tril(A::Adjoint{T,<:$SparseMatrixType}, k::Integer) where {T} =
$SparseMatrixType( tril(CuSparseMatrixCOO(_spadjoint(parent(A))), k) )

LinearAlgebra.triu(A::Union{$SparseMatrixType{T,M}, Transpose{T,<:$SparseMatrixType}, Adjoint{T,<:$SparseMatrixType}}) where {T,M} =
LinearAlgebra.triu(A::$SparseMatrixType{T,M}) where {T,M} =
$SparseMatrixType( triu(CuSparseMatrixCOO(A), 0) )
LinearAlgebra.tril(A::Union{$SparseMatrixType{T,M}, Transpose{T,<:$SparseMatrixType}, Adjoint{T,<:$SparseMatrixType}}) where {T,M} =
LinearAlgebra.triu(A::Union{Transpose{T,<:$SparseMatrixType}, Adjoint{T,<:$SparseMatrixType}}) where {T} =
$SparseMatrixType( triu(CuSparseMatrixCOO(A), 0) )
LinearAlgebra.tril(A::$SparseMatrixType{T,M}) where {T,M} =
$SparseMatrixType( tril(CuSparseMatrixCOO(A), 0) )
LinearAlgebra.tril(A::Union{Transpose{T,<:$SparseMatrixType}, Adjoint{T,<:$SparseMatrixType}}) where {T} =
$SparseMatrixType( tril(CuSparseMatrixCOO(A), 0) )

LinearAlgebra.kron(A::$SparseMatrixType{T,M}, B::$SparseMatrixType{T,M}) where {T,M} =
Expand Down
35 changes: 26 additions & 9 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,27 @@ array) or a tuple of the array dimensions. `own` optionally specified whether Ju
take ownership of the memory, calling `cudaFree` when the array is no longer referenced. The
`ctx` argument determines the CUDA context where the data is allocated in.
"""
function Base.unsafe_wrap(::Union{Type{CuArray},Type{CuArray{T}},Type{CuArray{T,N}},Type{CuArray{T,N,B}}},
function Base.unsafe_wrap(::Union{Type{CuArray},Type{CuArray{T}},Type{CuArray{T,N}}},
ptr::CuPtr{T}, dims::NTuple{N,Int};
own::Bool=false, ctx::CuContext=context()) where {T,N}
buf = _unsafe_wrapped_buff(T, ptr, dims; own, ctx)
storage = ArrayStorage(buf, own ? 1 : -1)
CuArray{T, length(dims)}(storage, dims)
end
function Base.unsafe_wrap(::Type{CuArray{T,N,B}},
ptr::CuPtr{T}, dims::NTuple{N,Int};
own::Bool=false, ctx::CuContext=context()) where {T,N,B}
buf = _unsafe_wrapped_buff(T, ptr, dims; own, ctx)
if typeof(buf) !== B
error("Declared buffer type does not match inferred buffer type.")
end
storage = ArrayStorage(buf, own ? 1 : -1)
CuArray{T, length(dims)}(storage, dims)
end

function _unsafe_wrapped_buff(::Type{T},
ptr::CuPtr{T}, dims::NTuple{N,Int};
own::Bool=false, ctx::CuContext=context()) where {T,N}
isbitstype(T) || error("Can only unsafe_wrap a pointer to a bits type")
sz = prod(dims) * sizeof(T)

Expand All @@ -259,16 +277,15 @@ function Base.unsafe_wrap(::Union{Type{CuArray},Type{CuArray{T}},Type{CuArray{T,
catch err
error("Could not identify the buffer type; are you passing a valid CUDA pointer to unsafe_wrap?")
end

if @isdefined(B) && typeof(buf) !== B
error("Declared buffer type does not match inferred buffer type.")
end

storage = ArrayStorage(buf, own ? 1 : -1)
CuArray{T, length(dims)}(storage, dims)
return buf
end

function Base.unsafe_wrap(Atype::Union{Type{CuArray},Type{CuArray{T}},Type{CuArray{T,1}},Type{CuArray{T,1,B}}},
function Base.unsafe_wrap(Atype::Union{Type{CuArray},Type{CuArray{T}},Type{CuArray{T,1}}},
p::CuPtr{T}, dim::Int;
own::Bool=false, ctx::CuContext=context()) where {T}
unsafe_wrap(Atype, p, (dim,); own, ctx)
end
function Base.unsafe_wrap(Atype::Type{CuArray{T,1,B}},
p::CuPtr{T}, dim::Int;
own::Bool=false, ctx::CuContext=context()) where {T,B}
unsafe_wrap(Atype, p, (dim,); own, ctx)
Expand Down
4 changes: 2 additions & 2 deletions src/device/intrinsics/wmma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ flatten_recurse(typ::Type{VecElement{T}}, e) where T = [:($e.value)]
unflatten_recurse(typ::Type{VecElement{T}}, e, idx) where T = :(VecElement{$T}($e[$idx])), idx + 1

# NTuples
function flatten_recurse(typ::Type{NTuple{N, T}}, e) where {N, T}
function flatten_recurse(typ::Type{T}, e) where {T <: NTuple}
ret = Expr[]

for (i, eltyp) in enumerate(typ.types)
Expand All @@ -372,7 +372,7 @@ function flatten_recurse(typ::Type{NTuple{N, T}}, e) where {N, T}
return ret
end

function unflatten_recurse(typ::Type{NTuple{N, T}}, e, idx) where {N, T}
function unflatten_recurse(typ::Type{T}, e, idx) where {T<:NTuple}
ret = Expr(:tuple)

for (i, eltyp) in enumerate(typ.types)
Expand Down
2 changes: 1 addition & 1 deletion test/aqua.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using Aqua
# Aqua.test_unbound_args(CUDA)
ua = Aqua.detect_unbound_args_recursively(CUDA)
@info "Number of unbound argument methods: $(length(ua))"
@test length(ua) 26
@test length(ua) 16

# See: https://github.com/SciML/OrdinaryDiffEq.jl/issues/1750
# Test that we're not introducing method ambiguities across deps
Expand Down

0 comments on commit 163d507

Please sign in to comment.