From 4c4c94f4781da4f4109086368205db8a2f7ec7c4 Mon Sep 17 00:00:00 2001 From: Jakob Nybo Nissen Date: Wed, 1 Jun 2022 03:16:54 +0200 Subject: [PATCH] Optimize findall(f, ::AbstractArray{Bool}) (#42202) Co-authored-by: Milan Bouchet-Valat --- base/array.jl | 39 +++++++++++++++++++++++++++++++-------- test/arrayops.jl | 8 ++++++++ 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/base/array.jl b/base/array.jl index c1b8e4cc1f07f..28e52a538b64b 100644 --- a/base/array.jl +++ b/base/array.jl @@ -2342,19 +2342,42 @@ function findall(A) end # Allocating result upfront is faster (possible only when collection can be iterated twice) -function findall(A::AbstractArray{Bool}) - n = count(A) +function _findall(f::Function, A::AbstractArray{Bool}) + n = count(f, A) I = Vector{eltype(keys(A))}(undef, n) + isempty(I) && return I + _findall(f, I, A) +end + +function _findall(f::Function, I::Vector, A::AbstractArray{Bool}) cnt = 1 - for (i,a) in pairs(A) - if a - I[cnt] = i - cnt += 1 - end + len = length(I) + for (k, v) in pairs(A) + @inbounds I[cnt] = k + cnt += f(v) + cnt > len && return I end - I + # In case of impure f, this line could potentially be hit. In that case, + # we can't assume I is the correct length. + resize!(I, cnt - 1) +end + +function _findall(f::Function, I::Vector, A::AbstractVector{Bool}) + i = firstindex(A) + cnt = 1 + len = length(I) + while cnt ≤ len + @inbounds I[cnt] = i + cnt += f(@inbounds A[i]) + i = nextind(A, i) + end + cnt - 1 == len ? I : resize!(I, cnt - 1) end +findall(f::Function, A::AbstractArray{Bool}) = _findall(f, A) +findall(f::Fix2{typeof(in)}, A::AbstractArray{Bool}) = _findall(f, A) +findall(A::AbstractArray{Bool}) = _findall(identity, A) + findall(x::Bool) = x ? [1] : Vector{Int}() findall(testf::Function, x::Number) = testf(x) ? [1] : Vector{Int}() findall(p::Fix2{typeof(in)}, x::Number) = x in p.x ? [1] : Vector{Int}() diff --git a/test/arrayops.jl b/test/arrayops.jl index 627847a2d1ace..b11731d394b65 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -545,9 +545,17 @@ end @testset "findall, findfirst, findnext, findlast, findprev" begin a = [0,1,2,3,0,1,2,3] + m = [false false; true false] @test findall(!iszero, a) == [2,3,4,6,7,8] @test findall(a.==2) == [3,7] @test findall(isodd,a) == [2,4,6,8] + @test findall(Bool[]) == Int[] + @test findall([false, false]) == Int[] + @test findall(m) == [k for (k,v) in pairs(m) if v] + @test findall(!, [false, true, true]) == [1] + @test findall(i -> true, [false, true, false]) == [1, 2, 3] + @test findall(i -> false, rand(2, 2)) == Int[] + @test findall(!, m) == [k for (k,v) in pairs(m) if !v] @test findfirst(!iszero, a) == 2 @test findfirst(a.==0) == 1 @test findfirst(a.==5) == nothing