Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.
/ corefx Public archive

Commit

Permalink
Adding IndexOf overload with three byte values.
Browse files Browse the repository at this point in the history
  • Loading branch information
ahsonkhan committed Mar 21, 2017
1 parent 3e53510 commit ebbbfdf
Show file tree
Hide file tree
Showing 4 changed files with 446 additions and 7 deletions.
6 changes: 2 additions & 4 deletions src/System.Memory/src/System/SpanExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,7 @@ public static int IndexOf(this Span<byte> span, byte value0, byte value1)
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int IndexOf(this Span<byte> span, byte value0, byte value1, byte value2)
{
return -1;
//return SpanHelpers.IndexOf(ref span.DangerousGetPinnableReference(), value0, value1, value2, span.Length);
return SpanHelpers.IndexOf(ref span.DangerousGetPinnableReference(), value0, value1, value2, span.Length);
}

/// <summary>
Expand All @@ -173,8 +172,7 @@ public static int IndexOf(this ReadOnlySpan<byte> span, byte value0, byte value1
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int IndexOf(this ReadOnlySpan<byte> span, byte value0, byte value1, byte value2)
{
return -1;
//return SpanHelpers.IndexOf(ref span.DangerousGetPinnableReference(), value0, value1, value2, span.Length);
return SpanHelpers.IndexOf(ref span.DangerousGetPinnableReference(), value0, value1, value2, span.Length);
}

/// <summary>
Expand Down
173 changes: 170 additions & 3 deletions src/System.Memory/src/System/SpanHelpers.byte.cs
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ public static unsafe int IndexOf(ref byte searchSpace, byte value0, byte value1,
}
SequentialScan:
#endif
while ((byte*)nLength >= (byte*)8)
while ((byte*)nLength >= (byte*)9)
{
nLength -= 8;

Expand All @@ -219,7 +219,7 @@ public static unsafe int IndexOf(ref byte searchSpace, byte value0, byte value1,
index += 8;
}

if ((byte*)nLength >= (byte*)4)
if ((byte*)nLength >= (byte*)5)
{
nLength -= 4;

Expand All @@ -235,7 +235,7 @@ public static unsafe int IndexOf(ref byte searchSpace, byte value0, byte value1,
index += 4;
}

while ((byte*)nLength > (byte*)1)
while ((byte*)nLength >= (byte*)2)
{
nLength -= 1;

Expand Down Expand Up @@ -311,6 +311,173 @@ public static unsafe int IndexOf(ref byte searchSpace, byte value0, byte value1,
return (int)(byte*)(index + 7);
}

public static unsafe int IndexOf(ref byte searchSpace, byte value0, byte value1, byte value2, int length)
{
Debug.Assert(length >= 0);

uint uValue0 = value0; // Use uint for comparisions to avoid unnecessary 8->32 extensions
uint uValue1 = value1; // Use uint for comparisions to avoid unnecessary 8->32 extensions
uint uValue2 = value2; // Use uint for comparisions to avoid unnecessary 8->32 extensions
IntPtr index = (IntPtr)0; // Use UIntPtr for arithmetic to avoid unnecessary 64->32->64 truncations
IntPtr nLength = (IntPtr)(uint)length;
#if !netstandard10
if (Vector.IsHardwareAccelerated && length >= Vector<byte>.Count * 2)
{
unchecked
{
int unaligned = (int)(byte*)Unsafe.AsPointer(ref searchSpace) & (Vector<byte>.Count - 1);
nLength = (IntPtr)(uint)unaligned;
}
}
SequentialScan:
#endif
while ((byte*)nLength >= (byte*)10)
{
nLength -= 8;

if (uValue0 == Unsafe.Add(ref searchSpace, index) &&
uValue1 == Unsafe.Add(ref searchSpace, index + 1) &&
uValue2 == Unsafe.Add(ref searchSpace, index + 2))
goto Found;
if (uValue0 == Unsafe.Add(ref searchSpace, index + 1) &&
uValue1 == Unsafe.Add(ref searchSpace, index + 2) &&
uValue2 == Unsafe.Add(ref searchSpace, index + 3))
goto Found1;
if (uValue0 == Unsafe.Add(ref searchSpace, index + 2) &&
uValue1 == Unsafe.Add(ref searchSpace, index + 3) &&
uValue2 == Unsafe.Add(ref searchSpace, index + 4))
goto Found2;
if (uValue0 == Unsafe.Add(ref searchSpace, index + 3) &&
uValue1 == Unsafe.Add(ref searchSpace, index + 4) &&
uValue2 == Unsafe.Add(ref searchSpace, index + 5))
goto Found3;
if (uValue0 == Unsafe.Add(ref searchSpace, index + 4) &&
uValue1 == Unsafe.Add(ref searchSpace, index + 5) &&
uValue2 == Unsafe.Add(ref searchSpace, index + 6))
goto Found4;
if (uValue0 == Unsafe.Add(ref searchSpace, index + 5) &&
uValue1 == Unsafe.Add(ref searchSpace, index + 6) &&
uValue2 == Unsafe.Add(ref searchSpace, index + 7))
goto Found5;
if (uValue0 == Unsafe.Add(ref searchSpace, index + 6) &&
uValue1 == Unsafe.Add(ref searchSpace, index + 7) &&
uValue2 == Unsafe.Add(ref searchSpace, index + 8))
goto Found6;
if (uValue0 == Unsafe.Add(ref searchSpace, index + 7) &&
uValue1 == Unsafe.Add(ref searchSpace, index + 8) &&
uValue2 == Unsafe.Add(ref searchSpace, index + 9))
goto Found7;

index += 8;
}

if ((byte*)nLength >= (byte*)6)
{
nLength -= 4;

if (uValue0 == Unsafe.Add(ref searchSpace, index) &&
uValue1 == Unsafe.Add(ref searchSpace, index + 1) &&
uValue2 == Unsafe.Add(ref searchSpace, index + 2))
goto Found;
if (uValue0 == Unsafe.Add(ref searchSpace, index + 1) &&
uValue1 == Unsafe.Add(ref searchSpace, index + 2) &&
uValue2 == Unsafe.Add(ref searchSpace, index + 3))
goto Found1;
if (uValue0 == Unsafe.Add(ref searchSpace, index + 2) &&
uValue1 == Unsafe.Add(ref searchSpace, index + 3) &&
uValue2 == Unsafe.Add(ref searchSpace, index + 4))
goto Found2;
if (uValue0 == Unsafe.Add(ref searchSpace, index + 3) &&
uValue1 == Unsafe.Add(ref searchSpace, index + 4) &&
uValue2 == Unsafe.Add(ref searchSpace, index + 5))
goto Found3;

index += 4;
}

while ((byte*)nLength >= (byte*)3)
{
nLength -= 1;

if (uValue0 == Unsafe.Add(ref searchSpace, index) &&
uValue1 == Unsafe.Add(ref searchSpace, index + 1) &&
uValue2 == Unsafe.Add(ref searchSpace, index + 2))
goto Found;

index += 1;
}
#if !netstandard10
if (Vector.IsHardwareAccelerated)
{
if ((int)(byte*)index >= length - 2)
{
goto NotFound;
}
nLength = (IntPtr)(uint)(length - Vector<byte>.Count);
// Get comparision Vector
Vector<byte> values0 = GetVector(value0);
Vector<byte> values1 = GetVector(value1);
Vector<byte> values2 = GetVector(value2);
do
{
var vData0 = Unsafe.ReadUnaligned<Vector<byte>>(ref Unsafe.AddByteOffset(ref searchSpace, index));
var vData1 = Unsafe.ReadUnaligned<Vector<byte>>(ref Unsafe.AddByteOffset(ref searchSpace, index + 1));
var vData2 = Unsafe.ReadUnaligned<Vector<byte>>(ref Unsafe.AddByteOffset(ref searchSpace, index + 2));

var vMatches = Vector.BitwiseAnd(
Vector.BitwiseAnd(
Vector.Equals(vData0, values0),
Vector.Equals(vData1, values1)),
Vector.Equals(vData2, values2));

if (!vMatches.Equals(Vector<byte>.Zero))
{
// Found match, reuse Vector values0 to keep register pressure low
values0 = vMatches;
break;
}
index += Vector<byte>.Count;
} while ((byte*)nLength > (byte*)index);

// Found match? Perform secondary search outside out of loop, so above loop body is small
if ((byte*)nLength > (byte*)index)
{
// Find offset of first match
index += LocateFirstFoundByte(values0);
// goto rather than inline return to keep function smaller
goto Found;
}

if ((int)(byte*)index <= length - 2)
{
unchecked
{
nLength = (IntPtr)(length - (int)(byte*)index);
}
goto SequentialScan;
}
}
NotFound: // Workaround for https://github.com/dotnet/coreclr/issues/9692
#endif
return -1;
Found: // Workaround for https://github.com/dotnet/coreclr/issues/9692
return (int)(byte*)index;
Found1:
return (int)(byte*)(index + 1);
Found2:
return (int)(byte*)(index + 2);
Found3:
return (int)(byte*)(index + 3);
Found4:
return (int)(byte*)(index + 4);
Found5:
return (int)(byte*)(index + 5);
Found6:
return (int)(byte*)(index + 6);
Found7:
return (int)(byte*)(index + 7);
}

public static unsafe bool SequenceEqual(ref byte first, ref byte second, int length)
{
Debug.Assert(length >= 0);
Expand Down
137 changes: 137 additions & 0 deletions src/System.Memory/tests/ReadOnlySpan/IndexOf.byte.cs
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,143 @@ public static void MakeSureNoChecksGoOutOfRangeTwo_Byte()
int index = span.IndexOf(99, 99);
Assert.Equal(-1, index);
}

for (int length = 0; length < 100; length++)
{
byte[] a = new byte[length + 3];
a[0] = 99;
a[1] = 99;
a[length] = 99;
a[length + 1] = 99;
a[length + 2] = 99;
ReadOnlySpan<byte> span = new ReadOnlySpan<byte>(a, 1, length);
int index = span.IndexOf(99, 98);
Assert.Equal(-1, index);
}
}

[Fact]
public static void ZeroLengthIndexOfThree_Byte()
{
ReadOnlySpan<byte> sp = new ReadOnlySpan<byte>(Array.Empty<byte>());
int idx = sp.IndexOf(0, 0, 0);
Assert.Equal(-1, idx);
}

[Fact]
public static void TestMatchThree_Byte()
{
for (int length = 0; length < 32; length++)
{
byte[] a = new byte[length];
for (int i = 0; i < length; i++)
{
a[i] = (byte)(i + 1);
}
ReadOnlySpan<byte> span = new ReadOnlySpan<byte>(a);

for (int targetIndex = 0; targetIndex < length - 2; targetIndex++)
{
byte target0 = a[targetIndex];
byte target1 = a[targetIndex + 1];
byte target2 = a[targetIndex + 2];
int idx = span.IndexOf(target0, target1, target2);
Assert.Equal(targetIndex, idx);
}
}
}

[Fact]
public static void TestNoMatchThree_Byte()
{
for (int length = 0; length < 32; length++)
{
byte[] a = new byte[length];
for (int i = 0; i < length; i++)
{
a[i] = (byte)(i + 1);
}
ReadOnlySpan<byte> span = new ReadOnlySpan<byte>(a);

for (int targetIndex = 0; targetIndex < length - 2; targetIndex++)
{
byte target0 = a[targetIndex];
byte target1 = (byte)(a[targetIndex + 1] + 1);
byte target2 = a[targetIndex + 2];
int idx = span.IndexOf(target0, target1, target2);
Assert.Equal(-1, idx);
}
}
}

[Fact]
public static void TestMultipleMatchThree_Byte()
{
for (int length = 4; length < 32; length++)
{
byte[] a = new byte[length];
for (int i = 0; i < length; i++)
{
a[i] = (byte)(i + 1);
}

a[length - 1] = 200;
a[length - 2] = 200;
a[length - 3] = 200;
a[length - 4] = 200;

ReadOnlySpan<byte> span = new ReadOnlySpan<byte>(a);
int idx = span.IndexOf(200, 200, 200);
Assert.Equal(length - 4, idx);
}
}

[Fact]
public static void MakeSureNoChecksGoOutOfRangeThree_Byte()
{
for (int length = 4; length < 5; length++)
{
byte[] a = new byte[length + 4];
a[0] = 99;
a[1] = 99;
a[2] = 98;
a[length] = 99;
a[length + 1] = 99;
a[length + 2] = 98;
a[length + 3] = 98;
ReadOnlySpan<byte> span = new ReadOnlySpan<byte>(a, 1, length);
int index = span.IndexOf(99, 99, 98);
Assert.Equal(-1, index);
}

for (int length = 0; length < 100; length++)
{
byte[] a = new byte[length + 4];
a[0] = 99;
a[1] = 99;
a[2] = 99;
a[length + 1] = 99;
a[length + 2] = 99;
a[length + 3] = 99;
ReadOnlySpan<byte> span = new ReadOnlySpan<byte>(a, 1, length);
int index = span.IndexOf(99, 99, 99);
Assert.Equal(-1, index);
}

for (int length = 0; length < 100; length++)
{
byte[] a = new byte[length + 4];
a[0] = 99;
a[1] = 99;
a[2] = 99;
a[length] = 99;
a[length + 1] = 99;
a[length + 2] = 99;
a[length + 3] = 99;
ReadOnlySpan<byte> span = new ReadOnlySpan<byte>(a, 1, length);
int index = span.IndexOf(99, 99, 98);
Assert.Equal(-1, index);
}
}
}
}
Loading

0 comments on commit ebbbfdf

Please sign in to comment.