diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs index ee8146c9a95..051eceb7b5e 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs @@ -796,7 +796,8 @@ private sealed class CloningExpressionVisitor : ExpressionVisitor _usedAliases = selectExpression._usedAliases.ToHashSet(), _projectionMapping = newProjectionMappings, _groupingCorrelationPredicate = groupingCorrelationPredicate, - _groupingParentSelectExpressionId = selectExpression._groupingParentSelectExpressionId + _groupingParentSelectExpressionId = selectExpression._groupingParentSelectExpressionId, + _groupingParentSelectExpressionTableCount = selectExpression._groupingParentSelectExpressionTableCount, }; newSelectExpression._tptLeftJoinTables.AddRange(selectExpression._tptLeftJoinTables); @@ -869,29 +870,122 @@ public GroupByAggregateLiftingExpressionVisitor(SelectExpression selectExpressio && subquery.Offset == null && subquery._groupBy.Count == 0 && subquery.Predicate != null - && ((AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27102", out var enabled) && enabled) + && ((AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27102", out var enabled27102) && enabled27102) || subquery.Predicate.Equals(subquery._groupingCorrelationPredicate)) - && ((AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27094", out var enabled2) && enabled2) + && ((AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27094", out var enabled27094) && enabled27094) || subquery._groupingParentSelectExpressionId == _selectExpression._groupingParentSelectExpressionId)) { var initialTableCounts = 0; - var potentialTableCount = Math.Min(_selectExpression._tables.Count, subquery._tables.Count); - for (var i = 0; i < potentialTableCount; i++) + if (AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27163", out var enabled27163) && enabled27163) { - if (!string.Equals( - _selectExpression._tableReferences[i].Alias, - subquery._tableReferences[i].Alias, StringComparison.OrdinalIgnoreCase)) + var potentialTableCount = Math.Min(_selectExpression._tables.Count, subquery._tables.Count); + for (var i = 0; i < potentialTableCount; i++) { - break; + if (!string.Equals( + _selectExpression._tableReferences[i].Alias, + subquery._tableReferences[i].Alias, StringComparison.OrdinalIgnoreCase)) + { + break; + } + + if (_selectExpression._tables[i] is SelectExpression originalNestedSelectExpression + && subquery._tables[i] is SelectExpression subqueryNestedSelectExpression) + { + CopyOverOwnedJoinInSameTable(originalNestedSelectExpression, subqueryNestedSelectExpression); + } + + initialTableCounts++; + } + } + else + { + initialTableCounts = _selectExpression._groupingParentSelectExpressionTableCount!.Value; + var potentialTableCount = Math.Min(_selectExpression._tables.Count, subquery._tables.Count); + // First verify that subquery has same structure for initial tables, + // If not then subquery may have different root than grouping element. + for (var i = 0; i < initialTableCounts; i++) + { + if (!string.Equals( + _selectExpression._tableReferences[i].Alias, + subquery._tableReferences[i].Alias, StringComparison.OrdinalIgnoreCase)) + { + initialTableCounts = 0; + break; + } } - if (_selectExpression._tables[i] is SelectExpression originalNestedSelectExpression - && subquery._tables[i] is SelectExpression subqueryNestedSelectExpression) + if (initialTableCounts > 0) { - CopyOverOwnedJoinInSameTable(originalNestedSelectExpression, subqueryNestedSelectExpression); + // If initial table structure matches and + // Parent has additional joins lifted already one of them is a subquery join + // Then we abort lifting if any of the joins from the subquery to lift are a subquery join + if (_selectExpression._tables.Skip(initialTableCounts) + .Select(e => UnwrapJoinExpression(e)) + .Any(e => e is SelectExpression)) + { + for (var i = initialTableCounts; i < subquery._tables.Count; i++) + { + if (UnwrapJoinExpression(subquery._tables[i]) is SelectExpression) + { + // If any of the join is to subquery then we abort the lifting group by term altogether. + initialTableCounts = 0; + break; + } + } + } } - initialTableCounts++; + if (initialTableCounts > 0) + { + // We need to copy over owned join which are coming from same initial tables. + for (var i = 0; i < initialTableCounts; i++) + { + if (_selectExpression._tables[i] is SelectExpression originalNestedSelectExpression + && subquery._tables[i] is SelectExpression subqueryNestedSelectExpression) + { + CopyOverOwnedJoinInSameTable(originalNestedSelectExpression, subqueryNestedSelectExpression); + } + } + + + for (var i = initialTableCounts; i < potentialTableCount; i++) + { + // Try to match additional tables for the cases where we can match exact so we can avoid lifting + // same joins to parent + if (!string.Equals( + _selectExpression._tableReferences[i].Alias, + subquery._tableReferences[i].Alias, StringComparison.OrdinalIgnoreCase)) + { + break; + } + + var outerTableExpressionBase = _selectExpression._tables[i]; + var innerTableExpressionBase = subquery._tables[i]; + + if (outerTableExpressionBase is InnerJoinExpression outerInnerJoin + && innerTableExpressionBase is InnerJoinExpression innerInnerJoin) + { + outerTableExpressionBase = outerInnerJoin.Table as TableExpression; + innerTableExpressionBase = innerInnerJoin.Table as TableExpression; + } + else if (outerTableExpressionBase is LeftJoinExpression outerLeftJoin + && innerTableExpressionBase is LeftJoinExpression innerLeftJoin) + { + outerTableExpressionBase = outerLeftJoin.Table as TableExpression; + innerTableExpressionBase = innerLeftJoin.Table as TableExpression; + } + + if (outerTableExpressionBase is TableExpression outerTable + && innerTableExpressionBase is TableExpression innerTable + && !(string.Equals(outerTable.Name, innerTable.Name, StringComparison.OrdinalIgnoreCase) + && string.Equals(outerTable.Schema, innerTable.Schema, StringComparison.OrdinalIgnoreCase))) + { + break; + } + + initialTableCounts++; + } + } } if (initialTableCounts > 0) @@ -900,7 +994,7 @@ public GroupByAggregateLiftingExpressionVisitor(SelectExpression selectExpressio // We only replace columns from initial tables. // Additional tables may have been added to outer from other terms which may end up matching on table alias var columnExpressionReplacingExpressionVisitor = - AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27083", out var enabled3) && enabled3 + AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27083", out var enabled27083) && enabled27083 ? new ColumnExpressionReplacingExpressionVisitor( subquery, _selectExpression._tableReferences) : new ColumnExpressionReplacingExpressionVisitor( diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs index 208b73cf9f6..b0a99727d4b 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs @@ -63,6 +63,7 @@ public sealed partial class SelectExpression : TableExpressionBase private SqlExpression? _groupingCorrelationPredicate; private Guid? _groupingParentSelectExpressionId; + private int? _groupingParentSelectExpressionTableCount; private CloningExpressionVisitor? _cloningExpressionVisitor; private SelectExpression( @@ -1256,7 +1257,7 @@ public GroupByShaperExpression ApplyGrouping( // We generate the cloned expression before changing identifier for this SelectExpression // because we are going to erase grouping for cloned expression. - if (!(AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27094", out var enabled2) && enabled2)) + if (!(AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27094", out var enabled27094) && enabled27094)) { _groupingParentSelectExpressionId = Guid.NewGuid(); @@ -1267,10 +1268,15 @@ public GroupByShaperExpression ApplyGrouping( .Aggregate((l, r) => sqlExpressionFactory.AndAlso(l, r)); clonedSelectExpression._groupBy.Clear(); clonedSelectExpression.ApplyPredicate(correlationPredicate); - if (!(AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27102", out var enabled) && enabled)) + if (!(AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27102", out var enabled27102) && enabled27102)) { clonedSelectExpression._groupingCorrelationPredicate = clonedSelectExpression.Predicate; } + if (!(AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27163", out var enabled27163) && enabled27163)) + { + _groupingParentSelectExpressionTableCount = _tables.Count; + + } if (!_identifier.All(e => _groupBy.Contains(e.Column))) { @@ -1495,6 +1501,7 @@ private void ApplySetOperation(SetOperationType setOperationType, SelectExpressi Offset = Offset, Limit = Limit, _groupingParentSelectExpressionId = _groupingParentSelectExpressionId, + _groupingParentSelectExpressionTableCount = _groupingParentSelectExpressionTableCount, _groupingCorrelationPredicate = _groupingCorrelationPredicate }; Offset = null; @@ -1504,6 +1511,7 @@ private void ApplySetOperation(SetOperationType setOperationType, SelectExpressi Having = null; _groupingCorrelationPredicate = null; _groupingParentSelectExpressionId = null; + _groupingParentSelectExpressionTableCount = null; _groupBy.Clear(); _orderings.Clear(); _tables.Clear(); @@ -2819,7 +2827,8 @@ private SqlRemappingVisitor PushdownIntoSubqueryInternal() Having = Having, Offset = Offset, Limit = Limit, - _groupingParentSelectExpressionId = _groupingParentSelectExpressionId + _groupingParentSelectExpressionId = _groupingParentSelectExpressionId, + _groupingParentSelectExpressionTableCount = _groupingParentSelectExpressionTableCount }; subquery._usedAliases = _usedAliases; _tables.Clear(); @@ -3482,7 +3491,8 @@ protected override Expression VisitChildren(ExpressionVisitor visitor) Tags = Tags, _usedAliases = _usedAliases, _groupingCorrelationPredicate = groupingCorrelationPredicate, - _groupingParentSelectExpressionId = _groupingParentSelectExpressionId + _groupingParentSelectExpressionId = _groupingParentSelectExpressionId, + _groupingParentSelectExpressionTableCount = _groupingParentSelectExpressionTableCount, }; newSelectExpression._tptLeftJoinTables.AddRange(_tptLeftJoinTables); diff --git a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs index f0fdfd2057e..d3e4cfed26e 100644 --- a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs @@ -786,5 +786,108 @@ protected class Table public int Id { get; set; } public int? Value { get; set; } } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Group_by_multiple_aggregate_joining_different_tables(bool async) + { + var contextFactory = await InitializeAsync(); + using var context = contextFactory.CreateContext(); + + var query = context.Parents + .GroupBy(x => new { }) + .Select(g => new + { + Test1 = g + .Select(x => x.Child1.Value1) + .Distinct() + .Count(), + Test2 = g + .Select(x => x.Child2.Value2) + .Distinct() + .Count() + }); + + var orders = async + ? await query.ToListAsync() + : query.ToList(); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Group_by_multiple_aggregate_joining_different_tables_with_query_filter(bool async) + { + var contextFactory = await InitializeAsync(); + using var context = contextFactory.CreateContext(); + + var query = context.Parents + .GroupBy(x => new { }) + .Select(g => new + { + Test1 = g + .Select(x => x.ChildFilter1.Value1) + .Distinct() + .Count(), + Test2 = g + .Select(x => x.ChildFilter2.Value2) + .Distinct() + .Count() + }); + + var orders = async + ? await query.ToListAsync() + : query.ToList(); + } + + protected class Context27163 : DbContext + { + public Context27163(DbContextOptions options) + : base(options) + { + } + + public DbSet Parents { get; set; } + + protected override void OnModelCreating(ModelBuilder modelBuilder) + { + modelBuilder.Entity().HasQueryFilter(e => e.Filter1 == "Filter1"); + modelBuilder.Entity().HasQueryFilter(e => e.Filter2 == "Filter2"); + } + } + + public class Parent + { + public int Id { get; set; } + public Child1 Child1 { get; set; } + public Child2 Child2 { get; set; } + public ChildFilter1 ChildFilter1 { get; set; } + public ChildFilter2 ChildFilter2 { get; set; } + } + + public class Child1 + { + public int Id { get; set; } + public string Value1 { get; set; } + } + + public class Child2 + { + public int Id { get; set; } + public string Value2 { get; set; } + } + + public class ChildFilter1 + { + public int Id { get; set; } + public string Filter1 { get; set; } + public string Value1 { get; set; } + } + + public class ChildFilter2 + { + public int Id { get; set; } + public string Filter2 { get; set; } + public string Value2 { get; set; } + } } } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/ComplexNavigationsSharedTypeQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/ComplexNavigationsSharedTypeQuerySqlServerTest.cs index e8bfd57fd2a..1bf2105ecca 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/ComplexNavigationsSharedTypeQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/ComplexNavigationsSharedTypeQuerySqlServerTest.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Threading.Tasks; +using Xunit; using Xunit.Abstractions; namespace Microsoft.EntityFrameworkCore.Query @@ -299,6 +300,18 @@ FROM [Level1] AS [l1] ) AS [t0] ON [l].[Id] = [t0].[Level1_Optional_Id]"); } + [ConditionalTheory(Skip = "Issue#26104")] + public override Task GroupBy_aggregate_where_required_relationship(bool async) + { + return base.GroupBy_aggregate_where_required_relationship(async); + } + + [ConditionalTheory(Skip = "Issue#26104")] + public override Task GroupBy_aggregate_where_required_relationship_2(bool async) + { + return base.GroupBy_aggregate_where_required_relationship_2(async); + } + private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected); } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs index 75499b65b24..da76b8115e4 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs @@ -213,5 +213,49 @@ ORDER BY (SELECT 1)), 0) AS [C] FROM [Table] AS [t] GROUP BY [t].[Value]"); } + + public override async Task Group_by_multiple_aggregate_joining_different_tables(bool async) + { + await base.Group_by_multiple_aggregate_joining_different_tables(async); + + AssertSql( + @"SELECT COUNT(DISTINCT ([c].[Value1])) AS [Test1], COUNT(DISTINCT ([c0].[Value2])) AS [Test2] +FROM ( + SELECT [p].[Child1Id], [p].[Child2Id], 1 AS [Key] + FROM [Parents] AS [p] +) AS [t] +LEFT JOIN [Child1] AS [c] ON [t].[Child1Id] = [c].[Id] +LEFT JOIN [Child2] AS [c0] ON [t].[Child2Id] = [c0].[Id] +GROUP BY [t].[Key]"); + } + + public override async Task Group_by_multiple_aggregate_joining_different_tables_with_query_filter(bool async) + { + await base.Group_by_multiple_aggregate_joining_different_tables_with_query_filter(async); + + AssertSql( + @"SELECT COUNT(DISTINCT ([t0].[Value1])) AS [Test1], ( + SELECT DISTINCT COUNT(DISTINCT ([t2].[Value2])) + FROM ( + SELECT [p0].[Id], [p0].[Child1Id], [p0].[Child2Id], [p0].[ChildFilter1Id], [p0].[ChildFilter2Id], 1 AS [Key] + FROM [Parents] AS [p0] + ) AS [t1] + LEFT JOIN ( + SELECT [c0].[Id], [c0].[Filter2], [c0].[Value2] + FROM [ChildFilter2] AS [c0] + WHERE [c0].[Filter2] = N'Filter2' + ) AS [t2] ON [t1].[ChildFilter2Id] = [t2].[Id] + WHERE [t].[Key] = [t1].[Key]) AS [Test2] +FROM ( + SELECT [p].[ChildFilter1Id], 1 AS [Key] + FROM [Parents] AS [p] +) AS [t] +LEFT JOIN ( + SELECT [c].[Id], [c].[Value1] + FROM [ChildFilter1] AS [c] + WHERE [c].[Filter1] = N'Filter1' +) AS [t0] ON [t].[ChildFilter1Id] = [t0].[Id] +GROUP BY [t].[Key]"); + } } } diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/ComplexNavigationsSharedTypeQuerySqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/ComplexNavigationsSharedTypeQuerySqliteTest.cs index 56c6b9a0459..19b4cea9cc5 100644 --- a/test/EFCore.Sqlite.FunctionalTests/Query/ComplexNavigationsSharedTypeQuerySqliteTest.cs +++ b/test/EFCore.Sqlite.FunctionalTests/Query/ComplexNavigationsSharedTypeQuerySqliteTest.cs @@ -28,5 +28,17 @@ public override async Task Prune_does_not_throw_null_ref(bool async) SqliteStrings.ApplyNotSupported, (await Assert.ThrowsAsync( () => base.Prune_does_not_throw_null_ref(async))).Message); + + [ConditionalTheory(Skip = "Issue#26104")] + public override Task GroupBy_aggregate_where_required_relationship(bool async) + { + return base.GroupBy_aggregate_where_required_relationship(async); + } + + [ConditionalTheory(Skip = "Issue#26104")] + public override Task GroupBy_aggregate_where_required_relationship_2(bool async) + { + return base.GroupBy_aggregate_where_required_relationship_2(async); + } } }