Skip to content

Commit

Permalink
Build basic structure of recursive CTE plan
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkMpn committed Sep 8, 2023
1 parent 8e31b54 commit 180694d
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 22 deletions.
62 changes: 62 additions & 0 deletions MarkMpn.Sql4Cds.Engine.Tests/CteTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,39 @@ WITH cte (id, n) AS (SELECT accountid, name FROM account)
</fetch>");
}

[TestMethod]
public void MultipleAnchorQueries()
{
    // A CTE whose body is two members joined by UNION ALL, neither of which
    // references the CTE itself. The plan is expected to be a SelectNode over
    // a ConcatenateNode with one FetchXmlScan per member.
    const string sql = @"
WITH cte (id, n) AS (SELECT accountid, name FROM account UNION ALL select contactid, fullname FROM contact)
SELECT * FROM cte";

    const string expectedAccountFetch = @"
<fetch>
<entity name='account'>
<attribute name='accountid' />
<attribute name='name' />
</entity>
</fetch>";

    const string expectedContactFetch = @"
<fetch>
<entity name='contact'>
<attribute name='contactid' />
<attribute name='fullname' />
</entity>
</fetch>";

    var builder = new ExecutionPlanBuilder(_localDataSource.Values, this);
    var builtPlans = builder.Build(sql, null, out _);

    // The whole statement must compile down to a single plan
    Assert.AreEqual(1, builtPlans.Length);

    var root = AssertNode<SelectNode>(builtPlans[0]);
    var union = AssertNode<ConcatenateNode>(root.Source);

    // First member: account
    var accountScan = AssertNode<FetchXmlScan>(union.Sources[0]);
    AssertFetchXml(accountScan, expectedAccountFetch);

    // Second member: contact
    var contactScan = AssertNode<FetchXmlScan>(union.Sources[1]);
    AssertFetchXml(contactScan, expectedContactFetch);
}

[TestMethod]
public void MergeFilters()
{
Expand Down Expand Up @@ -457,6 +490,35 @@ WITH cte (id, fname, lname) AS (
</fetch>");
}

[TestMethod]
public void SimpleRecursion()
{
    // Builds a plan for a CTE with one anchor member and one recursive member
    // (the second SELECT joins back to cte) and checks the shape of the
    // resulting node tree.
    var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this);

    var query = @"
WITH cte AS (
SELECT contactid, firstname, lastname FROM contact WHERE firstname = 'Mark'
UNION ALL
SELECT c.contactid, c.firstname, c.lastname FROM contact c INNER JOIN cte ON c.parentcustomerid = cte.contactid
)
SELECT * FROM cte";

    var plans = planBuilder.Build(query, null, out _);

    Assert.AreEqual(1, plans.Length);

    // Expected shape:
    //   Select
    //     IndexSpool (producer)
    //       Concatenate
    //         [0] ComputeScalar -> FetchXmlScan (anchor)
    //         [1] Assert -> NestedLoop
    //                         left:  ComputeScalar -> TableSpool (consumer)
    //                         right: FetchXmlScan
    var select = AssertNode<SelectNode>(plans[0]);
    var spoolProducer = AssertNode<IndexSpoolNode>(select.Source);
    var concat = AssertNode<ConcatenateNode>(spoolProducer.Source);
    var depth0 = AssertNode<ComputeScalarNode>(concat.Sources[0]);
    var anchor = AssertNode<FetchXmlScan>(depth0.Source);
    var assert = AssertNode<AssertNode>(concat.Sources[1]);
    var nestedLoop = AssertNode<NestedLoopNode>(assert.Source);
    var depthPlus1 = AssertNode<ComputeScalarNode>(nestedLoop.LeftSource);

    // The table spool is the *source* of the depth-incrementing ComputeScalar.
    // Checking depthPlus1 itself could never pass, because it has just been
    // asserted to be a ComputeScalarNode.
    var spoolConsumer = AssertNode<TableSpoolNode>(depthPlus1.Source);
    var children = AssertNode<FetchXmlScan>(nestedLoop.RightSource);
}

