union-types: use insertion (stable) sort instead of qsort #45896

Merged
merged 1 commit into from
Jul 6, 2022
8 changes: 6 additions & 2 deletions base/sort.jl
@@ -515,8 +515,12 @@ function sort!(v::AbstractVector, lo::Integer, hi::Integer, ::InsertionSortAlg,
     @inbounds for i = lo+1:hi
         j = i
         x = v[i]
-        while j > lo && lt(o, x, v[j-1])
-            v[j] = v[j-1]
+        while j > lo
+            y = v[j-1]
Member

Do we have any case where the compiler doesn't already perform this optimization?
The only case I can think of is a stateful getindex, e.g.

function getindex(x::LogAtReadArray, i::Int)
    println(i)
    getindex(x.data, i)
end

which isn't particularly compelling.

Member Author

For any moderately complex or non-inlined getindex, the redundant call might not be eliminated.

Member

I guess those likely exist, I've just never seen a subtype of AbstractVector with a complex getindex.

Member Author

Any non-inlined user-defined lt method will also prevent the compiler from performing this as a direct optimization (it would be semantically invalid).

Member

I don't see how that would be semantically invalid, but I'm not familiar with the relevant compiler semantics.

Member

It could throw, for example.

Member

I'm still not seeing it. Even if lt throws, y = v[j-1]; lt(o, x, y) is indistinguishable from lt(o, x, v[j-1]).

As a sanity check,

using Primes

@noinline function lt(x, y)
    isprime(x) && error()
    x < y
end

@code_llvm sort!([1,2,3], 1, 3, InsertionSort, Base.Order.Lt(lt))

Looks the same before and after.

Member Author

That looks different before/after to me. There is an excess load in the before case.

Member

@LilithHafner, Jul 6, 2022

I see now! lt could mutate v, which would make this optimization change behavior. That behavior change is okay with me, but not okay with the compiler. The compiler cannot reuse the result of v[j - 1] unless v[j - 1] is pure with respect to the intervening calls (including lt), which it can only know if both getindex and lt are simple and inlined.

I must have made a mistake in diffing. You're right, there is an extra load in the before case.

I'd be hard-pressed to find a case where this makes a measurable runtime difference, but I now see why it is an improvement in theory.

Thanks, @KristofferC and @vtjnash for explaining this to me!
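
The point reached in this thread can be illustrated outside Julia. Below is a minimal C sketch, not code from this PR, of the two loop shapes driven by a comparator that writes to the array being sorted. The names g_arr, lt_mutating, isort_reload, isort_cached and variants_differ are hypothetical and exist only for this illustration. Because the "before" shape reads a[j-1] a second time after the comparator has run, the two shapes can produce different results, which is why a compiler may not merge the two loads on its own when the comparator is opaque.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical global that lets the comparator reach the array being sorted. */
static int *g_arr;

/* Comparator with a side effect: each call adds 10 to element 0. */
static int lt_mutating(int x, int y)
{
    g_arr[0] += 10;
    return x < y;
}

/* "Before" shape: a[j-1] is loaded for the comparison and again for the shift. */
static void isort_reload(int *a, size_t len, int (*lt)(int, int))
{
    for (size_t i = 1; i < len; i++) {
        int x = a[i];
        size_t j = i;
        while (j > 0 && lt(x, a[j - 1])) {
            a[j] = a[j - 1]; /* second load sees whatever lt wrote */
            j--;
        }
        a[j] = x;
    }
}

/* "After" shape: a[j-1] is loaded once into y and reused. */
static void isort_cached(int *a, size_t len, int (*lt)(int, int))
{
    for (size_t i = 1; i < len; i++) {
        int x = a[i];
        size_t j = i;
        while (j > 0) {
            int y = a[j - 1];
            if (!lt(x, y))
                break;
            a[j] = y; /* uses the value read before lt ran */
            j--;
        }
        a[j] = x;
    }
}

/* Runs both shapes on {5, 1}; returns 1 if the resulting arrays differ. */
static int variants_differ(void)
{
    int a[2] = {5, 1};
    int b[2] = {5, 1};
    g_arr = a;
    isort_reload(a, 2, lt_mutating);
    g_arr = b;
    isort_cached(b, 2, lt_mutating);
    /* Both shapes still move x = 1 to the front. */
    assert(a[0] == 1 && b[0] == 1);
    return a[0] != b[0] || a[1] != b[1];
}
```

Calling variants_differ() from a main function shows the divergence: the reloading shape shifts the value the comparator wrote, while the cached shape shifts the value it read before the call.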

+            if !lt(o, x, y)
+                break
+            end
+            v[j] = y
             j -= 1
         end
         v[j] = x
35 changes: 25 additions & 10 deletions src/jltypes.c
@@ -420,10 +420,8 @@ static int datatype_name_cmp(jl_value_t *a, jl_value_t *b) JL_NOTSAFEPOINT

 // sort singletons first, then DataTypes, then UnionAlls,
 // ties broken alphabetically including module name & type parameters
-static int union_sort_cmp(const void *ap, const void *bp) JL_NOTSAFEPOINT
+static int union_sort_cmp(jl_value_t *a, jl_value_t *b) JL_NOTSAFEPOINT
 {
-    jl_value_t *a = *(jl_value_t**)ap;
-    jl_value_t *b = *(jl_value_t**)bp;
     if (a == NULL)
         return b == NULL ? 0 : 1;
     if (b == NULL)
@@ -458,16 +456,33 @@ static int union_sort_cmp(const void *ap, const void *bp) JL_NOTSAFEPOINT
     }
 }

+static void isort_union(jl_value_t **a, size_t len) JL_NOTSAFEPOINT
+{
+    size_t i, j;
+    for (i = 1; i < len; i++) {
+        jl_value_t *x = a[i];
+        for (j = i; j > 0; j--) {
+            jl_value_t *y = a[j - 1];
+            if (!(union_sort_cmp(x, y) < 0))
+                break;
+            a[j] = y;
+        }
+        a[j] = x;
+    }
+}
+
 JL_DLLEXPORT jl_value_t *jl_type_union(jl_value_t **ts, size_t n)
 {
-    if (n == 0) return (jl_value_t*)jl_bottom_type;
+    if (n == 0)
+        return (jl_value_t*)jl_bottom_type;
     size_t i;
-    for(i=0; i < n; i++) {
+    for (i = 0; i < n; i++) {
         jl_value_t *pi = ts[i];
         if (!(jl_is_type(pi) || jl_is_typevar(pi)))
             jl_type_error("Union", (jl_value_t*)jl_type_type, pi);
     }
-    if (n == 1) return ts[0];
+    if (n == 1)
+        return ts[0];

     size_t nt = count_union_components(ts, n);
     jl_value_t **temp;
@@ -476,9 +491,9 @@ JL_DLLEXPORT jl_value_t *jl_type_union(jl_value_t **ts, size_t n)
     flatten_type_union(ts, n, temp, &count);
     assert(count == nt);
     size_t j;
-    for(i=0; i < nt; i++) {
-        int has_free = temp[i]!=NULL && jl_has_free_typevars(temp[i]);
-        for(j=0; j < nt; j++) {
+    for (i = 0; i < nt; i++) {
+        int has_free = temp[i] != NULL && jl_has_free_typevars(temp[i]);
+        for (j = 0; j < nt; j++) {
             if (j != i && temp[i] && temp[j]) {
                 if (temp[i] == jl_bottom_type ||
                     temp[j] == (jl_value_t*)jl_any_type ||
@@ -490,7 +505,7 @@ JL_DLLEXPORT jl_value_t *jl_type_union(jl_value_t **ts, size_t n)
             }
         }
     }
-    qsort(temp, nt, sizeof(jl_value_t*), union_sort_cmp);
+    isort_union(temp, nt);
     jl_value_t **ptu = &temp[nt];
     *ptu = jl_bottom_type;
     int k;
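
To illustrate why the diff swaps qsort for isort_union and changes the comparator's signature, here is a small standalone C sketch under stated assumptions. It is not Julia runtime code: item_t, item_cmp, item_cmp_qsort, isort_items, stable_demo and qsort_keys_sorted are hypothetical names, and an integer rank stands in for the real ordering of union components. It shows a typed comparator called directly, the extra pointer dereference a qsort-style adapter needs, and that an insertion sort which breaks on "not less than" keeps equal elements in input order. The C standard leaves the relative order of equal elements after qsort unspecified, so the sketch only checks key order for the qsort path.

```c
#include <assert.h>
#include <stddef.h>
#include <stdlib.h>

/* Hypothetical element type: rank is the sort key, tag labels the item. */
typedef struct {
    int rank;
    const char *tag;
} item_t;

/* Typed comparator, called directly. NULL entries sort last. */
static int item_cmp(const item_t *a, const item_t *b)
{
    if (a == NULL)
        return b == NULL ? 0 : 1;
    if (b == NULL)
        return -1;
    return (a->rank > b->rank) - (a->rank < b->rank);
}

/* qsort-style adapter: qsort passes pointers to the array slots. */
static int item_cmp_qsort(const void *ap, const void *bp)
{
    const item_t *a = *(const item_t *const *)ap;
    const item_t *b = *(const item_t *const *)bp;
    return item_cmp(a, b);
}

/* Insertion sort over an array of pointers, same loop shape as isort_union. */
static void isort_items(const item_t **a, size_t len)
{
    size_t i, j;
    for (i = 1; i < len; i++) {
        const item_t *x = a[i];
        for (j = i; j > 0; j--) {
            const item_t *y = a[j - 1];
            if (!(item_cmp(x, y) < 0))
                break; /* ties stop here, so x stays behind y */
            a[j] = y;
        }
        a[j] = x;
    }
}

static const item_t b1 = {2, "b1"}, a1 = {1, "a1"}, b2 = {2, "b2"}, a2 = {1, "a2"};

/* Returns 1 if equal-rank items kept their input order and NULL sorted last. */
static int stable_demo(void)
{
    const item_t *v[5] = {&b1, NULL, &a1, &b2, &a2};
    assert(item_cmp(&b1, &b2) == 0); /* the two rank-2 items really do tie */
    isort_items(v, 5);
    return v[0] == &a1 && v[1] == &a2 && v[2] == &b1 && v[3] == &b2 && v[4] == NULL;
}

/* Returns 1 if qsort ordered the keys; the order within ties is not checked. */
static int qsort_keys_sorted(void)
{
    const item_t *v[5] = {&b1, NULL, &a1, &b2, &a2};
    qsort(v, 5, sizeof v[0], item_cmp_qsort);
    return v[0]->rank == 1 && v[1]->rank == 1 &&
           v[2]->rank == 2 && v[3]->rank == 2 && v[4] == NULL;
}
```

Insertion sort is quadratic in the worst case, so this trade only makes sense when the arrays are short; whether that holds for every caller of jl_type_union is not something this sketch establishes.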