Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add workspace/buffer support to sorting #45330

Merged
merged 1 commit into from
Jun 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 59 additions & 36 deletions base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module Sort
import ..@__MODULE__, ..parentmodule
const Base = parentmodule(@__MODULE__)
using .Base.Order
using .Base: copymutable, LinearIndices, length, (:), iterate,
using .Base: copymutable, LinearIndices, length, (:), iterate, elsize,
eachindex, axes, first, last, similar, zip, OrdinalRange, firstindex, lastindex,
AbstractVector, @inbounds, AbstractRange, @eval, @inline, Vector, @noinline,
AbstractMatrix, AbstractUnitRange, isless, identity, eltype, >, <, <=, >=, |, +, -, *, !,
Expand Down Expand Up @@ -599,12 +599,13 @@ function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::QuickSortAlg, o::
return v
end

function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::MergeSortAlg, o::Ordering, t=similar(v,0))
function sort!(v::AbstractVector{T}, lo::Integer, hi::Integer, a::MergeSortAlg, o::Ordering,
t0::Union{AbstractVector{T}, Nothing}=nothing) where T
@inbounds if lo < hi
hi-lo <= SMALL_THRESHOLD && return sort!(v, lo, hi, SMALL_ALGORITHM, o)

m = midpoint(lo, hi)
(length(t) < m-lo+1) && resize!(t, m-lo+1)
t = workspace(v, t0, m-lo+1)

sort!(v, lo, m, a, o, t)
sort!(v, m+1, hi, a, o, t)
Expand Down Expand Up @@ -731,7 +732,8 @@ end

# For AbstractVector{Bool}, counting sort is always best.
# This is an implementation of counting sort specialized for Bools.
function sort!(v::AbstractVector{<:Bool}, lo::Integer, hi::Integer, a::AdaptiveSort, o::Ordering)
function sort!(v::AbstractVector{B}, lo::Integer, hi::Integer, a::AdaptiveSort, o::Ordering,
t::Union{AbstractVector{B}, Nothing}=nothing) where {B <: Bool}
first = lt(o, false, true) ? false : lt(o, true, false) ? true : return v
count = 0
@inbounds for i in lo:hi
Expand All @@ -744,6 +746,10 @@ function sort!(v::AbstractVector{<:Bool}, lo::Integer, hi::Integer, a::AdaptiveS
v
end

workspace(v::AbstractVector, ::Nothing, len::Integer) = similar(v, len)
function workspace(v::AbstractVector{T}, t::AbstractVector{T}, len::Integer) where T
length(t) < len ? resize!(t, len) : t
end
maybe_unsigned(x::Integer) = x # this is necessary to avoid calling unsigned on BigInt
maybe_unsigned(x::BitSigned) = unsigned(x)
function _extrema(v::AbstractArray, lo::Integer, hi::Integer, o::Ordering)
Expand All @@ -755,10 +761,11 @@ function _extrema(v::AbstractArray, lo::Integer, hi::Integer, o::Ordering)
end
mn, mx
end
function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::AdaptiveSort, o::Ordering)
function sort!(v::AbstractVector{T}, lo::Integer, hi::Integer, a::AdaptiveSort, o::Ordering,
t::Union{AbstractVector{T}, Nothing}=nothing) where T
# if the sorting task is not UIntMappable, then we can't radix sort or sort_int_range!
# so we skip straight to the fallback algorithm which is comparison based.
U = UIntMappable(eltype(v), o)
U = UIntMappable(T, o)
U === nothing && return sort!(v, lo, hi, a.fallback, o)

# to avoid introducing excessive detection costs for the trivial sorting problem
Expand All @@ -783,7 +790,7 @@ function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::AdaptiveSort, o::
# UInt128 does not support fast bit shifting so we never
# dispatch to radix sort but we may still perform count sort
if sizeof(U) > 8
if eltype(v) <: Integer && o isa DirectOrdering
if T <: Integer && o isa DirectOrdering
v_min, v_max = _extrema(v, lo, hi, Forward)
v_range = maybe_unsigned(v_max-v_min)
v_range == 0 && return v # all same
Expand All @@ -799,7 +806,7 @@ function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::AdaptiveSort, o::

v_min, v_max = _extrema(v, lo, hi, o)
lt(o, v_min, v_max) || return v # all same
if eltype(v) <: Integer && o isa DirectOrdering
if T <: Integer && o isa DirectOrdering
R = o === Reverse
v_range = maybe_unsigned(R ? v_min-v_max : v_max-v_min)
if v_range < div(lenm1, 2) # count sort will be superior if v's range is very small
Expand Down Expand Up @@ -849,7 +856,7 @@ function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::AdaptiveSort, o::
u[i] -= u_min
end