private T AssertNode<T>(IExecutionPlanNode node) where T : IExecutionPlanNode
{
Assert.IsInstanceOfType(node, typeof(T));
Expand Down
133 changes: 112 additions & 21 deletions MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -184,38 +184,129 @@ private void ConvertStatement(TSqlStatement statement, ExecutionPlanOptimizer op
var cteValidator = new CteValidatorVisitor();
cte.Accept(cteValidator);

if (!cteValidator.IsRecursive)
// Start by converting the anchor query to a subquery
var plan = ConvertSelectStatement(cteValidator.AnchorQuery, hints, null, null, _nodeContext);

// Apply column aliases
if (cte.Columns.Count > 0)
{
plan.ExpandWildcardColumns(_nodeContext);

if (cte.Columns.Count < plan.ColumnSet.Count)
throw new NotSupportedQueryFragmentException($"'{cteValidator.Name}' has more columns than were specified in the column list.", cte);

if (cte.Columns.Count > plan.ColumnSet.Count)
throw new NotSupportedQueryFragmentException($"'{cteValidator.Name}' has fewer columns than were specified in the column list.", cte);

for (var i = 0; i < cte.Columns.Count; i++)
plan.ColumnSet[i].OutputColumn = cte.Columns[i].Value;
}

for (var i = 0; i < plan.ColumnSet.Count; i++)
{
if (plan.ColumnSet[i].OutputColumn == null)
throw new NotSupportedQueryFragmentException($"No column name was specified for column {i+1} of '{cteValidator.Name}'", cte);
}

var anchorQuery = new AliasNode(plan, cte.ExpressionName, _nodeContext);

if (cteValidator.RecursiveQueries.Count > 0)
{
// If the CTE isn't recursive then we can just convert it to a subquery
var plan = ConvertSelectStatement(cte.QueryExpression, hints, null, null, _nodeContext);
var ctePlan = anchorQuery.Source;

// Apply column aliases
if (cte.Columns.Count > 0)
// Add a ComputeScalar node to add the initial recursion depth (0)
var recursionDepthField = _nodeContext.GetExpressionName();
var initialRecursionDepthComputeScalar = new ComputeScalarNode
{
plan.ExpandWildcardColumns(_nodeContext);
Source = ctePlan,
Columns =
{
[recursionDepthField] = new IntegerLiteral { Value = "0" }
}
};

if (cte.Columns.Count < plan.ColumnSet.Count)
throw new NotSupportedQueryFragmentException($"'{cteValidator.Name}' has more columns than were specified in the column list.", cte);
// Add a ConcatenateNode to combine the anchor results with the recursion results
var recurseConcat = new ConcatenateNode
{
Sources = { initialRecursionDepthComputeScalar },
};

if (cte.Columns.Count > plan.ColumnSet.Count)
throw new NotSupportedQueryFragmentException($"'{cteValidator.Name}' has fewer columns than were specified in the column list.", cte);
foreach (var col in anchorQuery.ColumnSet)
{
var concatCol = new ConcatenateColumn
{
SourceColumns = { col.SourceColumn },
OutputColumn = col.OutputColumn
};

for (var i = 0; i < cte.Columns.Count; i++)
plan.ColumnSet[i].OutputColumn = cte.Columns[i].Value;
recurseConcat.ColumnSet.Add(concatCol);
}

for (var i = 0; i < plan.ColumnSet.Count; i++)
recurseConcat.ColumnSet.Add(new ConcatenateColumn
{
if (plan.ColumnSet[i].OutputColumn == null)
throw new NotSupportedQueryFragmentException($"No column name was specified for column {i+1} of '{cteValidator.Name}'", cte);
}
SourceColumns = { recursionDepthField },
OutputColumn = recursionDepthField
});

_cteSubplans.Add(cte.ExpressionName.Value, new AliasNode(plan, cte.ExpressionName, _nodeContext));
}
else
{
throw new NotSupportedQueryFragmentException("Recursive CTEs are not yet supported", cte);
// Add an IndexSpool node in stack mode to enable the recursion
var recurseIndexStack = new IndexSpoolNode
{
Source = recurseConcat,
// TODO: WithStack = true
};

// Pull the same records into the recursive loop
var recurseTableSpool = new TableSpoolNode
{
// TODO: Producer = recurseIndexStack
};

// Increment the depth
var incrementedDepthField = _nodeContext.GetExpressionName();
var incrementRecursionDepthComputeScalar = new ComputeScalarNode
{
Source = recurseTableSpool,
Columns =
{
[incrementedDepthField] = new BinaryExpression
{
FirstExpression = recursionDepthField.ToColumnReference(),
BinaryExpressionType = BinaryExpressionType.Add,
SecondExpression = new IntegerLiteral { Value = "1" }
}
}
};

// Use a nested loop to pass through the records to the recursive queries
var recurseLoop = new NestedLoopNode
{
LeftSource = incrementRecursionDepthComputeScalar,
// TODO: Capture all CTE fields in the outer references
JoinType = QualifiedJoinType.Inner,
};

// Ensure we don't get stuck in an infinite loop
var assert = new AssertNode
{
Source = recurseLoop,
Assertion = e =>
{
var depth = e.GetAttributeValue<SqlInt32>(incrementedDepthField);
return depth.Value < 100;
},
ErrorMessage = "Recursion depth exceeded"
};

// Combine the recursion results into the main results
recurseConcat.Sources.Add(assert);

// TODO: Update the sources for each field in the concat node
recurseConcat.ColumnSet.Last().SourceColumns.Add(incrementedDepthField);

anchorQuery.Source = incrementRecursionDepthComputeScalar;
}

_cteSubplans.Add(cte.ExpressionName.Value, new AliasNode(plan, cte.ExpressionName, _nodeContext));
}
}
}
Expand Down
29 changes: 28 additions & 1 deletion MarkMpn.Sql4Cds.Engine/Visitors/CteValidatorVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace MarkMpn.Sql4Cds.Engine.Visitors
/// <remarks>
/// https://learn.microsoft.com/en-us/sql/t-sql/queries/with-common-table-expression-transact-sql?view=sql-server-ver16
/// </remarks>
class CteValidatorVisitor : TSqlFragmentVisitor
class CteValidatorVisitor : TSqlConcreteFragmentVisitor
{
private int _cteReferenceCount;
private FunctionCall _scalarAggregate;
Expand All @@ -22,6 +22,10 @@ class CteValidatorVisitor : TSqlFragmentVisitor

public bool IsRecursive { get; private set; }

public QueryExpression AnchorQuery { get; private set; }

public List<QueryExpression> RecursiveQueries { get; } = new List<QueryExpression>();

public override void Visit(CommonTableExpression node)
{
Name = node.ExpressionName.Value;
Expand Down Expand Up @@ -69,11 +73,34 @@ public override void ExplicitVisit(BinaryQueryExpression node)
{
base.ExplicitVisit(node);

if (!IsRecursive)
AnchorQuery = node;

// UNION ALL is the only set operator allowed between the last anchor member and first recursive member, and when combining multiple recursive members.
if (IsRecursive && (node.BinaryQueryExpressionType != BinaryQueryExpressionType.Union || !node.All))
throw new NotSupportedQueryFragmentException($"Recursive common table expression '{Name}' does not contain a top-level UNION ALL operator", node);
}

/// <summary>
/// Walks the query specification, then files it as either the anchor query
/// or one of the recursive queries.
/// </summary>
/// <param name="node">The query specification being visited</param>
public override void ExplicitVisit(QuerySpecification node)
{
    // Traverse first: IsRecursive is only inspected once the base traversal
    // of this node has finished.
    base.ExplicitVisit(node);

    if (IsRecursive)
    {
        RecursiveQueries.Add(node);
        return;
    }

    // No recursion detected so far, so this is the latest anchor candidate
    AnchorQuery = node;
}

/// <summary>
/// Walks the parenthesised query, then files it as either the anchor query
/// or one of the recursive queries.
/// </summary>
/// <param name="node">The parenthesised query expression being visited</param>
public override void ExplicitVisit(QueryParenthesisExpression node)
{
    // Traverse first: IsRecursive is only inspected once the base traversal
    // of this node has finished.
    base.ExplicitVisit(node);

    // NOTE(review): if the base traversal reaches the QuerySpecification
    // override for the nested query, a parenthesised recursive member looks
    // like it is added to RecursiveQueries twice (inner query and this
    // wrapper) - confirm whether that is intended.
    if (IsRecursive)
    {
        RecursiveQueries.Add(node);
        return;
    }

    // No recursion detected so far, so this is the latest anchor candidate
    AnchorQuery = node;
}

public override void Visit(QuerySpecification node)
{
_scalarAggregate = null;
Expand Down

0 comments on commit 180694d

Please sign in to comment.