From b0fcb744b4bdfbe2cedf4b20465223a06fb95f92 Mon Sep 17 00:00:00 2001 From: Mark Carrington <31017244+MarkMpn@users.noreply.github.com> Date: Sat, 20 Apr 2024 19:54:22 +0100 Subject: [PATCH] Extended validation of GROUP BY and EXECUTE AS expressions --- MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsError.cs | 23 ++++++++++++- .../ExecutionPlanBuilder.cs | 33 +++++++++++++++++++ .../MarkMpn.Sql4Cds.Engine.projitems | 1 + .../Visitors/AggregateCollectingVisitor.cs | 2 +- .../Visitors/GroupByValidatingVisitor.cs | 32 ++++++++++++++++++ .../Visitors/ScalarSubqueryVisitor.cs | 5 +++ 6 files changed, 94 insertions(+), 2 deletions(-) create mode 100644 MarkMpn.Sql4Cds.Engine/Visitors/GroupByValidatingVisitor.cs diff --git a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsError.cs b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsError.cs index ac40b64d..7bf67135 100644 --- a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsError.cs +++ b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsError.cs @@ -165,7 +165,7 @@ internal static Sql4CdsError InvalidCollation(Identifier collation) internal static Sql4CdsError NonAggregateColumnReference(ColumnReferenceExpression column) { - var tableName = column.MultiPartIdentifier.Identifiers[column.MultiPartIdentifier.Identifiers.Count - 2].Value; + var tableName = column.MultiPartIdentifier.Identifiers.Count == 1 ? "" : column.MultiPartIdentifier.Identifiers[column.MultiPartIdentifier.Identifiers.Count - 2].Value; var columnName = column.MultiPartIdentifier.Identifiers[column.MultiPartIdentifier.Identifiers.Count - 1].Value; return Create(8120, column, (SqlInt32)tableName.Length, Collation.USEnglish.ToSqlString(tableName), (SqlInt32)columnName.Length, Collation.USEnglish.ToSqlString(columnName)); @@ -709,6 +709,27 @@ internal static Sql4CdsError DivideByZero() return Create(8134, null); } + internal static Sql4CdsError InvalidAggregateOrSubqueryInGroupByClause(TSqlFragment fragment) + { + return Create(144, fragment); + } + + internal static Sql4CdsError SubqueriesNotAllowed(TSqlFragment fragment) + { + return Create(1046, fragment); + } + + internal static Sql4CdsError ConstantExpressionsOnly(ColumnReferenceExpression col) + { + var name = col.GetColumnName(); + return Create(128, col, (SqlInt32)name.Length, Collation.USEnglish.ToSqlString(name)); + } + + internal static Sql4CdsError InvalidTypeForStatement(TSqlFragment fragment, string name) + { + return Create(15533, fragment, name); + } + private static string GetTypeName(DataTypeReference type) { if (type is SqlDataTypeReference sqlType) diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs index 806f4dda..1249bacc 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs @@ -1165,6 +1165,23 @@ private ExecuteAsNode ConvertExecuteAsStatement(ExecuteAsStatement impersonate) impersonate.ExecuteContext.Kind != ExecuteAsOption.User) throw new NotSupportedQueryFragmentException(Sql4CdsError.NotSupported(impersonate.ExecuteContext, impersonate.ExecuteContext.Kind.ToString())); + var subqueries = new ScalarSubqueryVisitor(); + impersonate.ExecuteContext.Principal.Accept(subqueries); + if (subqueries.Subqueries.Count > 0) + throw new NotSupportedQueryFragmentException(Sql4CdsError.SubqueriesNotAllowed(subqueries.Subqueries[0])); + + var columns = new ColumnCollectingVisitor(); + impersonate.ExecuteContext.Principal.Accept(columns); + if (columns.Columns.Count > 0) + throw new NotSupportedQueryFragmentException(Sql4CdsError.ConstantExpressionsOnly(columns.Columns[0])); + + // Validate the expression + var ecc = new ExpressionCompilationContext(_nodeContext, null, null); + var type = impersonate.ExecuteContext.Principal.GetType(ecc, out _); + + if (type != typeof(SqlString) && type != typeof(SqlEntityReference)) + throw new NotSupportedQueryFragmentException(Sql4CdsError.InvalidTypeForStatement(impersonate.ExecuteContext.Principal, "Execute As")); + IExecutionPlanNodeInternal source; // Create a SELECT query to find the user ID @@ -1219,6 +1236,16 @@ private ExecuteAsNode ConvertExecuteAsStatement(ExecuteAsStatement impersonate) ComparisonType = BooleanComparisonType.Equals, SecondExpression = impersonate.ExecuteContext.Principal } + }, + GroupByClause = new GroupByClause + { + GroupingSpecifications = + { + new ExpressionGroupingSpecification + { + Expression = impersonate.ExecuteContext.Principal + } + } } } }; @@ -2911,6 +2938,12 @@ private IDataExecutionPlanNodeInternal ConvertGroupByAggregates(IDataExecutionPl if (querySpec.GroupByClause.GroupByOption != GroupByOption.None) throw new NotSupportedQueryFragmentException(Sql4CdsError.NotSupported(querySpec.GroupByClause, $"GROUP BY {querySpec.GroupByClause.GroupByOption}")); + + var groupByValidator = new GroupByValidatingVisitor(); + querySpec.GroupByClause.Accept(groupByValidator); + + if (groupByValidator.Error != null) + throw new NotSupportedQueryFragmentException(groupByValidator.Error); } var schema = source.GetSchema(context); diff --git a/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.projitems b/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.projitems index f65944ee..78cdd812 100644 --- a/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.projitems +++ b/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.projitems @@ -133,6 +133,7 @@ + diff --git a/MarkMpn.Sql4Cds.Engine/Visitors/AggregateCollectingVisitor.cs b/MarkMpn.Sql4Cds.Engine/Visitors/AggregateCollectingVisitor.cs index 277a7384..0e02bedc 100644 --- a/MarkMpn.Sql4Cds.Engine/Visitors/AggregateCollectingVisitor.cs +++ b/MarkMpn.Sql4Cds.Engine/Visitors/AggregateCollectingVisitor.cs @@ -71,7 +71,7 @@ public override void ExplicitVisit(QueryDerivedTable node) // Do not recurse into subqueries - they'll be handled separately } - private bool IsAggregate(FunctionCall func) + internal static bool IsAggregate(FunctionCall func) { if (func.OverClause != null) return false; diff --git a/MarkMpn.Sql4Cds.Engine/Visitors/GroupByValidatingVisitor.cs b/MarkMpn.Sql4Cds.Engine/Visitors/GroupByValidatingVisitor.cs new file mode 100644 index 00000000..d966a700 --- /dev/null +++ b/MarkMpn.Sql4Cds.Engine/Visitors/GroupByValidatingVisitor.cs @@ -0,0 +1,32 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.SqlServer.TransactSql.ScriptDom; + +namespace MarkMpn.Sql4Cds.Engine.Visitors +{ + class GroupByValidatingVisitor : TSqlFragmentVisitor + { + public Sql4CdsError Error { get; private set; } + + public override void Visit(GroupByClause node) + { + if (node.All == true) + Error = Sql4CdsError.NotSupported(node, "GROUP BY ALL"); + + if (node.GroupByOption != GroupByOption.None) + Error = Sql4CdsError.NotSupported(node, $"GROUP BY {node.GroupByOption}"); + } + + public override void Visit(ScalarSubquery node) + { + Error = Sql4CdsError.InvalidAggregateOrSubqueryInGroupByClause(node); + } + + public override void Visit(FunctionCall node) + { + if (AggregateCollectingVisitor.IsAggregate(node)) + Error = Sql4CdsError.InvalidAggregateOrSubqueryInGroupByClause(node); + } + } +} diff --git a/MarkMpn.Sql4Cds.Engine/Visitors/ScalarSubqueryVisitor.cs b/MarkMpn.Sql4Cds.Engine/Visitors/ScalarSubqueryVisitor.cs index 4ada24ef..b8a4528a 100644 --- a/MarkMpn.Sql4Cds.Engine/Visitors/ScalarSubqueryVisitor.cs +++ b/MarkMpn.Sql4Cds.Engine/Visitors/ScalarSubqueryVisitor.cs @@ -15,5 +15,10 @@ public override void ExplicitVisit(ScalarSubquery node) { Subqueries.Add(node); } + + public override void ExplicitVisit(GroupByClause node) + { + // Subqueries aren't allowed in the GROUP BY clause - don't collect them so they are still present to produce validation errors later + } } }