diff --git a/stdlib/SparseArrays/src/sparsematrix.jl b/stdlib/SparseArrays/src/sparsematrix.jl index 41d0414b662ea..bcc518f527f63 100644 --- a/stdlib/SparseArrays/src/sparsematrix.jl +++ b/stdlib/SparseArrays/src/sparsematrix.jl @@ -2073,7 +2073,7 @@ function _findr(op, A, region, Tv) throw(ArgumentError("array slices must be non-empty")) else ri = Base.reduced_indices0(A, region) - return (similar(A, ri), zeros(Ti, ri)) + return (zeros(Tv, ri), zeros(Ti, ri)) end end @@ -3303,6 +3303,10 @@ dropstored!(A::AbstractSparseMatrixCSC, ::Colon) = dropstored!(A, :, :) # Sparse concatenation +promote_idxtype(::AbstractSparseMatrixCSC{<:Any, Ti}) where {Ti} = Ti +promote_idxtype(::AbstractSparseMatrixCSC{<:Any, Ti}, X::AbstractSparseMatrixCSC...) where {Ti} = + promote_type(Ti, promote_idxtype(X...)) + function vcat(X::AbstractSparseMatrixCSC...) num = length(X) mX = Int[ size(x, 1) for x in X ] @@ -3317,7 +3321,7 @@ function vcat(X::AbstractSparseMatrixCSC...) end Tv = promote_eltype(X...) - Ti = promote_eltype(map(x->rowvals(x), X)...) + Ti = promote_idxtype(X...) nnzX = Int[ nnz(x) for x in X ] nnz_res = sum(nnzX) @@ -3369,7 +3373,7 @@ function hcat(X::AbstractSparseMatrixCSC...) n = sum(nX) Tv = promote_eltype(X...) - Ti = promote_eltype(map(x->rowvals(x), X)...) + Ti = promote_idxtype(X...) colptr = Vector{Ti}(undef, n+1) nnzX = Int[ nnz(x) for x in X ] diff --git a/stdlib/SparseArrays/src/sparsevector.jl b/stdlib/SparseArrays/src/sparsevector.jl index eecda6d72add8..d8745ae21e12d 100644 --- a/stdlib/SparseArrays/src/sparsevector.jl +++ b/stdlib/SparseArrays/src/sparsevector.jl @@ -1101,17 +1101,22 @@ function vcat(Xin::_SparseConcatGroup...) X = map(x -> SparseMatrixCSC(issparse(x) ? x : sparse(x)), Xin) vcat(X...) end -function hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...) - nbr = length(rows) # number of block rows - - tmp_rows = Vector{SparseMatrixCSC}(undef, nbr) - k = 0 - @inbounds for i = 1 : nbr - tmp_rows[i] = hcat(X[(1 : rows[i]) .+ k]...) - k += rows[i] - end - vcat(tmp_rows...) -end +hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...) = + vcat(_hvcat_rows(rows, X...)...) +function _hvcat_rows((row1, rows...)::Tuple{Vararg{Int}}, X::_SparseConcatGroup...) + if row1 ≤ 0 + throw(ArgumentError("length of block row must be positive, got $row1")) + end + # assert `X` is non-empty so that inference of `eltype` won't include `Type{Union{}}` + T = eltype(X::Tuple{Any,Vararg{Any}}) + # inference of `getindex` may be imprecise in case `row1` is not const-propagated up + # to here, so help inference with the following type-assertions + return ( + hcat(X[1 : row1]::Tuple{typeof(X[1]),Vararg{T}}...), + _hvcat_rows(rows, X[row1+1:end]::Tuple{Vararg{T}}...)... + ) +end +_hvcat_rows(::Tuple{}, X::_SparseConcatGroup...) = () # make sure UniformScaling objects are converted to sparse matrices for concatenation promote_to_array_type(A::Tuple{Vararg{Union{_SparseConcatGroup,UniformScaling}}}) = SparseMatrixCSC diff --git a/stdlib/SparseArrays/test/sparse.jl b/stdlib/SparseArrays/test/sparse.jl index 96630b4be0b8c..2e07f42182a5e 100644 --- a/stdlib/SparseArrays/test/sparse.jl +++ b/stdlib/SparseArrays/test/sparse.jl @@ -184,7 +184,7 @@ end sz34 = spzeros(3, 4) se77 = sparse(1.0I, 7, 7) @testset "h+v concatenation" begin - @test [se44 sz42 sz41; sz34 se33] == se77 + @test @inferred(hvcat((3, 2), se44, sz42, sz41, sz34, se33)) == se77 # [se44 sz42 sz41; sz34 se33] @test length(nonzeros([sp33 0I; 1I 0I])) == 6 end @@ -1355,10 +1355,10 @@ end @testset "argmax, argmin, findmax, findmin" begin S = sprand(100,80, 0.5) A = Array(S) - @test argmax(S) == argmax(A) - @test argmin(S) == argmin(A) - @test findmin(S) == findmin(A) - @test findmax(S) == findmax(A) + @test @inferred(argmax(S)) == argmax(A) + @test @inferred(argmin(S)) == argmin(A) + @test @inferred(findmin(S)) == findmin(A) + @test @inferred(findmax(S)) == findmax(A) for region in [(1,), (2,), (1,2)], m in [findmax, findmin] @test m(S, dims=region) == m(A, dims=region) end @@ -2224,7 +2224,7 @@ end # Test that concatenations of pairs of sparse matrices yield sparse arrays @test issparse(vcat(spmat, spmat)) @test issparse(hcat(spmat, spmat)) - @test issparse(hvcat((2,), spmat, spmat)) + @test issparse(@inferred(hvcat((2,), spmat, spmat))) @test issparse(cat(spmat, spmat; dims=(1,2))) # Test that concatenations of a sparse matrice with a dense matrix/vector yield sparse arrays @test issparse(vcat(spmat, densemat))