From b0fcb744b4bdfbe2cedf4b20465223a06fb95f92 Mon Sep 17 00:00:00 2001
From: Mark Carrington <31017244+MarkMpn@users.noreply.github.com>
Date: Sat, 20 Apr 2024 19:54:22 +0100
Subject: [PATCH] Extended validation of GROUP BY and EXECUTE AS expressions
---
MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsError.cs | 23 ++++++++++++-
.../ExecutionPlanBuilder.cs | 33 +++++++++++++++++++
.../MarkMpn.Sql4Cds.Engine.projitems | 1 +
.../Visitors/AggregateCollectingVisitor.cs | 2 +-
.../Visitors/GroupByValidatingVisitor.cs | 32 ++++++++++++++++++
.../Visitors/ScalarSubqueryVisitor.cs | 5 +++
6 files changed, 94 insertions(+), 2 deletions(-)
create mode 100644 MarkMpn.Sql4Cds.Engine/Visitors/GroupByValidatingVisitor.cs
diff --git a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsError.cs b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsError.cs
index ac40b64d..7bf67135 100644
--- a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsError.cs
+++ b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsError.cs
@@ -165,7 +165,7 @@ internal static Sql4CdsError InvalidCollation(Identifier collation)
internal static Sql4CdsError NonAggregateColumnReference(ColumnReferenceExpression column)
{
- var tableName = column.MultiPartIdentifier.Identifiers[column.MultiPartIdentifier.Identifiers.Count - 2].Value;
+ var tableName = column.MultiPartIdentifier.Identifiers.Count == 1 ? "" : column.MultiPartIdentifier.Identifiers[column.MultiPartIdentifier.Identifiers.Count - 2].Value;
var columnName = column.MultiPartIdentifier.Identifiers[column.MultiPartIdentifier.Identifiers.Count - 1].Value;
return Create(8120, column, (SqlInt32)tableName.Length, Collation.USEnglish.ToSqlString(tableName), (SqlInt32)columnName.Length, Collation.USEnglish.ToSqlString(columnName));
@@ -709,6 +709,27 @@ internal static Sql4CdsError DivideByZero()
return Create(8134, null);
}
+ internal static Sql4CdsError InvalidAggregateOrSubqueryInGroupByClause(TSqlFragment fragment)
+ {
+ return Create(144, fragment);
+ }
+
+ internal static Sql4CdsError SubqueriesNotAllowed(TSqlFragment fragment)
+ {
+ return Create(1046, fragment);
+ }
+
+ internal static Sql4CdsError ConstantExpressionsOnly(ColumnReferenceExpression col)
+ {
+ var name = col.GetColumnName();
+ return Create(128, col, (SqlInt32)name.Length, Collation.USEnglish.ToSqlString(name));
+ }
+
+ internal static Sql4CdsError InvalidTypeForStatement(TSqlFragment fragment, string name)
+ {
+ return Create(15533, fragment, name);
+ }
+
private static string GetTypeName(DataTypeReference type)
{
if (type is SqlDataTypeReference sqlType)
diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs
index 806f4dda..1249bacc 100644
--- a/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs
+++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs
@@ -1165,6 +1165,23 @@ private ExecuteAsNode ConvertExecuteAsStatement(ExecuteAsStatement impersonate)
impersonate.ExecuteContext.Kind != ExecuteAsOption.User)
throw new NotSupportedQueryFragmentException(Sql4CdsError.NotSupported(impersonate.ExecuteContext, impersonate.ExecuteContext.Kind.ToString()));
+ var subqueries = new ScalarSubqueryVisitor();
+ impersonate.ExecuteContext.Principal.Accept(subqueries);
+ if (subqueries.Subqueries.Count > 0)
+ throw new NotSupportedQueryFragmentException(Sql4CdsError.SubqueriesNotAllowed(subqueries.Subqueries[0]));
+
+ var columns = new ColumnCollectingVisitor();
+ impersonate.ExecuteContext.Principal.Accept(columns);
+ if (columns.Columns.Count > 0)
+ throw new NotSupportedQueryFragmentException(Sql4CdsError.ConstantExpressionsOnly(columns.Columns[0]));
+
+ // Validate the expression
+ var ecc = new ExpressionCompilationContext(_nodeContext, null, null);
+ var type = impersonate.ExecuteContext.Principal.GetType(ecc, out _);
+
+ if (type != typeof(SqlString) && type != typeof(SqlEntityReference))
+ throw new NotSupportedQueryFragmentException(Sql4CdsError.InvalidTypeForStatement(impersonate.ExecuteContext.Principal, "Execute As"));
+
IExecutionPlanNodeInternal source;
// Create a SELECT query to find the user ID
@@ -1219,6 +1236,16 @@ private ExecuteAsNode ConvertExecuteAsStatement(ExecuteAsStatement impersonate)
ComparisonType = BooleanComparisonType.Equals,
SecondExpression = impersonate.ExecuteContext.Principal
}
+ },
+ GroupByClause = new GroupByClause
+ {
+ GroupingSpecifications =
+ {
+ new ExpressionGroupingSpecification
+ {
+ Expression = impersonate.ExecuteContext.Principal
+ }
+ }
}
}
};
@@ -2911,6 +2938,12 @@ private IDataExecutionPlanNodeInternal ConvertGroupByAggregates(IDataExecutionPl
if (querySpec.GroupByClause.GroupByOption != GroupByOption.None)
throw new NotSupportedQueryFragmentException(Sql4CdsError.NotSupported(querySpec.GroupByClause, $"GROUP BY {querySpec.GroupByClause.GroupByOption}"));
+
+ var groupByValidator = new GroupByValidatingVisitor();
+ querySpec.GroupByClause.Accept(groupByValidator);
+
+ if (groupByValidator.Error != null)
+ throw new NotSupportedQueryFragmentException(groupByValidator.Error);
}
var schema = source.GetSchema(context);
diff --git a/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.projitems b/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.projitems
index f65944ee..78cdd812 100644
--- a/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.projitems
+++ b/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.projitems
@@ -133,6 +133,7 @@
+
diff --git a/MarkMpn.Sql4Cds.Engine/Visitors/AggregateCollectingVisitor.cs b/MarkMpn.Sql4Cds.Engine/Visitors/AggregateCollectingVisitor.cs
index 277a7384..0e02bedc 100644
--- a/MarkMpn.Sql4Cds.Engine/Visitors/AggregateCollectingVisitor.cs
+++ b/MarkMpn.Sql4Cds.Engine/Visitors/AggregateCollectingVisitor.cs
@@ -71,7 +71,7 @@ public override void ExplicitVisit(QueryDerivedTable node)
// Do not recurse into subqueries - they'll be handled separately
}
- private bool IsAggregate(FunctionCall func)
+ internal static bool IsAggregate(FunctionCall func)
{
if (func.OverClause != null)
return false;
diff --git a/MarkMpn.Sql4Cds.Engine/Visitors/GroupByValidatingVisitor.cs b/MarkMpn.Sql4Cds.Engine/Visitors/GroupByValidatingVisitor.cs
new file mode 100644
index 00000000..d966a700
--- /dev/null
+++ b/MarkMpn.Sql4Cds.Engine/Visitors/GroupByValidatingVisitor.cs
@@ -0,0 +1,32 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Microsoft.SqlServer.TransactSql.ScriptDom;
+
+namespace MarkMpn.Sql4Cds.Engine.Visitors
+{
+ class GroupByValidatingVisitor : TSqlFragmentVisitor
+ {
+ public Sql4CdsError Error { get; private set; }
+
+ public override void Visit(GroupByClause node)
+ {
+ if (node.All == true)
+ Error = Sql4CdsError.NotSupported(node, "GROUP BY ALL");
+
+ if (node.GroupByOption != GroupByOption.None)
+ Error = Sql4CdsError.NotSupported(node, $"GROUP BY {node.GroupByOption}");
+ }
+
+ public override void Visit(ScalarSubquery node)
+ {
+ Error = Sql4CdsError.InvalidAggregateOrSubqueryInGroupByClause(node);
+ }
+
+ public override void Visit(FunctionCall node)
+ {
+ if (AggregateCollectingVisitor.IsAggregate(node))
+ Error = Sql4CdsError.InvalidAggregateOrSubqueryInGroupByClause(node);
+ }
+ }
+}
diff --git a/MarkMpn.Sql4Cds.Engine/Visitors/ScalarSubqueryVisitor.cs b/MarkMpn.Sql4Cds.Engine/Visitors/ScalarSubqueryVisitor.cs
index 4ada24ef..b8a4528a 100644
--- a/MarkMpn.Sql4Cds.Engine/Visitors/ScalarSubqueryVisitor.cs
+++ b/MarkMpn.Sql4Cds.Engine/Visitors/ScalarSubqueryVisitor.cs
@@ -15,5 +15,10 @@ public override void ExplicitVisit(ScalarSubquery node)
{
Subqueries.Add(node);
}
+
+ public override void ExplicitVisit(GroupByClause node)
+ {
+ // Subqueries aren't allowed in the GROUP BY clause - don't collect them so they are still present to produce validation errors later
+ }
}
}