diff --git a/base/sort.jl b/base/sort.jl index e5a2e822ac6e2..5f602be32febb 100644 --- a/base/sort.jl +++ b/base/sort.jl @@ -5,7 +5,7 @@ module Sort import ..@__MODULE__, ..parentmodule const Base = parentmodule(@__MODULE__) using .Base.Order -using .Base: copymutable, LinearIndices, length, (:), iterate, +using .Base: copymutable, LinearIndices, length, (:), iterate, elsize, eachindex, axes, first, last, similar, zip, OrdinalRange, firstindex, lastindex, AbstractVector, @inbounds, AbstractRange, @eval, @inline, Vector, @noinline, AbstractMatrix, AbstractUnitRange, isless, identity, eltype, >, <, <=, >=, |, +, -, *, !, @@ -599,12 +599,13 @@ function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::QuickSortAlg, o:: return v end -function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::MergeSortAlg, o::Ordering, t=similar(v,0)) +function sort!(v::AbstractVector{T}, lo::Integer, hi::Integer, a::MergeSortAlg, o::Ordering, + t0::Union{AbstractVector{T}, Nothing}=nothing) where T @inbounds if lo < hi hi-lo <= SMALL_THRESHOLD && return sort!(v, lo, hi, SMALL_ALGORITHM, o) m = midpoint(lo, hi) - (length(t) < m-lo+1) && resize!(t, m-lo+1) + t = workspace(v, t0, m-lo+1) sort!(v, lo, m, a, o, t) sort!(v, m+1, hi, a, o, t) @@ -731,7 +732,8 @@ end # For AbstractVector{Bool}, counting sort is always best. # This is an implementation of counting sort specialized for Bools. -function sort!(v::AbstractVector{<:Bool}, lo::Integer, hi::Integer, a::AdaptiveSort, o::Ordering) +function sort!(v::AbstractVector{B}, lo::Integer, hi::Integer, a::AdaptiveSort, o::Ordering, + t::Union{AbstractVector{B}, Nothing}=nothing) where {B <: Bool} first = lt(o, false, true) ? false : lt(o, true, false) ? true : return v count = 0 @inbounds for i in lo:hi @@ -744,6 +746,10 @@ function sort!(v::AbstractVector{<:Bool}, lo::Integer, hi::Integer, a::AdaptiveS v end +workspace(v::AbstractVector, ::Nothing, len::Integer) = similar(v, len) +function workspace(v::AbstractVector{T}, t::AbstractVector{T}, len::Integer) where T + length(t) < len ? resize!(t, len) : t +end maybe_unsigned(x::Integer) = x # this is necessary to avoid calling unsigned on BigInt maybe_unsigned(x::BitSigned) = unsigned(x) function _extrema(v::AbstractArray, lo::Integer, hi::Integer, o::Ordering) @@ -755,10 +761,11 @@ function _extrema(v::AbstractArray, lo::Integer, hi::Integer, o::Ordering) end mn, mx end -function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::AdaptiveSort, o::Ordering) +function sort!(v::AbstractVector{T}, lo::Integer, hi::Integer, a::AdaptiveSort, o::Ordering, + t::Union{AbstractVector{T}, Nothing}=nothing) where T # if the sorting task is not UIntMappable, then we can't radix sort or sort_int_range! # so we skip straight to the fallback algorithm which is comparison based. - U = UIntMappable(eltype(v), o) + U = UIntMappable(T, o) U === nothing && return sort!(v, lo, hi, a.fallback, o) # to avoid introducing excessive detection costs for the trivial sorting problem @@ -783,7 +790,7 @@ function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::AdaptiveSort, o:: # UInt128 does not support fast bit shifting so we never # dispatch to radix sort but we may still perform count sort if sizeof(U) > 8 - if eltype(v) <: Integer && o isa DirectOrdering + if T <: Integer && o isa DirectOrdering v_min, v_max = _extrema(v, lo, hi, Forward) v_range = maybe_unsigned(v_max-v_min) v_range == 0 && return v # all same @@ -799,7 +806,7 @@ function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::AdaptiveSort, o:: v_min, v_max = _extrema(v, lo, hi, o) lt(o, v_min, v_max) || return v # all same - if eltype(v) <: Integer && o isa DirectOrdering + if T <: Integer && o isa DirectOrdering R = o === Reverse v_range = maybe_unsigned(R ? v_min-v_max : v_max-v_min) if v_range < div(lenm1, 2) # count sort will be superior if v's range is very small @@ -849,7 +856,7 @@ function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::AdaptiveSort, o:: u[i] -= u_min end - u2 = radix_sort!(u, lo, hi, bits, similar(u)) + u2 = radix_sort!(u, lo, hi, bits, reinterpret(U, workspace(v, t, hi))) uint_unmap!(v, u2, lo, hi, o, u_min) end @@ -860,8 +867,14 @@ defalg(v::AbstractArray{<:Union{Number, Missing}}) = DEFAULT_UNSTABLE defalg(v::AbstractArray{Missing}) = DEFAULT_UNSTABLE # for method disambiguation defalg(v::AbstractArray{Union{}}) = DEFAULT_UNSTABLE # for method disambiguation -function sort!(v::AbstractVector, alg::Algorithm, order::Ordering) - sort!(v,firstindex(v),lastindex(v),alg,order) +function sort!(v::AbstractVector{T}, alg::Algorithm, + order::Ordering, t::Union{AbstractVector{T}, Nothing}=nothing) where T + sort!(v, firstindex(v), lastindex(v), alg, order, t) +end + +function sort!(v::AbstractVector{T}, lo::Integer, hi::Integer, alg::Algorithm, + order::Ordering, t::Union{AbstractVector{T}, Nothing}=nothing) where T + sort!(v, lo, hi, alg, order) end """ @@ -904,13 +917,14 @@ julia> v = [(1, "c"), (3, "a"), (2, "b")]; sort!(v, by = x -> x[2]); v (1, "c") ``` """ -function sort!(v::AbstractVector; +function sort!(v::AbstractVector{T}; alg::Algorithm=defalg(v), lt=isless, by=identity, rev::Union{Bool,Nothing}=nothing, - order::Ordering=Forward) - sort!(v, alg, ord(lt,by,rev,order)) + order::Ordering=Forward, + workspace::Union{AbstractVector{T}, Nothing}=nothing) where T + sort!(v, alg, ord(lt,by,rev,order), workspace) end # sort! for vectors of few unique integers @@ -1098,7 +1112,8 @@ function sortperm(v::AbstractVector; lt=isless, by=identity, rev::Union{Bool,Nothing}=nothing, - order::Ordering=Forward) + order::Ordering=Forward, + workspace::Union{AbstractVector, Nothing}=nothing) ordr = ord(lt,by,rev,order) if ordr === Forward && isa(v,Vector) && eltype(v)<:Integer n = length(v) @@ -1112,7 +1127,7 @@ function sortperm(v::AbstractVector; end end p = copymutable(eachindex(v)) - sort!(p, alg, Perm(ordr,v)) + sort!(p, alg, Perm(ordr,v), workspace) end @@ -1139,13 +1154,14 @@ julia> v[p] 3 ``` """ -function sortperm!(x::AbstractVector{<:Integer}, v::AbstractVector; +function sortperm!(x::AbstractVector{T}, v::AbstractVector; alg::Algorithm=DEFAULT_UNSTABLE, lt=isless, by=identity, rev::Union{Bool,Nothing}=nothing, order::Ordering=Forward, - initialized::Bool=false) + initialized::Bool=false, + workspace::Union{AbstractVector{T}, Nothing}=nothing) where T <: Integer if axes(x,1) != axes(v,1) throw(ArgumentError("index vector must have the same length/indices as the source vector, $(axes(x,1)) != $(axes(v,1))")) end @@ -1154,7 +1170,7 @@ function sortperm!(x::AbstractVector{<:Integer}, v::AbstractVector; x[i] = i end end - sort!(x, alg, Perm(ord(lt,by,rev,order),v)) + sort!(x, alg, Perm(ord(lt,by,rev,order),v), workspace) end # sortperm for vectors of few unique integers @@ -1212,13 +1228,14 @@ julia> sort(A, dims = 2) 1 2 ``` """ -function sort(A::AbstractArray; +function sort(A::AbstractArray{T}; dims::Integer, alg::Algorithm=DEFAULT_UNSTABLE, lt=isless, by=identity, rev::Union{Bool,Nothing}=nothing, - order::Ordering=Forward) + order::Ordering=Forward, + workspace::Union{AbstractVector{T}, Nothing}=similar(A, 0)) where T dim = dims order = ord(lt,by,rev,order) n = length(axes(A, dim)) @@ -1226,19 +1243,19 @@ function sort(A::AbstractArray; pdims = (dim, setdiff(1:ndims(A), dim)...) # put the selected dimension first Ap = permutedims(A, pdims) Av = vec(Ap) - sort_chunks!(Av, n, alg, order) + sort_chunks!(Av, n, alg, order, workspace) permutedims(Ap, invperm(pdims)) else Av = A[:] - sort_chunks!(Av, n, alg, order) + sort_chunks!(Av, n, alg, order, workspace) reshape(Av, axes(A)) end end -@noinline function sort_chunks!(Av, n, alg, order) +@noinline function sort_chunks!(Av, n, alg, order, t) inds = LinearIndices(Av) for s = first(inds):n:last(inds) - sort!(Av, s, s+n-1, alg, order) + sort!(Av, s, s+n-1, alg, order, t) end Av end @@ -1272,13 +1289,14 @@ julia> sort!(A, dims = 2); A 3 4 ``` """ -function sort!(A::AbstractArray; +function sort!(A::AbstractArray{T}; dims::Integer, alg::Algorithm=defalg(A), lt=isless, by=identity, rev::Union{Bool,Nothing}=nothing, - order::Ordering=Forward) + order::Ordering=Forward, + workspace::Union{AbstractVector{T}, Nothing}=nothing) where T ordr = ord(lt, by, rev, order) nd = ndims(A) k = dims @@ -1288,7 +1306,7 @@ function sort!(A::AbstractArray; remdims = ntuple(i -> i == k ? 1 : axes(A, i), nd) for idx in CartesianIndices(remdims) Av = view(A, ntuple(i -> i == k ? Colon() : idx[i], nd)...) - sort!(Av, alg, ordr) + sort!(Av, alg, ordr, workspace) end A end @@ -1505,10 +1523,11 @@ issignleft(o::ForwardOrdering, x::Floats) = lt(o, x, zero(x)) issignleft(o::ReverseOrdering, x::Floats) = lt(o, x, -zero(x)) issignleft(o::Perm, i::Integer) = issignleft(o.order, o.data[i]) -function fpsort!(v::AbstractVector, a::Algorithm, o::Ordering) +function fpsort!(v::AbstractVector, a::Algorithm, o::Ordering, + t::Union{AbstractVector, Nothing}=nothing) # fpsort!'s optimizations speed up comparisons, of which there are O(nlogn). # The overhead is O(n). For n < 10, it's not worth it. - length(v) < 10 && return sort!(v, firstindex(v), lastindex(v), SMALL_ALGORITHM, o) + length(v) < 10 && return sort!(v, firstindex(v), lastindex(v), SMALL_ALGORITHM, o, t) i, j = lo, hi = specials2end!(v,a,o) @inbounds while true @@ -1518,8 +1537,8 @@ function fpsort!(v::AbstractVector, a::Algorithm, o::Ordering) v[i], v[j] = v[j], v[i] i += 1; j -= 1 end - sort!(v, lo, j, a, left(o)) - sort!(v, i, hi, a, right(o)) + sort!(v, lo, j, a, left(o), t) + sort!(v, i, hi, a, right(o), t) return v end @@ -1527,10 +1546,14 @@ end fpsort!(v::AbstractVector, a::Sort.PartialQuickSort, o::Ordering) = sort!(v, firstindex(v), lastindex(v), a, o) -sort!(v::FPSortable, a::Algorithm, o::DirectOrdering) = - fpsort!(v, a, o) -sort!(v::AbstractVector{<:Union{Signed, Unsigned}}, a::Algorithm, o::Perm{<:DirectOrdering,<:FPSortable}) = - fpsort!(v, a, o) +function sort!(v::FPSortable, a::Algorithm, o::DirectOrdering, + t::Union{FPSortable, Nothing}=nothing) + fpsort!(v, a, o, t) +end +function sort!(v::AbstractVector{<:Union{Signed, Unsigned}}, a::Algorithm, + o::Perm{<:DirectOrdering,<:FPSortable}, t::Union{AbstractVector, Nothing}=nothing) + fpsort!(v, a, o, t) +end end # module Sort.Float diff --git a/test/sorting.jl b/test/sorting.jl index dd577f3baaef5..43d7ebbdf67de 100644 --- a/test/sorting.jl +++ b/test/sorting.jl @@ -654,6 +654,34 @@ end end end +@testset "workspace()" begin + for v in [[1, 2, 3], [0.0]] + for t0 in vcat([nothing], [similar(v,i) for i in 1:5]), len in 0:5 + t = Base.Sort.workspace(v, t0, len) + @test eltype(t) == eltype(v) + @test length(t) >= len + @test firstindex(t) == 1 + end + end +end + +@testset "sort(x; workspace=w) " begin + for n in [1,10,100,1000] + v = rand(n) + w = [0.0] + @test sort(v) == sort(v; workspace=w) + @test sort!(copy(v)) == sort!(copy(v); workspace=w) + @test sortperm(v) == sortperm(v; workspace=[4]) + @test sortperm!(Vector{Int}(undef, n), v) == sortperm!(Vector{Int}(undef, n), v; workspace=[4]) + + n > 100 && continue + M = rand(n, n) + @test sort(M; dims=2) == sort(M; dims=2, workspace=w) + @test sort!(copy(M); dims=1) == sort!(copy(M); dims=1, workspace=w) + end +end + + @testset "searchsorted" begin numTypes = [ Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UInt128,