Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements for SpanHelpers.IndexOf #64872

Merged
merged 3 commits into from
Feb 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 60 additions & 36 deletions src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public static int IndexOf(ref byte searchSpace, int searchSpaceLength, ref byte
if (valueTailLength == 0)
return IndexOf(ref searchSpace, value, searchSpaceLength); // for single-byte values use plain IndexOf

int offset = 0;
nint offset = 0;
byte valueHead = value;
int searchSpaceMinusValueTailLength = searchSpaceLength - valueTailLength;
if (Vector128.IsHardwareAccelerated && searchSpaceMinusValueTailLength >= Vector128<byte>.Count)
Expand Down Expand Up @@ -54,7 +54,7 @@ public static int IndexOf(ref byte searchSpace, int searchSpaceLength, ref byte
if (SequenceEqual(
ref Unsafe.Add(ref searchSpace, offset + 1),
ref valueTail, (nuint)(uint)valueTailLength)) // The (nuint)-cast is necessary to pick the correct overload
return offset; // The tail matched. Return a successful find.
return (int)offset; // The tail matched. Return a successful find.

remainingSearchSpaceLength--;
offset++;
Expand All @@ -69,48 +69,60 @@ ref Unsafe.Add(ref searchSpace, offset + 1),
// Find the last unique (which is not equal to ch1) byte
// the algorithm is fine if both are equal, just a little bit less efficient
byte ch2Val = Unsafe.Add(ref value, valueTailLength);
int ch1ch2Distance = valueTailLength;
nint ch1ch2Distance = valueTailLength;
while (ch2Val == value && ch1ch2Distance > 1)
ch2Val = Unsafe.Add(ref value, --ch1ch2Distance);

Vector256<byte> ch1 = Vector256.Create(value);
Vector256<byte> ch2 = Vector256.Create(ch2Val);

nint searchSpaceMinusValueTailLengthAndVector =
searchSpaceMinusValueTailLength - (nint)Vector256<byte>.Count;

do
{
Debug.Assert(offset >= 0);
// Make sure we don't go out of bounds
Debug.Assert(offset + ch1ch2Distance + Vector256<byte>.Count <= searchSpaceLength);

Vector256<byte> cmpCh1 = Vector256.Equals(ch1, Vector256.LoadUnsafe(ref searchSpace, (nuint)offset));
Vector256<byte> cmpCh2 = Vector256.Equals(ch2, Vector256.LoadUnsafe(ref searchSpace, (nuint)(offset + ch1ch2Distance)));
Vector256<byte> cmpCh1 = Vector256.Equals(ch1, Vector256.LoadUnsafe(ref searchSpace, (nuint)offset));
Vector256<byte> cmpAnd = (cmpCh1 & cmpCh2).AsByte();

// Early out: cmpAnd is all zeros
if (cmpAnd != Vector256<byte>.Zero)
{
uint mask = cmpAnd.ExtractMostSignificantBits();
do
{
int bitPos = BitOperations.TrailingZeroCount(mask);
if (valueLength == 2 || // we already matched two bytes
SequenceEqual(
ref Unsafe.Add(ref searchSpace, offset + bitPos),
ref value, (nuint)(uint)valueLength)) // The (nuint)-cast is necessary to pick the correct overload
{
return offset + bitPos;
}
mask = BitOperations.ResetLowestSetBit(mask); // Clear the lowest set bit
} while (mask != 0);
goto CANDIDATE_FOUND;
}

LOOP_FOOTER:
offset += Vector256<byte>.Count;

if (offset == searchSpaceMinusValueTailLength)
return -1;

// Overlap with the current chunk for trailing elements
if (offset > searchSpaceMinusValueTailLength - Vector256<byte>.Count)
offset = searchSpaceMinusValueTailLength - Vector256<byte>.Count;
if (offset > searchSpaceMinusValueTailLengthAndVector)
offset = searchSpaceMinusValueTailLengthAndVector;

continue;

CANDIDATE_FOUND:
uint mask = cmpAnd.ExtractMostSignificantBits();
do
{
int bitPos = BitOperations.TrailingZeroCount(mask);
if (valueLength == 2 || // we already matched two bytes
SequenceEqual(
ref Unsafe.Add(ref searchSpace, offset + bitPos),
ref value, (nuint)(uint)valueLength)) // The (nuint)-cast is necessary to pick the correct overload
{
return (int)(offset + bitPos);
}
mask = BitOperations.ResetLowestSetBit(mask); // Clear the lowest set bit
} while (mask != 0);
goto LOOP_FOOTER;

} while (true);
}
else // 128bit vector path (SSE2 or AdvSimd)
Expand All @@ -125,42 +137,54 @@ ref Unsafe.Add(ref searchSpace, offset + bitPos),
Vector128<byte> ch1 = Vector128.Create(value);
Vector128<byte> ch2 = Vector128.Create(ch2Val);

nint searchSpaceMinusValueTailLengthAndVector =
searchSpaceMinusValueTailLength - (nint)Vector128<byte>.Count;

do
{
Debug.Assert(offset >= 0);
// Make sure we don't go out of bounds
Debug.Assert(offset + ch1ch2Distance + Vector128<byte>.Count <= searchSpaceLength);

Vector128<byte> cmpCh1 = Vector128.Equals(ch1, Vector128.LoadUnsafe(ref searchSpace, (nuint)offset));
Vector128<byte> cmpCh2 = Vector128.Equals(ch2, Vector128.LoadUnsafe(ref searchSpace, (nuint)(offset + ch1ch2Distance)));
Vector128<byte> cmpCh1 = Vector128.Equals(ch1, Vector128.LoadUnsafe(ref searchSpace, (nuint)offset));
Vector128<byte> cmpAnd = (cmpCh1 & cmpCh2).AsByte();

// Early out: cmpAnd is all zeros
if (cmpAnd != Vector128<byte>.Zero)
{
uint mask = cmpAnd.ExtractMostSignificantBits();
do
{
int bitPos = BitOperations.TrailingZeroCount(mask);
if (valueLength == 2 || // we already matched two bytes
SequenceEqual(
ref Unsafe.Add(ref searchSpace, offset + bitPos),
ref value, (nuint)(uint)valueLength)) // The (nuint)-cast is necessary to pick the correct overload
{
return offset + bitPos;
}
// Clear the lowest set bit
mask = BitOperations.ResetLowestSetBit(mask);
} while (mask != 0);
goto CANDIDATE_FOUND;
}

LOOP_FOOTER:
offset += Vector128<byte>.Count;

if (offset == searchSpaceMinusValueTailLength)
return -1;

// Overlap with the current chunk for trailing elements
if (offset > searchSpaceMinusValueTailLength - Vector128<byte>.Count)
offset = searchSpaceMinusValueTailLength - Vector128<byte>.Count;
if (offset > searchSpaceMinusValueTailLengthAndVector)
offset = searchSpaceMinusValueTailLengthAndVector;

continue;

CANDIDATE_FOUND:
uint mask = cmpAnd.ExtractMostSignificantBits();
do
{
int bitPos = BitOperations.TrailingZeroCount(mask);
if (valueLength == 2 || // we already matched two bytes
SequenceEqual(
ref Unsafe.Add(ref searchSpace, offset + bitPos),
ref value, (nuint)(uint)valueLength)) // The (nuint)-cast is necessary to pick the correct overload
{
return (int)(offset + bitPos);
}
// Clear the lowest set bit
mask = BitOperations.ResetLowestSetBit(mask);
} while (mask != 0);
goto LOOP_FOOTER;

} while (true);
}
}
Expand Down
124 changes: 74 additions & 50 deletions src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Char.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public static int IndexOf(ref char searchSpace, int searchSpaceLength, ref char
return IndexOf(ref searchSpace, value, searchSpaceLength);
}

