Skip to content

Commit

Permalink
Revert "Fix some inferability issues in SparseArrays" (JuliaLang#40332)
Browse files Browse the repository at this point in the history
  • Loading branch information
vtjnash authored and johanmon committed Jul 5, 2021
1 parent 28e65e4 commit d7d921e
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 21 deletions.
10 changes: 3 additions & 7 deletions stdlib/SparseArrays/src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2045,7 +2045,7 @@ function _findr(op, A, region, Tv)
throw(ArgumentError("array slices must be non-empty"))
else
ri = Base.reduced_indices0(A, region)
return (zeros(Tv, ri), zeros(Ti, ri))
return (similar(A, ri), zeros(Ti, ri))
end
end

Expand Down Expand Up @@ -3274,10 +3274,6 @@ dropstored!(A::AbstractSparseMatrixCSC, ::Colon) = dropstored!(A, :, :)

# Sparse concatenation

promote_idxtype(::AbstractSparseMatrixCSC{<:Any, Ti}) where {Ti} = Ti
promote_idxtype(::AbstractSparseMatrixCSC{<:Any, Ti}, X::AbstractSparseMatrixCSC...) where {Ti} =
promote_type(Ti, promote_idxtype(X...))

function vcat(X::AbstractSparseMatrixCSC...)
num = length(X)
mX = Int[ size(x, 1) for x in X ]
Expand All @@ -3292,7 +3288,7 @@ function vcat(X::AbstractSparseMatrixCSC...)
end

Tv = promote_eltype(X...)
Ti = promote_idxtype(X...)
Ti = promote_eltype(map(x->rowvals(x), X)...)

nnzX = Int[ nnz(x) for x in X ]
nnz_res = sum(nnzX)
Expand Down Expand Up @@ -3344,7 +3340,7 @@ function hcat(X::AbstractSparseMatrixCSC...)
n = sum(nX)

Tv = promote_eltype(X...)
Ti = promote_idxtype(X...)
Ti = promote_eltype(map(x->rowvals(x), X)...)

colptr = Vector{Ti}(undef, n+1)
nnzX = Int[ nnz(x) for x in X ]
Expand Down
17 changes: 9 additions & 8 deletions stdlib/SparseArrays/src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1091,16 +1091,17 @@ function vcat(Xin::_SparseConcatGroup...)
X = map(x -> SparseMatrixCSC(issparse(x) ? x : sparse(x)), Xin)
vcat(X...)
end
hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...) =
vcat(_hvcat_rows(rows, X...)...)
function _hvcat_rows((row1, rows...)::Tuple{Vararg{Int}}, X::_SparseConcatGroup...)
if row1 0
throw(ArgumentError("length of block row must be positive, got $row1"))
function hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...)
nbr = length(rows) # number of block rows

tmp_rows = Vector{SparseMatrixCSC}(undef, nbr)
k = 0
@inbounds for i = 1 : nbr
tmp_rows[i] = hcat(X[(1 : rows[i]) .+ k]...)
k += rows[i]
end
# provide X[1] separately to convince inference that we don't call hcat() without arguments
return (hcat(X[1], X[2 : row1]...), _hvcat_rows(rows, X[row1+1:end]...)...)
vcat(tmp_rows...)
end
_hvcat_rows(::Tuple{}, X::_SparseConcatGroup...) = ()

# make sure UniformScaling objects are converted to sparse matrices for concatenation
promote_to_array_type(A::Tuple{Vararg{Union{_SparseConcatGroup,UniformScaling}}}) = SparseMatrixCSC
Expand Down
12 changes: 6 additions & 6 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ end
sz34 = spzeros(3, 4)
se77 = sparse(1.0I, 7, 7)
@testset "h+v concatenation" begin
@test @inferred(hvcat((3, 2), se44, sz42, sz41, sz34, se33)) == se77 # [se44 sz42 sz41; sz34 se33]
@test [se44 sz42 sz41; sz34 se33] == se77
@test length(nonzeros([sp33 0I; 1I 0I])) == 6
end

Expand Down Expand Up @@ -1338,10 +1338,10 @@ end
@testset "argmax, argmin, findmax, findmin" begin
S = sprand(100,80, 0.5)
A = Array(S)
@test @inferred(argmax(S)) == argmax(A)
@test @inferred(argmin(S)) == argmin(A)
@test @inferred(findmin(S)) == findmin(A)
@test @inferred(findmax(S)) == findmax(A)
@test argmax(S) == argmax(A)
@test argmin(S) == argmin(A)
@test findmin(S) == findmin(A)
@test findmax(S) == findmax(A)
for region in [(1,), (2,), (1,2)], m in [findmax, findmin]
@test m(S, dims=region) == m(A, dims=region)
end
Expand Down Expand Up @@ -2201,7 +2201,7 @@ end
# Test that concatenations of pairs of sparse matrices yield sparse arrays
@test issparse(vcat(spmat, spmat))
@test issparse(hcat(spmat, spmat))
@test issparse(@inferred(hvcat((2,), spmat, spmat)))
@test issparse(hvcat((2,), spmat, spmat))
@test issparse(cat(spmat, spmat; dims=(1,2)))
# Test that concatenations of a sparse matrice with a dense matrix/vector yield sparse arrays
@test issparse(vcat(spmat, densemat))
Expand Down

0 comments on commit d7d921e

Please sign in to comment.