From bbaaf547292544f4d7a85d2d698ad4001b643952 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Sat, 21 Nov 2015 18:03:47 +0100 Subject: [PATCH] Fix findmax and findmin with iterables The state is not guaranteed to be equivalent to the index of the element. --- base/abstractarray.jl | 3 --- base/array.jl | 32 ++++++++++++++++---------------- base/multidimensional.jl | 4 +--- test/arrayops.jl | 12 +++++++++++- 4 files changed, 28 insertions(+), 23 deletions(-) diff --git a/base/abstractarray.jl b/base/abstractarray.jl index 0f45fbe64fde5..039c19f518b31 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -400,9 +400,6 @@ start(A::AbstractArray) = (@_inline_meta(); itr = eachindex(A); (itr, start(itr) next(A::AbstractArray,i) = (@_inline_meta(); (idx, s) = next(i[1], i[2]); (A[idx], (i[1], s))) done(A::AbstractArray,i) = done(i[1], i[2]) -iterstate(i) = i -iterstate(i::Tuple{UnitRange{Int},Int}) = i[2] - # eachindex iterates over all indices. LinearSlow definitions are later. eachindex(A::AbstractArray) = (@_inline_meta(); eachindex(linearindexing(A), A)) diff --git a/base/array.jl b/base/array.jl index 04fd8bf09bf4a..aeb4c6c6d274d 100644 --- a/base/array.jl +++ b/base/array.jl @@ -821,36 +821,36 @@ function findmax(a) if isempty(a) throw(ArgumentError("collection must be non-empty")) end - i = start(a) - mi = i - m, i = next(a, i) - while !done(a, i) - iold = i - ai, i = next(a, i) + s = start(a) + mi = i = 1 + m, s = next(a, s) + while !done(a, s) + ai, s = next(a, s) + i += 1 if ai > m || m!=m m = ai - mi = iold + mi = i end end - return (m, iterstate(mi)) + return (m, mi) end function findmin(a) if isempty(a) throw(ArgumentError("collection must be non-empty")) end - i = start(a) - mi = i - m, i = next(a, i) - while !done(a, i) - iold = i - ai, i = next(a, i) + s = start(a) + mi = i = 1 + m, s = next(a, s) + while !done(a, s) + ai, s = next(a, s) + i += 1 if ai < m || m!=m m = ai - mi = iold + mi = i end end - return (m, iterstate(mi)) + return (m, mi) end indmax(a) = findmax(a)[2] diff --git a/base/multidimensional.jl b/base/multidimensional.jl index ea5e722c8f33f..3f5d0c5815def 100644 --- a/base/multidimensional.jl +++ b/base/multidimensional.jl @@ -3,7 +3,7 @@ ### Multidimensional iterators module IteratorsMD -import Base: eltype, length, start, done, next, last, getindex, setindex!, linearindexing, min, max, eachindex, ndims, iterstate +import Base: eltype, length, start, done, next, last, getindex, setindex!, linearindexing, min, max, eachindex, ndims importall ..Base.Operators import Base: simd_outer_range, simd_inner_length, simd_index, @generated import Base: @nref, @ncall, @nif, @nexprs, LinearFast, LinearSlow, to_index @@ -59,8 +59,6 @@ immutable CartesianRange{I<:CartesianIndex} stop::I end -iterstate{CR<:CartesianRange,CI<:CartesianIndex}(i::Tuple{CR,CI}) = Base._sub2ind(i[1].stop.I, i[2].I) - @generated function CartesianRange{N}(I::CartesianIndex{N}) startargs = fill(1, N) :(CartesianRange($I($(startargs...)), I)) diff --git a/test/arrayops.jl b/test/arrayops.jl index 7fd2cfe857026..24d24c2d29624 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -306,7 +306,7 @@ for i = 1:3 end @test isequal(a,findn(z)) -#argmin argmax +#findmin findmax indmin indmax @test indmax([10,12,9,11]) == 2 @test indmin([10,12,9,11]) == 3 @test findmin([NaN,3.2,1.8]) == (1.8,3) @@ -316,6 +316,16 @@ end @test findmin([3.2,1.8,NaN,2.0]) == (1.8,2) @test findmax([3.2,1.8,NaN,2.0]) == (3.2,1) +# #14085 +@test findmax(4:9) == (9,6) +@test indmax(4:9) == 6 +@test findmin(4:9) == (4,1) +@test indmin(4:9) == 1 +@test findmax(5:-2:1) == (5,1) +@test indmax(5:-2:1) == 1 +@test findmin(5:-2:1) == (1,3) +@test indmin(5:-2:1) == 3 + ## permutedims ## #keeps the num of dim