int offset = 0;
nint offset = 0;
char valueHead = value;
int searchSpaceMinusValueTailLength = searchSpaceLength - valueTailLength;
if (Vector128.IsHardwareAccelerated && searchSpaceMinusValueTailLength >= Vector128<ushort>.Count)
Expand Down Expand Up @@ -59,7 +59,7 @@ ref Unsafe.As<char, byte>(ref Unsafe.Add(ref searchSpace, offset + 1)),
ref valueTail,
(nuint)(uint)valueTailLength * 2))
{
return offset; // The tail matched. Return a successful find.
return (int)offset; // The tail matched. Return a successful find.
}

remainingSearchSpaceLength--;
Expand All @@ -75,109 +75,133 @@ ref Unsafe.As<char, byte>(ref Unsafe.Add(ref searchSpace, offset + 1)),
// Find the last unique (which is not equal to ch1) character
// the algorithm is fine if both are equal, just a little bit less efficient
ushort ch2Val = Unsafe.Add(ref value, valueTailLength);
int ch1ch2Distance = valueTailLength;
nint ch1ch2Distance = valueTailLength;
while (ch2Val == valueHead && ch1ch2Distance > 1)
ch2Val = Unsafe.Add(ref value, --ch1ch2Distance);

Vector256<ushort> ch1 = Vector256.Create((ushort)valueHead);
Vector256<ushort> ch2 = Vector256.Create(ch2Val);

