Skip to content

Commit

Permalink
Extended type checking for join comparisons
Browse files Browse the repository at this point in the history
Fixes #520
  • Loading branch information
MarkMpn committed Jul 25, 2024
1 parent b2e246f commit 63c13f9
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 39 deletions.
51 changes: 51 additions & 0 deletions MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5825,6 +5825,57 @@ public void CollationConflict()
planBuilder.Build(query, null, out _);
}

[TestMethod]
public void CollationConflictJoin()
{
var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" });
var query = "SELECT * FROM prod.dbo.account p INNER JOIN french.dbo.account f ON p.name = f.name";

try
{
planBuilder.Build(query, null, out _);
Assert.Fail();
}
catch (NotSupportedQueryFragmentException ex)
{
Assert.AreEqual(468, ex.Errors.Single().Number);
}
}

[TestMethod]
public void TypeConflictJoin()
{
var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" });
var query = "SELECT * FROM account p INNER JOIN account f ON p.accountid = f.turnover";

try
{
planBuilder.Build(query, null, out _);
Assert.Fail();
}
catch (NotSupportedQueryFragmentException ex)
{
Assert.AreEqual(206, ex.Errors.Single().Number);
}
}

[TestMethod]
public void TypeConflictCrossInstanceJoin()
{
var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" });
var query = "SELECT * FROM prod.dbo.account p INNER JOIN french.dbo.account f ON p.accountid = f.turnover";

try
{
planBuilder.Build(query, null, out _);
Assert.Fail();
}
catch (NotSupportedQueryFragmentException ex)
{
Assert.AreEqual(206, ex.Errors.Single().Number);
}
}

[TestMethod]
public void ExplicitCollation()
{
Expand Down
14 changes: 14 additions & 0 deletions MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsError.cs
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,20 @@ private static string GetTypeName(DataTypeReference type)

return ((UserDataTypeReference)type).Name.ToSql();
}

/// <summary>
/// Creates a copy of this error for a specific fragment
/// </summary>
/// <param name="expression">The fragment the error should be applied to</param>
/// <returns></returns>
/// <exception cref="NotImplementedException"></exception>
internal Sql4CdsError ForFragment(BooleanComparisonExpression expression)
{
if (expression == null)
return this;

return new Sql4CdsError(Class, -1, Number, Procedure, Server, State, Message, expression);
}
}

/// <summary>
Expand Down
29 changes: 28 additions & 1 deletion MarkMpn.Sql4Cds.Engine/ExecutionPlan/FoldableJoinNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ public ColumnReferenceExpression RightAttribute
[DisplayName("Right Attributes")]
public List<ColumnReferenceExpression> RightAttributes { get; } = new List<ColumnReferenceExpression>();

internal List<BooleanComparisonExpression> Expressions { get; } = new List<BooleanComparisonExpression>();

/// <summary>
/// The type of comparison that is used for the two inputs
/// </summary>
Expand All @@ -87,7 +89,7 @@ public ColumnReferenceExpression RightAttribute

public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList<OptimizerHint> hints)
{
// For inner joins, additional join criteria are eqivalent to doing the join without them and then applying the filter
// For inner joins, additional join criteria are equivalent to doing the join without them and then applying the filter
// We've already got logic in the Filter node for efficiently folding those queries, so split them out and let it do
// what it can
if (JoinType == QualifiedJoinType.Inner && AdditionalJoinCriteria != null)
Expand All @@ -107,7 +109,13 @@ public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext
RightSource.Parent = this;

var leftSchema = LeftSource.GetSchema(context);
var leftCompilationContext = new ExpressionCompilationContext(context, leftSchema, null);
var rightSchema = RightSource.GetSchema(context);
var rightCompilationContext = new ExpressionCompilationContext(context, rightSchema, null);

// Check the types of the comparisons
for (var i = 0; i < LeftAttributes.Count; i++)
ValidateComparison(context, leftCompilationContext, rightCompilationContext, i);

FoldDefinedValues(rightSchema);

Expand Down Expand Up @@ -172,6 +180,25 @@ public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext
return this;
}

