Skip to content

Commit

Permalink
Use existing Range vectorized fill in Enumerable.OrderBy (#99538)
Browse files Browse the repository at this point in the history
We already have a method that vectorizes the fill in Enumerable.Range. We can use the same helper in OrderBy when filling the integer map used to enable stability.
  • Loading branch information
stephentoub authored Mar 12, 2024
1 parent 19d7768 commit 342bc7e
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -339,12 +339,9 @@ private abstract class EnumerableSorter<TElement>
private int[] ComputeMap(TElement[] elements, int count)
{
ComputeKeys(elements, count);
int[] map = new int[count];
for (int i = 0; i < map.Length; i++)
{
map[i] = i;
}

int[] map = new int[count];
FillIncrementing(map, 0);
return map;
}

Expand Down
40 changes: 3 additions & 37 deletions src/libraries/System.Linq/src/System/Linq/Range.SpeedOpt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

namespace System.Linq
{
Expand All @@ -21,51 +18,20 @@ public override int[] ToArray()
{
int start = _start;
int[] array = new int[_end - start];
Fill(array, start);
FillIncrementing(array, start);
return array;
}

public override List<int> ToList()
{
(int start, int end) = (_start, _end);
List<int> list = new List<int>(end - start);
Fill(SetCountAndGetSpan(list, end - start), start);
FillIncrementing(SetCountAndGetSpan(list, end - start), start);
return list;
}

public void CopyTo(int[] array, int arrayIndex) =>
Fill(array.AsSpan(arrayIndex, _end - _start), _start);

private static void Fill(Span<int> destination, int value)
{
ref int pos = ref MemoryMarshal.GetReference(destination);
ref int end = ref Unsafe.Add(ref pos, destination.Length);

if (Vector.IsHardwareAccelerated &&
destination.Length >= Vector<int>.Count)
{
Vector<int> init = Vector<int>.Indices;
Vector<int> current = new Vector<int>(value) + init;
Vector<int> increment = new Vector<int>(Vector<int>.Count);

ref int oneVectorFromEnd = ref Unsafe.Subtract(ref end, Vector<int>.Count);
do
{
current.StoreUnsafe(ref pos);
current += increment;
pos = ref Unsafe.Add(ref pos, Vector<int>.Count);
}
while (!Unsafe.IsAddressGreaterThan(ref pos, ref oneVectorFromEnd));

value = current[0];
}

while (Unsafe.IsAddressLessThan(ref pos, ref end))
{
pos = value++;
pos = ref Unsafe.Add(ref pos, 1);
}
}
FillIncrementing(array.AsSpan(arrayIndex, _end - _start), _start);

public override int GetCount(bool onlyIfCheap) => _end - _start;

Expand Down
35 changes: 35 additions & 0 deletions src/libraries/System.Linq/src/System/Linq/Range.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

using System.Collections.Generic;
using System.Diagnostics;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

namespace System.Linq
{
Expand Down Expand Up @@ -71,5 +74,37 @@ public override void Dispose()
_state = -1; // Don't reset current
}
}

/// <summary>Fills the <paramref name="destination"/> with incrementing numbers, starting from <paramref name="value"/>.</summary>
private static void FillIncrementing(Span<int> destination, int value)
{
ref int pos = ref MemoryMarshal.GetReference(destination);
ref int end = ref Unsafe.Add(ref pos, destination.Length);

if (Vector.IsHardwareAccelerated &&
destination.Length >= Vector<int>.Count)
{
Vector<int> init = Vector<int>.Indices;
Vector<int> current = new Vector<int>(value) + init;
Vector<int> increment = new Vector<int>(Vector<int>.Count);

ref int oneVectorFromEnd = ref Unsafe.Subtract(ref end, Vector<int>.Count);
do
{
current.StoreUnsafe(ref pos);
current += increment;
pos = ref Unsafe.Add(ref pos, Vector<int>.Count);
}
while (!Unsafe.IsAddressGreaterThan(ref pos, ref oneVectorFromEnd));

value = current[0];
}

while (Unsafe.IsAddressLessThan(ref pos, ref end))
{
pos = value++;
pos = ref Unsafe.Add(ref pos, 1);
}
}
}
}

0 comments on commit 342bc7e

Please sign in to comment.