nint searchSpaceMinusValueTailLengthAndVector =
searchSpaceMinusValueTailLength - (nint)Vector256<ushort>.Count;

do
{
// Make sure we don't go out of bounds
Debug.Assert(offset + ch1ch2Distance + Vector256<ushort>.Count <= searchSpaceLength);

Vector256<ushort> cmpCh1 = Vector256.Equals(ch1, LoadVector256(ref searchSpace, offset));
Vector256<ushort> cmpCh2 = Vector256.Equals(ch2, LoadVector256(ref searchSpace, offset + ch1ch2Distance));
Vector256<ushort> cmpCh1 = Vector256.Equals(ch1, LoadVector256(ref searchSpace, offset));
Vector256<byte> cmpAnd = (cmpCh1 & cmpCh2).AsByte();

// Early out: cmpAnd is all zeros
if (cmpAnd != Vector256<byte>.Zero)
{
uint mask = cmpAnd.ExtractMostSignificantBits();
do
{
int bitPos = BitOperations.TrailingZeroCount(mask);
// div by 2 (shr) because we work with 2-byte chars
int charPos = (int)((uint)bitPos / 2);
if (valueLength == 2 || // we already matched two chars
SequenceEqual(
ref Unsafe.As<char, byte>(ref Unsafe.Add(ref searchSpace, offset + charPos)),
ref Unsafe.As<char, byte>(ref value), (nuint)(uint)valueLength * 2))
{
return offset + charPos;
}

// Clear the two lowest set bits
if (Bmi1.IsSupported)
mask = Bmi1.ResetLowestSetBit(Bmi1.ResetLowestSetBit(mask));
else
mask &= ~(uint)(0b11 << bitPos);
} while (mask != 0);
goto CANDIDATE_FOUND;
}

LOOP_FOOTER:
offset += Vector256<ushort>.Count;

if (offset == searchSpaceMinusValueTailLength)
return -1;

// Overlap with the current chunk for trailing elements
if (offset > searchSpaceMinusValueTailLength - Vector256<ushort>.Count)
offset = searchSpaceMinusValueTailLength - Vector256<ushort>.Count;
if (offset > searchSpaceMinusValueTailLengthAndVector)
offset = searchSpaceMinusValueTailLengthAndVector;

continue;

CANDIDATE_FOUND:
uint mask = cmpAnd.ExtractMostSignificantBits();
do
{
int bitPos = BitOperations.TrailingZeroCount(mask);
// div by 2 (shr) because we work with 2-byte chars
nint charPos = (nint)((uint)bitPos / 2);
if (valueLength == 2 || // we already matched two chars
SequenceEqual(
ref Unsafe.As<char, byte>(ref Unsafe.Add(ref searchSpace, offset + charPos)),
ref Unsafe.As<char, byte>(ref value), (nuint)(uint)valueLength * 2))
{
return (int)(offset + charPos);
}

// Clear the two lowest set bits
if (Bmi1.IsSupported)
mask = Bmi1.ResetLowestSetBit(Bmi1.ResetLowestSetBit(mask));
else
mask &= ~(uint)(0b11 << bitPos);
} while (mask != 0);
goto LOOP_FOOTER;

} while (true);
}
else // 128bit vector path (SSE2 or AdvSimd)
{
// Find the last unique (which is not equal to ch1) character
// the algorithm is fine if both are equal, just a little bit less efficient
ushort ch2Val = Unsafe.Add(ref value, valueTailLength);
int ch1ch2Distance = valueTailLength;
nint ch1ch2Distance = valueTailLength;
while (ch2Val == valueHead && ch1ch2Distance > 1)
ch2Val = Unsafe.Add(ref value, --ch1ch2Distance);

Vector128<ushort> ch1 = Vector128.Create((ushort)valueHead);
Vector128<ushort> ch2 = Vector128.Create(ch2Val);

nint searchSpaceMinusValueTailLengthAndVector =
searchSpaceMinusValueTailLength - (nint)Vector128<ushort>.Count;

do
{
// Make sure we don't go out of bounds
Debug.Assert(offset + ch1ch2Distance + Vector128<ushort>.Count <= searchSpaceLength);

Vector128<ushort> cmpCh1 = Vector128.Equals(ch1, LoadVector128(ref searchSpace, offset));
Vector128<ushort> cmpCh2 = Vector128.Equals(ch2, LoadVector128(ref searchSpace, offset + ch1ch2Distance));
Vector128<ushort> cmpCh1 = Vector128.Equals(ch1, LoadVector128(ref searchSpace, offset));
Vector128<byte> cmpAnd = (cmpCh1 & cmpCh2).AsByte();

// Early out: cmpAnd is all zeros
if (cmpAnd != Vector128<byte>.Zero)
{
uint mask = cmpAnd.ExtractMostSignificantBits();
do
{
int bitPos = BitOperations.TrailingZeroCount(mask);
// div by 2 (shr) because we work with 2-byte chars
int charPos = (int)((uint)bitPos / 2);
if (valueLength == 2 || // we already matched two chars
SequenceEqual(
ref Unsafe.As<char, byte>(ref Unsafe.Add(ref searchSpace, offset + charPos)),
ref Unsafe.As<char, byte>(ref value), (nuint)(uint)valueLength * 2))
{
return offset + charPos;
}

// Clear two lowest set bits
if (Bmi1.IsSupported)
mask = Bmi1.ResetLowestSetBit(Bmi1.ResetLowestSetBit(mask));
else
mask &= ~(uint)(0b11 << bitPos);
} while (mask != 0);
goto CANDIDATE_FOUND;
}

LOOP_FOOTER:
offset += Vector128<ushort>.Count;

if (offset == searchSpaceMinusValueTailLength)
return -1;

// Overlap with the current chunk for trailing elements
if (offset > searchSpaceMinusValueTailLength - Vector128<ushort>.Count)
offset = searchSpaceMinusValueTailLength - Vector128<ushort>.Count;
if (offset > searchSpaceMinusValueTailLengthAndVector)
offset = searchSpaceMinusValueTailLengthAndVector;

continue;

CANDIDATE_FOUND:
uint mask = cmpAnd.ExtractMostSignificantBits();
do
{
int bitPos = BitOperations.TrailingZeroCount(mask);
// div by 2 (shr) because we work with 2-byte chars
int charPos = (int)((uint)bitPos / 2);
if (valueLength == 2 || // we already matched two chars
SequenceEqual(
ref Unsafe.As<char, byte>(ref Unsafe.Add(ref searchSpace, offset + charPos)),
ref Unsafe.As<char, byte>(ref value), (nuint)(uint)valueLength * 2))
{
return (int)(offset + charPos);
}

// Clear two lowest set bits
if (Bmi1.IsSupported)
mask = Bmi1.ResetLowestSetBit(Bmi1.ResetLowestSetBit(mask));
else
mask &= ~(uint)(0b11 << bitPos);
} while (mask != 0);
goto LOOP_FOOTER;

} while (true);
}
}
Expand Down