diff --git a/base/sparse/abstractsparse.jl b/base/sparse/abstractsparse.jl index 1402241c422c5..05d65d97d9645 100644 --- a/base/sparse/abstractsparse.jl +++ b/base/sparse/abstractsparse.jl @@ -31,16 +31,16 @@ end # The following two methods should be overloaded by concrete types to avoid # allocating the I = find(...) -_sparse_findnext(v::AbstractSparseArray, i) = (I = find(v); n = searchsortedfirst(I, i); n<=length(I) ? I[n] : 0) -_sparse_findprev(v::AbstractSparseArray, i) = (I = find(v); n = searchsortedlast(I, i); n>0 ? I[n] : 0) +_sparse_findnextnz(v::AbstractSparseArray, i) = (I = find(!iszero, v); n = searchsortedfirst(I, i); n<=length(I) ? I[n] : 0) +_sparse_findprevnz(v::AbstractSparseArray, i) = (I = find(!iszero, v); n = searchsortedlast(I, i); n>0 ? I[n] : 0) -function findnext(v::AbstractSparseArray, i::Int) - j = _sparse_findnext(v, i) +function findnext(f::typeof(!iszero), v::AbstractSparseArray, i::Int) + j = _sparse_findnextnz(v, i) if j == 0 return 0 end - while v[j] == 0 - j = _sparse_findnext(v, j+1) + while !f(v[j]) + j = _sparse_findnextnz(v, j+1) if j == 0 return 0 end @@ -48,13 +48,13 @@ function findnext(v::AbstractSparseArray, i::Int) return j end -function findprev(v::AbstractSparseArray, i::Int) - j = _sparse_findprev(v, i) +function findprev(f::typeof(!iszero), v::AbstractSparseArray, i::Int) + j = _sparse_findprevnz(v, i) if j == 0 return 0 end - while v[j] == 0 - j = _sparse_findprev(v, j-1) + while !f(v[j]) + j = _sparse_findprevnz(v, j-1) if j == 0 return 0 end diff --git a/base/sparse/sparsematrix.jl b/base/sparse/sparsematrix.jl index 1405dcb841ea8..ace945eb5ff0d 100644 --- a/base/sparse/sparsematrix.jl +++ b/base/sparse/sparsematrix.jl @@ -1318,7 +1318,7 @@ function findnz(S::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti} return (I, J, V) end -function _sparse_findnext(m::SparseMatrixCSC, i::Int) +function _sparse_findnextnz(m::SparseMatrixCSC, i::Int) if i > length(m) return 0 end @@ -1336,7 +1336,7 @@ function _sparse_findnext(m::SparseMatrixCSC, i::Int) return sub2ind(m, m.rowval[nextlo], nextcol-1) end -function _sparse_findprev(m::SparseMatrixCSC, i::Int) +function _sparse_findprevnz(m::SparseMatrixCSC, i::Int) if i < 1 return 0 end diff --git a/base/sparse/sparsevector.jl b/base/sparse/sparsevector.jl index 922b1b10e55c1..b900b696cacb5 100644 --- a/base/sparse/sparsevector.jl +++ b/base/sparse/sparsevector.jl @@ -737,7 +737,7 @@ function findnz(x::SparseVector{Tv,Ti}) where {Tv,Ti} return (I, V) end -function _sparse_findnext(v::SparseVector, i::Int) +function _sparse_findnextnz(v::SparseVector, i::Int) n = searchsortedfirst(v.nzind, i) if n > length(v.nzind) return 0 @@ -746,7 +746,7 @@ function _sparse_findnext(v::SparseVector, i::Int) end end -function _sparse_findprev(v::SparseVector, i::Int) +function _sparse_findprevnz(v::SparseVector, i::Int) n = searchsortedlast(v.nzind, i) if n < 1 return 0 diff --git a/test/sparse/sparse.jl b/test/sparse/sparse.jl index 19732dc88dc83..645182a30ff5e 100644 --- a/test/sparse/sparse.jl +++ b/test/sparse/sparse.jl @@ -2155,8 +2155,8 @@ end x_sp = sparse(x) for i=1:length(x) - @test findnext(x,i) == findnext(x_sp,i) - @test findprev(x,i) == findprev(x_sp,i) + @test findnext(!iszero, x,i) == findnext(!iszero, x_sp,i) + @test findprev(!iszero, x,i) == findprev(!iszero, x_sp,i) end y = [0 0 0 0 0; @@ -2167,15 +2167,15 @@ end y_sp = sparse(y) for i=1:length(y) - @test findnext(y,i) == findnext(y_sp,i) - @test findprev(y,i) == findprev(y_sp,i) + @test findnext(!iszero, y,i) == findnext(!iszero, y_sp,i) + @test findprev(!iszero, y,i) == findprev(!iszero, y_sp,i) end z_sp = sparsevec(Dict(1=>1, 5=>1, 8=>0, 10=>1)) z = collect(z_sp) for i=1:length(z) - @test findnext(z,i) == findnext(z_sp,i) - @test findprev(z,i) == findprev(z_sp,i) + @test findnext(!iszero, z,i) == findnext(!iszero, z_sp,i) + @test findprev(!iszero, z,i) == findprev(!iszero, z_sp,i) end end