Skip to content

Commit

Permalink
Fixed errors with CTE rewrites using alises in recursive references a…
Browse files Browse the repository at this point in the history
…nd unnecessary quotes in column references
  • Loading branch information
MarkMpn committed Nov 21, 2023
1 parent f7cd407 commit f6ec7cc
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1826,7 +1826,7 @@ public static string EscapeIdentifier(this string identifier)
/// <remarks>
/// https://learn.microsoft.com/en-us/sql/relational-databases/databases/database-identifiers?view=sql-server-ver16&redirectedfrom=MSDN#rules-for-regular-identifiers
/// </remarks>
private static bool IsValidIdentifier(string identifier)
public static bool IsValidIdentifier(this string identifier)
{
if (String.IsNullOrEmpty(identifier))
return false;
Expand Down
34 changes: 33 additions & 1 deletion MarkMpn.Sql4Cds.Engine/TSqlFragmentExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using Microsoft.SqlServer.TransactSql.ScriptDom;
using MarkMpn.Sql4Cds.Engine.ExecutionPlan;
using Microsoft.SqlServer.TransactSql.ScriptDom;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Linq.Expressions;

Expand Down Expand Up @@ -31,6 +33,36 @@ public static string ToSql(this TSqlFragment fragment)
return sql;
}

/// <summary>
/// Converts a <see cref="TSqlFragment"/> to the corresponding SQL string in a standardised way
/// </summary>
/// <param name="fragment">The SQL DOM fragment to convert</param>
/// <returns>The SQL string that the fragment can be parsed from</returns>
public static string ToNormalizedSql(this TSqlFragment fragment)
{
var tokens = new Sql160ScriptGenerator().GenerateTokens(fragment);

using (var writer = new StringWriter())
{
foreach (var token in tokens)
{
if (token.TokenType == TSqlTokenType.Identifier)
{
var value = Identifier.DecodeIdentifier(token.Text, out var quoteType);

if (quoteType != QuoteType.NotQuoted && value.IsValidIdentifier())
token.Text = value;
}

writer.Write(token.Text);
}

writer.Flush();

return writer.ToString();
}
}

/// <summary>
/// Creates a clone of a <see cref="TSqlFragment"/>
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace MarkMpn.Sql4Cds.Engine.Visitors
/// </summary>
class RemoveRecursiveCTETableReferencesVisitor : TSqlConcreteFragmentVisitor
{
private readonly string _name;
private string _name;
private readonly string[] _columnNames;
private readonly Dictionary<string, string> _outerReferences;
private BooleanExpression _joinPredicate;
Expand All @@ -34,7 +34,13 @@ private bool IsRecursiveReference(TableReference tableReference)
if (namedTable.SchemaObject.Identifiers.Count != 1)
return false;

return namedTable.SchemaObject.BaseIdentifier.Value.Equals(_name, StringComparison.OrdinalIgnoreCase);
if (!namedTable.SchemaObject.BaseIdentifier.Value.Equals(_name, StringComparison.OrdinalIgnoreCase))
return false;

if (namedTable.Alias != null)
_name = namedTable.Alias.Value;

return true;
}

private InlineDerivedTable CreateInlineDerivedTable()
Expand Down
6 changes: 3 additions & 3 deletions MarkMpn.Sql4Cds.Engine/Visitors/RewriteVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class RewriteVisitor : RewriteVisitorBase
public RewriteVisitor(IDictionary<ScalarExpression,string> rewrites)
{
_mappings = rewrites
.GroupBy(kvp => kvp.Key.ToSql(), StringComparer.OrdinalIgnoreCase)
.GroupBy(kvp => kvp.Key.ToNormalizedSql(), StringComparer.OrdinalIgnoreCase)
.ToDictionary(
g => g.Key,
g => (ScalarExpression) g.First().Value.ToColumnReference(),
Expand All @@ -37,7 +37,7 @@ public RewriteVisitor(IDictionary<ScalarExpression,string> rewrites)
public RewriteVisitor(IDictionary<ScalarExpression,ScalarExpression> rewrites)
{
_mappings = rewrites
.GroupBy(kvp => kvp.Key.ToSql(), StringComparer.OrdinalIgnoreCase)
.GroupBy(kvp => kvp.Key.ToNormalizedSql(), StringComparer.OrdinalIgnoreCase)
.ToDictionary(
g => g.Key,
g => g.First().Value,
Expand All @@ -51,7 +51,7 @@ protected override ScalarExpression ReplaceExpression(ScalarExpression expressio
if (expression == null)
return null;

if (_mappings.TryGetValue(expression.ToSql(), out var column))
if (_mappings.TryGetValue(expression.ToNormalizedSql(), out var column))
{
name = (column as ColumnReferenceExpression)?.MultiPartIdentifier?.Identifiers?.Last()?.Value;
return column;
Expand Down

0 comments on commit f6ec7cc

Please sign in to comment.