Skip to content

Commit

Permalink
Ensure MaxBy/MinBy return first element if all keys are null. (#61364)
Browse files Browse the repository at this point in the history
  • Loading branch information
eiriktsarpalis authored Nov 9, 2021
1 parent 057e34d commit 61d603f
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 36 deletions.
19 changes: 13 additions & 6 deletions src/libraries/System.Linq/src/System/Linq/Max.cs
Original file line number Diff line number Diff line change
Expand Up @@ -583,15 +583,22 @@ public static decimal Max(this IEnumerable<decimal> source)

if (default(TKey) is null)
{
while (key == null)
if (key == null)
{
if (!e.MoveNext())
TSource firstValue = value;

do
{
return value;
}
if (!e.MoveNext())
{
// All keys are null, surface the first element.
return firstValue;
}

value = e.Current;
key = keySelector(value);
value = e.Current;
key = keySelector(value);
}
while (key == null);
}

while (e.MoveNext())
Expand Down
19 changes: 13 additions & 6 deletions src/libraries/System.Linq/src/System/Linq/Min.cs
Original file line number Diff line number Diff line change
Expand Up @@ -541,15 +541,22 @@ public static decimal Min(this IEnumerable<decimal> source)

if (default(TKey) is null)
{
while (key == null)
if (key == null)
{
if (!e.MoveNext())
TSource firstValue = value;

do
{
return value;
}
if (!e.MoveNext())
{
// All keys are null, surface the first element.
return firstValue;
}

value = e.Current;
key = keySelector(value);
value = e.Current;
key = keySelector(value);
}
while (key == null);
}

while (e.MoveNext())
Expand Down
24 changes: 12 additions & 12 deletions src/libraries/System.Linq/tests/MaxTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -890,27 +890,27 @@ public static void MaxBy_Generic_EmptyReferenceSource_ReturnsNull()
}

[Fact]
public static void MaxBy_Generic_StructSourceAllKeysAreNull_ReturnsLastElement()
public static void MaxBy_Generic_StructSourceAllKeysAreNull_ReturnsFirstElement()
{
Assert.Equal(4, Enumerable.Range(0, 5).MaxBy(x => default(string)));
Assert.Equal(4, Enumerable.Range(0, 5).MaxBy(x => default(string), comparer: null));
Assert.Equal(4, Enumerable.Range(0, 5).MaxBy(x => default(string), Comparer<string>.Create((_, _) => throw new InvalidOperationException("comparer should not be called."))));
Assert.Equal(0, Enumerable.Range(0, 5).MaxBy(x => default(string)));
Assert.Equal(0, Enumerable.Range(0, 5).MaxBy(x => default(string), comparer: null));
Assert.Equal(0, Enumerable.Range(0, 5).MaxBy(x => default(string), Comparer<string>.Create((_, _) => throw new InvalidOperationException("comparer should not be called."))));
}

[Fact]
public static void MaxBy_Generic_NullableSourceAllKeysAreNull_ReturnsLastElement()
public static void MaxBy_Generic_NullableSourceAllKeysAreNull_ReturnsFirstElement()
{
Assert.Equal(4, Enumerable.Range(0, 5).Cast<int?>().MaxBy(x => default(int?)));
Assert.Equal(4, Enumerable.Range(0, 5).Cast<int?>().MaxBy(x => default(int?), comparer: null));
Assert.Equal(4, Enumerable.Range(0, 5).Cast<int?>().MaxBy(x => default(int?), Comparer<int?>.Create((_, _) => throw new InvalidOperationException("comparer should not be called."))));
Assert.Equal(0, Enumerable.Range(0, 5).Cast<int?>().MaxBy(x => default(int?)));
Assert.Equal(0, Enumerable.Range(0, 5).Cast<int?>().MaxBy(x => default(int?), comparer: null));
Assert.Equal(0, Enumerable.Range(0, 5).Cast<int?>().MaxBy(x => default(int?), Comparer<int?>.Create((_, _) => throw new InvalidOperationException("comparer should not be called."))));
}

