Skip to content

Commit

Permalink
Optimized findnext() for sparse: update now that predicate needs to b…
Browse files Browse the repository at this point in the history
…e explicit

Since we now need explicit predicates [1], this optimization only works
if we know that the predicate is a function that is false for zero
values. As suggested in that pull request, we could find out by calling
`f(zero(eltype(array)))` and hoping that `f` is pure, but I like being
a bit more conservative and only applying this optimization only to the
case where we *know* `f` is equal to `!iszero`.

For clarity, this commit also renames the helper method
_sparse_findnext()  to _sparse_findnextnz(), because now that the
predicate-less version doesn't exist anymore, the `nz` part isn't
implicit anymore either.

[1]: JuliaLang#23812
  • Loading branch information
tkluck committed Nov 3, 2017
1 parent 1ac4141 commit fe4b76e
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 20 deletions.
20 changes: 10 additions & 10 deletions base/sparse/abstractsparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,30 +31,30 @@ end

# The following two methods should be overloaded by concrete types to avoid
# allocating the I = find(...)
_sparse_findnext(v::AbstractSparseArray, i) = (I = find(v); n = searchsortedfirst(I, i); n<=length(I) ? I[n] : 0)
_sparse_findprev(v::AbstractSparseArray, i) = (I = find(v); n = searchsortedlast(I, i); n>0 ? I[n] : 0)
_sparse_findnextnz(v::AbstractSparseArray, i) = (I = find(!iszero, v); n = searchsortedfirst(I, i); n<=length(I) ? I[n] : 0)
_sparse_findprevnz(v::AbstractSparseArray, i) = (I = find(!iszero, v); n = searchsortedlast(I, i); n>0 ? I[n] : 0)

function findnext(v::AbstractSparseArray, i::Int)
j = _sparse_findnext(v, i)
function findnext(f::typeof(!iszero), v::AbstractSparseArray, i::Int)
j = _sparse_findnextnz(v, i)
if j == 0
return 0
end
while v[j] == 0
j = _sparse_findnext(v, j+1)
while !f(v[j])
j = _sparse_findnextnz(v, j+1)
if j == 0
return 0
end
end
return j
end

function findprev(v::AbstractSparseArray, i::Int)
j = _sparse_findprev(v, i)
function findprev(f::typeof(!iszero), v::AbstractSparseArray, i::Int)
j = _sparse_findprevnz(v, i)
if j == 0
return 0
end
while v[j] == 0
j = _sparse_findprev(v, j-1)
while !f(v[j])
j = _sparse_findprevnz(v, j-1)
if j == 0
return 0
end
Expand Down
4 changes: 2 additions & 2 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1318,7 +1318,7 @@ function findnz(S::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
return (I, J, V)
end

function _sparse_findnext(m::SparseMatrixCSC, i::Int)
function _sparse_findnextnz(m::SparseMatrixCSC, i::Int)
if i > length(m)
return 0
end
Expand All @@ -1336,7 +1336,7 @@ function _sparse_findnext(m::SparseMatrixCSC, i::Int)
return sub2ind(m, m.rowval[nextlo], nextcol-1)
end

function _sparse_findprev(m::SparseMatrixCSC, i::Int)
function _sparse_findprevnz(m::SparseMatrixCSC, i::Int)
if i < 1
return 0
end
Expand Down
4 changes: 2 additions & 2 deletions base/sparse/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ function findnz(x::SparseVector{Tv,Ti}) where {Tv,Ti}
return (I, V)
end

function _sparse_findnext(v::SparseVector, i::Int)
function _sparse_findnextnz(v::SparseVector, i::Int)
n = searchsortedfirst(v.nzind, i)
if n > length(v.nzind)
return 0
Expand All @@ -746,7 +746,7 @@ function _sparse_findnext(v::SparseVector, i::Int)
end
end

function _sparse_findprev(v::SparseVector, i::Int)
function _sparse_findprevnz(v::SparseVector, i::Int)
n = searchsortedlast(v.nzind, i)
if n < 1
return 0
Expand Down
12 changes: 6 additions & 6 deletions test/sparse/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2155,8 +2155,8 @@ end
x_sp = sparse(x)

for i=1:length(x)
@test findnext(x,i) == findnext(x_sp,i)
@test findprev(x,i) == findprev(x_sp,i)
@test findnext(!iszero, x,i) == findnext(!iszero, x_sp,i)
@test findprev(!iszero, x,i) == findprev(!iszero, x_sp,i)
end

y = [0 0 0 0 0;
Expand All @@ -2167,15 +2167,15 @@ end
y_sp = sparse(y)

for i=1:length(y)
@test findnext(y,i) == findnext(y_sp,i)
@test findprev(y,i) == findprev(y_sp,i)
@test findnext(!iszero, y,i) == findnext(!iszero, y_sp,i)
@test findprev(!iszero, y,i) == findprev(!iszero, y_sp,i)
end

z_sp = sparsevec(Dict(1=>1, 5=>1, 8=>0, 10=>1))
z = collect(z_sp)

for i=1:length(z)
@test findnext(z,i) == findnext(z_sp,i)
@test findprev(z,i) == findprev(z_sp,i)
@test findnext(!iszero, z,i) == findnext(!iszero, z_sp,i)
@test findprev(!iszero, z,i) == findprev(!iszero, z_sp,i)
end
end

0 comments on commit fe4b76e

Please sign in to comment.