private void ValidateComparison(NodeCompilationContext context, ExpressionCompilationContext leftCompilationContext, ExpressionCompilationContext rightCompilationContext, int i)
{
LeftAttributes[i].GetType(leftCompilationContext, out var leftColType);
RightAttributes[i].GetType(rightCompilationContext, out var rightColType);

var expression = i < Expressions.Count ? Expressions[i] : null;
if (!SqlTypeConverter.CanMakeConsistentTypes(leftColType, rightColType, context.PrimaryDataSource, null, "equals", out var keyType))
throw new NotSupportedQueryFragmentException(Sql4CdsError.TypeClash(expression, leftColType, rightColType));

if (keyType is SqlDataTypeReferenceWithCollation keyTypeWithCollation && keyTypeWithCollation.CollationLabel == CollationLabel.NoCollation)
throw new NotSupportedQueryFragmentException(keyTypeWithCollation.CollationConflictError.ForFragment(expression));

ValidateComparison(expression, keyType, leftColType, leftCompilationContext, rightColType, rightCompilationContext, i);
}

protected virtual void ValidateComparison(BooleanComparisonExpression expression, DataTypeReference keyType, DataTypeReference leftColType, ExpressionCompilationContext leftCompilationContext, DataTypeReference rightColType, ExpressionCompilationContext rightCompilationContext, int i)
{
}

private IDataExecutionPlanNodeInternal PrependFilters(IDataExecutionPlanNodeInternal folded, NodeCompilationContext context, IList<OptimizerHint> hints, params FilterNode[] filters)
{
foreach (var filter in filters)
Expand Down
80 changes: 44 additions & 36 deletions MarkMpn.Sql4Cds.Engine/ExecutionPlan/HashJoinNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ public override bool Equals(object obj)
}
}

private Func<ExpressionExecutionContext, object>[] _leftKeyAccessors;
private Func<ExpressionExecutionContext, object>[] _rightKeyAccessors;
private IDictionary<CompoundKey, List<OuterRecord>> _hashTable;