[Fact]
public static void MaxBy_Generic_ReferenceSourceAllKeysAreNull_ReturnsLastElement()
public static void MaxBy_Generic_ReferenceSourceAllKeysAreNull_ReturnsFirstElement()
{
Assert.Equal("4", Enumerable.Range(0, 5).Select(x => x.ToString()).MaxBy(x => default(string)));
Assert.Equal("4", Enumerable.Range(0, 5).Select(x => x.ToString()).MaxBy(x => default(string), comparer: null));
Assert.Equal("4", Enumerable.Range(0, 5).Select(x => x.ToString()).MaxBy(x => default(string), Comparer<string>.Create((_, _) => throw new InvalidOperationException("comparer should not be called."))));
Assert.Equal("0", Enumerable.Range(0, 5).Select(x => x.ToString()).MaxBy(x => default(string)));
Assert.Equal("0", Enumerable.Range(0, 5).Select(x => x.ToString()).MaxBy(x => default(string), comparer: null));
Assert.Equal("0", Enumerable.Range(0, 5).Select(x => x.ToString()).MaxBy(x => default(string), Comparer<string>.Create((_, _) => throw new InvalidOperationException("comparer should not be called."))));
}

[Theory]
Expand Down
24 changes: 12 additions & 12 deletions src/libraries/System.Linq/tests/MinTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -868,27 +868,27 @@ public static void MinBy_Generic_EmptyReferenceSource_ReturnsNull()
}

[Fact]
public static void MinBy_Generic_StructSourceAllKeysAreNull_ReturnsLastElement()
public static void MinBy_Generic_StructSourceAllKeysAreNull_ReturnsFirstElement()
{
Assert.Equal(4, Enumerable.Range(0, 5).MinBy(x => default(string)));
Assert.Equal(4, Enumerable.Range(0, 5).MinBy(x => default(string), comparer: null));
Assert.Equal(4, Enumerable.Range(0, 5).MinBy(x => default(string), Comparer<string>.Create((_, _) => throw new InvalidOperationException("comparer should not be called."))));
Assert.Equal(0, Enumerable.Range(0, 5).MinBy(x => default(string)));
Assert.Equal(0, Enumerable.Range(0, 5).MinBy(x => default(string), comparer: null));
Assert.Equal(0, Enumerable.Range(0, 5).MinBy(x => default(string), Comparer<string>.Create((_, _) => throw new InvalidOperationException("comparer should not be called."))));
}

[Fact]
public static void MinBy_Generic_NullableSourceAllKeysAreNull_ReturnsLastElement()
public static void MinBy_Generic_NullableSourceAllKeysAreNull_ReturnsFirstElement()
{
Assert.Equal(4, Enumerable.Range(0, 5).Cast<int?>().MinBy(x => default(int?)));
Assert.Equal(4, Enumerable.Range(0, 5).Cast<int?>().MinBy(x => default(int?), comparer: null));
Assert.Equal(4, Enumerable.Range(0, 5).Cast<int?>().MinBy(x => default(int?), Comparer<int?>.Create((_, _) => throw new InvalidOperationException("comparer should not be called."))));
Assert.Equal(0, Enumerable.Range(0, 5).Cast<int?>().MinBy(x => default(int?)));
Assert.Equal(0, Enumerable.Range(0, 5).Cast<int?>().MinBy(x => default(int?), comparer: null));
Assert.Equal(0, Enumerable.Range(0, 5).Cast<int?>().MinBy(x => default(int?), Comparer<int?>.Create((_, _) => throw new InvalidOperationException("comparer should not be called."))));
}

[Fact]
public static void MinBy_Generic_ReferenceSourceAllKeysAreNull_ReturnsLastElement()
public static void MinBy_Generic_ReferenceSourceAllKeysAreNull_ReturnsFirstElement()
{
Assert.Equal("4", Enumerable.Range(0, 5).Select(x => x.ToString()).MinBy(x => default(string)));
Assert.Equal("4", Enumerable.Range(0, 5).Select(x => x.ToString()).MinBy(x => default(string), comparer: null));
Assert.Equal("4", Enumerable.Range(0, 5).Select(x => x.ToString()).MinBy(x => default(string), Comparer<string>.Create((_, _) => throw new InvalidOperationException("comparer should not be called."))));
Assert.Equal("0", Enumerable.Range(0, 5).Select(x => x.ToString()).MinBy(x => default(string)));
Assert.Equal("0", Enumerable.Range(0, 5).Select(x => x.ToString()).MinBy(x => default(string), comparer: null));
Assert.Equal("0", Enumerable.Range(0, 5).Select(x => x.ToString()).MinBy(x => default(string), Comparer<string>.Create((_, _) => throw new InvalidOperationException("comparer should not be called."))));
}

[Theory]
Expand Down

0 comments on commit 61d603f

Please sign in to comment.