Skip to content

Commit

Permalink
Fix some inferability issues in SparseArrays (#41187)
Browse files Browse the repository at this point in the history
* Fix a type-instability in sparse `findmin`/`findmax`

The helper function `_findr` would usually return a `Vector` as first
argument, but would use a `SparseMatrixCSC` in the empty case. Fix by
always using `Vector`.

* Make sparse `hvcat` inferable

This also requires making sparse `vcat` and `hcat` inferable in the
vararg case which in turn requires a different way to determine the
resulting index type, now implemented similar to `promote_eltype`.
  • Loading branch information
martinholters authored Jun 16, 2021
1 parent dd94ceb commit d98fb01
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 20 deletions.
10 changes: 7 additions & 3 deletions stdlib/SparseArrays/src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2073,7 +2073,7 @@ function _findr(op, A, region, Tv)
throw(ArgumentError("array slices must be non-empty"))
else
ri = Base.reduced_indices0(A, region)
return (similar(A, ri), zeros(Ti, ri))
return (zeros(Tv, ri), zeros(Ti, ri))
end
end

Expand Down Expand Up @@ -3303,6 +3303,10 @@ dropstored!(A::AbstractSparseMatrixCSC, ::Colon) = dropstored!(A, :, :)

# Sparse concatenation

promote_idxtype(::AbstractSparseMatrixCSC{<:Any, Ti}) where {Ti} = Ti
promote_idxtype(::AbstractSparseMatrixCSC{<:Any, Ti}, X::AbstractSparseMatrixCSC...) where {Ti} =
promote_type(Ti, promote_idxtype(X...))

function vcat(X::AbstractSparseMatrixCSC...)
num = length(X)
mX = Int[ size(x, 1) for x in X ]
Expand All @@ -3317,7 +3321,7 @@ function vcat(X::AbstractSparseMatrixCSC...)
end

Tv = promote_eltype(X...)
Ti = promote_eltype(map(x->rowvals(x), X)...)
Ti = promote_idxtype(X...)

nnzX = Int[ nnz(x) for x in X ]
nnz_res = sum(nnzX)
Expand Down Expand Up @@ -3369,7 +3373,7 @@ function hcat(X::AbstractSparseMatrixCSC...)
n = sum(nX)

Tv = promote_eltype(X...)
Ti = promote_eltype(map(x->rowvals(x), X)...)
Ti = promote_idxtype(X...)

colptr = Vector{Ti}(undef, n+1)
nnzX = Int[ nnz(x) for x in X ]
Expand Down
27 changes: 16 additions & 11 deletions stdlib/SparseArrays/src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1101,17 +1101,22 @@ function vcat(Xin::_SparseConcatGroup...)
X = map(x -> SparseMatrixCSC(issparse(x) ? x : sparse(x)), Xin)
vcat(X...)
end
function hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...)
nbr = length(rows) # number of block rows

tmp_rows = Vector{SparseMatrixCSC}(undef, nbr)
k = 0
@inbounds for i = 1 : nbr
tmp_rows[i] = hcat(X[(1 : rows[i]) .+ k]...)
k += rows[i]
end
vcat(tmp_rows...)
end
hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...) =
vcat(_hvcat_rows(rows, X...)...)
function _hvcat_rows((row1, rows...)::Tuple{Vararg{Int}}, X::_SparseConcatGroup...)
if row1 0
throw(ArgumentError("length of block row must be positive, got $row1"))
end
# assert `X` is non-empty so that inference of `eltype` won't include `Type{Union{}}`
T = eltype(X::Tuple{Any,Vararg{Any}})
# inference of `getindex` may be imprecise in case `row1` is not const-propagated up
# to here, so help inference with the following type-assertions
return (
hcat(X[1 : row1]::Tuple{typeof(X[1]),Vararg{T}}...),
_hvcat_rows(rows, X[row1+1:end]::Tuple{Vararg{T}}...)...
)
end
_hvcat_rows(::Tuple{}, X::_SparseConcatGroup...) = ()

# make sure UniformScaling objects are converted to sparse matrices for concatenation
promote_to_array_type(A::Tuple{Vararg{Union{_SparseConcatGroup,UniformScaling}}}) = SparseMatrixCSC
Expand Down
12 changes: 6 additions & 6 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ end
sz34 = spzeros(3, 4)
se77 = sparse(1.0I, 7, 7)
@testset "h+v concatenation" begin
@test [se44 sz42 sz41; sz34 se33] == se77
@test @inferred(hvcat((3, 2), se44, sz42, sz41, sz34, se33)) == se77 # [se44 sz42 sz41; sz34 se33]
@test length(nonzeros([sp33 0I; 1I 0I])) == 6
end

Expand Down Expand Up @@ -1355,10 +1355,10 @@ end
@testset "argmax, argmin, findmax, findmin" begin
S = sprand(100,80, 0.5)
A = Array(S)
@test argmax(S) == argmax(A)
@test argmin(S) == argmin(A)
@test findmin(S) == findmin(A)
@test findmax(S) == findmax(A)
@test @inferred(argmax(S)) == argmax(A)
@test @inferred(argmin(S)) == argmin(A)
@test @inferred(findmin(S)) == findmin(A)
@test @inferred(findmax(S)) == findmax(A)
for region in [(1,), (2,), (1,2)], m in [findmax, findmin]
@test m(S, dims=region) == m(A, dims=region)
end
Expand Down Expand Up @@ -2224,7 +2224,7 @@ end
# Test that concatenations of pairs of sparse matrices yield sparse arrays
@test issparse(vcat(spmat, spmat))
@test issparse(hcat(spmat, spmat))
@test issparse(hvcat((2,), spmat, spmat))
@test issparse(@inferred(hvcat((2,), spmat, spmat)))
@test issparse(cat(spmat, spmat; dims=(1,2)))
# Test that concatenations of a sparse matrice with a dense matrix/vector yield sparse arrays
@test issparse(vcat(spmat, densemat))
Expand Down

0 comments on commit d98fb01

Please sign in to comment.