protected override IEnumerable<Entity> ExecuteInternal(NodeExecutionContext context)
Expand All @@ -68,46 +70,13 @@ protected override IEnumerable<Entity> ExecuteInternal(NodeExecutionContext cont

// Build the hash table
var leftSchema = LeftSource.GetSchema(context);
var leftCompilationContext = new ExpressionCompilationContext(context, leftSchema, null);
var rightSchema = RightSource.GetSchema(context);
var rightCompilationContext = new ExpressionCompilationContext(context, rightSchema, null);

Func<ExpressionExecutionContext, object>[] leftKeyAccessors = new Func<ExpressionExecutionContext, object>[LeftAttributes.Count];
Func<ExpressionExecutionContext, object>[] rightKeyAccessors = new Func<ExpressionExecutionContext, object>[LeftAttributes.Count];

for (var i = 0; i < LeftAttributes.Count; i++)
{
LeftAttributes[i].GetType(leftCompilationContext, out var leftColType);
RightAttributes[i].GetType(rightCompilationContext, out var rightColType);

if (!SqlTypeConverter.CanMakeConsistentTypes(leftColType, rightColType, context.PrimaryDataSource, null, null, out var keyType))
throw new QueryExecutionException($"Cannot match key types {leftColType.ToSql()} and {rightColType.ToSql()}");

Identifier keyTypeCollation = null;

if (keyType is SqlDataTypeReferenceWithCollation keyTypeWithCollation)
keyTypeCollation = new Identifier { Value = keyTypeWithCollation.Collation.Name };

var leftKeyAccessor = (ScalarExpression)LeftAttributes[i];
if (!leftColType.IsSameAs(keyType))
leftKeyAccessor = new ConvertCall { Parameter = leftKeyAccessor, DataType = keyType, Collation = keyTypeCollation };
var leftKeyConverter = leftKeyAccessor.Compile(leftCompilationContext);

var rightKeyAccessor = (ScalarExpression)RightAttributes[i];
if (!rightColType.IsSameAs(keyType))
rightKeyAccessor = new ConvertCall { Parameter = rightKeyAccessor, DataType = keyType, Collation = keyTypeCollation };
var rightKeyConverter = rightKeyAccessor.Compile(rightCompilationContext);

leftKeyAccessors[i] = leftKeyConverter;
rightKeyAccessors[i] = rightKeyConverter;
}

var expressionContext = new ExpressionExecutionContext(context);

foreach (var entity in LeftSource.Execute(context))
{
expressionContext.Entity = entity;
var key = new CompoundKey(leftKeyAccessors.Select(accessor => accessor(expressionContext)));
var key = new CompoundKey(_leftKeyAccessors.Select(accessor => accessor(expressionContext)));

if (!_hashTable.TryGetValue(key, out var list))
{
Expand All @@ -122,7 +91,7 @@ protected override IEnumerable<Entity> ExecuteInternal(NodeExecutionContext cont
foreach (var entity in RightSource.Execute(context))
{
expressionContext.Entity = entity;
var key = new CompoundKey(rightKeyAccessors.Select(accessor => accessor(expressionContext)));
var key = new CompoundKey(_rightKeyAccessors.Select(accessor => accessor(expressionContext)));

var matched = false;

Expand Down Expand Up @@ -169,13 +138,21 @@ protected override IReadOnlyList<string> GetSortOrder(INodeSchema outerSchema, I

public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList<OptimizerHint> hints)
{
_leftKeyAccessors = new Func<ExpressionExecutionContext, object>[LeftAttributes.Count];
_rightKeyAccessors = new Func<ExpressionExecutionContext, object>[LeftAttributes.Count];

var folded = base.FoldQuery(context, hints);

if (folded != this)
return folded;

var leftSchema = LeftSource.GetSchema(context);
var leftCompilationContext = new ExpressionCompilationContext(context, leftSchema, null);
var rightSchema = RightSource.GetSchema(context);
var rightCompilationContext = new ExpressionCompilationContext(context, rightSchema, null);

if (SemiJoin)
return folded;
return this;

// If we can't fold this query, try to make sure the smaller table is used as the left input to reduce the
// number of records held in memory in the hash table
Expand Down Expand Up @@ -212,6 +189,32 @@ public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext
return this;
}

protected override void ValidateComparison(BooleanComparisonExpression expression, DataTypeReference keyType, DataTypeReference leftColType, ExpressionCompilationContext leftCompilationContext, DataTypeReference rightColType, ExpressionCompilationContext rightCompilationContext, int i)
{
Identifier keyTypeCollation = null;

if (keyType is SqlDataTypeReferenceWithCollation keyTypeWithCollation)
{
if (keyTypeWithCollation.CollationLabel == CollationLabel.NoCollation)
throw new NotSupportedQueryFragmentException(keyTypeWithCollation.CollationConflictError.ForFragment(expression));

keyTypeCollation = new Identifier { Value = keyTypeWithCollation.Collation.Name };
}

var leftKeyAccessor = (ScalarExpression)LeftAttributes[i];
if (!leftColType.IsSameAs(keyType))
leftKeyAccessor = new ConvertCall { Parameter = leftKeyAccessor, DataType = keyType, Collation = keyTypeCollation };
var leftKeyConverter = leftKeyAccessor.Compile(leftCompilationContext);

var rightKeyAccessor = (ScalarExpression)RightAttributes[i];
if (!rightColType.IsSameAs(keyType))
rightKeyAccessor = new ConvertCall { Parameter = rightKeyAccessor, DataType = keyType, Collation = keyTypeCollation };
var rightKeyConverter = rightKeyAccessor.Compile(rightCompilationContext);

_leftKeyAccessors[i] = leftKeyConverter;
_rightKeyAccessors[i] = rightKeyConverter;
}

public override object Clone()
{
var clone = new HashJoinNode
Expand All @@ -225,6 +228,8 @@ public override object Clone()
OutputRightSchema = OutputRightSchema,
ComparisonType = ComparisonType,
AntiJoin = AntiJoin,
_leftKeyAccessors = _leftKeyAccessors,
_rightKeyAccessors = _rightKeyAccessors
};

foreach (var attr in LeftAttributes)
Expand All @@ -233,6 +238,9 @@ public override object Clone()
foreach (var attr in RightAttributes)
clone.RightAttributes.Add(attr);

foreach (var expr in Expressions)
clone.Expressions.Add(expr);

foreach (var kvp in DefinedValues)
clone.DefinedValues.Add(kvp);

Expand Down
14 changes: 12 additions & 2 deletions MarkMpn.Sql4Cds.Engine/ExecutionPlan/MergeJoinNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -265,13 +265,20 @@ public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext
{
LeftSource = LeftSource,
RightSource = RightSource,
LeftAttribute = LeftAttribute,
RightAttribute = RightAttribute,
JoinType = JoinType,
AdditionalJoinCriteria = AdditionalJoinCriteria,
SemiJoin = SemiJoin
};

foreach (var attr in LeftAttributes)
hashJoin.LeftAttributes.Add(attr);

foreach (var attr in RightAttributes)
hashJoin.RightAttributes.Add(attr);

foreach (var expr in Expressions)
hashJoin.Expressions.Add(expr);

foreach (var kvp in DefinedValues)
hashJoin.DefinedValues.Add(kvp);

Expand Down Expand Up @@ -429,6 +436,9 @@ public override object Clone()
foreach (var attr in RightAttributes)
clone.RightAttributes.Add(attr);

foreach (var expr in Expressions)
clone.Expressions.Add(expr);

foreach (var kvp in DefinedValues)
clone.DefinedValues.Add(kvp);

Expand Down
3 changes: 3 additions & 0 deletions MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4648,6 +4648,7 @@ private IDataExecutionPlanNodeInternal ConvertTableReference(TableReference refe
LeftAttribute = joinConditionVisitor.LhsKey.Clone(),
RightSource = rhs,
RightAttribute = joinConditionVisitor.RhsKey.Clone(),
Expressions = { joinConditionVisitor.JoinCondition },
JoinType = join.QualifiedJoinType,
AdditionalJoinCriteria = join.SearchCondition.RemoveCondition(joinConditionVisitor.JoinCondition).Clone()
};
Expand All @@ -4660,6 +4661,7 @@ private IDataExecutionPlanNodeInternal ConvertTableReference(TableReference refe
LeftAttribute = joinConditionVisitor.RhsKey.Clone(),
RightSource = lhs,
RightAttribute = joinConditionVisitor.LhsKey.Clone(),
Expressions = { joinConditionVisitor.JoinCondition },
AdditionalJoinCriteria = join.SearchCondition.RemoveCondition(joinConditionVisitor.JoinCondition).Clone()
};

Expand Down Expand Up @@ -4690,6 +4692,7 @@ private IDataExecutionPlanNodeInternal ConvertTableReference(TableReference refe
LeftAttribute = joinConditionVisitor.LhsKey.Clone(),
RightSource = rhs,
RightAttribute = joinConditionVisitor.RhsKey.Clone(),
Expressions = { joinConditionVisitor.JoinCondition },
JoinType = join.QualifiedJoinType,
AdditionalJoinCriteria = join.SearchCondition.RemoveCondition(joinConditionVisitor.JoinCondition).Clone()
};
Expand Down

0 comments on commit 63c13f9

Please sign in to comment.