u2 = radix_sort!(u, lo, hi, bits, similar(u))
u2 = radix_sort!(u, lo, hi, bits, reinterpret(U, workspace(v, t, hi)))
uint_unmap!(v, u2, lo, hi, o, u_min)
end

Expand All @@ -860,8 +867,14 @@ defalg(v::AbstractArray{<:Union{Number, Missing}}) = DEFAULT_UNSTABLE
defalg(v::AbstractArray{Missing}) = DEFAULT_UNSTABLE # for method disambiguation
defalg(v::AbstractArray{Union{}}) = DEFAULT_UNSTABLE # for method disambiguation

function sort!(v::AbstractVector, alg::Algorithm, order::Ordering)
sort!(v,firstindex(v),lastindex(v),alg,order)
function sort!(v::AbstractVector{T}, alg::Algorithm,
order::Ordering, t::Union{AbstractVector{T}, Nothing}=nothing) where T
sort!(v, firstindex(v), lastindex(v), alg, order, t)
end

function sort!(v::AbstractVector{T}, lo::Integer, hi::Integer, alg::Algorithm,
order::Ordering, t::Union{AbstractVector{T}, Nothing}=nothing) where T
sort!(v, lo, hi, alg, order)
end

"""
Expand Down Expand Up @@ -904,13 +917,14 @@ julia> v = [(1, "c"), (3, "a"), (2, "b")]; sort!(v, by = x -> x[2]); v
(1, "c")
```
"""
function sort!(v::AbstractVector;
function sort!(v::AbstractVector{T};
alg::Algorithm=defalg(v),
lt=isless,
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward)
sort!(v, alg, ord(lt,by,rev,order))
order::Ordering=Forward,
workspace::Union{AbstractVector{T}, Nothing}=nothing) where T
sort!(v, alg, ord(lt,by,rev,order), workspace)
end

# sort! for vectors of few unique integers
Expand Down Expand Up @@ -1098,7 +1112,8 @@ function sortperm(v::AbstractVector;
lt=isless,
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward)
order::Ordering=Forward,
workspace::Union{AbstractVector, Nothing}=nothing)
ordr = ord(lt,by,rev,order)
if ordr === Forward && isa(v,Vector) && eltype(v)<:Integer
n = length(v)
Expand All @@ -1112,7 +1127,7 @@ function sortperm(v::AbstractVector;
end
end
p = copymutable(eachindex(v))
sort!(p, alg, Perm(ordr,v))
sort!(p, alg, Perm(ordr,v), workspace)
end


Expand All @@ -1139,13 +1154,14 @@ julia> v[p]
3
```
"""
function sortperm!(x::AbstractVector{<:Integer}, v::AbstractVector;
function sortperm!(x::AbstractVector{T}, v::AbstractVector;
alg::Algorithm=DEFAULT_UNSTABLE,
lt=isless,
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward,
initialized::Bool=false)
initialized::Bool=false,
workspace::Union{AbstractVector{T}, Nothing}=nothing) where T <: Integer
if axes(x,1) != axes(v,1)
throw(ArgumentError("index vector must have the same length/indices as the source vector, $(axes(x,1)) != $(axes(v,1))"))
end
Expand All @@ -1154,7 +1170,7 @@ function sortperm!(x::AbstractVector{<:Integer}, v::AbstractVector;
x[i] = i
end
end
sort!(x, alg, Perm(ord(lt,by,rev,order),v))
sort!(x, alg, Perm(ord(lt,by,rev,order),v), workspace)
end

# sortperm for vectors of few unique integers
Expand Down Expand Up @@ -1212,33 +1228,34 @@ julia> sort(A, dims = 2)
1 2
```
"""
function sort(A::AbstractArray;
function sort(A::AbstractArray{T};
dims::Integer,
alg::Algorithm=DEFAULT_UNSTABLE,
lt=isless,
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward)
order::Ordering=Forward,
workspace::Union{AbstractVector{T}, Nothing}=similar(A, 0)) where T
dim = dims
order = ord(lt,by,rev,order)
n = length(axes(A, dim))
if dim != 1
pdims = (dim, setdiff(1:ndims(A), dim)...) # put the selected dimension first
Ap = permutedims(A, pdims)
Av = vec(Ap)
sort_chunks!(Av, n, alg, order)
sort_chunks!(Av, n, alg, order, workspace)
permutedims(Ap, invperm(pdims))
else
Av = A[:]
sort_chunks!(Av, n, alg, order)
sort_chunks!(Av, n, alg, order, workspace)
reshape(Av, axes(A))
end
end

@noinline function sort_chunks!(Av, n, alg, order)
@noinline function sort_chunks!(Av, n, alg, order, t)
inds = LinearIndices(Av)
for s = first(inds):n:last(inds)
sort!(Av, s, s+n-1, alg, order)
sort!(Av, s, s+n-1, alg, order, t)
end
Av
end
Expand Down Expand Up @@ -1272,13 +1289,14 @@ julia> sort!(A, dims = 2); A
3 4
```
"""
function sort!(A::AbstractArray;
function sort!(A::AbstractArray{T};
dims::Integer,
alg::Algorithm=defalg(A),
lt=isless,
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward)
order::Ordering=Forward,
workspace::Union{AbstractVector{T}, Nothing}=nothing) where T
ordr = ord(lt, by, rev, order)
nd = ndims(A)
k = dims
Expand All @@ -1288,7 +1306,7 @@ function sort!(A::AbstractArray;
remdims = ntuple(i -> i == k ? 1 : axes(A, i), nd)
for idx in CartesianIndices(remdims)
Av = view(A, ntuple(i -> i == k ? Colon() : idx[i], nd)...)
sort!(Av, alg, ordr)
sort!(Av, alg, ordr, workspace)
end
A
end
Expand Down Expand Up @@ -1505,10 +1523,11 @@ issignleft(o::ForwardOrdering, x::Floats) = lt(o, x, zero(x))
issignleft(o::ReverseOrdering, x::Floats) = lt(o, x, -zero(x))
issignleft(o::Perm, i::Integer) = issignleft(o.order, o.data[i])

function fpsort!(v::AbstractVector, a::Algorithm, o::Ordering)
function fpsort!(v::AbstractVector, a::Algorithm, o::Ordering,
t::Union{AbstractVector, Nothing}=nothing)
# fpsort!'s optimizations speed up comparisons, of which there are O(nlogn).
# The overhead is O(n). For n < 10, it's not worth it.
length(v) < 10 && return sort!(v, firstindex(v), lastindex(v), SMALL_ALGORITHM, o)
length(v) < 10 && return sort!(v, firstindex(v), lastindex(v), SMALL_ALGORITHM, o, t)

i, j = lo, hi = specials2end!(v,a,o)
@inbounds while true
Expand All @@ -1518,19 +1537,23 @@ function fpsort!(v::AbstractVector, a::Algorithm, o::Ordering)
v[i], v[j] = v[j], v[i]
i += 1; j -= 1
end
sort!(v, lo, j, a, left(o))
sort!(v, i, hi, a, right(o))
sort!(v, lo, j, a, left(o), t)
sort!(v, i, hi, a, right(o), t)
return v
end


fpsort!(v::AbstractVector, a::Sort.PartialQuickSort, o::Ordering) =
sort!(v, firstindex(v), lastindex(v), a, o)

sort!(v::FPSortable, a::Algorithm, o::DirectOrdering) =
fpsort!(v, a, o)
sort!(v::AbstractVector{<:Union{Signed, Unsigned}}, a::Algorithm, o::Perm{<:DirectOrdering,<:FPSortable}) =
fpsort!(v, a, o)
function sort!(v::FPSortable, a::Algorithm, o::DirectOrdering,
t::Union{FPSortable, Nothing}=nothing)
fpsort!(v, a, o, t)
end
function sort!(v::AbstractVector{<:Union{Signed, Unsigned}}, a::Algorithm,
o::Perm{<:DirectOrdering,<:FPSortable}, t::Union{AbstractVector, Nothing}=nothing)
fpsort!(v, a, o, t)
end

end # module Sort.Float

Expand Down
28 changes: 28 additions & 0 deletions test/sorting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,34 @@ end
end
end

@testset "workspace()" begin
for v in [[1, 2, 3], [0.0]]
for t0 in vcat([nothing], [similar(v,i) for i in 1:5]), len in 0:5
t = Base.Sort.workspace(v, t0, len)
@test eltype(t) == eltype(v)
@test length(t) >= len
@test firstindex(t) == 1
end
end
end

@testset "sort(x; workspace=w) " begin
for n in [1,10,100,1000]
v = rand(n)
w = [0.0]
@test sort(v) == sort(v; workspace=w)
@test sort!(copy(v)) == sort!(copy(v); workspace=w)
@test sortperm(v) == sortperm(v; workspace=[4])
@test sortperm!(Vector{Int}(undef, n), v) == sortperm!(Vector{Int}(undef, n), v; workspace=[4])

n > 100 && continue
M = rand(n, n)
@test sort(M; dims=2) == sort(M; dims=2, workspace=w)
@test sort!(copy(M); dims=1) == sort!(copy(M); dims=1, workspace=w)
end
end


@testset "searchsorted" begin
numTypes = [ Int8, Int16, Int32, Int64, Int128,
UInt8, UInt16, UInt32, UInt64, UInt128,
Expand Down