From 48692a85f2b926b95797f2c42916af5feb321583 Mon Sep 17 00:00:00 2001 From: Mark Carrington Date: Wed, 15 Mar 2023 19:56:23 +0000 Subject: [PATCH 01/34] Collation progress --- .../CollationTests.cs | 69 ++++ .../ExecutionPlanNodeTests.cs | 60 ++-- .../ExecutionPlanTests.cs | 30 +- .../MarkMpn.Sql4Cds.Engine.Tests.csproj | 1 + .../Ado/Sql4CdsConnection.cs | 4 +- .../Ado/Sql4CdsParameter.cs | 18 +- MarkMpn.Sql4Cds.Engine/Collation.cs | 193 +++++++++++ MarkMpn.Sql4Cds.Engine/DataSource.cs | 49 +++ MarkMpn.Sql4Cds.Engine/DataTypeHelpers.cs | 133 +++++++- .../ExecutionPlan/BaseDataNode.cs | 20 +- .../ExecutionPlan/BaseDmlNode.cs | 17 +- .../ExecutionPlan/ConcatenateNode.cs | 2 +- .../ExecutionPlan/ExpressionExtensions.cs | 306 ++++++++++-------- .../ExecutionPlan/FetchXmlScan.cs | 32 +- .../ExecutionPlan/FilterNode.cs | 32 +- .../ExecutionPlan/GlobalOptionSetQueryNode.cs | 2 +- .../ExecutionPlan/MergeJoinNode.cs | 11 +- .../ExecutionPlan/MetadataQueryNode.cs | 4 +- .../ExecutionPlan/SqlTypeConverter.cs | 40 ++- .../ExecutionPlan/UpdateNode.cs | 4 +- .../MarkMpn.Sql4Cds.Engine.projitems | 7 + MarkMpn.Sql4Cds.Engine/MetadataExtensions.cs | 16 +- .../Resources/CollationNameToLCID.txt | 135 ++++++++ 23 files changed, 934 insertions(+), 251 deletions(-) create mode 100644 MarkMpn.Sql4Cds.Engine.Tests/CollationTests.cs create mode 100644 MarkMpn.Sql4Cds.Engine/Collation.cs create mode 100644 MarkMpn.Sql4Cds.Engine/Resources/CollationNameToLCID.txt diff --git a/MarkMpn.Sql4Cds.Engine.Tests/CollationTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/CollationTests.cs new file mode 100644 index 00000000..36770021 --- /dev/null +++ b/MarkMpn.Sql4Cds.Engine.Tests/CollationTests.cs @@ -0,0 +1,69 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace MarkMpn.Sql4Cds.Engine.Tests +{ + [TestClass] + public class CollationTests + { + [DataRow("Latin1_General")] // Missing case & accent sensitivity + [DataRow("Latin1_General_CI")] // Missing accent sensitivity + [DataRow("Latin1_General_AI")] // Missing case sensitivity + [DataRow("Latin1_General_CS_CI_AI")] // Conflicting case sensitivity + [DataRow("Latin1_General_CS_AS_AI")] // Conflicting accent sensitivity + [DataRow("Latin1_General_BIN_CS_AS")] // Binary comparision can't be combined + [DataRow("Latin1_General_BIN2_CS_AS")] // Binary comparision can't be combined + [DataRow("Latin2_General_CS_AS")] // Invalid name + [DataTestMethod] + public void InvalidCollations(string name) + { + Assert.IsFalse(Collation.TryParse(name, out _)); + } + + [DataRow(true, true)] + [DataRow(true, false)] + [DataRow(false, true)] + [DataRow(false, false)] + [DataTestMethod] + public void Latin1_General(bool cs, bool @as) + { + Assert.IsTrue(Collation.TryParse($"Latin1_General_{(cs?"CS":"CI")}_{(@as ? "AS" : "AI")}", out var coll)); + + var s1 = coll.ToSqlString("hello"); + var s2 = coll.ToSqlString("Héllo"); + var s3 = coll.ToSqlString("héllo"); + var s4 = coll.ToSqlString("Hello"); + + if (!cs && !@as) + Assert.AreEqual(s1, s2); + else + Assert.AreNotEqual(s1, s2); + + if (!@as) + { + Assert.AreEqual(s1, s3); + Assert.AreEqual(s2, s4); + } + else + { + Assert.AreNotEqual(s1, s3); + Assert.AreNotEqual(s2, s4); + } + + if (!cs) + { + Assert.AreEqual(s1, s4); + Assert.AreEqual(s2, s3); + } + else + { + Assert.AreNotEqual(s1, s4); + Assert.AreNotEqual(s2, s3); + } + } + } +} diff --git a/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanNodeTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanNodeTests.cs index 9f4343fb..41721271 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanNodeTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanNodeTests.cs @@ -32,7 +32,7 @@ public void ConstantScanTest() }, Schema = { - ["firstname"] = typeof(SqlString).ToSqlType() + ["firstname"] = typeof(SqlString).ToSqlType(null) }, Alias = "test" }; @@ -63,7 +63,7 @@ public void FilterNodeTest() }, Schema = { - ["firstname"] = typeof(SqlString).ToSqlType() + ["firstname"] = typeof(SqlString).ToSqlType(null) }, Alias = "test" }, @@ -112,8 +112,8 @@ public void MergeJoinInnerTest() }, Schema = { - ["key"] = typeof(SqlInt32).ToSqlType(), - ["firstname"] = typeof(SqlString).ToSqlType() + ["key"] = typeof(SqlInt32).ToSqlType(null), + ["firstname"] = typeof(SqlString).ToSqlType(null) }, Alias = "f" }, @@ -146,8 +146,8 @@ public void MergeJoinInnerTest() }, Schema = { - ["key"] = typeof(SqlInt32).ToSqlType(), - ["lastname"] = typeof(SqlString).ToSqlType() + ["key"] = typeof(SqlInt32).ToSqlType(null), + ["lastname"] = typeof(SqlString).ToSqlType(null) }, Alias = "l" }, @@ -192,8 +192,8 @@ public void MergeJoinLeftOuterTest() }, Schema = { - ["key"] = typeof(SqlInt32).ToSqlType(), - ["firstname"] = typeof(SqlString).ToSqlType() + ["key"] = typeof(SqlInt32).ToSqlType(null), + ["firstname"] = typeof(SqlString).ToSqlType(null) }, Alias = "f" }, @@ -226,8 +226,8 @@ public void MergeJoinLeftOuterTest() }, Schema = { - ["key"] = typeof(SqlInt32).ToSqlType(), - ["lastname"] = typeof(SqlString).ToSqlType() + ["key"] = typeof(SqlInt32).ToSqlType(null), + ["lastname"] = typeof(SqlString).ToSqlType(null) }, Alias = "l" }, @@ -274,8 +274,8 @@ public void MergeJoinRightOuterTest() }, Schema = { - ["key"] = typeof(SqlInt32).ToSqlType(), - ["firstname"] = typeof(SqlString).ToSqlType() + ["key"] = typeof(SqlInt32).ToSqlType(null), + ["firstname"] = typeof(SqlString).ToSqlType(null) }, Alias = "f" }, @@ -308,8 +308,8 @@ public void MergeJoinRightOuterTest() }, Schema = { - ["key"] = typeof(SqlInt32).ToSqlType(), - ["lastname"] = typeof(SqlString).ToSqlType() + ["key"] = typeof(SqlInt32).ToSqlType(null), + ["lastname"] = typeof(SqlString).ToSqlType(null) }, Alias = "l" }, @@ -354,7 +354,7 @@ public void AssertionTest() }, Schema = { - ["name"] = typeof(SqlString).ToSqlType() + ["name"] = typeof(SqlString).ToSqlType(null) }, Alias = "test" }, @@ -393,8 +393,8 @@ public void ComputeScalarTest() }, Schema = { - ["value1"] = typeof(SqlInt32).ToSqlType(), - ["value2"] = typeof(SqlInt32).ToSqlType() + ["value1"] = typeof(SqlInt32).ToSqlType(null), + ["value2"] = typeof(SqlInt32).ToSqlType(null) } }, Columns = @@ -455,8 +455,8 @@ public void DistinctTest() }, Schema = { - ["value1"] = typeof(SqlInt32).ToSqlType(), - ["value2"] = typeof(SqlInt32).ToSqlType() + ["value1"] = typeof(SqlInt32).ToSqlType(null), + ["value2"] = typeof(SqlInt32).ToSqlType(null) }, Alias = "test" } @@ -497,8 +497,8 @@ public void DistinctCaseInsensitiveTest() }, Schema = { - ["value1"] = typeof(SqlString).ToSqlType(), - ["value2"] = typeof(SqlInt32).ToSqlType() + ["value1"] = typeof(SqlString).ToSqlType(null), + ["value2"] = typeof(SqlInt32).ToSqlType(null) }, Alias = "test" } @@ -554,9 +554,9 @@ public void SortNodeTest() }, Schema = { - ["value1"] = typeof(SqlString).ToSqlType(), - ["value2"] = typeof(SqlInt32).ToSqlType(), - ["expectedorder"] = typeof(SqlInt32).ToSqlType() + ["value1"] = typeof(SqlString).ToSqlType(null), + ["value2"] = typeof(SqlInt32).ToSqlType(null), + ["expectedorder"] = typeof(SqlInt32).ToSqlType(null) }, Alias = "test" } @@ -619,9 +619,9 @@ public void SortNodePresortedTest() }, Schema = { - ["value1"] = typeof(SqlString).ToSqlType(), - ["value2"] = typeof(SqlInt32).ToSqlType(), - ["expectedorder"] = typeof(SqlInt32).ToSqlType() + ["value1"] = typeof(SqlString).ToSqlType(null), + ["value2"] = typeof(SqlInt32).ToSqlType(null), + ["expectedorder"] = typeof(SqlInt32).ToSqlType(null) }, Alias = "test" } @@ -652,7 +652,7 @@ public void TableSpoolTest() }, Schema = { - ["value1"] = typeof(SqlInt32).ToSqlType() + ["value1"] = typeof(SqlInt32).ToSqlType(null) }, Alias = "test" }; @@ -690,7 +690,7 @@ public void CaseInsenstiveHashMatchAggregateNodeTest() }, Schema = { - ["value1"] = typeof(SqlString).ToSqlType() + ["value1"] = typeof(SqlString).ToSqlType(null) }, Alias = "src" }; @@ -823,7 +823,7 @@ private HashMatchAggregateNode CreateAggregateTest(params int[] values) { Schema = { - ["i"] = typeof(SqlInt32).ToSqlType() + ["i"] = typeof(SqlInt32).ToSqlType(null) }, Alias = "l" }; diff --git a/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs index af053fd5..6b78f1c2 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Data; using System.Data.SqlTypes; +using System.Globalization; using System.IO; using System.Linq; using System.Reflection; @@ -3942,7 +3943,7 @@ INSERT INTO account (name) VALUES ('one') var parameters = new Dictionary { - ["@param1"] = typeof(SqlInt32).ToSqlType() + ["@param1"] = typeof(SqlInt32).ToSqlType(null) }; var plans = planBuilder.Build(query, parameters, out _); @@ -3975,7 +3976,7 @@ INSERT INTO account (name) VALUES (@param1) var parameters = new Dictionary { - ["@param1"] = typeof(SqlInt32).ToSqlType() + ["@param1"] = typeof(SqlInt32).ToSqlType(null) }; var plans = planBuilder.Build(query, parameters, out _); @@ -4004,7 +4005,7 @@ INSERT INTO account (name) VALUES (@param1) var parameters = new Dictionary { - ["@param1"] = typeof(SqlString).ToSqlType() + ["@param1"] = typeof(SqlString).ToSqlType(null) }; var plans = planBuilder.Build(query, parameters, out _); @@ -5231,5 +5232,28 @@ public void FoldMultipleJoinConditionsWithKnownValue() "); } + + [TestMethod] + public void ExplicitCollation() + { + var metadata = new AttributeMetadataCache(_service); + var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + + var query = "SELECT 'abc' COLLATE French_CI_AS"; + + planBuilder.Build(query, null, out _); + } + + [ExpectedException(typeof(NotSupportedQueryFragmentException))] + [TestMethod] + public void TwoExplicitCollationsError() + { + var metadata = new AttributeMetadataCache(_service); + var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + + var query = "SELECT ('abc' COLLATE French_CI_AS) COLLATE French_CS_AS"; + + planBuilder.Build(query, null, out _); + } } } diff --git a/MarkMpn.Sql4Cds.Engine.Tests/MarkMpn.Sql4Cds.Engine.Tests.csproj b/MarkMpn.Sql4Cds.Engine.Tests/MarkMpn.Sql4Cds.Engine.Tests.csproj index aac0df38..3d7a48d4 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/MarkMpn.Sql4Cds.Engine.Tests.csproj +++ b/MarkMpn.Sql4Cds.Engine.Tests/MarkMpn.Sql4Cds.Engine.Tests.csproj @@ -151,6 +151,7 @@ + diff --git a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsConnection.cs b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsConnection.cs index 47a7f5cb..e5ccaf7c 100644 --- a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsConnection.cs +++ b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsConnection.cs @@ -67,8 +67,8 @@ public Sql4CdsConnection(IDictionary dataSources) _globalVariableTypes = new Dictionary(StringComparer.OrdinalIgnoreCase) { - ["@@IDENTITY"] = typeof(SqlEntityReference).ToSqlType(), - ["@@ROWCOUNT"] = typeof(SqlInt32).ToSqlType() + ["@@IDENTITY"] = DataTypeHelpers.EntityReference, + ["@@ROWCOUNT"] = DataTypeHelpers.Int }; _globalVariableValues = new Dictionary(StringComparer.OrdinalIgnoreCase) { diff --git a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsParameter.cs b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsParameter.cs index cbee0fbf..a20800cc 100644 --- a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsParameter.cs +++ b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsParameter.cs @@ -27,6 +27,12 @@ public Sql4CdsParameter(string name, object value) Value = value; } + /// + public int LocaleId { get; set; } = Collation.USEnglish.LCID; + + /// + public SqlCompareOptions CompareInfo { get; set; } = Collation.USEnglish.CompareOptions; + public override DbType DbType { get @@ -143,11 +149,11 @@ internal DataTypeReference GetDataType() switch (DbType) { case DbType.AnsiString: - _dataType = DataTypeHelpers.VarChar(Size); + _dataType = DataTypeHelpers.VarChar(Size, new Collation(null, LocaleId, CompareInfo), CollationLabel.CoercibleDefault); break; case DbType.AnsiStringFixedLength: - _dataType = DataTypeHelpers.Char(Size); + _dataType = DataTypeHelpers.Char(Size, new Collation(null, LocaleId, CompareInfo), CollationLabel.CoercibleDefault); break; case DbType.Binary: @@ -220,11 +226,11 @@ internal DataTypeReference GetDataType() break; case DbType.String: - _dataType = DataTypeHelpers.NVarChar(Size); + _dataType = DataTypeHelpers.NVarChar(Size, new Collation(null, LocaleId, CompareInfo), CollationLabel.CoercibleDefault); break; case DbType.StringFixedLength: - _dataType = DataTypeHelpers.NChar(Size); + _dataType = DataTypeHelpers.NChar(Size, new Collation(null, LocaleId, CompareInfo), CollationLabel.CoercibleDefault); break; case DbType.Time: @@ -244,11 +250,11 @@ internal DataTypeReference GetDataType() break; case DbType.VarNumeric: - _dataType = DataTypeHelpers.NVarChar(Int32.MaxValue); + _dataType = DataTypeHelpers.NVarChar(Int32.MaxValue, new Collation(null, LocaleId, CompareInfo), CollationLabel.CoercibleDefault); break; case DbType.Xml: - _dataType = DataTypeHelpers.NVarChar(Int32.MaxValue); + _dataType = DataTypeHelpers.NVarChar(Int32.MaxValue, new Collation(null, LocaleId, CompareInfo), CollationLabel.CoercibleDefault); break; } } diff --git a/MarkMpn.Sql4Cds.Engine/Collation.cs b/MarkMpn.Sql4Cds.Engine/Collation.cs new file mode 100644 index 00000000..acb40582 --- /dev/null +++ b/MarkMpn.Sql4Cds.Engine/Collation.cs @@ -0,0 +1,193 @@ +using System; +using System.Collections.Generic; +using System.Data.SqlTypes; +using System.IO; +using System.Reflection; +using System.Text; + +namespace MarkMpn.Sql4Cds.Engine +{ + /// + /// Describes a collation to be used to compare strings + /// + class Collation + { + private static Dictionary _collationNameToLcid; + + static Collation() + { + _collationNameToLcid = new Dictionary(StringComparer.OrdinalIgnoreCase); + + using (var stream = Assembly.GetExecutingAssembly().GetManifestResourceStream("MarkMpn.Sql4Cds.Engine.resources.CollationNameToLCID.txt")) + using (var reader = new StreamReader(stream)) + { + string line; + + while ((line = reader.ReadLine()) != null) + { + var parts = line.Split('\t'); + _collationNameToLcid[parts[0]] = Int32.Parse(parts[1]); + } + } + } + + /// + /// Creates a collation using the locale ID and additional comparison options + /// + /// The name the collation was parsed from + /// The locale ID to use + /// Additional comparison options + public Collation(string name, int lcid, SqlCompareOptions compareOptions) + { + Name = name; + LCID = lcid; + CompareOptions = compareOptions; + } + + /// + /// Creates a collation using the locale ID and common additional options + /// + /// The locale ID to use + /// Indicates if comparisons are case sensitive + /// Indicates if comparisons are accent sensitive + public Collation(int lcid, bool caseSensitive, bool accentSensitive) + { + var compareOptions = SqlCompareOptions.IgnoreKanaType | SqlCompareOptions.IgnoreWidth; + + if (!caseSensitive) + compareOptions |= SqlCompareOptions.IgnoreCase; + + if (!accentSensitive) + compareOptions |= SqlCompareOptions.IgnoreNonSpace; + + LCID = lcid; + CompareOptions = compareOptions; + } + + /// + /// Returns the locale ID to use when comparing strings + /// + public int LCID { get; } + + /// + /// Returns the additional options to use when comparing strings + /// + public SqlCompareOptions CompareOptions { get; } + + /// + /// Returns the name of the collation + /// + /// + /// This will be null for the default collation + /// + public string Name { get; } + + /// + /// Returns the default collation to be used for system data + /// + public static Collation USEnglish { get; } = new Collation(1033, false, false); + + /// + /// Attempts to parse the name of a collation to the corresponding details + /// + /// The name of the collation to parse + /// The details of the collation parsed from the + /// true if the could be parsed, or false otherwise + public static bool TryParse(string name, out Collation coll) + { + var compareOptions = SqlCompareOptions.IgnoreKanaType | SqlCompareOptions.IgnoreWidth; + var parts = name.Split('_'); + var @as = false; + var cs = false; + + for (var i = parts.Length - 1; i >= 0; i--) + { + switch (parts[i].ToUpperInvariant()) + { + case "BIN": + compareOptions |= SqlCompareOptions.BinarySort; + break; + + case "BIN2": + compareOptions |= SqlCompareOptions.BinarySort2; + break; + + case "CI": + compareOptions |= SqlCompareOptions.IgnoreCase; + break; + + case "CS": + cs = true; + break; + + case "AI": + compareOptions |= SqlCompareOptions.IgnoreNonSpace; + break; + + case "AS": + @as = true; + break; + + case "KS": + compareOptions &= ~SqlCompareOptions.IgnoreKanaType; + break; + + case "WS": + compareOptions &= ~SqlCompareOptions.IgnoreWidth; + break; + + case "UTF8": + break; + + default: + // Check we've got sufficient and non-contradictory information + if ((compareOptions.HasFlag(SqlCompareOptions.BinarySort) || compareOptions.HasFlag(SqlCompareOptions.BinarySort2))) + { + // If BIN or BIN2 are set, other options shouldn't be set + if (i < parts.Length - 2) + break; + } + else + { + // Must specify case sensitivity + if (!compareOptions.HasFlag(SqlCompareOptions.IgnoreCase) && !cs) + break; + + // Must specify accent sensitivity + if (!compareOptions.HasFlag(SqlCompareOptions.IgnoreNonSpace) && !@as) + break; + + // Can't be both CS and CI + if (compareOptions.HasFlag(SqlCompareOptions.IgnoreCase) && cs) + break; + + // Can't be both AS and AI + if (compareOptions.HasFlag(SqlCompareOptions.IgnoreNonSpace) && @as) + break; + } + + var collationName = String.Join("_", parts, 0, i + 1); + + if (!_collationNameToLcid.TryGetValue(collationName, out var lcid)) + break; + + coll = new Collation(name, lcid, compareOptions); + return true; + } + } + + coll = null; + return false; + } + + /// + /// Applies the current collation to a string value + /// + /// The string value to apply the collation to + /// A new value with the given string and the current collation + public SqlString ToSqlString(string value) + { + return new SqlString(value, LCID, CompareOptions); + } + } +} diff --git a/MarkMpn.Sql4Cds.Engine/DataSource.cs b/MarkMpn.Sql4Cds.Engine/DataSource.cs index 35bf98ee..994a378f 100644 --- a/MarkMpn.Sql4Cds.Engine/DataSource.cs +++ b/MarkMpn.Sql4Cds.Engine/DataSource.cs @@ -45,6 +45,8 @@ public DataSource(IOrganizationService org) Name = name; TableSizeCache = new TableSizeCache(org, Metadata); MessageCache = new MessageCache(org, Metadata); + + DefaultCollation = LoadDefaultCollation(); } /// @@ -78,5 +80,52 @@ public DataSource() /// A cache of the messages that the instance supports /// public IMessageCache MessageCache { get; set; } + + /// + /// Returns the default collation used by this instance + /// + internal Collation DefaultCollation { get; } + + private Collation LoadDefaultCollation() + { + var qry = new QueryExpression("organization") + { + ColumnSet = new ColumnSet("lcid") + }; + var org = Connection.RetrieveMultiple(qry).Entities[0]; + var lcid = org.GetAttributeValue("lcid"); + + // Collation options are set based on the default language. Most are CI/AI but a few are not + // https://learn.microsoft.com/en-us/power-platform/admin/language-collations#language-and-associated-collation-used-with-dataverse + // On-prem databases may be configured with any default collation, but this is not exposed through any API. + var ci = true; + var ai = true; + + switch (lcid) + { + case 1035: + case 1048: + case 1050: + case 1051: + case 1053: + case 1054: + case 1057: + case 1058: + case 1061: + case 1062: + case 1063: + case 1066: + case 1069: + case 1081: + case 1086: + case 1087: + case 1110: + case 2074: + ai = false; + break; + } + + return new Collation(lcid, !ci, !ai); + } } } diff --git a/MarkMpn.Sql4Cds.Engine/DataTypeHelpers.cs b/MarkMpn.Sql4Cds.Engine/DataTypeHelpers.cs index e4c62364..4bb48650 100644 --- a/MarkMpn.Sql4Cds.Engine/DataTypeHelpers.cs +++ b/MarkMpn.Sql4Cds.Engine/DataTypeHelpers.cs @@ -12,14 +12,14 @@ namespace MarkMpn.Sql4Cds.Engine /// static class DataTypeHelpers { - public static SqlDataTypeReference VarChar(int length) + public static SqlDataTypeReference VarChar(int length, Collation collation, CollationLabel collationLabel) { - return new SqlDataTypeReference { SqlDataTypeOption = SqlDataTypeOption.VarChar, Parameters = { length <= 8000 ? (Literal) new IntegerLiteral { Value = length.ToString(CultureInfo.InvariantCulture) } : new MaxLiteral() } }; + return new SqlDataTypeReferenceWithCollation { SqlDataTypeOption = SqlDataTypeOption.VarChar, Parameters = { length <= 8000 ? (Literal) new IntegerLiteral { Value = length.ToString(CultureInfo.InvariantCulture) } : new MaxLiteral() }, Collation = collation, CollationLabel = collationLabel }; } - public static SqlDataTypeReference Char(int length) + public static SqlDataTypeReference Char(int length, Collation collation, CollationLabel collationLabel) { - return new SqlDataTypeReference { SqlDataTypeOption = SqlDataTypeOption.Char, Parameters = { length <= 8000 ? (Literal)new IntegerLiteral { Value = length.ToString(CultureInfo.InvariantCulture) } : new MaxLiteral() } }; + return new SqlDataTypeReferenceWithCollation { SqlDataTypeOption = SqlDataTypeOption.Char, Parameters = { length <= 8000 ? (Literal)new IntegerLiteral { Value = length.ToString(CultureInfo.InvariantCulture) } : new MaxLiteral() }, Collation = collation, CollationLabel = collationLabel }; } public static SqlDataTypeReference VarBinary(int length) @@ -77,14 +77,14 @@ public static UserDataTypeReference Object(Type type) public static SqlDataTypeReference Float { get; } = new SqlDataTypeReference { SqlDataTypeOption = SqlDataTypeOption.Float }; - public static SqlDataTypeReference NVarChar(int length) + public static SqlDataTypeReference NVarChar(int length, Collation collation, CollationLabel collationLabel) { - return new SqlDataTypeReference { SqlDataTypeOption = SqlDataTypeOption.NVarChar, Parameters = { length <= 8000 ? (Literal) new IntegerLiteral { Value = length.ToString(CultureInfo.InvariantCulture) } : new MaxLiteral() } }; + return new SqlDataTypeReferenceWithCollation { SqlDataTypeOption = SqlDataTypeOption.NVarChar, Parameters = { length <= 8000 ? (Literal) new IntegerLiteral { Value = length.ToString(CultureInfo.InvariantCulture) } : new MaxLiteral() }, Collation = collation, CollationLabel = collationLabel }; } - public static SqlDataTypeReference NChar(int length) + public static SqlDataTypeReference NChar(int length, Collation collation, CollationLabel collationLabel) { - return new SqlDataTypeReference { SqlDataTypeOption = SqlDataTypeOption.NChar, Parameters = { length <= 8000 ? (Literal)new IntegerLiteral { Value = length.ToString(CultureInfo.InvariantCulture) } : new MaxLiteral() } }; + return new SqlDataTypeReferenceWithCollation { SqlDataTypeOption = SqlDataTypeOption.NChar, Parameters = { length <= 8000 ? (Literal)new IntegerLiteral { Value = length.ToString(CultureInfo.InvariantCulture) } : new MaxLiteral() }, Collation = collation, CollationLabel = collationLabel }; } public static SqlDataTypeReference Time(short scale) @@ -388,6 +388,7 @@ public static bool IsSameAs(this DataTypeReference x, DataTypeReference y) /// true if the could be successfully parsed, or false otherwise public static bool TryParse(string value, out DataTypeReference parsedType) { + // TODO: Collation parsedType = null; var name = value; @@ -447,8 +448,10 @@ public bool Equals(DataTypeReference x, DataTypeReference y) var xUser = x as UserDataTypeReference; var xSql = x as SqlDataTypeReference; + var xColl = x as SqlDataTypeReferenceWithCollation; var yUser = y as UserDataTypeReference; var ySql = y as SqlDataTypeReference; + var yColl = y as SqlDataTypeReferenceWithCollation; if (xUser != null && yUser != null) return String.Join(".", xUser.Name.Identifiers.Select(i => i.Value)).Equals(String.Join(".", yUser.Name.Identifiers.Select(i => i.Value)), StringComparison.OrdinalIgnoreCase); @@ -477,6 +480,9 @@ public bool Equals(DataTypeReference x, DataTypeReference y) return false; } + if (xColl != null && yColl != null &&!xColl.Collation.Equals(yColl.Collation)) + return false; + return true; } @@ -491,4 +497,115 @@ public int GetHashCode(DataTypeReference obj) throw new NotSupportedException(); } } + + /// + /// Extends the standard with additional collation information + /// + class SqlDataTypeReferenceWithCollation : SqlDataTypeReference + { + /// + /// Returns or sets the collation that the data will use + /// + public Collation Collation { get; set; } + + /// + /// Indicates how the has been spplied + /// + public CollationLabel CollationLabel { get; set; } + + /// + /// Applies the precedence rules to convert values of different collations to a single collation + /// + /// The type of the first expression + /// The type of the second expression + /// The final collation to use + /// The final collation label to use + /// The final collation to use + internal static bool TryConvertCollation(SqlDataTypeReference lhsSql, SqlDataTypeReference rhsSql, out Collation collation, out CollationLabel collationLabel) + { + collation = null; + collationLabel = CollationLabel.NoCollation; + + if (!(lhsSql is SqlDataTypeReferenceWithCollation lhsSqlWithColl)) + return false; + + if (!(rhsSql is SqlDataTypeReferenceWithCollation rhsSqlWithColl)) + return false; + + // Two different explicit collations cannot be converted + if (lhsSqlWithColl.CollationLabel == CollationLabel.Explicit && + rhsSqlWithColl.CollationLabel == CollationLabel.Explicit && + !lhsSqlWithColl.Collation.Equals(rhsSqlWithColl.Collation)) + return false; + + // If either collation is explicit, use that + if (lhsSqlWithColl.CollationLabel == CollationLabel.Explicit) + { + collation = lhsSqlWithColl.Collation; + collationLabel = CollationLabel.Explicit; + return true; + } + + if (rhsSqlWithColl.CollationLabel == CollationLabel.Explicit) + { + collation = rhsSqlWithColl.Collation; + collationLabel = CollationLabel.Explicit; + return true; + } + + // If either label is no collation, use that + if (lhsSqlWithColl.CollationLabel == CollationLabel.NoCollation || + rhsSqlWithColl.CollationLabel == CollationLabel.NoCollation) + { + collationLabel = CollationLabel.NoCollation; + return true; + } + + if (lhsSqlWithColl.CollationLabel == CollationLabel.Implicit && + rhsSqlWithColl.CollationLabel == CollationLabel.Implicit) + { + if (lhsSqlWithColl.Collation.Equals(rhsSqlWithColl.Collation)) + { + // Two identical implicit collations remains unchanged + // This doesn't appear to be explicitly defined in the docs, but seems reasonable + collation = lhsSqlWithColl.Collation; + collationLabel = CollationLabel.Implicit; + return true; + } + else + { + // Two different implicit collations results in no collation + collation = null; + collationLabel = CollationLabel.NoCollation; + return true; + } + } + + // Implicit > coercible default + if (lhsSqlWithColl.CollationLabel == CollationLabel.Implicit) + { + collation = lhsSqlWithColl.Collation; + collationLabel = CollationLabel.Implicit; + return true; + } + + if (rhsSqlWithColl.CollationLabel == CollationLabel.Implicit) + { + collation = rhsSqlWithColl.Collation; + collationLabel = CollationLabel.Implicit; + return true; + } + + collationLabel = CollationLabel.CoercibleDefault; + return true; + } + } + + enum CollationLabel + { + CoercibleDefault, + Implicit, + Explicit, + NoCollation + } } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDataNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDataNode.cs index 4e59eac4..79a852f9 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDataNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDataNode.cs @@ -196,6 +196,7 @@ public void MergeStatsFrom(BaseDataNode other) /// /// Translates filter criteria from ScriptDom to FetchXML /// + /// The main for this connection /// The to use to get metadata /// to indicate how the query can be executed /// The SQL criteria to attempt to translate to FetchXML @@ -207,9 +208,9 @@ public void MergeStatsFrom(BaseDataNode other) /// The types of any parameters that can be used /// The FetchXML version of the that is generated by this method /// true if the can be translated to FetchXML, or false otherwise - protected bool TranslateFetchXMLCriteria(IAttributeMetadataCache metadata, IQueryExecutionOptions options, BooleanExpression criteria, INodeSchema schema, string allowedPrefix, string targetEntityName, string targetEntityAlias, object[] items, IDictionary parameterTypes, out filter filter) + protected bool TranslateFetchXMLCriteria(DataSource primaryDataSource, IAttributeMetadataCache metadata, IQueryExecutionOptions options, BooleanExpression criteria, INodeSchema schema, string allowedPrefix, string targetEntityName, string targetEntityAlias, object[] items, IDictionary parameterTypes, out filter filter) { - if (!TranslateFetchXMLCriteria(metadata, options, criteria, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, parameterTypes, out var condition, out filter)) + if (!TranslateFetchXMLCriteria(primaryDataSource, metadata, options, criteria, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, parameterTypes, out var condition, out filter)) return false; if (condition != null) @@ -221,6 +222,7 @@ protected bool TranslateFetchXMLCriteria(IAttributeMetadataCache metadata, IQuer /// /// Translates filter criteria from ScriptDom to FetchXML /// + /// The main for this connection /// The to use to get metadata /// to indicate how the query can be executed /// The SQL criteria to attempt to translate to FetchXML @@ -233,16 +235,16 @@ protected bool TranslateFetchXMLCriteria(IAttributeMetadataCache metadata, IQuer /// The FetchXML version of the that is generated by this method when it covers multiple conditions /// The FetchXML version of the that is generated by this method when it is for a single condition only /// true if the can be translated to FetchXML, or false otherwise - private bool TranslateFetchXMLCriteria(IAttributeMetadataCache metadata, IQueryExecutionOptions options, BooleanExpression criteria, INodeSchema schema, string allowedPrefix, string targetEntityName, string targetEntityAlias, object[] items, IDictionary parameterTypes, out condition condition, out filter filter) + private bool TranslateFetchXMLCriteria(DataSource primaryDataSource, IAttributeMetadataCache metadata, IQueryExecutionOptions options, BooleanExpression criteria, INodeSchema schema, string allowedPrefix, string targetEntityName, string targetEntityAlias, object[] items, IDictionary parameterTypes, out condition condition, out filter filter) { condition = null; filter = null; if (criteria is BooleanBinaryExpression binary) { - if (!TranslateFetchXMLCriteria(metadata, options, binary.FirstExpression, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, parameterTypes, out var lhsCondition, out var lhsFilter)) + if (!TranslateFetchXMLCriteria(primaryDataSource, metadata, options, binary.FirstExpression, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, parameterTypes, out var lhsCondition, out var lhsFilter)) return false; - if (!TranslateFetchXMLCriteria(metadata, options, binary.SecondExpression, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, parameterTypes, out var rhsCondition, out var rhsFilter)) + if (!TranslateFetchXMLCriteria(primaryDataSource, metadata, options, binary.SecondExpression, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, parameterTypes, out var rhsCondition, out var rhsFilter)) return false; filter = new filter @@ -259,7 +261,7 @@ private bool TranslateFetchXMLCriteria(IAttributeMetadataCache metadata, IQueryE if (criteria is BooleanParenthesisExpression paren) { - return TranslateFetchXMLCriteria(metadata, options, paren.Expression, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, parameterTypes, out condition, out filter); + return TranslateFetchXMLCriteria(primaryDataSource, metadata, options, paren.Expression, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, parameterTypes, out condition, out filter); } if (criteria is BooleanComparisonExpression comparison) @@ -320,7 +322,7 @@ private bool TranslateFetchXMLCriteria(IAttributeMetadataCache metadata, IQueryE } // If we still couldn't find the column name and value, this isn't a pattern we can support in FetchXML - if (field == null || (literal == null && func == null && variable == null && parameterless == null && globalVariable == null && (field2 == null || !options.ColumnComparisonAvailable) && !expr.IsConstantValueExpression(schema, options, out literal))) + if (field == null || (literal == null && func == null && variable == null && parameterless == null && globalVariable == null && (field2 == null || !options.ColumnComparisonAvailable) && !expr.IsConstantValueExpression(primaryDataSource, schema, options, out literal))) return false; // Select the correct FetchXML operator @@ -425,14 +427,14 @@ private bool TranslateFetchXMLCriteria(IAttributeMetadataCache metadata, IQueryE } else if (func != null) { - if (func.IsConstantValueExpression(schema, options, out literal)) + if (func.IsConstantValueExpression(primaryDataSource, schema, options, out literal)) values = new[] { literal }; else return false; } else if (parameterless != null) { - if (parameterless.IsConstantValueExpression(schema, options, out literal)) + if (parameterless.IsConstantValueExpression(primaryDataSource, schema, options, out literal)) { values = new[] { literal }; } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs index 657cf402..9394d61c 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs @@ -220,6 +220,7 @@ protected List GetDmlSourceEntities(IDictionary data var dataTable = new DataTable(); var schemaTable = dataReader.GetSchemaTable(); var columnTypes = new Dictionary(StringComparer.OrdinalIgnoreCase); + var targetDataSource = dataSources[DataSource]; for (var i = 0; i < schemaTable.Rows.Count; i++) { @@ -239,10 +240,10 @@ protected List GetDmlSourceEntities(IDictionary data { case "binary": colSqlType = DataTypeHelpers.Binary(colSize); break; case "varbinary": colSqlType = DataTypeHelpers.VarBinary(colSize); break; - case "char": colSqlType = DataTypeHelpers.Char(colSize); break; - case "varchar": colSqlType = DataTypeHelpers.VarChar(colSize); break; - case "nchar": colSqlType = DataTypeHelpers.NChar(colSize); break; - case "nvarchar": colSqlType = DataTypeHelpers.NVarChar(colSize); break; + case "char": colSqlType = DataTypeHelpers.Char(colSize, targetDataSource.DefaultCollation, CollationLabel.Implicit); break; + case "varchar": colSqlType = DataTypeHelpers.VarChar(colSize, targetDataSource.DefaultCollation, CollationLabel.Implicit); break; + case "nchar": colSqlType = DataTypeHelpers.NChar(colSize, targetDataSource.DefaultCollation, CollationLabel.Implicit); break; + case "nvarchar": colSqlType = DataTypeHelpers.NVarChar(colSize, targetDataSource.DefaultCollation, CollationLabel.Implicit); break; case "datetime": colSqlType = DataTypeHelpers.DateTime; break; case "smalldatetime": colSqlType = DataTypeHelpers.SmallDateTime; break; case "date": colSqlType = DataTypeHelpers.Date; break; @@ -330,9 +331,9 @@ protected List GetDmlSourceEntities(IDictionary data /// The time zone that datetime values are supplied in /// The records that are being mapped /// - protected Dictionary> CompileColumnMappings(IAttributeMetadataCache cache, string logicalName, IDictionary mappings, INodeSchema schema, DateTimeKind dateTimeKind, List entities) + protected Dictionary> CompileColumnMappings(DataSource dataSource, string logicalName, IDictionary mappings, INodeSchema schema, DateTimeKind dateTimeKind, List entities) { - var metadata = cache[logicalName]; + var metadata = dataSource.Metadata[logicalName]; var attributes = metadata.Attributes.ToDictionary(a => a.LogicalName, StringComparer.OrdinalIgnoreCase); var attributeAccessors = new Dictionary>(); @@ -353,7 +354,7 @@ protected Dictionary> CompileColumnMappings(IAttrib var sourceSqlType = schema.Schema[sourceColumnName]; var destType = attr.GetAttributeType(); - var destSqlType = attr.IsPrimaryId == true ? DataTypeHelpers.UniqueIdentifier : attr.GetAttributeSqlType(cache, true); + var destSqlType = attr.IsPrimaryId == true ? DataTypeHelpers.UniqueIdentifier : attr.GetAttributeSqlType(dataSource, true); if (attr is LookupAttributeMetadata && metadata.IsIntersect == true) { @@ -402,7 +403,7 @@ protected Dictionary> CompileColumnMappings(IAttrib { var sourceTargetColumnName = mappings[destAttributeName + "type"]; var sourceTargetType = schema.Schema[sourceTargetColumnName]; - var stringType = DataTypeHelpers.NVarChar(MetadataExtensions.EntityLogicalNameMaxLength); + var stringType = DataTypeHelpers.NVarChar(MetadataExtensions.EntityLogicalNameMaxLength, dataSource.DefaultCollation, CollationLabel.Implicit); targetExpr = Expression.Property(entityParam, typeof(Entity).GetCustomAttribute().MemberName, Expression.Constant(sourceTargetColumnName)); targetExpr = SqlTypeConverter.Convert(targetExpr, sourceTargetType, stringType); targetExpr = SqlTypeConverter.Convert(targetExpr, typeof(string)); diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ConcatenateNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ConcatenateNode.cs index 8d218033..2ab3ced4 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ConcatenateNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ConcatenateNode.cs @@ -86,7 +86,7 @@ public override IDataExecutionPlanNodeInternal FoldQuery(IDictionaryA mapping of parameter names to their types that are available to the expression /// The SQL data type that will be returned /// The type of value that will be returned by the expression - public static Type GetType(this TSqlFragment expr, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, out DataTypeReference sqlType) + public static Type GetType(this TSqlFragment expr, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, out DataTypeReference sqlType) { var entityParam = Expression.Parameter(typeof(Entity)); var parameterParam = Expression.Parameter(typeof(IDictionary)); var optionsParam = Expression.Parameter(typeof(IQueryExecutionOptions)); - var expression = ToExpression(expr, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + var expression = ToExpression(expr, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); return expression.Type; } @@ -47,13 +47,13 @@ public static Type GetType(this TSqlFragment expr, INodeSchema schema, INodeSche /// The schema of the node that the expression will be evaluated in the context of /// A mapping of parameter names to their types that are available to the expression /// A function that accepts a representing the data values of a record, a holding parameter values and an defining how the query should be run and returns the value of the expression - public static Func, IQueryExecutionOptions, object> Compile(this TSqlFragment expr, INodeSchema schema, IDictionary parameterTypes) + public static Func, IQueryExecutionOptions, object> Compile(this TSqlFragment expr, DataSource primaryDataSource, INodeSchema schema, IDictionary parameterTypes) { var entityParam = Expression.Parameter(typeof(Entity)); var parameterParam = Expression.Parameter(typeof(IDictionary)); var optionsParam = Expression.Parameter(typeof(IQueryExecutionOptions)); - var expression = ToExpression(expr, schema, null, parameterTypes, entityParam, parameterParam, optionsParam, out _); + var expression = ToExpression(expr, primaryDataSource, schema, null, parameterTypes, entityParam, parameterParam, optionsParam, out _); expression = Expr.Box(expression); return Expression.Lambda, IQueryExecutionOptions, object>>(expression, entityParam, parameterParam, optionsParam).Compile(); @@ -66,80 +66,80 @@ public static Func, IQueryExecutionOptions, /// The schema of the node that the expression will be evaluated in the context of /// A mapping of parameter names to their types that are available to the expression /// A function that accepts a representing the data values of a record, a holding parameter values and an defining how the query should be run and returns the value of the expression - public static Func, IQueryExecutionOptions, bool> Compile(this BooleanExpression b, INodeSchema schema, IDictionary parameterTypes) + public static Func, IQueryExecutionOptions, bool> Compile(this BooleanExpression b, DataSource primaryDataSource, INodeSchema schema, IDictionary parameterTypes) { var entityParam = Expression.Parameter(typeof(Entity)); var parameterParam = Expression.Parameter(typeof(IDictionary)); var optionsParam = Expression.Parameter(typeof(IQueryExecutionOptions)); - var expression = ToExpression(b, schema, null, parameterTypes, entityParam, parameterParam, optionsParam, out _); + var expression = ToExpression(b, primaryDataSource, schema, null, parameterTypes, entityParam, parameterParam, optionsParam, out _); expression = Expression.IsTrue(expression); return Expression.Lambda, IQueryExecutionOptions, bool>>(expression, entityParam, parameterParam, optionsParam).Compile(); } - private static Expression ToExpression(this TSqlFragment expr, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this TSqlFragment expr, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { if (expr is ColumnReferenceExpression col) - return ToExpression(col, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(col, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is IdentifierLiteral guid) - return ToExpression(guid, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(guid, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is IntegerLiteral i) - return ToExpression(i, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(i, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is MoneyLiteral money) - return ToExpression(money, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(money, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is NullLiteral n) - return ToExpression(n, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(n, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is NumericLiteral num) - return ToExpression(num, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(num, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is RealLiteral real) - return ToExpression(real, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(real, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is StringLiteral str) - return ToExpression(str, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(str, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is OdbcLiteral odbc) - return ToExpression(odbc, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(odbc, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is BooleanBinaryExpression boolBin) - return ToExpression(boolBin, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(boolBin, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is BooleanComparisonExpression cmp) - return ToExpression(cmp, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(cmp, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is BooleanParenthesisExpression boolParen) - return ToExpression(boolParen, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(boolParen, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is InPredicate inPred) - return ToExpression(inPred, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(inPred, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is BooleanIsNullExpression isNull) - return ToExpression(isNull, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(isNull, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is LikePredicate like) - return ToExpression(like, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(like, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is BooleanNotExpression not) - return ToExpression(not, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(not, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is FullTextPredicate fullText) - return ToExpression(fullText, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(fullText, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is Microsoft.SqlServer.TransactSql.ScriptDom.BinaryExpression bin) - return ToExpression(bin, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(bin, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is FunctionCall func) - return ToExpression(func, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(func, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is ParenthesisExpression paren) - return ToExpression(paren, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(paren, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is Microsoft.SqlServer.TransactSql.ScriptDom.UnaryExpression unary) - return ToExpression(unary, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(unary, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is VariableReference var) - return ToExpression(var, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(var, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is SimpleCaseExpression simpleCase) - return ToExpression(simpleCase, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(simpleCase, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is SearchedCaseExpression searchedCase) - return ToExpression(searchedCase, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(searchedCase, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is ConvertCall convert) - return ToExpression(convert, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(convert, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is CastCall cast) - return ToExpression(cast, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(cast, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is ParameterlessCall parameterless) - return ToExpression(parameterless, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(parameterless, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else if (expr is GlobalVariableExpression global) - return ToExpression(global, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(global, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); else throw new NotSupportedQueryFragmentException("Unhandled expression type", expr); } - private static Expression ToExpression(ColumnReferenceExpression col, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(ColumnReferenceExpression col, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { var name = col.GetColumnName(); @@ -173,50 +173,56 @@ private static Expression ToExpression(ColumnReferenceExpression col, INodeSchem return Expression.Convert(expr, sqlType.ToNetType(out _)); } - private static Expression ToExpression(IdentifierLiteral guid, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(IdentifierLiteral guid, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { sqlType = DataTypeHelpers.UniqueIdentifier; return Expression.Constant(new SqlGuid(guid.Value)); } - private static Expression ToExpression(IntegerLiteral i, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(IntegerLiteral i, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { sqlType = DataTypeHelpers.Int; return Expression.Constant(new SqlInt32(Int32.Parse(i.Value, CultureInfo.InvariantCulture))); } - private static Expression ToExpression(MoneyLiteral money, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(MoneyLiteral money, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { sqlType = DataTypeHelpers.Money; return Expression.Constant(new SqlMoney(Decimal.Parse(money.Value, CultureInfo.InvariantCulture))); } - private static Expression ToExpression(NullLiteral n, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(NullLiteral n, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { sqlType = DataTypeHelpers.ImplicitIntForNullLiteral; return Expression.Constant(SqlInt32.Null); } - private static Expression ToExpression(NumericLiteral num, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(NumericLiteral num, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { var value = new SqlDecimal(Decimal.Parse(num.Value, CultureInfo.InvariantCulture)); sqlType = DataTypeHelpers.Decimal(value.Precision, value.Scale); return Expression.Constant(value); } - private static Expression ToExpression(RealLiteral real, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(RealLiteral real, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { sqlType = DataTypeHelpers.Real; return Expression.Constant(new SqlDouble(Double.Parse(real.Value, CultureInfo.InvariantCulture))); } - private static Expression ToExpression(StringLiteral str, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(StringLiteral str, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { - sqlType = str.IsNational ? DataTypeHelpers.NVarChar(str.Value.Length) : DataTypeHelpers.VarChar(str.Value.Length); - return Expression.Constant(new SqlString(str.Value, CultureInfo.CurrentCulture.LCID, SqlCompareOptions.IgnoreCase | SqlCompareOptions.IgnoreNonSpace)); + var collationLabel = CollationLabel.CoercibleDefault; + var collation = GetCollation(primaryDataSource, str.Collation, ref collationLabel); + + sqlType = str.IsNational + ? DataTypeHelpers.NVarChar(str.Value.Length, collation, collationLabel) + : DataTypeHelpers.VarChar(str.Value.Length, collation, collationLabel); + + return Expression.Constant(collation.ToSqlString(str.Value)); } - private static Expression ToExpression(OdbcLiteral odbc, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(OdbcLiteral odbc, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { switch (odbc.OdbcLiteralType) { @@ -237,7 +243,7 @@ private static Expression ToExpression(OdbcLiteral odbc, INodeSchema schema, INo } } - private static Expression ToExpression(BooleanComparisonExpression cmp, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(BooleanComparisonExpression cmp, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { // Special case for field = func() where func is defined in FetchXmlConditionMethods if (cmp.FirstExpression is ColumnReferenceExpression && @@ -247,15 +253,15 @@ cmp.SecondExpression is FunctionCall func { var parameters = func.Parameters.Select(p => { - var paramExpr = p.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var paramType); + var paramExpr = p.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var paramType); return new KeyValuePair(paramExpr, paramType); }).ToList(); - var colExpr = cmp.FirstExpression.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var colType); + var colExpr = cmp.FirstExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var colType); parameters.Insert(0, new KeyValuePair(colExpr, colType)); var paramTypes = parameters.Select(p => p.Value).ToArray(); var paramExpressions = parameters.Select(p => p.Key).ToArray(); - var fetchXmlComparison = GetMethod(typeof(FetchXmlConditionMethods), func, paramTypes, false, optionsParam, ref paramExpressions, out sqlType); + var fetchXmlComparison = GetMethod(typeof(FetchXmlConditionMethods), primaryDataSource, func, paramTypes, false, optionsParam, ref paramExpressions, out sqlType); if (fetchXmlComparison != null) return Expr.Call(fetchXmlComparison, paramExpressions); @@ -263,10 +269,10 @@ cmp.SecondExpression is FunctionCall func sqlType = DataTypeHelpers.Bit; - var lhs = cmp.FirstExpression.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var lhsType); - var rhs = cmp.SecondExpression.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var rhsType); + var lhs = cmp.FirstExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var lhsType); + var rhs = cmp.SecondExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var rhsType); - if (!SqlTypeConverter.CanMakeConsistentTypes(lhsType, rhsType, out var type)) + if (!SqlTypeConverter.CanMakeConsistentTypes(lhsType, rhsType, primaryDataSource, out var type)) { // Special case - we can filter on entity reference types by string if (lhs.Type == typeof(SqlEntityReference) && rhs.Type == typeof(SqlString) || @@ -332,12 +338,12 @@ cmp.SecondExpression is StringLiteral str && } } - private static Expression ToExpression(BooleanBinaryExpression bin, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(BooleanBinaryExpression bin, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { sqlType = DataTypeHelpers.Bit; - var lhs = bin.FirstExpression.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out _); - var rhs = bin.SecondExpression.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out _); + var lhs = bin.FirstExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out _); + var rhs = bin.SecondExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out _); if (bin.BinaryExpressionType == BooleanBinaryExpressionType.And) return Expression.AndAlso(lhs, rhs); @@ -345,17 +351,17 @@ private static Expression ToExpression(BooleanBinaryExpression bin, INodeSchema return Expression.OrElse(lhs, rhs); } - private static Expression ToExpression(BooleanParenthesisExpression paren, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(BooleanParenthesisExpression paren, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { - return paren.Expression.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return paren.Expression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); } - private static Expression ToExpression(Microsoft.SqlServer.TransactSql.ScriptDom.BinaryExpression bin, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(Microsoft.SqlServer.TransactSql.ScriptDom.BinaryExpression bin, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { - var lhs = bin.FirstExpression.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var lhsSqlType); - var rhs = bin.SecondExpression.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var rhsSqlType); + var lhs = bin.FirstExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var lhsSqlType); + var rhs = bin.SecondExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var rhsSqlType); - if (!SqlTypeConverter.CanMakeConsistentTypes(lhsSqlType, rhsSqlType, out var type)) + if (!SqlTypeConverter.CanMakeConsistentTypes(lhsSqlType, rhsSqlType, primaryDataSource, out var type)) throw new NotSupportedQueryFragmentException($"No implicit conversion exists for types {lhsSqlType.ToSql()} and {rhsSqlType.ToSql()}", bin); // For decimal types, need to work out the precision and scale of the result depending on the type of operation @@ -439,19 +445,35 @@ private static Expression ToExpression(Microsoft.SqlServer.TransactSql.ScriptDom else expr = Expression.Add(lhs, rhs); - // Special case for SqlString length calculation - if (lhsSqlType is SqlDataTypeReference lhsSql && - rhsSqlType is SqlDataTypeReference rhsSql && + // Special case for SqlString length & collation calculation + if (lhsSqlType is SqlDataTypeReferenceWithCollation lhsSql && + rhsSqlType is SqlDataTypeReferenceWithCollation rhsSql && lhs.Type == typeof(SqlString) && rhs.Type == typeof(SqlString) && lhsSql.Parameters.Count == 1 && rhsSql.Parameters.Count == 1 && - lhsSql.Parameters[0].LiteralType == LiteralType.Integer && - rhsSql.Parameters[0].LiteralType == LiteralType.Integer && - Int32.TryParse(lhsSql.Parameters[0].Value, out var lhsLength) && - Int32.TryParse(rhsSql.Parameters[0].Value, out var rhsLength)) + sqlType is SqlDataTypeReferenceWithCollation sqlTypeWithColl) { - sqlType = DataTypeHelpers.NVarChar(lhsLength + rhsLength); + int lhsLength; + int rhsLength; + + if (lhsSql.Parameters[0].LiteralType != LiteralType.Integer || + !Int32.TryParse(lhsSql.Parameters[0].Value, out lhsLength)) + lhsLength = 8000; + + if (rhsSql.Parameters[0].LiteralType != LiteralType.Integer || + !Int32.TryParse(rhsSql.Parameters[0].Value, out rhsLength)) + rhsLength = 8000; + + var length = lhsLength + rhsLength; + + sqlType = new SqlDataTypeReferenceWithCollation + { + SqlDataTypeOption = ((SqlDataTypeReference)type).SqlDataTypeOption, + Parameters = { length <= 8000 ? (Literal)new IntegerLiteral { Value = length.ToString(CultureInfo.InvariantCulture) } : new MaxLiteral() }, + Collation = sqlTypeWithColl.Collation, + CollationLabel = sqlTypeWithColl.CollationLabel + }; } break; @@ -495,7 +517,7 @@ rhsSqlType is SqlDataTypeReference rhsSql && sqlType = type; if (sqlType == null) - sqlType = expr.Type.ToSqlType(); + sqlType = expr.Type.ToSqlType(primaryDataSource); return expr; } @@ -520,7 +542,7 @@ private static SqlDateTime SubtractSqlDateTime(SqlDateTime lhs, SqlDateTime rhs) return lhs - ts; } - private static MethodInfo GetMethod(FunctionCall func, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out Expression[] paramExpressions, out DataTypeReference sqlType) + private static MethodInfo GetMethod(FunctionCall func, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out Expression[] paramExpressions, out DataTypeReference sqlType) { KeyValuePair[] paramExpressionsWithType; @@ -547,10 +569,10 @@ private static MethodInfo GetMethod(FunctionCall func, INodeSchema schema, INode throw new NotSupportedQueryFragmentException("Expected a datepart name", param); } - return new KeyValuePair(Expression.Constant(col.MultiPartIdentifier.Identifiers.Single().Value), DataTypeHelpers.NVarChar(col.MultiPartIdentifier.Identifiers.Single().Value.Length)); + return new KeyValuePair(Expression.Constant(col.MultiPartIdentifier.Identifiers.Single().Value), DataTypeHelpers.NVarChar(col.MultiPartIdentifier.Identifiers.Single().Value.Length, primaryDataSource.DefaultCollation, CollationLabel.CoercibleDefault)); } - var paramExpr = param.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var paramType); + var paramExpr = param.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var paramType); return new KeyValuePair(paramExpr, paramType); }) .ToArray(); @@ -560,7 +582,7 @@ private static MethodInfo GetMethod(FunctionCall func, INodeSchema schema, INode paramExpressionsWithType = func.Parameters .Select(param => { - var paramExpr = param.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var paramType); + var paramExpr = param.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var paramType); return new KeyValuePair(paramExpr, paramType); }) .ToArray(); @@ -570,10 +592,10 @@ private static MethodInfo GetMethod(FunctionCall func, INodeSchema schema, INode .Select(kvp => kvp.Key) .ToArray(); - return GetMethod(typeof(ExpressionFunctions), func, paramExpressionsWithType.Select(kvp => kvp.Value).ToArray(), true, optionsParam, ref paramExpressions, out sqlType); + return GetMethod(typeof(ExpressionFunctions), primaryDataSource, func, paramExpressionsWithType.Select(kvp => kvp.Value).ToArray(), true, optionsParam, ref paramExpressions, out sqlType); } - private static MethodInfo GetMethod(Type targetType, FunctionCall func, DataTypeReference[] paramTypes, bool throwOnMissing, Expression optionsParam, ref Expression[] paramExpressions, out DataTypeReference sqlType) + private static MethodInfo GetMethod(Type targetType, DataSource primaryDataSource, FunctionCall func, DataTypeReference[] paramTypes, bool throwOnMissing, Expression optionsParam, ref Expression[] paramExpressions, out DataTypeReference sqlType) { // Find a method that implements this function var methods = targetType @@ -644,8 +666,11 @@ private static MethodInfo GetMethod(Type targetType, FunctionCall func, DataType // Use the [MaxLength(value)] attribute from the method where available var methodMaxLength = method.GetCustomAttribute(); + // TODO: Add an attribute to indicate if the collation should be taken from a parameter to the function + // or use the default collation for the connection + if (methodMaxLength?.MaxLength != null) - sqlType = DataTypeHelpers.NVarChar(methodMaxLength.MaxLength.Value); + sqlType = DataTypeHelpers.NVarChar(methodMaxLength.MaxLength.Value, primaryDataSource.DefaultCollation, CollationLabel.CoercibleDefault); // Work out precise type from parameter with [MaxLength] attribute where available for (var i = 0; i < parameters.Length; i++) @@ -653,7 +678,7 @@ private static MethodInfo GetMethod(Type targetType, FunctionCall func, DataType if (parameters[i].GetCustomAttribute() != null) { if (parameters[i].ParameterType == typeof(SqlInt32) && paramExpressions[i] is ConstantExpression lengthConst && lengthConst.Value is SqlInt32 length && !length.IsNull) - sqlType = DataTypeHelpers.NVarChar(length.Value); + sqlType = DataTypeHelpers.NVarChar(length.Value, primaryDataSource.DefaultCollation, CollationLabel.CoercibleDefault); else if (parameters[i].ParameterType == typeof(SqlString) && paramTypes[i].ToNetType(out var sqlStringType) == typeof(SqlString)) sqlType = paramTypes[i]; @@ -717,31 +742,31 @@ private static MethodInfo GetMethod(Type targetType, FunctionCall func, DataType continue; } - if (!SqlTypeConverter.CanChangeTypeImplicit(paramTypes[i], paramType.ToSqlType())) - throw new NotSupportedQueryFragmentException($"Cannot convert {paramTypes[i].ToSql()} to {paramType.ToSqlType().ToSql()}", i < paramOffset ? func : func.Parameters[i - paramOffset]); + if (!SqlTypeConverter.CanChangeTypeImplicit(paramTypes[i], paramType.ToSqlType(primaryDataSource))) + throw new NotSupportedQueryFragmentException($"Cannot convert {paramTypes[i].ToSql()} to {paramType.ToSqlType(primaryDataSource).ToSql()}", i < paramOffset ? func : func.Parameters[i - paramOffset]); } for (var i = parameters.Length; i < paramTypes.Length; i++) { var paramType = parameters.Last().ParameterType.GetElementType(); - if (!SqlTypeConverter.CanChangeTypeImplicit(paramTypes[i], paramType.ToSqlType())) - throw new NotSupportedQueryFragmentException($"Cannot convert {paramTypes[i].ToSql()} to {paramType.ToSqlType().ToSql()}", i < paramOffset ? func : func.Parameters[i - paramOffset]); + if (!SqlTypeConverter.CanChangeTypeImplicit(paramTypes[i], paramType.ToSqlType(primaryDataSource))) + throw new NotSupportedQueryFragmentException($"Cannot convert {paramTypes[i].ToSql()} to {paramType.ToSqlType(primaryDataSource).ToSql()}", i < paramOffset ? func : func.Parameters[i - paramOffset]); } if (sqlType == null) - sqlType = method.ReturnType.ToSqlType(); + sqlType = method.ReturnType.ToSqlType(primaryDataSource); return method; } - private static Expression ToExpression(this FunctionCall func, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this FunctionCall func, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { if (func.OverClause != null) throw new NotSupportedQueryFragmentException("Window functions are not supported", func); // Find the method to call and get the expressions for the parameter values - var method = GetMethod(func, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var paramValues, out sqlType); + var method = GetMethod(func, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var paramValues, out sqlType); // Convert the parameters to the expected types var parameters = method.GetParameters(); @@ -760,14 +785,14 @@ private static Expression ToExpression(this FunctionCall func, INodeSchema schem return expr; } - private static Expression ToExpression(this ParenthesisExpression paren, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this ParenthesisExpression paren, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { - return paren.Expression.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return paren.Expression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); } - private static Expression ToExpression(this Microsoft.SqlServer.TransactSql.ScriptDom.UnaryExpression unary, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this Microsoft.SqlServer.TransactSql.ScriptDom.UnaryExpression unary, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { - var value = unary.Expression.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + var value = unary.Expression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); switch (unary.UnaryExpressionType) { @@ -785,20 +810,20 @@ private static Expression ToExpression(this Microsoft.SqlServer.TransactSql.Scri } } - private static Expression ToExpression(this InPredicate inPred, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this InPredicate inPred, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { if (inPred.Subquery != null) throw new NotSupportedQueryFragmentException("Subquery should have been eliminated by query plan", inPred); - var exprValue = inPred.Expression.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var exprType); + var exprValue = inPred.Expression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var exprType); Expression result = null; foreach (var value in inPred.Values) { - var comparisonValue = value.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var comparisonType); + var comparisonValue = value.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var comparisonType); - if (!SqlTypeConverter.CanMakeConsistentTypes(exprType, comparisonType, out var type)) + if (!SqlTypeConverter.CanMakeConsistentTypes(exprType, comparisonType, primaryDataSource, out var type)) throw new NotSupportedQueryFragmentException($"No implicit conversion exists for types {exprType.ToSql()} and {comparisonType.ToSql()}", inPred); var convertedExprValue = exprValue; @@ -821,7 +846,7 @@ private static Expression ToExpression(this InPredicate inPred, INodeSchema sche return result; } - private static Expression ToExpression(this VariableReference var, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this VariableReference var, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { if (parameterTypes == null || !parameterTypes.TryGetValue(var.Name, out sqlType)) throw new NotSupportedQueryFragmentException("Undefined variable", var); @@ -830,7 +855,7 @@ private static Expression ToExpression(this VariableReference var, INodeSchema s return Expression.Convert(expr, sqlType.ToNetType(out _)); } - private static Expression ToExpression(this GlobalVariableExpression var, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this GlobalVariableExpression var, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { if (parameterTypes == null || !parameterTypes.TryGetValue(var.Name, out sqlType)) throw new NotSupportedQueryFragmentException("Undefined variable", var); @@ -839,9 +864,9 @@ private static Expression ToExpression(this GlobalVariableExpression var, INodeS return Expression.Convert(expr, sqlType.ToNetType(out _)); } - private static Expression ToExpression(this BooleanIsNullExpression isNull, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this BooleanIsNullExpression isNull, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { - var value = isNull.Expression.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out _); + var value = isNull.Expression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out _); value = SqlTypeConverter.NullCheck(value); if (isNull.IsNot) @@ -852,16 +877,17 @@ private static Expression ToExpression(this BooleanIsNullExpression isNull, INod return value; } - private static Expression ToExpression(this LikePredicate like, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this LikePredicate like, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { DataTypeReference escapeType = null; - var value = like.FirstExpression.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var valueType); - var pattern = like.SecondExpression.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var patternType); - var escape = like.EscapeExpression?.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out escapeType); + var value = like.FirstExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var valueType); + var pattern = like.SecondExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var patternType); + var escape = like.EscapeExpression?.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out escapeType); + // TODO: Use the collations of the value/pattern and ensure they are consistent sqlType = DataTypeHelpers.Bit; - var stringType = DataTypeHelpers.NVarChar(Int32.MaxValue); + var stringType = DataTypeHelpers.NVarChar(Int32.MaxValue, primaryDataSource.DefaultCollation, CollationLabel.CoercibleDefault); if (value.Type != typeof(SqlString)) { @@ -1008,23 +1034,23 @@ private static SqlBoolean Like(SqlString value, Regex pattern, bool not) return result; } - private static Expression ToExpression(this SimpleCaseExpression simpleCase, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this SimpleCaseExpression simpleCase, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { // Convert all the different elements to expressions - var value = simpleCase.InputExpression.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var valueType); + var value = simpleCase.InputExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var valueType); var whenClauses = simpleCase.WhenClauses.Select(when => { - var whenExpr = when.WhenExpression.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var whenType); + var whenExpr = when.WhenExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var whenType); return new { Expression = whenExpr, Type = whenType }; }).ToList(); var caseTypes = new DataTypeReference[whenClauses.Count]; var thenClauses = simpleCase.WhenClauses.Select(when => { - var thenExpr = when.ThenExpression.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var thenType); + var thenExpr = when.ThenExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var thenType); return new { Expression = thenExpr, Type = thenType }; }).ToList(); DataTypeReference elseType = null; - var elseValue = simpleCase.ElseExpression?.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out elseType); + var elseValue = simpleCase.ElseExpression?.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out elseType); // First pass to determine final return type DataTypeReference type = null; @@ -1033,7 +1059,7 @@ private static Expression ToExpression(this SimpleCaseExpression simpleCase, INo { var whenType = whenClauses[i].Type; - if (!SqlTypeConverter.CanMakeConsistentTypes(valueType, whenType, out var caseType)) + if (!SqlTypeConverter.CanMakeConsistentTypes(valueType, whenType, primaryDataSource, out var caseType)) throw new NotSupportedQueryFragmentException($"Cannot compare values of type {value.Type} and {whenType}", simpleCase.WhenClauses[i].WhenExpression); caseTypes[i] = caseType; @@ -1042,7 +1068,7 @@ private static Expression ToExpression(this SimpleCaseExpression simpleCase, INo if (type == null) type = thenType; - else if (!SqlTypeConverter.CanMakeConsistentTypes(type, thenType, out type)) + else if (!SqlTypeConverter.CanMakeConsistentTypes(type, thenType, primaryDataSource, out type)) throw new NotSupportedQueryFragmentException($"Cannot determine return type", simpleCase); } @@ -1050,7 +1076,7 @@ private static Expression ToExpression(this SimpleCaseExpression simpleCase, INo { if (type == null) type = elseType; - else if (!SqlTypeConverter.CanMakeConsistentTypes(type, elseType, out type)) + else if (!SqlTypeConverter.CanMakeConsistentTypes(type, elseType, primaryDataSource, out type)) throw new NotSupportedQueryFragmentException($"Cannot determine return type", simpleCase); } @@ -1097,21 +1123,21 @@ private static Expression ToExpression(this SimpleCaseExpression simpleCase, INo return result; } - private static Expression ToExpression(this SearchedCaseExpression searchedCase, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this SearchedCaseExpression searchedCase, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { // Convert all the different elements to expressions var whenClauses = searchedCase.WhenClauses.Select(when => { - var whenExpr = when.WhenExpression.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var whenType); + var whenExpr = when.WhenExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var whenType); return new { Expression = whenExpr, Type = whenType }; }).ToList(); var thenClauses = searchedCase.WhenClauses.Select(when => { - var thenExpr = when.ThenExpression.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var thenType); + var thenExpr = when.ThenExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var thenType); return new { Expression = thenExpr, Type = thenType }; }).ToList(); DataTypeReference elseType = null; - var elseValue = searchedCase.ElseExpression?.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out elseType); + var elseValue = searchedCase.ElseExpression?.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out elseType); // First pass to determine final return type DataTypeReference type = null; @@ -1122,7 +1148,7 @@ private static Expression ToExpression(this SearchedCaseExpression searchedCase, if (type == null) type = thenType; - else if (!SqlTypeConverter.CanMakeConsistentTypes(type, thenType, out type)) + else if (!SqlTypeConverter.CanMakeConsistentTypes(type, thenType, primaryDataSource, out type)) throw new NotSupportedQueryFragmentException($"Cannot determine return type", searchedCase); } @@ -1130,7 +1156,7 @@ private static Expression ToExpression(this SearchedCaseExpression searchedCase, { if (type == null) type = elseType; - else if (!SqlTypeConverter.CanMakeConsistentTypes(type, elseType, out type)) + else if (!SqlTypeConverter.CanMakeConsistentTypes(type, elseType, primaryDataSource, out type)) throw new NotSupportedQueryFragmentException($"Cannot determine return type", searchedCase); } @@ -1171,9 +1197,9 @@ private static Expression ToExpression(this SearchedCaseExpression searchedCase, return result; } - private static Expression ToExpression(this BooleanNotExpression not, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this BooleanNotExpression not, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { - var value = not.Expression.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + var value = not.Expression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); return Expression.Not(value); } @@ -1255,7 +1281,6 @@ public static bool IsType(this DataTypeReference type, SqlDataTypeOption sqlType [typeof(SqlInt64)] = DataTypeHelpers.BigInt, [typeof(SqlBinary)] = DataTypeHelpers.VarBinary(Int32.MaxValue), [typeof(SqlBoolean)] = DataTypeHelpers.Bit, - [typeof(SqlString)] = DataTypeHelpers.NVarChar(Int32.MaxValue), [typeof(SqlDateTime)] = DataTypeHelpers.DateTime, [typeof(SqlDecimal)] = DataTypeHelpers.Decimal(38, 10), [typeof(SqlDouble)] = DataTypeHelpers.Float, @@ -1277,16 +1302,19 @@ public static bool IsType(this DataTypeReference type, SqlDataTypeOption sqlType /// /// The data type to convert /// The equivalent SQL - public static DataTypeReference ToSqlType(this Type type) + public static DataTypeReference ToSqlType(this Type type, DataSource primaryDataSource) { + if (type == typeof(SqlString)) + return DataTypeHelpers.NVarChar(Int32.MaxValue, primaryDataSource?.DefaultCollation ?? Collation.USEnglish, CollationLabel.CoercibleDefault); + return _netTypeMapping[type]; } - private static Expression ToExpression(this ConvertCall convert, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this ConvertCall convert, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { - var value = convert.Parameter.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var valueType); + var value = convert.Parameter.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var valueType); DataTypeReference styleType = null; - var style = convert.Style?.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out styleType); + var style = convert.Style?.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out styleType); sqlType = convert.DataType; @@ -1305,14 +1333,14 @@ private static Expression ToExpression(this ConvertCall convert, INodeSchema sch return SqlTypeConverter.Convert(value, valueType, sqlType, style, styleType, convert); } - private static Expression ToExpression(this CastCall cast, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this CastCall cast, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { - return ToExpression(new ConvertCall { Parameter = cast.Parameter, DataType = cast.DataType }, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(new ConvertCall { Parameter = cast.Parameter, DataType = cast.DataType, Collation = cast.Collation }, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); } private static readonly Regex _containsParser = new Regex("^\\S+( OR \\S+)*$", RegexOptions.IgnoreCase | RegexOptions.Compiled); - private static Expression ToExpression(this FullTextPredicate fullText, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this FullTextPredicate fullText, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { // Only support simple CONTAINS calls to handle multi-select optionsets for now if (fullText.FullTextFunctionType != FullTextFunctionType.Contains) @@ -1330,8 +1358,8 @@ private static Expression ToExpression(this FullTextPredicate fullText, INodeSch if (fullText.LanguageTerm != null) throw new NotSupportedQueryFragmentException("LANGUAGE is not currently supported", fullText.LanguageTerm); - var col = fullText.Columns[0].ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var colType); - var stringType = DataTypeHelpers.NVarChar(Int32.MaxValue); + var col = fullText.Columns[0].ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var colType); + var stringType = DataTypeHelpers.NVarChar(Int32.MaxValue, primaryDataSource.DefaultCollation, CollationLabel.CoercibleDefault); if (!SqlTypeConverter.CanChangeTypeImplicit(colType, stringType)) throw new NotSupportedQueryFragmentException("Only string columns are supported", fullText.Columns[0]); @@ -1348,7 +1376,7 @@ private static Expression ToExpression(this FullTextPredicate fullText, INodeSch return Expr.Call(() => Contains(Expr.Arg(), Expr.Arg()), col, Expression.Constant(words)); } - var value = fullText.Value.ToExpression(schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var valueType); + var value = fullText.Value.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var valueType); if (!SqlTypeConverter.CanChangeTypeImplicit(valueType, stringType)) throw new NotSupportedQueryFragmentException($"Expected string value to match, got {value.Type}", fullText.Value); @@ -1391,7 +1419,7 @@ private static Regex[] GetContainsWords(string pattern, bool compile) .ToArray(); } - private static Expression ToExpression(this ParameterlessCall parameterless, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this ParameterlessCall parameterless, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) { switch (parameterless.ParameterlessCallType) { @@ -1521,7 +1549,7 @@ public static ColumnReferenceExpression ToColumnReference(this string colName) /// The schema that the expression is evaluated in /// The equivalent literal value /// true if the expression has a constant value, or false if it can change depending on the current data record - public static bool IsConstantValueExpression(this ScalarExpression expr, INodeSchema schema, IQueryExecutionOptions options, out Literal literal) + public static bool IsConstantValueExpression(this ScalarExpression expr, DataSource primaryDataSource, INodeSchema schema, IQueryExecutionOptions options, out Literal literal) { literal = expr as Literal; @@ -1546,7 +1574,7 @@ public static bool IsConstantValueExpression(this ScalarExpression expr, INodeSc if (parameterlessVisitor.ParameterlessCalls.Any(p => p.ParameterlessCallType != ParameterlessCallType.CurrentTimestamp)) return false; - var value = expr.Compile(schema, null)(null, null, options); + var value = expr.Compile(primaryDataSource, schema, null)(null, null, options); if (value == null || value is INullable n && n.IsNull) literal = new NullLiteral(); @@ -1569,5 +1597,17 @@ public static bool IsConstantValueExpression(this ScalarExpression expr, INodeSc return true; } + + private static Collation GetCollation(DataSource dataSource, Identifier collation, ref CollationLabel collationLabel) + { + if (collation == null) + return dataSource.DefaultCollation; + + if (!Collation.TryParse(collation.Value, out var coll)) + throw new NotSupportedQueryFragmentException("Invalid collation", collation); + + collationLabel = CollationLabel.Explicit; + return coll; + } } } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FetchXmlScan.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FetchXmlScan.cs index f0da114d..0092c0c3 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FetchXmlScan.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FetchXmlScan.cs @@ -662,7 +662,7 @@ public override INodeSchema GetSchema(IDictionary dataSource var notNullColumns = new HashSet(); var sortOrder = new List(); - AddSchemaAttributes(schema, aliases, ref primaryKey, notNullColumns, sortOrder, dataSource.Metadata, entity.name, Alias, entity.Items, true, false); + AddSchemaAttributes(dataSource, schema, aliases, ref primaryKey, notNullColumns, sortOrder, entity.name, Alias, entity.Items, true, false); _lastSchema = new NodeSchema( primaryKey: primaryKey, @@ -760,12 +760,12 @@ internal static bool IsValidAlias(string alias) return Regex.IsMatch(alias, "^[A-Za-z_][A-Za-z0-9_]*$"); } - private void AddSchemaAttributes(Dictionary schema, Dictionary> aliases, ref string primaryKey, HashSet notNullColumns, List sortOrder, IAttributeMetadataCache metadata, string entityName, string alias, object[] items, bool innerJoin, bool requireTablePrefix) + private void AddSchemaAttributes(DataSource dataSource, Dictionary schema, Dictionary> aliases, ref string primaryKey, HashSet notNullColumns, List sortOrder, string entityName, string alias, object[] items, bool innerJoin, bool requireTablePrefix) { if (items == null && !ReturnFullSchema) return; - var meta = metadata[entityName]; + var meta = dataSource.Metadata[entityName]; if (ReturnFullSchema && !FetchXml.aggregate) { @@ -779,8 +779,8 @@ private void AddSchemaAttributes(Dictionary schema, D var fullName = $"{alias}.{attrMetadata.LogicalName}"; var simpleName = requireTablePrefix ? null : attrMetadata.LogicalName; - var attrType = attrMetadata.GetAttributeSqlType(metadata, false); - AddSchemaAttribute(schema, aliases, notNullColumns, metadata, fullName, simpleName, attrType, attrMetadata, innerJoin); + var attrType = attrMetadata.GetAttributeSqlType(dataSource, false); + AddSchemaAttribute(dataSource, schema, aliases, notNullColumns, fullName, simpleName, attrType, attrMetadata, innerJoin); } } @@ -789,7 +789,7 @@ private void AddSchemaAttributes(Dictionary schema, D foreach (var attribute in items.OfType()) { var attrMetadata = meta.Attributes.Single(a => a.LogicalName == attribute.name); - var attrType = attrMetadata.GetAttributeSqlType(metadata, false); + var attrType = attrMetadata.GetAttributeSqlType(dataSource, false); if (attribute.aggregateSpecified && (attribute.aggregate == Engine.FetchXml.AggregateType.count || attribute.aggregate == Engine.FetchXml.AggregateType.countcolumn) || attribute.dategroupingSpecified) @@ -838,7 +838,7 @@ private void AddSchemaAttributes(Dictionary schema, D if (requireTablePrefix) attrAlias = null; - AddSchemaAttribute(schema, aliases, notNullColumns, metadata, fullName, attrAlias, attrType, attrMetadata, innerJoin); + AddSchemaAttribute(dataSource, schema, aliases, notNullColumns, fullName, attrAlias, attrType, attrMetadata, innerJoin); } if (items.OfType().Any()) @@ -851,11 +851,11 @@ private void AddSchemaAttributes(Dictionary schema, D if (attrMetadata.AttributeOf != null) continue; - var attrType = attrMetadata.GetAttributeSqlType(metadata, false); + var attrType = attrMetadata.GetAttributeSqlType(dataSource, false); var attrName = requireTablePrefix ? null : attrMetadata.LogicalName; var fullName = $"{alias}.{attrName}"; - AddSchemaAttribute(schema, aliases, notNullColumns, metadata, fullName, attrName, attrType, attrMetadata, innerJoin); + AddSchemaAttribute(dataSource, schema, aliases, notNullColumns, fullName, attrName, attrType, attrMetadata, innerJoin); } } @@ -897,7 +897,7 @@ private void AddSchemaAttributes(Dictionary schema, D if (primaryKey != null) { - var childMeta = metadata[linkEntity.name]; + var childMeta = dataSource.Metadata[linkEntity.name]; if (linkEntity.from != childMeta.PrimaryIdAttribute) { @@ -908,7 +908,7 @@ private void AddSchemaAttributes(Dictionary schema, D } } - AddSchemaAttributes(schema, aliases, ref primaryKey, notNullColumns, sortOrder, metadata, linkEntity.name, linkEntity.alias, linkEntity.Items, innerJoin && linkEntity.linktype == "inner", requireTablePrefix || linkEntity.RequireTablePrefix); + AddSchemaAttributes(dataSource, schema, aliases, ref primaryKey, notNullColumns, sortOrder, linkEntity.name, linkEntity.alias, linkEntity.Items, innerJoin && linkEntity.linktype == "inner", requireTablePrefix || linkEntity.RequireTablePrefix); } if (innerJoin) @@ -942,7 +942,7 @@ private void AddNotNullFilters(Dictionary schema, Dic AddNotNullFilters(schema, aliases, notNullColumns, alias, subFilter); } - private void AddSchemaAttribute(Dictionary schema, Dictionary> aliases, HashSet notNullColumns, IAttributeMetadataCache metadata, string fullName, string simpleName, DataTypeReference type, AttributeMetadata attrMetadata, bool innerJoin) + private void AddSchemaAttribute(DataSource dataSource, Dictionary schema, Dictionary> aliases, HashSet notNullColumns, string fullName, string simpleName, DataTypeReference type, AttributeMetadata attrMetadata, bool innerJoin) { var notNull = innerJoin && (attrMetadata.RequiredLevel?.Value == AttributeRequiredLevel.SystemRequired || attrMetadata.LogicalName == "createdon" || attrMetadata.LogicalName == "createdby" || attrMetadata.AttributeOf == "createdby"); @@ -957,16 +957,16 @@ private void AddSchemaAttribute(Dictionary schema, Di // Add standard virtual attributes if (attrMetadata is MultiSelectPicklistAttributeMetadata) - AddSchemaAttribute(schema, aliases, notNullColumns, fullName + "name", attrMetadata.LogicalName + "name", DataTypeHelpers.NVarChar(Int32.MaxValue), notNull); + AddSchemaAttribute(schema, aliases, notNullColumns, fullName + "name", attrMetadata.LogicalName + "name", DataTypeHelpers.NVarChar(Int32.MaxValue, dataSource.DefaultCollation, CollationLabel.CoercibleDefault), notNull); else if (attrMetadata is EnumAttributeMetadata || attrMetadata is BooleanAttributeMetadata) - AddSchemaAttribute(schema, aliases, notNullColumns, fullName + "name", attrMetadata.LogicalName + "name", DataTypeHelpers.NVarChar(LabelMaxLength), notNull); + AddSchemaAttribute(schema, aliases, notNullColumns, fullName + "name", attrMetadata.LogicalName + "name", DataTypeHelpers.NVarChar(LabelMaxLength, dataSource.DefaultCollation, CollationLabel.CoercibleDefault), notNull); if (attrMetadata is LookupAttributeMetadata lookup) { - AddSchemaAttribute(schema, aliases, notNullColumns, fullName + "name", attrMetadata.LogicalName + "name", DataTypeHelpers.NVarChar(lookup.Targets == null || lookup.Targets.Length == 0 ? 100 : lookup.Targets.Select(e => ((StringAttributeMetadata)metadata[e].Attributes.SingleOrDefault(a => a.LogicalName == metadata[e].PrimaryNameAttribute))?.MaxLength ?? 100).Max()), notNull); + AddSchemaAttribute(schema, aliases, notNullColumns, fullName + "name", attrMetadata.LogicalName + "name", DataTypeHelpers.NVarChar(lookup.Targets == null || lookup.Targets.Length == 0 ? 100 : lookup.Targets.Select(e => ((StringAttributeMetadata)dataSource.Metadata[e].Attributes.SingleOrDefault(a => a.LogicalName == dataSource.Metadata[e].PrimaryNameAttribute))?.MaxLength ?? 100).Max(), dataSource.DefaultCollation, CollationLabel.CoercibleDefault), notNull); if (lookup.Targets?.Length != 1 && lookup.AttributeType != AttributeTypeCode.PartyList) - AddSchemaAttribute(schema, aliases, notNullColumns, fullName + "type", attrMetadata.LogicalName + "type", DataTypeHelpers.NVarChar(MetadataExtensions.EntityLogicalNameMaxLength), notNull); + AddSchemaAttribute(schema, aliases, notNullColumns, fullName + "type", attrMetadata.LogicalName + "type", DataTypeHelpers.NVarChar(MetadataExtensions.EntityLogicalNameMaxLength, dataSource.DefaultCollation, CollationLabel.CoercibleDefault), notNull); } } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FilterNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FilterNode.cs index cba21513..919f62ba 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FilterNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FilterNode.cs @@ -35,7 +35,7 @@ class FilterNode : BaseDataNode, ISingleSourceExecutionPlanNode protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) { var schema = Source.GetSchema(dataSources, parameterTypes); - var filter = Filter.Compile(schema, parameterTypes); + var filter = Filter.Compile(dataSources[options.PrimaryDataSource], schema, parameterTypes); foreach (var entity in Source.Execute(dataSources, options, parameterTypes, parameterValues)) { @@ -681,7 +681,7 @@ private bool FoldFiltersToDataSources(IDictionary dataSource throw new NotSupportedQueryFragmentException("Missing datasource " + fetchXml.DataSource); // If the criteria are ANDed, see if any of the individual conditions can be translated to FetchXML - Filter = ExtractFetchXMLFilters(dataSource.Metadata, options, Filter, schema, null, fetchXml.Entity.name, fetchXml.Alias, fetchXml.Entity.Items, parameterTypes, out var fetchFilter); + Filter = ExtractFetchXMLFilters(dataSources[options.PrimaryDataSource], dataSource.Metadata, options, Filter, schema, null, fetchXml.Entity.name, fetchXml.Alias, fetchXml.Entity.Items, parameterTypes, out var fetchFilter); if (fetchFilter != null) { @@ -706,7 +706,7 @@ private bool FoldFiltersToDataSources(IDictionary dataSource if (source is MetadataQueryNode meta) { // If the criteria are ANDed, see if any of the individual conditions can be translated to the metadata query - Filter = ExtractMetadataFilters(Filter, meta, options, out var entityFilter, out var attributeFilter, out var relationshipFilter); + Filter = ExtractMetadataFilters(dataSources[options.PrimaryDataSource], Filter, meta, options, out var entityFilter, out var attributeFilter, out var relationshipFilter); meta.Query.AddFilter(entityFilter); @@ -1012,9 +1012,9 @@ public override void AddRequiredColumns(IDictionary dataSour Source.AddRequiredColumns(dataSources, parameterTypes, requiredColumns); } - private BooleanExpression ExtractFetchXMLFilters(IAttributeMetadataCache metadata, IQueryExecutionOptions options, BooleanExpression criteria, INodeSchema schema, string allowedPrefix, string targetEntityName, string targetEntityAlias, object[] items, IDictionary parameterTypes, out filter filter) + private BooleanExpression ExtractFetchXMLFilters(DataSource primaryDataSource, IAttributeMetadataCache metadata, IQueryExecutionOptions options, BooleanExpression criteria, INodeSchema schema, string allowedPrefix, string targetEntityName, string targetEntityAlias, object[] items, IDictionary parameterTypes, out filter filter) { - if (TranslateFetchXMLCriteria(metadata, options, criteria, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, parameterTypes, out filter)) + if (TranslateFetchXMLCriteria(primaryDataSource, metadata, options, criteria, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, parameterTypes, out filter)) return null; if (!(criteria is BooleanBinaryExpression bin)) @@ -1023,8 +1023,8 @@ private BooleanExpression ExtractFetchXMLFilters(IAttributeMetadataCache metadat if (bin.BinaryExpressionType != BooleanBinaryExpressionType.And) return criteria; - bin.FirstExpression = ExtractFetchXMLFilters(metadata, options, bin.FirstExpression, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, parameterTypes, out var lhsFilter); - bin.SecondExpression = ExtractFetchXMLFilters(metadata, options, bin.SecondExpression, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, parameterTypes, out var rhsFilter); + bin.FirstExpression = ExtractFetchXMLFilters(primaryDataSource, metadata, options, bin.FirstExpression, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, parameterTypes, out var lhsFilter); + bin.SecondExpression = ExtractFetchXMLFilters(primaryDataSource, metadata, options, bin.SecondExpression, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, parameterTypes, out var rhsFilter); filter = (lhsFilter != null && rhsFilter != null) ? new filter { Items = new object[] { lhsFilter, rhsFilter } } : lhsFilter ?? rhsFilter; @@ -1034,9 +1034,9 @@ private BooleanExpression ExtractFetchXMLFilters(IAttributeMetadataCache metadat return bin.FirstExpression ?? bin.SecondExpression; } - protected BooleanExpression ExtractMetadataFilters(BooleanExpression criteria, MetadataQueryNode meta, IQueryExecutionOptions options, out MetadataFilterExpression entityFilter, out MetadataFilterExpression attributeFilter, out MetadataFilterExpression relationshipFilter) + protected BooleanExpression ExtractMetadataFilters(DataSource primaryDataSource, BooleanExpression criteria, MetadataQueryNode meta, IQueryExecutionOptions options, out MetadataFilterExpression entityFilter, out MetadataFilterExpression attributeFilter, out MetadataFilterExpression relationshipFilter) { - if (TranslateMetadataCriteria(criteria, meta, options, out entityFilter, out attributeFilter, out relationshipFilter)) + if (TranslateMetadataCriteria(primaryDataSource, criteria, meta, options, out entityFilter, out attributeFilter, out relationshipFilter)) return null; if (!(criteria is BooleanBinaryExpression bin)) @@ -1045,8 +1045,8 @@ protected BooleanExpression ExtractMetadataFilters(BooleanExpression criteria, M if (bin.BinaryExpressionType != BooleanBinaryExpressionType.And) return criteria; - bin.FirstExpression = ExtractMetadataFilters(bin.FirstExpression, meta, options, out var lhsEntityFilter, out var lhsAttributeFilter, out var lhsRelationshipFilter); - bin.SecondExpression = ExtractMetadataFilters(bin.SecondExpression, meta, options, out var rhsEntityFilter, out var rhsAttributeFilter, out var rhsRelationshipFilter); + bin.FirstExpression = ExtractMetadataFilters(primaryDataSource, bin.FirstExpression, meta, options, out var lhsEntityFilter, out var lhsAttributeFilter, out var lhsRelationshipFilter); + bin.SecondExpression = ExtractMetadataFilters(primaryDataSource, bin.SecondExpression, meta, options, out var rhsEntityFilter, out var rhsAttributeFilter, out var rhsRelationshipFilter); entityFilter = (lhsEntityFilter != null && rhsEntityFilter != null) ? new MetadataFilterExpression { Filters = { lhsEntityFilter, rhsEntityFilter } } : lhsEntityFilter ?? rhsEntityFilter; attributeFilter = (lhsAttributeFilter != null && rhsAttributeFilter != null) ? new MetadataFilterExpression { Filters = { lhsAttributeFilter, rhsAttributeFilter } } : lhsAttributeFilter ?? rhsAttributeFilter; @@ -1057,7 +1057,7 @@ protected BooleanExpression ExtractMetadataFilters(BooleanExpression criteria, M return bin.FirstExpression ?? bin.SecondExpression; } - protected bool TranslateMetadataCriteria(BooleanExpression criteria, MetadataQueryNode meta, IQueryExecutionOptions options, out MetadataFilterExpression entityFilter, out MetadataFilterExpression attributeFilter, out MetadataFilterExpression relationshipFilter) + protected bool TranslateMetadataCriteria(DataSource primaryDataSource, BooleanExpression criteria, MetadataQueryNode meta, IQueryExecutionOptions options, out MetadataFilterExpression entityFilter, out MetadataFilterExpression attributeFilter, out MetadataFilterExpression relationshipFilter) { entityFilter = null; attributeFilter = null; @@ -1065,9 +1065,9 @@ protected bool TranslateMetadataCriteria(BooleanExpression criteria, MetadataQue if (criteria is BooleanBinaryExpression binary) { - if (!TranslateMetadataCriteria(binary.FirstExpression, meta, options, out var lhsEntityFilter, out var lhsAttributeFilter, out var lhsRelationshipFilter)) + if (!TranslateMetadataCriteria(primaryDataSource, binary.FirstExpression, meta, options, out var lhsEntityFilter, out var lhsAttributeFilter, out var lhsRelationshipFilter)) return false; - if (!TranslateMetadataCriteria(binary.SecondExpression, meta, options, out var rhsEntityFilter, out var rhsAttributeFilter, out var rhsRelationshipFilter)) + if (!TranslateMetadataCriteria(primaryDataSource, binary.SecondExpression, meta, options, out var rhsEntityFilter, out var rhsAttributeFilter, out var rhsRelationshipFilter)) return false; if (binary.BinaryExpressionType == BooleanBinaryExpressionType.Or) @@ -1174,7 +1174,7 @@ protected bool TranslateMetadataCriteria(BooleanExpression criteria, MetadataQue throw new InvalidOperationException(); } - var condition = new MetadataConditionExpression(parts[1], op, literal.Compile(null, null)(null, null, options)); + var condition = new MetadataConditionExpression(parts[1], op, literal.Compile(primaryDataSource, null, null)(null, null, options)); return TranslateMetadataCondition(condition, parts[0], meta, out entityFilter, out attributeFilter, out relationshipFilter); } @@ -1198,7 +1198,7 @@ protected bool TranslateMetadataCriteria(BooleanExpression criteria, MetadataQue if (inPred.Values.Any(val => !(val is Literal))) return false; - var condition = new MetadataConditionExpression(parts[1], inPred.NotDefined ? MetadataConditionOperator.NotIn : MetadataConditionOperator.In, inPred.Values.Select(val => val.Compile(null, null)(null, null, options)).ToArray()); + var condition = new MetadataConditionExpression(parts[1], inPred.NotDefined ? MetadataConditionOperator.NotIn : MetadataConditionOperator.In, inPred.Values.Select(val => val.Compile(primaryDataSource, null, null)(null, null, options)).ToArray()); return TranslateMetadataCondition(condition, parts[0], meta, out entityFilter, out attributeFilter, out relationshipFilter); } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/GlobalOptionSetQueryNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/GlobalOptionSetQueryNode.cs index 9263135c..5f232881 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/GlobalOptionSetQueryNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/GlobalOptionSetQueryNode.cs @@ -47,7 +47,7 @@ static GlobalOptionSetQueryNode() { type = MetadataQueryNode.GetPropertyType(prop.Property.PropertyType); } - else if (!SqlTypeConverter.CanMakeConsistentTypes(type, MetadataQueryNode.GetPropertyType(prop.Property.PropertyType), out type)) + else if (!SqlTypeConverter.CanMakeConsistentTypes(type, MetadataQueryNode.GetPropertyType(prop.Property.PropertyType), null, out type)) { // Can't make a consistent type for this property, so we can't use it type = null; diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/MergeJoinNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/MergeJoinNode.cs index 4f2c56ac..e9fe5921 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/MergeJoinNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/MergeJoinNode.cs @@ -18,6 +18,9 @@ class MergeJoinNode : FoldableJoinNode { protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) { + if (!dataSources.TryGetValue(options.PrimaryDataSource, out var dataSource)) + throw new QueryExecutionException("Invalid data source"); + // https://sqlserverfast.com/epr/merge-join/ // Implemented inner, left outer, right outer and full outer variants // Not implemented semi joins @@ -30,7 +33,7 @@ protected override IEnumerable ExecuteInternal(IDictionary ExecuteInternal(IDictionary GetPropertyAccessor(PropertyInfo prop, Type targetType) diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SqlTypeConverter.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SqlTypeConverter.cs index f5312846..6fae711b 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SqlTypeConverter.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SqlTypeConverter.cs @@ -185,9 +185,10 @@ private static void AddNullableTypeConversion(Func /// The type of the first value /// The type of the second value + /// The details of the primary data source being used for the connection /// The type that both values can be converted to /// true if the two values can be converted to a consistent type, or false otherwise - public static bool CanMakeConsistentTypes(DataTypeReference lhs, DataTypeReference rhs, out DataTypeReference consistent) + public static bool CanMakeConsistentTypes(DataTypeReference lhs, DataTypeReference rhs, DataSource primaryDataSource, out DataTypeReference consistent) { if (lhs.IsSameAs(rhs)) { @@ -252,7 +253,42 @@ public static bool CanMakeConsistentTypes(DataTypeReference lhs, DataTypeReferen } var targetType = _precendenceOrder[Math.Min(lhsPrecedence, rhsPrecedence)]; - var fullTargetType = new SqlDataTypeReference { SqlDataTypeOption = targetType }; + SqlDataTypeReference fullTargetType; + + if (targetType.IsStringType()) + { + var lhsColl = lhs as SqlDataTypeReferenceWithCollation; + var rhsColl = rhs as SqlDataTypeReferenceWithCollation; + + if (lhsColl != null && rhsColl != null) + { + if (!SqlDataTypeReferenceWithCollation.TryConvertCollation(lhsSql, rhsSql, out var coll, out var collLabel)) + { + consistent = null; + return false; + } + + fullTargetType = new SqlDataTypeReferenceWithCollation + { + SqlDataTypeOption = targetType, + Collation = coll, + CollationLabel = collLabel + }; + } + else + { + fullTargetType = new SqlDataTypeReferenceWithCollation + { + SqlDataTypeOption = targetType, + Collation = primaryDataSource?.DefaultCollation ?? Collation.USEnglish, + CollationLabel = CollationLabel.CoercibleDefault + }; + } + } + else + { + fullTargetType = new SqlDataTypeReference { SqlDataTypeOption = targetType }; + } // If we're converting to a type that uses a length, choose the longest length if (targetType == SqlDataTypeOption.Binary || targetType == SqlDataTypeOption.VarBinary || diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/UpdateNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/UpdateNode.cs index 2158858a..cfd58197 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/UpdateNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/UpdateNode.cs @@ -103,8 +103,8 @@ public override string Execute(IDictionary dataSources, IQue var dateTimeKind = options.UseLocalTimeZone ? DateTimeKind.Local : DateTimeKind.Utc; var fullMappings = new Dictionary(ColumnMappings); fullMappings[meta.PrimaryIdAttribute] = new UpdateMapping { OldValueColumn = PrimaryIdSource, NewValueColumn = PrimaryIdSource }; - newAttributeAccessors = CompileColumnMappings(dataSource.Metadata, LogicalName, fullMappings.Where(kvp => kvp.Value.NewValueColumn != null).ToDictionary(kvp => kvp.Key, kvp => kvp.Value.NewValueColumn), schema, dateTimeKind, entities); - oldAttributeAccessors = CompileColumnMappings(dataSource.Metadata, LogicalName, fullMappings.Where(kvp => kvp.Value.OldValueColumn != null).ToDictionary(kvp => kvp.Key, kvp => kvp.Value.OldValueColumn), schema, dateTimeKind, entities); + newAttributeAccessors = CompileColumnMappings(dataSource, LogicalName, fullMappings.Where(kvp => kvp.Value.NewValueColumn != null).ToDictionary(kvp => kvp.Key, kvp => kvp.Value.NewValueColumn), schema, dateTimeKind, entities); + oldAttributeAccessors = CompileColumnMappings(dataSource, LogicalName, fullMappings.Where(kvp => kvp.Value.OldValueColumn != null).ToDictionary(kvp => kvp.Key, kvp => kvp.Value.OldValueColumn), schema, dateTimeKind, entities); primaryIdAccessor = newAttributeAccessors[meta.PrimaryIdAttribute]; } diff --git a/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.projitems b/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.projitems index a2ddde9f..8fdd4258 100644 --- a/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.projitems +++ b/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.projitems @@ -25,6 +25,7 @@ + @@ -142,4 +143,10 @@ Designer + + + + + + \ No newline at end of file diff --git a/MarkMpn.Sql4Cds.Engine/MetadataExtensions.cs b/MarkMpn.Sql4Cds.Engine/MetadataExtensions.cs index 49ecb410..8a845ac3 100644 --- a/MarkMpn.Sql4Cds.Engine/MetadataExtensions.cs +++ b/MarkMpn.Sql4Cds.Engine/MetadataExtensions.cs @@ -82,10 +82,10 @@ public static Type GetAttributeType(this AttributeMetadata attrMetadata) throw new ApplicationException("Unknown attribute type " + attrMetadata.GetType()); } - public static DataTypeReference GetAttributeSqlType(this AttributeMetadata attrMetadata, IAttributeMetadataCache cache, bool write) + public static DataTypeReference GetAttributeSqlType(this AttributeMetadata attrMetadata, DataSource dataSource, bool write) { if (attrMetadata is MultiSelectPicklistAttributeMetadata) - return DataTypeHelpers.NVarChar(Int32.MaxValue); + return DataTypeHelpers.NVarChar(Int32.MaxValue, dataSource.DefaultCollation, CollationLabel.CoercibleDefault); var typeCode = attrMetadata.AttributeType; @@ -114,7 +114,7 @@ public static DataTypeReference GetAttributeSqlType(this AttributeMetadata attrM return DataTypeHelpers.Float; if (attrMetadata is EntityNameAttributeMetadata || typeCode == AttributeTypeCode.EntityName) - return DataTypeHelpers.NVarChar(EntityLogicalNameMaxLength); + return DataTypeHelpers.NVarChar(EntityLogicalNameMaxLength, dataSource.DefaultCollation, CollationLabel.CoercibleDefault); if (attrMetadata is ImageAttributeMetadata) return DataTypeHelpers.VarBinary(Int32.MaxValue); @@ -126,13 +126,13 @@ public static DataTypeReference GetAttributeSqlType(this AttributeMetadata attrM return DataTypeHelpers.BigInt; if (typeCode == AttributeTypeCode.PartyList) - return DataTypeHelpers.NVarChar(Int32.MaxValue); + return DataTypeHelpers.NVarChar(Int32.MaxValue, dataSource.DefaultCollation, CollationLabel.CoercibleDefault); if (attrMetadata is LookupAttributeMetadata || attrMetadata.IsPrimaryId == true || typeCode == AttributeTypeCode.Lookup || typeCode == AttributeTypeCode.Customer || typeCode == AttributeTypeCode.Owner) return DataTypeHelpers.EntityReference; if (attrMetadata is MemoAttributeMetadata || typeCode == AttributeTypeCode.Memo) - return DataTypeHelpers.NVarChar(write && attrMetadata is MemoAttributeMetadata memo && memo.MaxLength != null ? memo.MaxLength.Value : Int32.MaxValue); + return DataTypeHelpers.NVarChar(write && attrMetadata is MemoAttributeMetadata memo && memo.MaxLength != null ? memo.MaxLength.Value : Int32.MaxValue, dataSource.DefaultCollation, CollationLabel.CoercibleDefault); if (attrMetadata is MoneyAttributeMetadata || typeCode == AttributeTypeCode.Money) return DataTypeHelpers.Money; @@ -151,7 +151,7 @@ public static DataTypeReference GetAttributeSqlType(this AttributeMetadata attrM if (attrMetadata.LogicalName.StartsWith("address")) { var parts = attrMetadata.LogicalName.Split('_'); - if (parts.Length == 2 && Int32.TryParse(parts[0].Substring(7), out _) && cache.TryGetValue("customeraddress", out var addressMetadata)) + if (parts.Length == 2 && Int32.TryParse(parts[0].Substring(7), out _) && dataSource.Metadata.TryGetValue("customeraddress", out var addressMetadata)) { // Attribute is e.g. address1_postalcode. Get the equivalent attribute from the customeraddress // entity as it can have very different max length @@ -170,7 +170,7 @@ public static DataTypeReference GetAttributeSqlType(this AttributeMetadata attrM maxLength = maxLengthSetting.Value; } - return DataTypeHelpers.NVarChar(maxLength); + return DataTypeHelpers.NVarChar(maxLength, dataSource.DefaultCollation, CollationLabel.CoercibleDefault); } if (attrMetadata is UniqueIdentifierAttributeMetadata || typeCode == AttributeTypeCode.Uniqueidentifier) @@ -180,7 +180,7 @@ public static DataTypeReference GetAttributeSqlType(this AttributeMetadata attrM return DataTypeHelpers.UniqueIdentifier; if (attrMetadata.AttributeType == AttributeTypeCode.Virtual) - return DataTypeHelpers.NVarChar(Int32.MaxValue); + return DataTypeHelpers.NVarChar(Int32.MaxValue, dataSource.DefaultCollation, CollationLabel.CoercibleDefault); throw new ApplicationException("Unknown attribute type " + attrMetadata.GetType()); } diff --git a/MarkMpn.Sql4Cds.Engine/Resources/CollationNameToLCID.txt b/MarkMpn.Sql4Cds.Engine/Resources/CollationNameToLCID.txt new file mode 100644 index 00000000..6ab5a5ef --- /dev/null +++ b/MarkMpn.Sql4Cds.Engine/Resources/CollationNameToLCID.txt @@ -0,0 +1,135 @@ +Albanian 1052 +Arabic 1025 +Assamese 1101 +Azeri_Cyrillic 2092 +Azeri_Latin 1068 +Bashkir 1133 +Bengali 1093 +Bosnian_Cyrillic 8218 +Bosnian_Latin 5146 +Breton 1150 +Chinese_Hong_Kong_Stroke 3076 +Chinese_PRC 2052 +Chinese_PRC_Stroke 133124 +Chinese_Simplified_Pinyin 2052 +Chinese_Simplified_Stroke_Order 133124 +Chinese_Taiwan_Bopomofo 197636 +Chinese_Taiwan_Stroke 1028 +Chinese_Traditional_Bopomofo 197636 +Chinese_Traditional_Pinyin 5124 +Chinese_Traditional_Stroke_Count 1028 +Chinese_Traditional_Stroke_Order 136196 +Corsican 1155 +Croatian 1050 +Cyrillic_General 1049 +Czech 1029 +Danish_Greenlandic 1030 +Danish_Norwegian 1030 +Dari 1164 +Divehi 1125 +Estonian 1061 +Finnish_Swedish 1035 +French 1036 +Frisian 1122 +Georgian_Modern_Sort 66615 +German_PhoneBook 66567 +Greek 1032 +Hebrew 1037 +Hungarian 1038 +Hungarian_Technical 66574 +Icelandic 1039 +Indic_General 1081 +Japanese 1041 +Japanese_Bushu_Kakusu 263185 +Japanese_Bushu_Kakusu_140 263185 +Japanese_Unicode 66577 +Japanese_XJIS 1041 +Japanese_XJIS_140 1041 +Kazakh 1087 +Khmer 1107 +Korean 1042 +Korean_Wansung 1042 +Lao 1108 +Latin1_General 1033 +Latvian 1062 +Lithuanian 1063 +Macedonian_FYROM 1071 +Maltese 1082 +Maori 1153 +Mapudungan 1146 +Modern_Spanish 3082 +Mohawk 1148 +Nepali 1121 +Norwegian 1044 +Pashto 1123 +Persian 1065 +Polish 1045 +Romanian 1048 +Romansh 1047 +Sami_Norway 1083 +Sami_Sweden_Finland 2107 +Serbian_Cyrillic 3098 +Serbian_Latin 2074 +Slovak 1051 +Slovenian 1060 +SQL_1xCompat_CP850 1033 +SQL_AltDiction_CP850 1033 +SQL_AltDiction_Pref_CP850 1033 +SQL_AltDiction2_CP1253 1033 +SQL_Croatian_CP1250 1050 +SQL_Czech_CP1250 1029 +SQL_Danish_Pref_CP1 1030 +SQL_EBCDIC037_CP1 1033 +SQL_EBCDIC1141_CP1 66567 +SQL_EBCDIC273_CP1 66567 +SQL_EBCDIC277_2_CP1 1030 +SQL_EBCDIC277_CP1 1030 +SQL_EBCDIC278_CP1 1035 +SQL_EBCDIC280_CP1 1033 +SQL_EBCDIC284_CP1 3082 +SQL_EBCDIC285_CP1 1033 +SQL_EBCDIC297_CP1 1036 +SQL_Estonian_CP1257 1061 +SQL_Hungarian_CP1250 1038 +SQL_Icelandic_Pref_CP1 1039 +SQL_Latin1_General_CP1 1033 +SQL_Latin1_General_CP1250 1033 +SQL_Latin1_General_CP1251 1033 +SQL_Latin1_General_CP1253 1033 +SQL_Latin1_General_CP1254 1055 +SQL_Latin1_General_CP1255 1033 +SQL_Latin1_General_CP1256 1033 +SQL_Latin1_General_CP1257 1033 +SQL_Latin1_General_CP437 1033 +SQL_Latin1_General_CP850 1033 +SQL_Latin1_General_Pref_CP1 1033 +SQL_Latin1_General_Pref_CP437 1033 +SQL_Latin1_General_Pref_CP850 1033 +SQL_Latvian_CP1257 1062 +SQL_Lithuanian_CP1257 1063 +SQL_MixDiction_CP1253 1033 +SQL_Polish_CP1250 1045 +SQL_Romanian_CP1250 1048 +SQL_Slovak_CP1250 1051 +SQL_Slovenian_CP1250 1060 +SQL_SwedishPhone_Pref_CP1 1035 +SQL_SwedishStd_Pref_CP1 1035 +SQL_Ukrainian_CP1251 1058 +SQLandinavian_CP850 1035 +SQLandinavian_Pref_CP850 1035 +Syriac 1114 +Tamazight 2143 +Tatar 1092 +Thai 1054 +Tibetan 1105 +Traditional_Spanish 1034 +Turkish 1055 +Turkmen 1090 +Uighur 1152 +Ukrainian 1058 +Upper_Sorbian 1070 +Urdu 1056 +Uzbek_Latin 1091 +Vietnamese 1066 +Welsh 1106 +Yakut 1157 \ No newline at end of file From 6fd54a9c29eee6af6dccae9e775c0ccc24319e91 Mon Sep 17 00:00:00 2001 From: Mark Carrington Date: Thu, 16 Mar 2023 19:00:27 +0000 Subject: [PATCH 02/34] Collation progress --- .../ExecutionPlan/BaseAggregateNode.cs | 14 ++--- .../ExecutionPlan/ComputeScalarNode.cs | 6 +- .../ExecutionPlan/ConstantScanNode.cs | 2 +- .../ExecutionPlan/DeleteNode.cs | 2 +- .../ExecutionPlan/ExecuteAsNode.cs | 2 +- .../ExecutionPlan/ExecuteMessageNode.cs | 44 +++++++------- .../ExecutionPlan/ExpressionExtensions.cs | 4 +- .../ExecutionPlan/FoldableJoinNode.cs | 2 +- .../ExecutionPlan/GoToNode.cs | 2 +- .../ExecutionPlan/HashJoinNode.cs | 8 +-- .../ExecutionPlan/HashMatchAggregateNode.cs | 4 +- .../ExecutionPlan/IndexSpoolNode.cs | 2 +- .../ExecutionPlan/InsertNode.cs | 2 +- .../ExecutionPlan/NestedLoopNode.cs | 2 +- .../ExecutionPlan/OffsetFetchNode.cs | 16 ++--- .../ExecutionPlan/PartitionedAggregateNode.cs | 2 +- .../ExecutionPlan/PrintNode.cs | 2 +- .../ExecutionPlan/SortNode.cs | 8 +-- .../ExecutionPlan/TopNode.cs | 10 ++-- .../ExecutionPlan/WaitForNode.cs | 2 +- .../ExecutionPlanBuilder.cs | 58 ++++++++++--------- 21 files changed, 98 insertions(+), 96 deletions(-) diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseAggregateNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseAggregateNode.cs index 218843f6..2925e2ed 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseAggregateNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseAggregateNode.cs @@ -55,15 +55,15 @@ protected class AggregateFunctionState [Browsable(false)] public IDataExecutionPlanNodeInternal Source { get; set; } - protected void InitializeAggregates(INodeSchema schema, IDictionary parameterTypes) + protected void InitializeAggregates(DataSource primaryDataSource, INodeSchema schema, IDictionary parameterTypes) { foreach (var aggregate in Aggregates.Where(agg => agg.Value.SqlExpression != null)) { - aggregate.Value.SqlExpression.GetType(schema, null, parameterTypes, out var retType); + aggregate.Value.SqlExpression.GetType(primaryDataSource, schema, null, parameterTypes, out var retType); aggregate.Value.SourceType = retType; aggregate.Value.ReturnType = retType; - aggregate.Value.Expression = aggregate.Value.SqlExpression.Compile(schema, parameterTypes); + aggregate.Value.Expression = aggregate.Value.SqlExpression.Compile(primaryDataSource, schema, parameterTypes); // Return type of SUM and AVG is based on the input type with some modifications // https://docs.microsoft.com/en-us/sql/t-sql/functions/avg-transact-sql?view=sql-server-ver15#return-types @@ -82,13 +82,13 @@ protected void InitializeAggregates(INodeSchema schema, IDictionary parameterTypes) + protected void InitializePartitionedAggregates(DataSource primaryDataSource, INodeSchema schema, IDictionary parameterTypes) { foreach (var aggregate in Aggregates) { var sourceExpression = aggregate.Key.ToColumnReference(); - aggregate.Value.Expression = sourceExpression.Compile(schema, parameterTypes); - sourceExpression.GetType(schema, null, parameterTypes, out var retType); + aggregate.Value.Expression = sourceExpression.Compile(primaryDataSource, schema, parameterTypes); + sourceExpression.GetType(primaryDataSource, schema, null, parameterTypes, out var retType); aggregate.Value.SourceType = retType; aggregate.Value.ReturnType = retType; } @@ -213,7 +213,7 @@ public override INodeSchema GetSchema(IDictionary dataSource break; default: - aggregate.Value.SqlExpression.GetType(sourceSchema, null, parameterTypes, out aggregateType); + aggregate.Value.SqlExpression.GetType(dataSources[options.PrimaryDataSource], sourceSchema, null, parameterTypes, out aggregateType); // Return type of SUM and AVG is based on the input type with some modifications // https://docs.microsoft.com/en-us/sql/t-sql/functions/avg-transact-sql?view=sql-server-ver15#return-types diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ComputeScalarNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ComputeScalarNode.cs index 0d8d08d7..533ecb76 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ComputeScalarNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ComputeScalarNode.cs @@ -32,7 +32,7 @@ protected override IEnumerable ExecuteInternal(IDictionary new { Name = kvp.Key, Expression = kvp.Value.Compile(schema, parameterTypes) }) + .Select(kvp => new { Name = kvp.Key, Expression = kvp.Value.Compile(dataSources[options.PrimaryDataSource], schema, parameterTypes) }) .ToList(); foreach (var entity in Source.Execute(dataSources, options, parameterTypes, parameterValues)) @@ -55,7 +55,7 @@ public override INodeSchema GetSchema(IDictionary dataSource foreach (var calc in Columns) { - calc.Value.GetType(sourceSchema, null, parameterTypes, out var calcType); + calc.Value.GetType(dataSources[options.PrimaryDataSource], sourceSchema, null, parameterTypes, out var calcType); schema[calc.Key] = calcType; } @@ -118,7 +118,7 @@ calc.Value is CastCall c2 && c2.Parameter is Literal || } else { - calc.Value.GetType(null, null, parameterTypes, out var calcType); + calc.Value.GetType(dataSources[options.PrimaryDataSource], null, null, parameterTypes, out var calcType); constant.Schema[calc.Key] = calcType; } } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ConstantScanNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ConstantScanNode.cs index 11e71894..dc931359 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ConstantScanNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ConstantScanNode.cs @@ -40,7 +40,7 @@ protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQue if (secondaryKey != null) fullMappings[secondaryKey] = SecondaryIdSource; - var attributeAccessors = CompileColumnMappings(dataSource.Metadata, LogicalName, fullMappings, schema, dateTimeKind, entities); + var attributeAccessors = CompileColumnMappings(dataSource, LogicalName, fullMappings, schema, dateTimeKind, entities); primaryIdAccessor = attributeAccessors[primaryKey]; if (SecondaryIdSource != null) diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteAsNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteAsNode.cs index 525cc6de..0a5383ce 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteAsNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteAsNode.cs @@ -65,7 +65,7 @@ public override string Execute(IDictionary dataSources, IQue throw new QueryExecutionException("Ambiguous user"); // Precompile mappings with type conversions - var attributeAccessors = CompileColumnMappings(dataSource.Metadata, "systemuser", new Dictionary(StringComparer.OrdinalIgnoreCase) { ["systemuserid"] = UserIdSource }, schema, DateTimeKind.Unspecified, entities); + var attributeAccessors = CompileColumnMappings(dataSource, "systemuser", new Dictionary(StringComparer.OrdinalIgnoreCase) { ["systemuserid"] = UserIdSource }, schema, DateTimeKind.Unspecified, entities); var userIdAccessor = attributeAccessors["systemuserid"]; var userId = (Guid)userIdAccessor(entities[0]); diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteMessageNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteMessageNode.cs index 45673227..3030db97 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteMessageNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteMessageNode.cs @@ -119,9 +119,9 @@ public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary value.Key, value => { - var exprType = value.Value.GetType(null, null, parameterTypes, out _); + var exprType = value.Value.GetType(dataSources[options.PrimaryDataSource], null, null, parameterTypes, out _); var expectedType = ValueTypes[value.Key]; - var expr = value.Value.Compile(null, parameterTypes); + var expr = value.Value.Compile(dataSources[options.PrimaryDataSource], null, parameterTypes); var conversion = SqlTypeConverter.GetConversion(exprType, expectedType); return (Func, IQueryExecutionOptions, object>) ((IDictionary parameterValues, IQueryExecutionOptions opts) => conversion(expr(null, parameterValues, opts))); }); @@ -154,13 +154,13 @@ public override INodeSchema GetSchema(IDictionary dataSource sortOrder: null); } - private void SetOutputSchema(IAttributeMetadataCache metadata, Message message, TSqlFragment source) + private void SetOutputSchema(DataSource dataSource, Message message, TSqlFragment source) { // Add the response fields to the node schema if (message.OutputParameters.All(f => f.IsScalarType())) { foreach (var value in message.OutputParameters) - AddSchemaColumn(value.Name, SqlTypeConverter.NetToSqlType(value.Type).ToSqlType()); // TODO: How are OSV and ER fields represented? + AddSchemaColumn(value.Name, SqlTypeConverter.NetToSqlType(value.Type).ToSqlType(dataSource)); // TODO: How are OSV and ER fields represented? } else { @@ -172,13 +172,13 @@ private void SetOutputSchema(IAttributeMetadataCache metadata, Message message, if (type == typeof(AuditDetail)) { type = typeof(Entity); - otc = metadata["audit"].ObjectTypeCode; + otc = dataSource.Metadata["audit"].ObjectTypeCode; audit = true; } else if (firstValue.Type == typeof(AuditDetailCollection)) { type = typeof(EntityCollection); - otc = metadata["audit"].ObjectTypeCode; + otc = dataSource.Metadata["audit"].ObjectTypeCode; audit = true; } @@ -187,30 +187,30 @@ private void SetOutputSchema(IAttributeMetadataCache metadata, Message message, else EntityCollectionResponseParameter = firstValue.Name; - foreach (var attrMetadata in metadata[otc.Value].Attributes.Where(a => a.AttributeOf == null)) + foreach (var attrMetadata in dataSource.Metadata[otc.Value].Attributes.Where(a => a.AttributeOf == null)) { - AddSchemaColumn(attrMetadata.LogicalName, attrMetadata.GetAttributeSqlType(metadata, false)); + AddSchemaColumn(attrMetadata.LogicalName, attrMetadata.GetAttributeSqlType(dataSource, false)); // Add standard virtual attributes if (attrMetadata is EnumAttributeMetadata || attrMetadata is BooleanAttributeMetadata) - AddSchemaColumn(attrMetadata.LogicalName + "name", DataTypeHelpers.NVarChar(FetchXmlScan.LabelMaxLength)); + AddSchemaColumn(attrMetadata.LogicalName + "name", DataTypeHelpers.NVarChar(FetchXmlScan.LabelMaxLength, dataSource.DefaultCollation, CollationLabel.CoercibleDefault)); if (attrMetadata is LookupAttributeMetadata lookup) { - AddSchemaColumn(attrMetadata.LogicalName + "name", DataTypeHelpers.NVarChar(lookup.Targets == null || lookup.Targets.Length == 0 ? 100 : lookup.Targets.Select(e => ((StringAttributeMetadata)metadata[e].Attributes.SingleOrDefault(a => a.LogicalName == metadata[e].PrimaryNameAttribute))?.MaxLength ?? 100).Max())); + AddSchemaColumn(attrMetadata.LogicalName + "name", DataTypeHelpers.NVarChar(lookup.Targets == null || lookup.Targets.Length == 0 ? 100 : lookup.Targets.Select(e => ((StringAttributeMetadata)dataSource.Metadata[e].Attributes.SingleOrDefault(a => a.LogicalName == dataSource.Metadata[e].PrimaryNameAttribute))?.MaxLength ?? 100).Max(), dataSource.DefaultCollation, CollationLabel.CoercibleDefault)); if (lookup.Targets?.Length != 1 && lookup.AttributeType != AttributeTypeCode.PartyList) - AddSchemaColumn(attrMetadata.LogicalName + "type", DataTypeHelpers.NVarChar(MetadataExtensions.EntityLogicalNameMaxLength)); + AddSchemaColumn(attrMetadata.LogicalName + "type", DataTypeHelpers.NVarChar(MetadataExtensions.EntityLogicalNameMaxLength, dataSource.DefaultCollation, CollationLabel.CoercibleDefault)); } } if (audit) { - AddSchemaColumn("newvalues", DataTypeHelpers.NVarChar(Int32.MaxValue)); - AddSchemaColumn("oldvalues", DataTypeHelpers.NVarChar(Int32.MaxValue)); + AddSchemaColumn("newvalues", DataTypeHelpers.NVarChar(Int32.MaxValue, dataSource.DefaultCollation, CollationLabel.CoercibleDefault)); + AddSchemaColumn("oldvalues", DataTypeHelpers.NVarChar(Int32.MaxValue, dataSource.DefaultCollation, CollationLabel.CoercibleDefault)); } - _primaryKeyColumn = PrefixWithAlias(metadata[otc.Value].PrimaryIdAttribute); + _primaryKeyColumn = PrefixWithAlias(dataSource.Metadata[otc.Value].PrimaryIdAttribute); } if (!String.IsNullOrEmpty(PagingParameter) && EntityCollectionResponseParameter == null) @@ -487,7 +487,7 @@ public override object Clone() return clone; } - public static ExecuteMessageNode FromMessage(SchemaObjectFunctionTableReference tvf, DataSource dataSource, IDictionary parameterTypes) + public static ExecuteMessageNode FromMessage(SchemaObjectFunctionTableReference tvf, DataSource dataSource, DataSource primaryDataSource, IDictionary parameterTypes) { // All messages are in the "dbo" schema if (tvf.SchemaObject.SchemaIdentifier != null && !String.IsNullOrEmpty(tvf.SchemaObject.SchemaIdentifier.Value) && @@ -539,8 +539,8 @@ public static ExecuteMessageNode FromMessage(SchemaObjectFunctionTableReference { var f = expectedInputParameters[i]; var sourceExpression = tvf.Parameters[i]; - sourceExpression.GetType(null, null, parameterTypes, out var sourceType); - var expectedType = SqlTypeConverter.NetToSqlType(f.Type).ToSqlType(); + sourceExpression.GetType(primaryDataSource, null, null, parameterTypes, out var sourceType); + var expectedType = SqlTypeConverter.NetToSqlType(f.Type).ToSqlType(primaryDataSource); if (!SqlTypeConverter.CanChangeTypeImplicit(sourceType, expectedType)) throw new NotSupportedQueryFragmentException($"Cannot convert value of type {sourceType.ToSql()} to {expectedType.ToSql()}", tvf.Parameters[f.Position]); @@ -561,12 +561,12 @@ public static ExecuteMessageNode FromMessage(SchemaObjectFunctionTableReference node.ValueTypes[f.Name] = f.Type; } - node.SetOutputSchema(dataSource.Metadata, message, tvf.SchemaObject); + node.SetOutputSchema(dataSource, message, tvf.SchemaObject); return node; } - public static ExecuteMessageNode FromMessage(ExecutableProcedureReference sproc, DataSource dataSource, IDictionary parameterTypes) + public static ExecuteMessageNode FromMessage(ExecutableProcedureReference sproc, DataSource dataSource, DataSource primaryDataSource, IDictionary parameterTypes) { // All messages are in the "dbo" schema if (sproc.ProcedureReference.ProcedureReference.Name.SchemaIdentifier != null && !String.IsNullOrEmpty(sproc.ProcedureReference.ProcedureReference.Name.SchemaIdentifier.Value) && @@ -647,8 +647,8 @@ public static ExecuteMessageNode FromMessage(ExecutableProcedureReference sproc, throw new NotSupportedQueryFragmentException("Unknown parameter", sproc.Parameters[i]); var sourceExpression = sproc.Parameters[i].ParameterValue; - sourceExpression.GetType(null, null, parameterTypes, out var sourceType); - var expectedType = SqlTypeConverter.NetToSqlType(targetParam.Type).ToSqlType(); + sourceExpression.GetType(primaryDataSource, null, null, parameterTypes, out var sourceType); + var expectedType = SqlTypeConverter.NetToSqlType(targetParam.Type).ToSqlType(primaryDataSource); if (!SqlTypeConverter.CanChangeTypeImplicit(sourceType, expectedType)) throw new NotSupportedQueryFragmentException($"Cannot convert value of type {sourceType.ToSql()} to {expectedType.ToSql()}", sproc.Parameters[i].ParameterValue); @@ -679,7 +679,7 @@ public static ExecuteMessageNode FromMessage(ExecutableProcedureReference sproc, throw new NotSupportedQueryFragmentException($"Missing parameter '{inputParameter.Name}'", sproc); } - node.SetOutputSchema(dataSource.Metadata, message, sproc.ProcedureReference); + node.SetOutputSchema(dataSource, message, sproc.ProcedureReference); return node; } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs index 5fb2e53c..96841cf6 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs @@ -1302,10 +1302,10 @@ public static bool IsType(this DataTypeReference type, SqlDataTypeOption sqlType /// /// The data type to convert /// The equivalent SQL - public static DataTypeReference ToSqlType(this Type type, DataSource primaryDataSource) + public static DataTypeReference ToSqlType(this Type type, DataSource dataSource) { if (type == typeof(SqlString)) - return DataTypeHelpers.NVarChar(Int32.MaxValue, primaryDataSource?.DefaultCollation ?? Collation.USEnglish, CollationLabel.CoercibleDefault); + return DataTypeHelpers.NVarChar(Int32.MaxValue, dataSource?.DefaultCollation ?? Collation.USEnglish, CollationLabel.CoercibleDefault); return _netTypeMapping[type]; } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FoldableJoinNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FoldableJoinNode.cs index 130416ff..25d1c078 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FoldableJoinNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FoldableJoinNode.cs @@ -318,7 +318,7 @@ private bool FoldFetchXmlJoin(IDictionary dataSources, IQuer // in the new link entity or we must be using an inner join so we can use a post-filter node var additionalCriteria = AdditionalJoinCriteria; - if (TranslateFetchXMLCriteria(dataSource.Metadata, options, additionalCriteria, rightSchema, rightFetch.Alias, rightEntity.name, rightFetch.Alias, rightEntity.Items, parameterTypes, out var filter)) + if (TranslateFetchXMLCriteria(dataSources[options.PrimaryDataSource], dataSource.Metadata, options, additionalCriteria, rightSchema, rightFetch.Alias, rightEntity.name, rightFetch.Alias, rightEntity.Items, parameterTypes, out var filter)) { rightEntity.AddItem(filter); additionalCriteria = null; diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/GoToNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/GoToNode.cs index 8c30aa03..5cc822fb 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/GoToNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/GoToNode.cs @@ -110,7 +110,7 @@ public IRootExecutionPlanNodeInternal[] FoldQuery(IDictionary ExecuteInternal(IDictionary>(); var mergedSchema = GetSchema(dataSources, parameterTypes, true); - var additionalJoinCriteria = AdditionalJoinCriteria?.Compile(mergedSchema, parameterTypes); + var additionalJoinCriteria = AdditionalJoinCriteria?.Compile(dataSources[options.PrimaryDataSource], mergedSchema, parameterTypes); // Build the hash table var leftSchema = LeftSource.GetSchema(dataSources, parameterTypes); @@ -37,18 +37,18 @@ protected override IEnumerable ExecuteInternal(IDictionary ExecuteInternal(IDictionary>(new DistinctEqualityComparer(groupByCols)); - InitializeAggregates(schema, parameterTypes); + InitializeAggregates(dataSources[options.PrimaryDataSource], schema, parameterTypes); var aggregates = CreateAggregateFunctions(parameterValues, options, false); if (IsScalarAggregate) @@ -340,7 +340,7 @@ Source is FetchXmlScan fetch && attribute.dategrouping = dateGrouping.Value; attribute.dategroupingSpecified = true; } - else if (grouping.GetType(schema, null, parameterTypes, out _) == typeof(SqlDateTime)) + else if (grouping.GetType(dataSources[options.PrimaryDataSource], schema, null, parameterTypes, out _) == typeof(SqlDateTime)) { // Can't group on datetime columns without a DATEPART specification canUseFetchXmlAggregate = false; diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IndexSpoolNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IndexSpoolNode.cs index 000569b8..11128996 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IndexSpoolNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IndexSpoolNode.cs @@ -65,7 +65,7 @@ public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQue meta = dataSource.Metadata[LogicalName]; attributes = meta.Attributes.ToDictionary(a => a.LogicalName, StringComparer.OrdinalIgnoreCase); var dateTimeKind = options.UseLocalTimeZone ? DateTimeKind.Local : DateTimeKind.Utc; - attributeAccessors = CompileColumnMappings(dataSource.Metadata, LogicalName, ColumnMappings, schema, dateTimeKind, entities); + attributeAccessors = CompileColumnMappings(dataSource, LogicalName, ColumnMappings, schema, dateTimeKind, entities); attributeAccessors.TryGetValue(meta.PrimaryIdAttribute, out primaryIdAccessor); } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/NestedLoopNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/NestedLoopNode.cs index 84299778..0bf2af10 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/NestedLoopNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/NestedLoopNode.cs @@ -47,7 +47,7 @@ protected override IEnumerable ExecuteInternal(IDictionary ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) { - var offset = SqlTypeConverter.ChangeType(Offset.Compile(null, parameterTypes)(null, parameterValues, options)); - var fetch = SqlTypeConverter.ChangeType(Fetch.Compile(null, parameterTypes)(null, parameterValues, options)); + var offset = SqlTypeConverter.ChangeType(Offset.Compile(dataSources[options.PrimaryDataSource], null, parameterTypes)(null, parameterValues, options)); + var fetch = SqlTypeConverter.ChangeType(Fetch.Compile(dataSources[options.PrimaryDataSource], null, parameterTypes)(null, parameterValues, options)); if (offset < 0) throw new QueryExecutionException("The offset specified in a OFFSET clause may not be negative."); @@ -63,14 +63,14 @@ public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary(offsetLiteral.Compile(null, null)(null, null, options)); - var count = SqlTypeConverter.ChangeType(fetchLiteral.Compile(null, null)(null, null, options)); + var offset = SqlTypeConverter.ChangeType(offsetLiteral.Compile(dataSources[options.PrimaryDataSource], null, null)(null, null, options)); + var count = SqlTypeConverter.ChangeType(fetchLiteral.Compile(dataSources[options.PrimaryDataSource], null, null)(null, null, options)); var page = offset / count; if (page * count == offset && count <= 5000) @@ -94,8 +94,8 @@ protected override RowCountEstimate EstimateRowsOutInternal(IDictionary ExecuteInternal(IDictionary>(new DistinctEqualityComparer(groupByCols)); - InitializePartitionedAggregates(schema, parameterTypes); + InitializePartitionedAggregates(dataSources[options.PrimaryDataSource], schema, parameterTypes); var aggregates = CreateAggregateFunctions(parameterValues, options, true); var fetchXmlNode = (FetchXmlScan)Source; diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/PrintNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/PrintNode.cs index 2e47b356..7a7eff9c 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/PrintNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/PrintNode.cs @@ -65,7 +65,7 @@ public string Execute(IDictionary dataSources, IQueryExecuti public IRootExecutionPlanNodeInternal[] FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) { - _expression = Expression.Compile(null, parameterTypes); + _expression = Expression.Compile(dataSources[options.PrimaryDataSource], null, parameterTypes); return new[] { this }; } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SortNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SortNode.cs index ddad8e5a..01202570 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SortNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SortNode.cs @@ -43,7 +43,7 @@ protected override IEnumerable ExecuteInternal(IDictionary sort.Expression.Compile(schema, parameterTypes)).ToList(); + var expressions = Sorts.Select(sort => sort.Expression.Compile(dataSources[options.PrimaryDataSource], schema, parameterTypes)).ToList(); if (PresortedCount == 0) { @@ -367,7 +367,7 @@ private IDataExecutionPlanNodeInternal FoldSorts(IDictionary if (top != null) { - if (!top.Top.IsConstantValueExpression(null, options, out var topLiteral)) + if (!top.Top.IsConstantValueExpression(dataSources[options.PrimaryDataSource], null, options, out var topLiteral)) return this; if (Int32.Parse(topLiteral.Value, CultureInfo.InvariantCulture) > 50000) @@ -375,8 +375,8 @@ private IDataExecutionPlanNodeInternal FoldSorts(IDictionary } else if (offset != null) { - if (!offset.Offset.IsConstantValueExpression(null, options, out var offsetLiteral) || - !offset.Fetch.IsConstantValueExpression(null, options, out var fetchLiteral)) + if (!offset.Offset.IsConstantValueExpression(dataSources[options.PrimaryDataSource], null, options, out var offsetLiteral) || + !offset.Fetch.IsConstantValueExpression(dataSources[options.PrimaryDataSource], null, options, out var fetchLiteral)) return this; if (Int32.Parse(offsetLiteral.Value, CultureInfo.InvariantCulture) + Int32.Parse(fetchLiteral.Value, CultureInfo.InvariantCulture) > 50000) diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/TopNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/TopNode.cs index 3324789b..995581ba 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/TopNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/TopNode.cs @@ -50,12 +50,12 @@ class TopNode : BaseDataNode, ISingleSourceExecutionPlanNode protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) { int topCount; - Top.GetType(null, null, parameterTypes, out var topType); + Top.GetType(dataSources[options.PrimaryDataSource], null, null, parameterTypes, out var topType); if (Percent) { var top = new ConvertCall { Parameter = Top, DataType = DataTypeHelpers.Float }; - var topPercent = (SqlDouble)top.Compile(null, parameterTypes)(null, parameterValues, options); + var topPercent = (SqlDouble)top.Compile(dataSources[options.PrimaryDataSource], null, parameterTypes)(null, parameterValues, options); if (topPercent.IsNull) { @@ -76,7 +76,7 @@ protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecuti using (_timer.Run()) { if (_timeExpr == null) - _timeExpr = Time.Compile(null, parameterTypes); + _timeExpr = Time.Compile(dataSources[options.PrimaryDataSource], null, parameterTypes); var time = (SqlTime) _timeExpr(null, parameterValues, options); diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs index b091c2b9..1d2a1162 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs @@ -52,6 +52,8 @@ public ExecutionPlanBuilder(IEnumerable dataSources, IQueryExecution /// public bool EstimatedPlanOnly { get; set; } + private DataSource PrimaryDataSource => DataSources[Options.PrimaryDataSource]; + /// /// Builds the execution plans for a SQL command /// @@ -72,8 +74,8 @@ public IRootExecutionPlanNode[] Build(string sql, IDictionary(); @@ -243,7 +245,7 @@ private IRootExecutionPlanNodeInternal[] ConvertExecuteStatement(ExecuteStatemen var dataSource = SelectDataSource(sproc.ProcedureReference.ProcedureReference.Name); - var node = ExecuteMessageNode.FromMessage(sproc, dataSource, _parameterTypes); + var node = ExecuteMessageNode.FromMessage(sproc, dataSource, PrimaryDataSource, _parameterTypes); var schema = node.GetSchema(DataSources, _parameterTypes); dataSource.MessageCache.TryGetValue(node.MessageName, out var message); @@ -353,7 +355,7 @@ private IRootExecutionPlanNodeInternal ConvertWaitForStatement(WaitForStatement if (waitFor.WaitForOption == WaitForOption.Statement) throw new NotSupportedQueryFragmentException("WAITFOR is not supported", waitFor); - waitFor.Parameter.GetType(null, null, _parameterTypes, out var paramSqlType); + waitFor.Parameter.GetType(PrimaryDataSource, null, null, _parameterTypes, out var paramSqlType); var timeType = DataTypeHelpers.Time(3); if (!SqlTypeConverter.CanChangeTypeImplicit(paramSqlType, timeType)) @@ -408,15 +410,15 @@ private IRootExecutionPlanNodeInternal ConvertPrintStatement(PrintStatement prin // Check the expression for errors. Ensure it can be converted to a string var expr = print.Expression.Clone(); - if (expr.GetType(null, null, _parameterTypes, out _) != typeof(SqlString)) + if (expr.GetType(PrimaryDataSource, null, null, _parameterTypes, out _) != typeof(SqlString)) { expr = new ConvertCall { - DataType = typeof(SqlString).ToSqlType(), + DataType = typeof(SqlString).ToSqlType(PrimaryDataSource), Parameter = expr }; - expr.GetType(null, null, _parameterTypes, out _); + expr.GetType(PrimaryDataSource, null, null, _parameterTypes, out _); } return new PrintNode @@ -436,7 +438,7 @@ private IRootExecutionPlanNodeInternal ConvertIfWhileStatement(ConditionalNodeTy if (subqueryVisitor.Subqueries.Count == 0) { // Check the predicate for errors - predicate.GetType(null, null, _parameterTypes, out _); + predicate.GetType(PrimaryDataSource, null, null, _parameterTypes, out _); } else { @@ -716,7 +718,7 @@ private ExecuteAsNode ConvertExecuteAsStatement(ExecuteAsStatement impersonate) Alias = "systemuser", Schema = { - ["systemuserid"] = typeof(SqlString).ToSqlType() + ["systemuserid"] = typeof(SqlString).ToSqlType(PrimaryDataSource) }, Values = { @@ -943,13 +945,13 @@ attr is LookupAttributeMetadata lookupAttr && if (virtualTypeAttributes.Contains(colName)) { targetName = colName; - targetType = DataTypeHelpers.NVarChar(MetadataExtensions.EntityLogicalNameMaxLength); + targetType = DataTypeHelpers.NVarChar(MetadataExtensions.EntityLogicalNameMaxLength, dataSource.DefaultCollation, CollationLabel.CoercibleDefault); } else { var attr = attributes[colName]; targetName = attr.LogicalName; - targetType = attr.GetAttributeSqlType(dataSource.Metadata, true); + targetType = attr.GetAttributeSqlType(dataSource, true); // If we're inserting into a lookup field, the field type will be a SqlEntityReference. Change this to // a SqlGuid so we can accept any guid values, including from TDS endpoint where SqlEntityReference @@ -985,7 +987,7 @@ attr is LookupAttributeMetadata lookupAttr && if (targetLookupAttribute.Targets.Length > 1 && !virtualTypeAttributes.Contains(targetAttrName + "type") && targetLookupAttribute.AttributeType != AttributeTypeCode.PartyList && - (schema == null || (node.ColumnMappings[targetAttrName].ToColumnReference().GetType(schema, null, null, out var lookupType) != typeof(SqlEntityReference) && lookupType != DataTypeHelpers.ImplicitIntForNullLiteral))) + (schema == null || (node.ColumnMappings[targetAttrName].ToColumnReference().GetType(PrimaryDataSource, schema, null, null, out var lookupType) != typeof(SqlEntityReference) && lookupType != DataTypeHelpers.ImplicitIntForNullLiteral))) { // Special case: not required for listmember.entityid if (metadata.LogicalName == "listmember" && targetLookupAttribute.LogicalName == "entityid") @@ -1439,7 +1441,7 @@ private UpdateNode ConvertSetClause(IList setClauses, HashSet // entityid in listmember has an associated entitytypecode attribute if (virtualTypeAttributes.Contains(targetAttrName)) { - targetType = DataTypeHelpers.NVarChar(MetadataExtensions.EntityLogicalNameMaxLength); + targetType = DataTypeHelpers.NVarChar(MetadataExtensions.EntityLogicalNameMaxLength, dataSource.DefaultCollation, CollationLabel.CoercibleDefault); var targetAttribute = attributes[targetAttrName.Substring(0, targetAttrName.Length - 4)]; targetAttrName = targetAttribute.LogicalName + targetAttrName.Substring(targetAttrName.Length - 4, 4).ToLower(); @@ -1447,7 +1449,7 @@ private UpdateNode ConvertSetClause(IList setClauses, HashSet else { var targetAttribute = attributes[targetAttrName]; - targetType = targetAttribute.GetAttributeSqlType(dataSource.Metadata, true); + targetType = targetAttribute.GetAttributeSqlType(dataSource, true); targetAttrName = targetAttribute.LogicalName; // If we're updating a lookup field, the field type will be a SqlEntityReference. Change this to @@ -1459,7 +1461,7 @@ private UpdateNode ConvertSetClause(IList setClauses, HashSet var sourceColName = select.ColumnSet.Single(col => col.OutputColumn == "new_" + targetAttrName.ToLower()).SourceColumn; var sourceCol = sourceColName.ToColumnReference(); - sourceCol.GetType(schema, null, null, out var sourceType); + sourceCol.GetType(PrimaryDataSource, schema, null, null, out var sourceType); if (!SqlTypeConverter.CanChangeTypeImplicit(sourceType, targetType)) throw new NotSupportedQueryFragmentException($"Cannot convert value of type {sourceType.ToSql()} to {targetType.ToSql()}", assignment); @@ -1862,7 +1864,7 @@ private IDataExecutionPlanNodeInternal ConvertInSubqueries(IDataExecutionPlanNod foreach (var inSubquery in visitor.InSubqueries) { // Validate the LHS expression - inSubquery.Expression.GetType(schema, null, parameterTypes, out _); + inSubquery.Expression.GetType(PrimaryDataSource, schema, null, parameterTypes, out _); // Each query of the format "col1 IN (SELECT col2 FROM source)" becomes a left outer join: // LEFT JOIN source ON col1 = col2 @@ -2139,7 +2141,7 @@ private IDataExecutionPlanNodeInternal ConvertHavingClause(IDataExecutionPlanNod ConvertScalarSubqueries(havingClause.SearchCondition, hints, ref source, computeScalar, parameterTypes, query); // Validate the final expression - havingClause.SearchCondition.GetType(source.GetSchema(DataSources, parameterTypes), nonAggregateSchema, parameterTypes, out _); + havingClause.SearchCondition.GetType(PrimaryDataSource, source.GetSchema(DataSources, parameterTypes), nonAggregateSchema, parameterTypes, out _); return new FilterNode { @@ -2183,7 +2185,7 @@ private IDataExecutionPlanNodeInternal ConvertGroupByAggregates(IDataExecutionPl throw new NotSupportedQueryFragmentException("Unhandled GROUP BY expression", grouping); // Validate the GROUP BY expression - exprGroup.Expression.GetType(schema, null, parameterTypes, out _); + exprGroup.Expression.GetType(PrimaryDataSource, schema, null, parameterTypes, out _); if (exprGroup.Expression is ColumnReferenceExpression col) { @@ -2339,7 +2341,7 @@ private IDataExecutionPlanNodeInternal ConvertGroupByAggregates(IDataExecutionPl if (converted.AggregateType == AggregateType.CountStar) converted.SqlExpression = null; else - converted.SqlExpression.GetType(schema, null, parameterTypes, out _); + converted.SqlExpression.GetType(PrimaryDataSource, schema, null, parameterTypes, out _); // Create a name for the column that holds the aggregate value in the result set. string aggregateName; @@ -2418,8 +2420,8 @@ private IDataExecutionPlanNodeInternal ConvertOffsetClause(IDataExecutionPlanNod if (offsetClause == null) return source; - offsetClause.OffsetExpression.GetType(null, null, parameterTypes, out var offsetType); - offsetClause.FetchExpression.GetType(null, null, parameterTypes, out var fetchType); + offsetClause.OffsetExpression.GetType(PrimaryDataSource, null, null, parameterTypes, out var offsetType); + offsetClause.FetchExpression.GetType(PrimaryDataSource, null, null, parameterTypes, out var fetchType); var intType = DataTypeHelpers.Int; if (!SqlTypeConverter.CanChangeTypeImplicit(offsetType, intType)) @@ -2441,7 +2443,7 @@ private IDataExecutionPlanNodeInternal ConvertTopClause(IDataExecutionPlanNodeIn if (topRowFilter == null) return source; - topRowFilter.Expression.GetType(null, null, parameterTypes, out var topType); + topRowFilter.Expression.GetType(PrimaryDataSource, null, null, parameterTypes, out var topType); var targetType = topRowFilter.Percent ? DataTypeHelpers.Float : DataTypeHelpers.BigInt; if (!SqlTypeConverter.CanChangeTypeImplicit(topType, targetType)) @@ -2559,7 +2561,7 @@ private IDataExecutionPlanNodeInternal ConvertOrderByClause(IDataExecutionPlanNo } // Validate the expression - orderBy.Expression.GetType(schema, nonAggregateSchema, parameterTypes, out _); + orderBy.Expression.GetType(PrimaryDataSource, schema, nonAggregateSchema, parameterTypes, out _); sort.Sorts.Add(orderBy.Clone()); } @@ -2588,7 +2590,7 @@ private IDataExecutionPlanNodeInternal ConvertWhereClause(IDataExecutionPlanNode ConvertScalarSubqueries(whereClause.SearchCondition, hints, ref source, computeScalar, parameterTypes, query); // Validate the final expression - whereClause.SearchCondition.GetType(source.GetSchema(DataSources, parameterTypes), null, parameterTypes, out _); + whereClause.SearchCondition.GetType(PrimaryDataSource, source.GetSchema(DataSources, parameterTypes), null, parameterTypes, out _); return new FilterNode { @@ -2668,7 +2670,7 @@ private SelectNode ConvertSelectClause(IList selectElements, ILis if (!schema.ContainsColumn(colName, out colName)) { // Column name isn't valid. Use the expression extensions to throw a consistent error message - col.GetType(schema, nonAggregateSchema, parameterTypes, out _); + col.GetType(PrimaryDataSource, schema, nonAggregateSchema, parameterTypes, out _); } var alias = scalar.ColumnName?.Value ?? col.MultiPartIdentifier.Identifiers.Last().Value; @@ -2757,7 +2759,7 @@ private string ComputeScalarExpression(ScalarExpression expression, IList Date: Mon, 20 Mar 2023 17:09:15 +0000 Subject: [PATCH 03/34] Refactoring --- .../ExecutionPlanNodeTests.cs | 38 +- .../ExecutionPlanTests.cs | 508 ++++++------------ .../ExpressionTests.cs | 34 +- .../FakeXrmEasyTestsBase.cs | 50 +- .../Sql2FetchXmlTests.cs | 407 ++++++-------- .../Ado/Sql4CdsDataReader.cs | 8 +- .../Ado/Sql4CdsParameter.cs | 2 +- MarkMpn.Sql4Cds.Engine/DataSource.cs | 2 +- MarkMpn.Sql4Cds.Engine/DataTypeHelpers.cs | 1 + .../ExecutionPlan/Aggregate.cs | 46 +- .../ExecutionPlan/AliasNode.cs | 24 +- .../ExecutionPlan/AssertNode.cs | 20 +- .../ExecutionPlan/AssignVariablesNode.cs | 14 +- .../ExecutionPlan/BaseAggregateNode.cs | 35 +- .../ExecutionPlan/BaseDataNode.cs | 88 ++- .../ExecutionPlan/BaseDmlNode.cs | 52 +- .../ExecutionPlan/BaseJoinNode.cs | 14 +- .../ExecutionPlan/BaseNode.cs | 5 +- .../ExecutionPlan/BulkDeleteJobNode.cs | 12 +- .../ExecutionPlan/ComputeScalarNode.cs | 40 +- .../ExecutionPlan/ConcatenateNode.cs | 32 +- .../ExecutionPlan/ConditionalNode.cs | 26 +- .../ExecutionPlan/ConstantScanNode.cs | 17 +- .../ExecutionPlan/ContinueBreakNode.cs | 10 +- .../ExecutionPlan/DeclareVariablesNode.cs | 10 +- .../ExecutionPlan/DeleteNode.cs | 28 +- .../ExecutionPlan/DistinctNode.cs | 26 +- .../ExecutionPlan/ExecuteAsNode.cs | 10 +- .../ExecutionPlan/ExecuteMessageNode.cs | 61 ++- .../ExecutionPlan/ExpressionExtensions.cs | 319 ++++++----- .../ExecutionPlan/FetchXmlScan.cs | 67 ++- .../ExecutionPlan/FilterNode.cs | 144 ++--- .../ExecutionPlan/FoldableJoinNode.cs | 100 ++-- .../ExecutionPlan/GlobalOptionSetQueryNode.cs | 12 +- .../ExecutionPlan/GoToNode.cs | 18 +- .../ExecutionPlan/GotoLabelNode.cs | 4 +- .../ExecutionPlan/HashJoinNode.cs | 41 +- .../ExecutionPlan/HashMatchAggregateNode.cs | 46 +- .../ExecutionPlan/IDataExecutionPlanNode.cs | 17 +- .../IDataSetExecutionPlanNode.cs | 7 +- .../IDmlQueryExecutionPlanNode.cs | 5 +- .../ExecutionPlan/IExecutionPlanNode.cs | 7 +- .../ExecutionPlan/IGoToNode.cs | 7 +- .../ExecutionPlan/IRootExecutionPlanNode.cs | 6 +- .../ExecutionPlan/IndexSpoolNode.cs | 28 +- .../ExecutionPlan/InsertNode.cs | 22 +- .../ExecutionPlan/MergeJoinNode.cs | 42 +- .../ExecutionPlan/MetadataQueryNode.cs | 29 +- .../ExecutionPlan/NestedLoopNode.cs | 93 ++-- .../ExecutionPlan/OffsetFetchNode.cs | 42 +- .../ExecutionPlan/PartitionedAggregateNode.cs | 64 +-- .../ExecutionPlan/PrintNode.cs | 12 +- .../RetrieveTotalRecordCountNode.cs | 13 +- .../ExecutionPlan/RevertNode.cs | 8 +- .../ExecutionPlan/SelectNode.cs | 34 +- .../ExecutionPlan/SortNode.cs | 76 +-- .../ExecutionPlan/SqlNode.cs | 18 +- .../ExecutionPlan/SqlTypeConverter.cs | 99 ++-- .../ExecutionPlan/StreamAggregateNode.cs | 24 +- .../ExecutionPlan/TableSpoolNode.cs | 26 +- .../ExecutionPlan/TopNode.cs | 45 +- .../ExecutionPlan/TryCatchNode.cs | 30 +- .../ExecutionPlan/UpdateNode.cs | 20 +- .../ExecutionPlan/WaitForNode.cs | 16 +- .../ExecutionPlanBuilder.cs | 113 ++-- .../ExecutionPlanOptimizer.cs | 6 +- MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs | 38 +- .../MarkMpn.Sql4Cds.Engine.projitems | 1 + MarkMpn.Sql4Cds.Engine/MetaMetadataCache.cs | 4 +- MarkMpn.Sql4Cds.Engine/MetadataExtensions.cs | 10 +- MarkMpn.Sql4Cds.Engine/NodeContext.cs | 196 +++++++ 71 files changed, 1801 insertions(+), 1728 deletions(-) create mode 100644 MarkMpn.Sql4Cds.Engine/NodeContext.cs diff --git a/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanNodeTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanNodeTests.cs index 41721271..48fac417 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanNodeTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanNodeTests.cs @@ -37,8 +37,8 @@ public void ConstantScanTest() Alias = "test" }; - var results = node.Execute(_dataSources, new StubOptions(), null, null).ToArray(); - + var results = node.Execute(new NodeExecutionContext(_localDataSource, new StubOptions(), null, null)).ToArray(); + Assert.AreEqual(1, results.Length); Assert.AreEqual("Mark", ((SqlString)results[0]["test.firstname"]).Value); } @@ -84,7 +84,7 @@ public void FilterNodeTest() } }; - var results = node.Execute(_dataSources, new StubOptions(), null, null).ToArray(); + var results = node.Execute(new NodeExecutionContext(_localDataSource, new StubOptions(), null, null)).ToArray(); Assert.AreEqual(1, results.Length); Assert.AreEqual("Mark", ((SqlString)results[0]["test.firstname"]).Value); @@ -161,7 +161,7 @@ public void MergeJoinInnerTest() JoinType = QualifiedJoinType.Inner }; - var results = node.Execute(_dataSources, new StubOptions(), null, null).ToArray(); + var results = node.Execute(new NodeExecutionContext(_localDataSource, new StubOptions(), null, null)).ToArray(); Assert.AreEqual(2, results.Length); Assert.AreEqual("Mark", ((SqlString)results[0]["f.firstname"]).Value); @@ -241,7 +241,7 @@ public void MergeJoinLeftOuterTest() JoinType = QualifiedJoinType.LeftOuter }; - var results = node.Execute(_dataSources, new StubOptions(), null, null).ToArray(); + var results = node.Execute(new NodeExecutionContext(_localDataSource, new StubOptions(), null, null)).ToArray(); Assert.AreEqual(3, results.Length); Assert.AreEqual("Mark", ((SqlString)results[0]["f.firstname"]).Value); @@ -323,7 +323,7 @@ public void MergeJoinRightOuterTest() JoinType = QualifiedJoinType.RightOuter }; - var results = node.Execute(_dataSources, new StubOptions(), null, null).ToArray(); + var results = node.Execute(new NodeExecutionContext(_localDataSource, new StubOptions(), null, null)).ToArray(); Assert.AreEqual(3, results.Length); Assert.AreEqual("Mark", ((SqlString)results[0]["f.firstname"]).Value); @@ -362,7 +362,7 @@ public void AssertionTest() ErrorMessage = "Only Mark is allowed" }; - var results = node.Execute(_dataSources, new StubOptions(), null, null).GetEnumerator(); + var results = node.Execute(new NodeExecutionContext(_localDataSource, new StubOptions(), null, null)).GetEnumerator(); Assert.IsTrue(results.MoveNext()); Assert.AreEqual("Mark", results.Current.GetAttributeValue("test.name").Value); @@ -420,7 +420,7 @@ public void ComputeScalarTest() } }; - var results = node.Execute(_dataSources, new StubOptions(), null, null) + var results = node.Execute(new NodeExecutionContext(_localDataSource, new StubOptions(), null, null)) .Select(e => e.GetAttributeValue("mul").Value) .ToArray(); @@ -462,7 +462,7 @@ public void DistinctTest() } }; - var results = node.Execute(_dataSources, new StubOptions(), null, null) + var results = node.Execute(new NodeExecutionContext(_localDataSource, new StubOptions(), null, null)) .Select(e => e.GetAttributeValue("test.value1").Value) .ToArray(); @@ -504,7 +504,7 @@ public void DistinctCaseInsensitiveTest() } }; - var results = node.Execute(_dataSources, new StubOptions(), null, null) + var results = node.Execute(new NodeExecutionContext(_localDataSource, new StubOptions(), null, null)) .Select(e => e.GetAttributeValue("test.value1").Value) .ToArray(); @@ -562,7 +562,7 @@ public void SortNodeTest() } }; - var results = node.Execute(_dataSources, new StubOptions(), null, null) + var results = node.Execute(new NodeExecutionContext(_localDataSource, new StubOptions(), null, null)) .Select(e => e.GetAttributeValue("test.expectedorder").Value) .ToArray(); @@ -627,7 +627,7 @@ public void SortNodePresortedTest() } }; - var results = node.Execute(_dataSources, new StubOptions(), null, null) + var results = node.Execute(new NodeExecutionContext(_localDataSource, new StubOptions(), null, null)) .Select(e => e.GetAttributeValue("test.expectedorder").Value) .ToArray(); @@ -659,11 +659,11 @@ public void TableSpoolTest() var spool = new TableSpoolNode { Source = source }; - var results1 = spool.Execute(_dataSources, new StubOptions(), null, null) + var results1 = spool.Execute(new NodeExecutionContext(_localDataSource, new StubOptions(), null, null)) .Select(e => e.GetAttributeValue("test.value1").Value) .ToArray(); - var results2 = spool.Execute(_dataSources, new StubOptions(), null, null) + var results2 = spool.Execute(new NodeExecutionContext(_localDataSource, new StubOptions(), null, null)) .Select(e => e.GetAttributeValue("test.value1").Value) .ToArray(); @@ -708,7 +708,7 @@ public void CaseInsenstiveHashMatchAggregateNodeTest() } }; - var results = spool.Execute(_dataSources, new StubOptions(), null, null) + var results = spool.Execute(new NodeExecutionContext(_localDataSource, new StubOptions(), null, null)) .Select(e => new { Name = e.GetAttributeValue("src.value1").Value, Count = e.GetAttributeValue("count").Value }) .ToArray(); @@ -749,7 +749,7 @@ public void SqlTransformSingleResult() public void AggregateInitialTest() { var aggregate = CreateAggregateTest(); - var result = aggregate.Execute(_dataSources, new StubOptions(), null, null).Single(); + var result = aggregate.Execute(new NodeExecutionContext(_localDataSource, new StubOptions(), null, null)).Single(); Assert.AreEqual(SqlInt32.Null, result["min"]); Assert.AreEqual(SqlInt32.Null, result["max"]); @@ -767,7 +767,7 @@ public void AggregateInitialTest() public void AggregateSingleValueTest() { var aggregate = CreateAggregateTest(1); - var result = aggregate.Execute(_dataSources, new StubOptions(), null, null).Single(); + var result = aggregate.Execute(new NodeExecutionContext(_localDataSource, new StubOptions(), null, null)).Single(); Assert.AreEqual((SqlInt32)1, result["min"]); Assert.AreEqual((SqlInt32)1, result["max"]); @@ -785,7 +785,7 @@ public void AggregateSingleValueTest() public void AggregateTwoEqualValuesTest() { var aggregate = CreateAggregateTest(1, 1); - var result = aggregate.Execute(_dataSources, new StubOptions(), null, null).Single(); + var result = aggregate.Execute(new NodeExecutionContext(_localDataSource, new StubOptions(), null, null)).Single(); Assert.AreEqual((SqlInt32)1, result["min"]); Assert.AreEqual((SqlInt32)1, result["max"]); @@ -803,7 +803,7 @@ public void AggregateTwoEqualValuesTest() public void AggregateMultipleValuesTest() { var aggregate = CreateAggregateTest(1, 3, 1, 1); - var result = aggregate.Execute(_dataSources, new StubOptions(), null, null).Single(); + var result = aggregate.Execute(new NodeExecutionContext(_localDataSource, new StubOptions(), null, null)).Single(); Assert.AreEqual((SqlInt32)1, result["min"]); Assert.AreEqual((SqlInt32)3, result["max"]); diff --git a/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs index 6b78f1c2..f42b08df 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs @@ -11,6 +11,7 @@ using System.Threading.Tasks; using System.Xml.Serialization; using FakeXrmEasy; +using FakeXrmEasy.Extensions; using MarkMpn.Sql4Cds.Engine.ExecutionPlan; using Microsoft.SqlServer.TransactSql.ScriptDom; using Microsoft.VisualStudio.TestTools.UnitTesting; @@ -83,8 +84,7 @@ void IQueryExecutionOptions.Progress(double? progress, string message) [TestMethod] public void SimpleSelect() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT accountid, name FROM account"; @@ -106,8 +106,7 @@ public void SimpleSelect() [TestMethod] public void SimpleSelectStar() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT * FROM account"; @@ -141,8 +140,7 @@ public void SimpleSelectStar() [TestMethod] public void Join() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT accountid, name FROM account INNER JOIN contact ON account.accountid = contact.parentcustomerid"; @@ -166,8 +164,7 @@ public void Join() [TestMethod] public void JoinWithExtraCondition() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -200,8 +197,7 @@ public void JoinWithExtraCondition() [TestMethod] public void NonUniqueJoin() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT accountid, name FROM account INNER JOIN contact ON account.name = contact.fullname"; @@ -229,8 +225,7 @@ public void NonUniqueJoin() [TestMethod] public void NonUniqueJoinExpression() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT accountid, name FROM account INNER JOIN contact ON account.name = (contact.firstname + ' ' + contact.lastname)"; @@ -265,8 +260,7 @@ public void NonUniqueJoinExpression() [TestMethod] public void SimpleWhere() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -298,8 +292,7 @@ public void SimpleWhere() [TestMethod] public void SimpleSort() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -329,8 +322,7 @@ ORDER BY [TestMethod] public void SimpleSortIndex() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -360,8 +352,7 @@ ORDER BY [TestMethod] public void SimpleDistinct() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT DISTINCT @@ -387,8 +378,7 @@ SELECT DISTINCT [TestMethod] public void SimpleTop() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT TOP 10 @@ -415,8 +405,7 @@ SELECT TOP 10 [TestMethod] public void SimpleOffset() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -446,8 +435,7 @@ ORDER BY name [TestMethod] public void SimpleGroupAggregate() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -513,8 +501,7 @@ public void SimpleGroupAggregate() [TestMethod] public void AliasedAggregate() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -564,8 +551,7 @@ public void AliasedAggregate() [TestMethod] public void AliasedGroupingAggregate() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -615,8 +601,7 @@ public void AliasedGroupingAggregate() [TestMethod] public void SimpleAlias() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT accountid, name AS test FROM account"; @@ -638,8 +623,7 @@ public void SimpleAlias() [TestMethod] public void SimpleHaving() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -705,8 +689,7 @@ GROUP BY name [TestMethod] public void GroupByDatePart() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -764,8 +747,7 @@ public void GroupByDatePart() [TestMethod] public void GroupByDatePartUsingYearMonthDayFunctions() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -844,8 +826,7 @@ GROUP BY [TestMethod] public void PartialOrdering() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -882,8 +863,7 @@ ORDER BY [TestMethod] public void PartialOrderingAvoidingLegacyPagingWithTop() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT TOP 100 @@ -922,8 +902,7 @@ ORDER BY [TestMethod] public void PartialWhere() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -959,8 +938,7 @@ public void PartialWhere() [TestMethod] public void ComputeScalarSelect() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT firstname + ' ' + lastname AS fullname FROM contact WHERE firstname = 'Mark'"; @@ -989,8 +967,7 @@ public void ComputeScalarSelect() [TestMethod] public void ComputeScalarFilter() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT contactid FROM contact WHERE firstname + ' ' + lastname = 'Mark Carrington'"; @@ -1016,8 +993,7 @@ public void ComputeScalarFilter() [TestMethod] public void SelectSubqueryWithMergeJoin() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT firstname + ' ' + lastname AS fullname, 'Account: ' + (SELECT name FROM account WHERE accountid = parentcustomerid) AS accountname FROM contact WHERE firstname = 'Mark'"; @@ -1050,8 +1026,7 @@ public void SelectSubqueryWithMergeJoin() [TestMethod] public void SelectSubqueryWithNestedLoop() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT firstname + ' ' + lastname AS fullname, 'Account: ' + (SELECT name FROM account WHERE createdon = contact.createdon) AS accountname FROM contact WHERE firstname = 'Mark'"; @@ -1100,8 +1075,7 @@ public void SelectSubqueryWithNestedLoop() [TestMethod] public void SelectSubqueryWithChildRecordUsesNestedLoop() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT name, (SELECT TOP 1 fullname FROM contact WHERE parentcustomerid = account.accountid) FROM account WHERE name = 'Data8'"; @@ -1147,8 +1121,7 @@ public void SelectSubqueryWithChildRecordUsesNestedLoop() [TestMethod] public void SelectSubqueryWithSmallNestedLoop() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT TOP 10 firstname + ' ' + lastname AS fullname, 'Account: ' + (SELECT name FROM account WHERE createdon = contact.createdon) AS accountname FROM contact WHERE firstname = 'Mark'"; @@ -1193,8 +1166,7 @@ public void SelectSubqueryWithSmallNestedLoop() [TestMethod] public void SelectSubqueryWithNonCorrelatedNestedLoop() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT firstname + ' ' + lastname AS fullname, 'Account: ' + (SELECT TOP 1 name FROM account) AS accountname FROM contact WHERE firstname = 'Mark'"; @@ -1236,8 +1208,7 @@ public void SelectSubqueryWithNonCorrelatedNestedLoop() [TestMethod] public void SelectSubqueryWithCorrelatedSpooledNestedLoop() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT firstname + ' ' + lastname AS fullname, 'Account: ' + (SELECT name FROM account WHERE createdon = contact.createdon) AS accountname FROM contact WHERE firstname = 'Mark'"; @@ -1287,8 +1258,7 @@ public void SelectSubqueryWithCorrelatedSpooledNestedLoop() [TestMethod] public void SelectSubqueryWithPartiallyCorrelatedSpooledNestedLoop() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT firstname + ' ' + lastname AS fullname, 'Account: ' + (SELECT name FROM account WHERE createdon = contact.createdon AND employees > 10) AS accountname FROM contact WHERE firstname = 'Mark'"; @@ -1339,9 +1309,8 @@ public void SelectSubqueryWithPartiallyCorrelatedSpooledNestedLoop() [TestMethod] public void SelectSubqueryUsingOuterReferenceInSelectClause() { - var metadata = new AttributeMetadataCache(_service); var tableSize = new StubTableSizeCache(); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT firstname + ' ' + lastname AS fullname, 'Account: ' + (SELECT firstname + ' ' + name FROM account WHERE accountid = parentcustomerid) AS accountname FROM contact WHERE firstname = 'Mark'"; @@ -1389,8 +1358,7 @@ public void SelectSubqueryUsingOuterReferenceInSelectClause() [TestMethod] public void SelectSubqueryUsingOuterReferenceInOrderByClause() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT firstname FROM contact ORDER BY (SELECT TOP 1 name FROM account WHERE accountid = parentcustomerid ORDER BY firstname)"; @@ -1430,8 +1398,7 @@ public void SelectSubqueryUsingOuterReferenceInOrderByClause() [TestMethod] public void WhereSubquery() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT firstname + ' ' + lastname AS fullname FROM contact WHERE (SELECT name FROM account WHERE accountid = parentcustomerid) = 'Data8'"; @@ -1463,8 +1430,7 @@ public void WhereSubquery() [TestMethod] public void ComputeScalarDistinct() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT DISTINCT TOP 10 @@ -1494,8 +1460,7 @@ SELECT DISTINCT TOP 10 [TestMethod] public void UnionAll() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT name FROM account @@ -1534,8 +1499,7 @@ UNION ALL [TestMethod] public void SimpleInFilter() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -1570,8 +1534,7 @@ public void SimpleInFilter() [TestMethod] public void SubqueryInFilterUncorrelated() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -1612,8 +1575,7 @@ public void SubqueryInFilterUncorrelated() [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void SubqueryInFilterMultipleColumnsError() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -1630,8 +1592,7 @@ public void SubqueryInFilterMultipleColumnsError() [TestMethod] public void SubqueryInFilterUncorrelatedPrimaryKey() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -1666,8 +1627,7 @@ public void SubqueryInFilterUncorrelatedPrimaryKey() [TestMethod] public void SubqueryInFilterCorrelated() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -1715,8 +1675,7 @@ public void SubqueryInFilterCorrelated() [TestMethod] public void SubqueryNotInFilterCorrelated() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -1764,8 +1723,7 @@ public void SubqueryNotInFilterCorrelated() [TestMethod] public void ExistsFilterUncorrelated() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -1804,8 +1762,7 @@ public void ExistsFilterUncorrelated() [TestMethod] public void ExistsFilterCorrelatedPrimaryKey() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -1837,8 +1794,7 @@ public void ExistsFilterCorrelatedPrimaryKey() [TestMethod] public void ExistsFilterCorrelatedPrimaryKeyOr() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -1875,8 +1831,7 @@ public void ExistsFilterCorrelatedPrimaryKeyOr() [TestMethod] public void ExistsFilterCorrelated() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -1920,8 +1875,7 @@ public void ExistsFilterCorrelated() [TestMethod] public void NotExistsFilterCorrelated() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -1966,8 +1920,7 @@ public void NotExistsFilterCorrelated() [TestMethod] public void QueryDerivedTableSimple() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT TOP 10 @@ -1998,8 +1951,7 @@ SELECT TOP 10 [TestMethod] public void QueryDerivedTableAlias() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT TOP 10 @@ -2030,8 +1982,7 @@ SELECT TOP 10 [TestMethod] public void QueryDerivedTableValues() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT TOP 10 @@ -2050,7 +2001,7 @@ SELECT TOP 10 var filter = AssertNode(top.Source); var constant = AssertNode(filter.Source); - var schema = constant.GetSchema(_dataSources, null); + var schema = constant.GetSchema(new NodeCompilationContext(_dataSources, this, null)); Assert.AreEqual(typeof(SqlInt32), schema.Schema["a.ID"].ToNetType(out _)); Assert.AreEqual(typeof(SqlString), schema.Schema["a.name"].ToNetType(out _)); } @@ -2058,8 +2009,7 @@ SELECT TOP 10 [TestMethod] public void NoLockTableHint() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT TOP 10 @@ -2089,8 +2039,7 @@ SELECT TOP 10 [TestMethod] public void CrossJoin() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -2129,8 +2078,7 @@ CROSS JOIN [TestMethod] public void CrossApply() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -2167,8 +2115,7 @@ FROM contact [TestMethod] public void OuterApply() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -2205,8 +2152,7 @@ FROM contact [TestMethod] public void OuterApplyNestedLoop() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -2264,8 +2210,7 @@ ORDER BY firstname [TestMethod] public void FetchXmlNativeWhere() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -2297,8 +2242,7 @@ public void FetchXmlNativeWhere() [TestMethod] public void SimpleMetadataSelect() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT logicalname @@ -2319,8 +2263,7 @@ SELECT logicalname [TestMethod] public void SimpleMetadataWhere() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT logicalname @@ -2346,8 +2289,7 @@ FROM metadata.entity [TestMethod] public void CaseSensitiveMetadataWhere() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT logicalname @@ -2378,8 +2320,7 @@ FROM metadata.entity [TestMethod] public void SimpleUpdate() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "UPDATE account SET name = 'foo' WHERE name = 'bar'"; @@ -2408,8 +2349,7 @@ public void SimpleUpdate() [TestMethod] public void UpdateFromJoin() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "UPDATE a SET name = 'foo' FROM account a INNER JOIN contact c ON a.accountid = c.parentcustomerid WHERE name = 'bar'"; @@ -2441,8 +2381,7 @@ public void UpdateFromJoin() [TestMethod] public void QueryHints() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT accountid, name FROM account OPTION (OPTIMIZE FOR UNKNOWN, FORCE ORDER, RECOMPILE, USE HINT('DISABLE_OPTIMIZER_ROWGOAL'), USE HINT('ENABLE_QUERY_OPTIMIZER_HOTFIXES'), LOOP JOIN, MERGE JOIN, HASH JOIN, NO_PERFORMANCE_SPOOL, MAXRECURSION 2)"; @@ -2464,8 +2403,7 @@ public void QueryHints() [TestMethod] public void AggregateSort() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT name, count(*) from account group by name order by 2 desc"; @@ -2517,8 +2455,7 @@ public void AggregateSort() [TestMethod] public void FoldFilterWithNonFoldedJoin() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT name from account INNER JOIN contact ON left(name, 4) = left(firstname, 4) where name like 'Data8%' and firstname like 'Mark%'"; @@ -2555,8 +2492,7 @@ public void FoldFilterWithNonFoldedJoin() [TestMethod] public void FoldFilterWithInClause() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT name from account where name like 'Data8%' and primarycontactid in (select contactid from contact where firstname = 'Mark')"; @@ -2585,8 +2521,7 @@ public void FoldFilterWithInClause() [TestMethod] public void FoldFilterWithInClauseOr() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT name from account where name like 'Data8%' or primarycontactid in (select contactid from contact where firstname = 'Mark')"; @@ -2620,8 +2555,7 @@ public void FoldFilterWithInClauseWithoutPrimaryKey() try { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT name from account where name like 'Data8%' and createdon in (select createdon from contact where firstname = 'Mark')"; @@ -2659,8 +2593,7 @@ public void FoldFilterWithInClauseOnLinkEntityWithoutPrimaryKey() try { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT name from account inner join contact on account.accountid = contact.parentcustomerid where name like 'Data8%' and contact.createdon in (select createdon from contact where firstname = 'Mark')"; @@ -2700,8 +2633,7 @@ public void FoldFilterWithExistsClauseWithoutPrimaryKey() try { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT name from account where name like 'Data8%' and exists (select * from contact where firstname = 'Mark' and createdon = account.createdon)"; @@ -2736,8 +2668,7 @@ public void FoldFilterWithExistsClauseWithoutPrimaryKey() [TestMethod] public void DistinctNotRequiredWithPrimaryKey() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT DISTINCT accountid, name from account"; @@ -2759,8 +2690,7 @@ public void DistinctNotRequiredWithPrimaryKey() [TestMethod] public void DistinctRequiredWithoutPrimaryKey() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT DISTINCT accountid, name from account INNER JOIN contact ON account.accountid = contact.parentcustomerid"; @@ -2786,8 +2716,7 @@ public void DistinctRequiredWithoutPrimaryKey() [TestMethod] public void SimpleDelete() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "DELETE FROM account WHERE name = 'bar'"; @@ -2813,8 +2742,7 @@ public void SimpleDelete() [TestMethod] public void SimpleInsertSelect() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "INSERT INTO account (name) SELECT fullname FROM contact WHERE firstname = 'Mark'"; @@ -2840,8 +2768,7 @@ public void SimpleInsertSelect() [TestMethod] public void SelectDuplicateColumnNames() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT fullname, lastname + ', ' + firstname as fullname FROM contact WHERE firstname = 'Mark'"; @@ -2874,8 +2801,7 @@ public void SelectDuplicateColumnNames() [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void SubQueryDuplicateColumnNamesError() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT * FROM (SELECT fullname, lastname + ', ' + firstname as fullname FROM contact WHERE firstname = 'Mark') a"; @@ -2885,8 +2811,7 @@ public void SubQueryDuplicateColumnNamesError() [TestMethod] public void UnionDuplicateColumnNames() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"SELECT fullname, lastname + ', ' + firstname as fullname FROM contact WHERE firstname = 'Mark' UNION @@ -2917,8 +2842,7 @@ public void UnionDuplicateColumnNames() [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void SubQueryUnionDuplicateColumnNamesError() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"SELECT * FROM ( SELECT fullname, lastname + ', ' + firstname as fullname FROM contact WHERE firstname = 'Mark' UNION @@ -2954,8 +2878,7 @@ private void AssertFetchXml(FetchXmlScan node, string fetchXml) [TestMethod] public void SelectStarInSubquery() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"SELECT * FROM account WHERE accountid IN (SELECT parentcustomerid FROM contact)"; @@ -2995,8 +2918,7 @@ public void SelectStarInSubquery() [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void CannotSelectColumnsFromSemiJoin() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"SELECT contact.* FROM account WHERE accountid IN (SELECT parentcustomerid FROM contact)"; @@ -3006,8 +2928,7 @@ public void CannotSelectColumnsFromSemiJoin() [TestMethod] public void MinAggregateNotFoldedToFetchXmlForOptionset() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"SELECT new_name, min(new_optionsetvalue) FROM new_customentity GROUP BY new_name"; @@ -3032,8 +2953,7 @@ public void MinAggregateNotFoldedToFetchXmlForOptionset() [TestMethod] public void HelpfulErrorMessageOnMissingGroupBy() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"SELECT new_name, min(new_optionsetvalue) FROM new_customentity"; @@ -3051,8 +2971,7 @@ public void HelpfulErrorMessageOnMissingGroupBy() [TestMethod] public void AggregateInSubquery() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"SELECT firstname FROM contact @@ -3144,8 +3063,8 @@ GROUP BY firstname ["firstname"] = "Matt" }, }; - - var result = select.Execute(_localDataSource, this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + + var result = select.Execute(new NodeExecutionContext(_localDataSource, this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(result); @@ -3157,8 +3076,7 @@ GROUP BY firstname [TestMethod] public void SelectVirtualNameAttributeFromLinkEntity() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT parentcustomeridname FROM account INNER JOIN contact ON account.accountid = contact.parentcustomerid"; @@ -3181,8 +3099,7 @@ public void SelectVirtualNameAttributeFromLinkEntity() [TestMethod] public void DuplicatedDistinctColumns() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT DISTINCT name AS n1, name AS n2 FROM account"; @@ -3204,8 +3121,7 @@ public void DuplicatedDistinctColumns() [TestMethod] public void GroupByDatetimeWithoutDatePart() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT createdon, COUNT(*) FROM account GROUP BY createdon"; @@ -3228,8 +3144,7 @@ public void GroupByDatetimeWithoutDatePart() [TestMethod] public void MetadataExpressions() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT collectionschemaname + '.' + entitysetname FROM metadata.entity WHERE description LIKE '%test%'"; @@ -3254,8 +3169,7 @@ public void MetadataExpressions() [TestMethod] public void AliasedAttribute() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT name AS n1 FROM account WHERE name = 'test'"; @@ -3282,8 +3196,7 @@ public void AliasedAttribute() [TestMethod] public void MultipleAliases() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT name AS n1, name AS n2 FROM account WHERE name = 'test'"; @@ -3322,7 +3235,8 @@ public void CrossInstanceJoin() Connection = _context.GetOrganizationService(), Metadata = metadata1, TableSizeCache = new StubTableSizeCache(), - MessageCache = new StubMessageCache() + MessageCache = new StubMessageCache(), + DefaultCollation = Collation.USEnglish }, new DataSource { @@ -3330,11 +3244,13 @@ public void CrossInstanceJoin() Connection = _context2.GetOrganizationService(), Metadata = metadata2, TableSizeCache = new StubTableSizeCache(), - MessageCache = new StubMessageCache() + MessageCache = new StubMessageCache(), + DefaultCollation = Collation.USEnglish }, new DataSource { - Name = "local" // Hack so that ((IQueryExecutionOptions)this).PrimaryDataSource = "local" doesn't cause test to fail + Name = "local", // Hack so that ((IQueryExecutionOptions)this).PrimaryDataSource = "local" doesn't cause test to fail + DefaultCollation = Collation.USEnglish } }; var planBuilder = new ExecutionPlanBuilder(datasources, this); @@ -3383,8 +3299,7 @@ public void CrossInstanceJoin() [TestMethod] public void FilterOnGroupByExpression() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT @@ -3428,8 +3343,7 @@ GROUP BY [TestMethod] public void SystemFunctions() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT CURRENT_TIMESTAMP, CURRENT_USER, GETDATE(), USER_NAME()"; @@ -3445,8 +3359,7 @@ public void SystemFunctions() [TestMethod] public void FoldEqualsCurrentUser() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT name FROM account WHERE ownerid = CURRENT_USER"; @@ -3470,8 +3383,7 @@ public void FoldEqualsCurrentUser() [TestMethod] public void EntityReferenceInQuery() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT name FROM account WHERE accountid IN ('0000000000000000-0000-0000-000000000000', '0000000000000000-0000-0000-000000000001')"; @@ -3498,8 +3410,7 @@ public void EntityReferenceInQuery() [TestMethod] public void OrderBySelectExpression() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT name + 'foo' FROM account ORDER BY 1"; @@ -3525,8 +3436,7 @@ public void OrderBySelectExpression() [TestMethod] public void DistinctOrderByUsesScalarAggregate() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT DISTINCT account.accountid FROM metadata.entity INNER JOIN account ON entity.metadataid = account.accountid"; @@ -3556,8 +3466,7 @@ public void DistinctOrderByUsesScalarAggregate() [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void WindowFunctionsNotSupported() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT COUNT(accountid) OVER(PARTITION BY accountid) AS test FROM account"; @@ -3567,8 +3476,7 @@ public void WindowFunctionsNotSupported() [TestMethod] public void DeclareVariableSetLiteralSelect() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" DECLARE @test int @@ -3605,7 +3513,7 @@ DECLARE @test int { if (plan is IDataReaderExecutionPlanNode selectQuery) { - var results = selectQuery.Execute(_dataSources, this, parameterTypes, parameterValues, CommandBehavior.Default); + var results = selectQuery.Execute(new NodeExecutionContext(_dataSources, this, parameterTypes, parameterValues), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(results); @@ -3615,7 +3523,7 @@ DECLARE @test int } else if (plan is IDmlQueryExecutionPlanNode dmlQuery) { - dmlQuery.Execute(_dataSources, this, parameterTypes, parameterValues, out _); + dmlQuery.Execute(new NodeExecutionContext(_dataSources, this, parameterTypes, parameterValues), out _); } } } @@ -3623,8 +3531,7 @@ DECLARE @test int [TestMethod] public void SetVariableInDeclaration() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" DECLARE @test int = 1 @@ -3660,7 +3567,7 @@ public void SetVariableInDeclaration() { if (plan is IDataReaderExecutionPlanNode selectQuery) { - var results = selectQuery.Execute(_dataSources, this, parameterTypes, parameterValues, CommandBehavior.Default); + var results = selectQuery.Execute(new NodeExecutionContext(_dataSources, this, parameterTypes, parameterValues), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(results); @@ -3670,7 +3577,7 @@ public void SetVariableInDeclaration() } else if (plan is IDmlQueryExecutionPlanNode dmlQuery) { - dmlQuery.Execute(_dataSources, this, parameterTypes, parameterValues, out _); + dmlQuery.Execute(new NodeExecutionContext(_dataSources, this, parameterTypes, parameterValues), out _); } } } @@ -3679,8 +3586,7 @@ public void SetVariableInDeclaration() [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void UnknownVariable() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SET @test = 1"; @@ -3691,8 +3597,7 @@ public void UnknownVariable() [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void DuplicateVariable() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" DECLARE @test INT @@ -3704,8 +3609,7 @@ DECLARE @test INT [TestMethod] public void VariableTypeConversionIntToString() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" DECLARE @test varchar(3) @@ -3721,7 +3625,7 @@ DECLARE @test varchar(3) { if (plan is IDataReaderExecutionPlanNode selectQuery) { - var results = selectQuery.Execute(_dataSources, this, parameterTypes, parameterValues, CommandBehavior.Default); + var results = selectQuery.Execute(new NodeExecutionContext(_dataSources, this, parameterTypes, parameterValues), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(results); @@ -3731,7 +3635,7 @@ DECLARE @test varchar(3) } else if (plan is IDmlQueryExecutionPlanNode dmlQuery) { - dmlQuery.Execute(_dataSources, this, parameterTypes, parameterValues, out _); + dmlQuery.Execute(new NodeExecutionContext(_dataSources, this, parameterTypes, parameterValues), out _); } } } @@ -3739,8 +3643,7 @@ DECLARE @test varchar(3) [TestMethod] public void VariableTypeConversionStringTruncation() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" DECLARE @test varchar(3) @@ -3756,7 +3659,7 @@ DECLARE @test varchar(3) { if (plan is IDataReaderExecutionPlanNode selectQuery) { - var results = selectQuery.Execute(_dataSources, this, parameterTypes, parameterValues, CommandBehavior.Default); + var results = selectQuery.Execute(new NodeExecutionContext(_localDataSource, this, parameterTypes, parameterValues), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(results); @@ -3766,7 +3669,7 @@ DECLARE @test varchar(3) } else if (plan is IDmlQueryExecutionPlanNode dmlQuery) { - dmlQuery.Execute(_dataSources, this, parameterTypes, parameterValues, out _); + dmlQuery.Execute(new NodeExecutionContext(_localDataSource, this, parameterTypes, parameterValues), out _); } } } @@ -3775,8 +3678,7 @@ DECLARE @test varchar(3) [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void CannotCombineSetVariableAndDataRetrievalInSelect() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); // A SELECT statement that assigns a value to a variable must not be combined with data-retrieval operations var query = @" @@ -3789,8 +3691,7 @@ DECLARE @test varchar(3) [TestMethod] public void SetVariableWithSelectUsesFinalValue() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" DECLARE @test varchar(3) @@ -3850,7 +3751,7 @@ DECLARE @test varchar(3) { if (plan is IDataReaderExecutionPlanNode selectQuery) { - var results = selectQuery.Execute(_localDataSource, this, parameterTypes, parameterValues, CommandBehavior.Default); + var results = selectQuery.Execute(new NodeExecutionContext(_localDataSource, this, parameterTypes, parameterValues), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(results); @@ -3860,7 +3761,7 @@ DECLARE @test varchar(3) } else if (plan is IDmlQueryExecutionPlanNode dmlQuery) { - dmlQuery.Execute(_localDataSource, this, parameterTypes, parameterValues, out _); + dmlQuery.Execute(new NodeExecutionContext(_localDataSource, this, parameterTypes, parameterValues), out _); } } } @@ -3868,8 +3769,7 @@ DECLARE @test varchar(3) [TestMethod] public void VarCharLengthDefaultsTo1() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" DECLARE @test varchar @@ -3885,7 +3785,7 @@ DECLARE @test varchar { if (plan is IDataReaderExecutionPlanNode selectQuery) { - var results = selectQuery.Execute(_dataSources, this, parameterTypes, parameterValues, CommandBehavior.Default); + var results = selectQuery.Execute(new NodeExecutionContext(_localDataSource, this, parameterTypes, parameterValues), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(results); @@ -3895,7 +3795,7 @@ DECLARE @test varchar } else if (plan is IDmlQueryExecutionPlanNode dmlQuery) { - dmlQuery.Execute(_dataSources, this, parameterTypes, parameterValues, out _); + dmlQuery.Execute(new NodeExecutionContext(_localDataSource, this, parameterTypes, parameterValues), out _); } } } @@ -3904,8 +3804,7 @@ DECLARE @test varchar [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void CursorVariableNotSupported() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" DECLARE @test CURSOR"; @@ -3917,8 +3816,7 @@ public void CursorVariableNotSupported() [ExpectedException(typeof(NotSupportedQueryFragmentException))] public void TableVariableNotSupported() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" DECLARE @test TABLE (ID INT)"; @@ -3929,8 +3827,7 @@ public void TableVariableNotSupported() [TestMethod] public void IfStatement() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this) { EstimatedPlanOnly = true }; + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this) { EstimatedPlanOnly = true }; var query = @" IF @param1 = 1 @@ -3964,8 +3861,7 @@ INSERT INTO account (name) VALUES ('one') [TestMethod] public void WhileStatement() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this) { EstimatedPlanOnly = true }; + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this) { EstimatedPlanOnly = true }; var query = @" WHILE @param1 < 10 @@ -3994,8 +3890,7 @@ INSERT INTO account (name) VALUES (@param1) [TestMethod] public void IfNotExists() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this) { EstimatedPlanOnly = true }; + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this) { EstimatedPlanOnly = true }; var query = @" IF NOT EXISTS(SELECT * FROM account WHERE name = @param1) @@ -4020,8 +3915,7 @@ INSERT INTO account (name) VALUES (@param1) [TestMethod] public void DuplicatedAliases() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT name, createdon AS name FROM account"; @@ -4043,8 +3937,7 @@ public void DuplicatedAliases() [TestMethod] public void MetadataLeftJoinData() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT entity.logicalname, account.name, contact.firstname @@ -4083,8 +3976,7 @@ public void MetadataLeftJoinData() [TestMethod] public void NotEqualExcludesNull() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" SELECT name FROM account WHERE name <> 'Data8'"; @@ -4110,8 +4002,7 @@ public void NotEqualExcludesNull() [TestMethod] public void DoNotFoldFilterOnNameVirtualAttributeWithTooManyJoins() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @" select top 10 a.name @@ -4204,8 +4095,7 @@ from account a [TestMethod] public void FilterOnVirtualTypeAttributeEquals() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"SELECT firstname FROM contact WHERE parentcustomeridtype = 'contact'"; @@ -4229,8 +4119,7 @@ public void FilterOnVirtualTypeAttributeEquals() [TestMethod] public void FilterOnVirtualTypeAttributeEqualsImpossible() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"SELECT firstname FROM contact WHERE parentcustomeridtype = 'non-existent-entity'"; @@ -4254,8 +4143,7 @@ public void FilterOnVirtualTypeAttributeEqualsImpossible() [TestMethod] public void FilterOnVirtualTypeAttributeNotEquals() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"SELECT firstname FROM contact WHERE parentcustomeridtype <> 'contact'"; @@ -4280,8 +4168,7 @@ public void FilterOnVirtualTypeAttributeNotEquals() [TestMethod] public void FilterOnVirtualTypeAttributeNotInImpossible() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"SELECT firstname FROM contact WHERE parentcustomeridtype NOT IN ('account', 'contact')"; @@ -4308,8 +4195,7 @@ public void FilterOnVirtualTypeAttributeNotInImpossible() [TestMethod] public void FilterOnVirtualTypeAttributeNull() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"SELECT firstname FROM contact WHERE parentcustomeridtype IS NULL"; @@ -4333,8 +4219,7 @@ public void FilterOnVirtualTypeAttributeNull() [TestMethod] public void FilterOnVirtualTypeAttributeNotNull() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"SELECT firstname FROM contact WHERE parentcustomeridtype IS NOT NULL"; @@ -4358,8 +4243,7 @@ public void FilterOnVirtualTypeAttributeNotNull() [TestMethod] public void SubqueriesInValueList() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"SELECT a FROM (VALUES ('a'), ((SELECT TOP 1 firstname FROM contact)), ('b'), (1)) AS MyTable (a)"; @@ -4388,8 +4272,7 @@ public void SubqueriesInValueList() [TestMethod] public void FoldFilterOnIdentity() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"SELECT name FROM account WHERE accountid = @@IDENTITY"; @@ -4414,8 +4297,7 @@ public void FoldFilterOnIdentity() [TestMethod] public void FoldPrimaryIdInQuery() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"SELECT name FROM account WHERE accountid IN (SELECT accountid FROM account INNER JOIN contact ON account.primarycontactid = contact.contactid WHERE name = 'Data8')"; @@ -4441,8 +4323,7 @@ public void FoldPrimaryIdInQuery() [TestMethod] public void FoldPrimaryIdInQueryWithTop() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"DELETE FROM account WHERE accountid IN (SELECT TOP 10 accountid FROM account ORDER BY createdon DESC)"; @@ -4465,8 +4346,7 @@ public void FoldPrimaryIdInQueryWithTop() [TestMethod] public void InsertParameters() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"DECLARE @name varchar(100) = 'test'; INSERT INTO account (name) VALUES (@name)"; @@ -4484,8 +4364,7 @@ public void InsertParameters() [TestMethod] public void NotExistsParameters() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"DECLARE @firstname AS VARCHAR (100) = 'Mark', @lastname AS VARCHAR (100) = 'Carrington'; @@ -4526,8 +4405,7 @@ INSERT INTO contact (firstname, lastname) [TestMethod] public void UpdateParameters() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"declare @name varchar(100) = 'Data8', @employees int = 10 UPDATE account SET employees = @employees WHERE name = @name"; @@ -4558,8 +4436,7 @@ public void UpdateParameters() [TestMethod] public void CountUsesAggregateByDefault() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"SELECT count(*) FROM account"; @@ -4583,8 +4460,7 @@ public void CountUsesAggregateByDefault() [TestMethod] public void CountUsesRetrieveTotalRecordCountWithHint() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"SELECT count(*) FROM account OPTION (USE HINT ('RETRIEVE_TOTAL_RECORD_COUNT'))"; @@ -4603,8 +4479,7 @@ public void CountUsesRetrieveTotalRecordCountWithHint() [TestMethod] public void MaxDOPUsesDefault() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"UPDATE account SET name = 'test'"; @@ -4619,8 +4494,7 @@ public void MaxDOPUsesDefault() [TestMethod] public void MaxDOPUsesHint() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"UPDATE account SET name = 'test' OPTION (MAXDOP 7)"; @@ -4635,8 +4509,7 @@ public void MaxDOPUsesHint() [TestMethod] public void SubqueryUsesSpoolByDefault() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"SELECT accountid, (SELECT TOP 1 fullname FROM contact) FROM account"; @@ -4655,8 +4528,7 @@ public void SubqueryUsesSpoolByDefault() [TestMethod] public void SubqueryDoesntUseSpoolWithHint() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"SELECT accountid, (SELECT TOP 1 fullname FROM contact) FROM account OPTION (NO_PERFORMANCE_SPOOL)"; @@ -4674,8 +4546,7 @@ public void SubqueryDoesntUseSpoolWithHint() [TestMethod] public void BypassPluginExecutionUsesDefault() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"UPDATE account SET name = 'test'"; @@ -4690,8 +4561,7 @@ public void BypassPluginExecutionUsesDefault() [TestMethod] public void BypassPluginExecutionUsesHint() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"UPDATE account SET name = 'test' OPTION (USE HINT ('BYPASS_CUSTOM_PLUGIN_EXECUTION'))"; @@ -4706,8 +4576,7 @@ public void BypassPluginExecutionUsesHint() [TestMethod] public void PageSizeUsesHint() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"SELECT name FROM account OPTION (USE HINT ('FETCHXML_PAGE_SIZE_100'))"; @@ -4729,8 +4598,7 @@ public void PageSizeUsesHint() [TestMethod] public void DistinctOrderByOptionSet() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT DISTINCT new_optionsetvalue FROM new_customentity ORDER BY new_optionsetvalue"; @@ -4752,8 +4620,7 @@ public void DistinctOrderByOptionSet() [TestMethod] public void DistinctVirtualAttribute() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT DISTINCT new_optionsetvaluename FROM new_customentity"; @@ -4778,8 +4645,7 @@ public void DistinctVirtualAttribute() [TestMethod] public void TopAliasStar() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT TOP 10 A.* FROM account A"; @@ -4800,8 +4666,7 @@ public void TopAliasStar() [TestMethod] public void OrderByStar() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT * FROM account ORDER BY primarycontactid"; @@ -4823,8 +4688,7 @@ public void OrderByStar() [TestMethod] public void UpdateColumnInWhereClause() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "UPDATE account SET name = '1' WHERE name <> '1'"; @@ -4854,8 +4718,7 @@ public void UpdateColumnInWhereClause() [TestMethod] public void NestedOrFilters() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT * FROM account WHERE name = '1' OR name = '2' OR name = '3' OR name = '4'"; @@ -4883,8 +4746,7 @@ public void NestedOrFilters() [TestMethod] public void UnknownHint() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT * FROM account OPTION(USE HINT('invalid'))"; @@ -4894,8 +4756,7 @@ public void UnknownHint() [TestMethod] public void MultipleTablesJoinFromWhereClause() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT firstname FROM account, contact WHERE accountid = parentcustomerid AND lastname = 'Carrington' AND name = 'Data8'"; @@ -4924,8 +4785,7 @@ public void MultipleTablesJoinFromWhereClause() [TestMethod] public void MultipleTablesJoinFromWhereClauseReversed() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT firstname FROM account, contact WHERE lastname = 'Carrington' AND name = 'Data8' AND parentcustomerid = accountid"; @@ -4954,8 +4814,7 @@ public void MultipleTablesJoinFromWhereClauseReversed() [TestMethod] public void MultipleTablesJoinFromWhereClause3() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT firstname FROM account, contact, systemuser WHERE accountid = parentcustomerid AND lastname = 'Carrington' AND name = 'Data8' AND account.ownerid = systemuserid"; @@ -4985,8 +4844,7 @@ public void MultipleTablesJoinFromWhereClause3() [TestMethod] public void NestedInSubqueries() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT firstname FROM contact WHERE parentcustomerid IN (SELECT accountid FROM account WHERE primarycontactid IN (SELECT contactid FROM contact WHERE lastname = 'Carrington'))"; @@ -5014,8 +4872,7 @@ public void NestedInSubqueries() [TestMethod] public void SpoolNestedLoop() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT account.name, contact.fullname FROM account INNER JOIN contact ON account.accountid = contact.parentcustomerid OR account.createdon < contact.createdon"; @@ -5050,8 +4907,7 @@ public void SpoolNestedLoop() [TestMethod] public void SelectFromTVF() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT * FROM SampleMessage('test')"; @@ -5070,8 +4926,7 @@ public void SelectFromTVF() [TestMethod] public void OuterApplyCorrelatedTVF() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT account.name, msg.OutputParam1 FROM account OUTER APPLY (SELECT * FROM SampleMessage(account.name)) AS msg WHERE account.name = 'Data8'"; @@ -5098,8 +4953,7 @@ public void OuterApplyCorrelatedTVF() [TestMethod] public void OuterApplyUncorrelatedTVF() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT account.name, msg.OutputParam1 FROM account OUTER APPLY (SELECT * FROM SampleMessage('test')) AS msg WHERE account.name = 'Data8'"; @@ -5126,8 +4980,7 @@ public void OuterApplyUncorrelatedTVF() [TestMethod] public void TVFScalarSubqueryParameter() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT * FROM SampleMessage((SELECT TOP 1 name FROM account))"; @@ -5169,8 +5022,7 @@ public void TVFScalarSubqueryParameter() [TestMethod] public void ExecuteSproc() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "EXEC SampleMessage 'test'"; @@ -5187,8 +5039,7 @@ public void ExecuteSproc() [TestMethod] public void ExecuteSprocNamedParameters() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"DECLARE @i int EXEC SampleMessage @StringParam = 'test', @OutputParam2 = @i OUTPUT @@ -5206,8 +5057,7 @@ public void ExecuteSprocNamedParameters() [TestMethod] public void FoldMultipleJoinConditionsWithKnownValue() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = @"SELECT a.name, c.fullname FROM account a INNER JOIN contact c ON a.accountid = c.parentcustomerid AND a.name = c.fullname WHERE a.name = 'Data8'"; @@ -5236,8 +5086,7 @@ public void FoldMultipleJoinConditionsWithKnownValue() [TestMethod] public void ExplicitCollation() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT 'abc' COLLATE French_CI_AS"; @@ -5248,8 +5097,7 @@ public void ExplicitCollation() [TestMethod] public void TwoExplicitCollationsError() { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var query = "SELECT ('abc' COLLATE French_CI_AS) COLLATE French_CS_AS"; diff --git a/MarkMpn.Sql4Cds.Engine.Tests/ExpressionTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/ExpressionTests.cs index 7181611a..b324c11c 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/ExpressionTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/ExpressionTests.cs @@ -13,7 +13,7 @@ namespace MarkMpn.Sql4Cds.Engine.Tests { [TestClass] - public class ExpressionTests + public class ExpressionTests : FakeXrmEasyTestsBase { [TestMethod] public void SimpleCaseExpression() @@ -38,28 +38,29 @@ public void SimpleCaseExpression() }; var schema = new NodeSchema(new Dictionary { - ["name"] = DataTypeHelpers.NVarChar(100) + ["name"] = DataTypeHelpers.NVarChar(100, Collation.USEnglish, CollationLabel.CoercibleDefault) }, new Dictionary>(), null, Array.Empty(), Array.Empty()); var parameterTypes = new Dictionary(); - var func = expr.Compile(schema, parameterTypes); - var options = new StubOptions(); - var parameterValues = new Dictionary(); + var compilationContext = new ExpressionCompilationContext(_localDataSource, options, parameterTypes, schema, null); + var func = expr.Compile(compilationContext); var record = new Entity { - ["name"] = SqlTypeConverter.UseDefaultCollation("One") + ["name"] = compilationContext.PrimaryDataSource.DefaultCollation.ToSqlString("One") }; + var executionContext = new ExpressionExecutionContext(compilationContext); + executionContext.Entity = record; - var value = func(record, parameterValues, options); + var value = func(executionContext); Assert.AreEqual((SqlInt32)1, value); - record["name"] = SqlTypeConverter.UseDefaultCollation("Two"); - value = func(record, parameterValues, options); + record["name"] = compilationContext.PrimaryDataSource.DefaultCollation.ToSqlString("Two"); + value = func(executionContext); Assert.AreEqual((SqlInt32)2, value); - record["name"] = SqlTypeConverter.UseDefaultCollation("Five"); - value = func(record, parameterValues, options); + record["name"] = compilationContext.PrimaryDataSource.DefaultCollation.ToSqlString("Five"); + value = func(executionContext); Assert.AreEqual((SqlInt32)3, value); } @@ -80,18 +81,19 @@ public void FormatDateTime() ["createdon"] = DataTypeHelpers.DateTime }, new Dictionary>(), null, Array.Empty(), Array.Empty()); var parameterTypes = new Dictionary(); - var func = expr.Compile(schema, parameterTypes); - var options = new StubOptions(); - var parameterValues = new Dictionary(); + var compilationContext = new ExpressionCompilationContext(_localDataSource, options, parameterTypes, schema, null); + var func = expr.Compile(compilationContext); var record = new Entity { ["createdon"] = (SqlDateTime)new DateTime(2022, 1, 2) }; + var executionContext = new ExpressionExecutionContext(compilationContext); + executionContext.Entity = record; - var value = func(record, parameterValues, options); - Assert.AreEqual(SqlTypeConverter.UseDefaultCollation("2022-01-02"), value); + var value = func(executionContext); + Assert.AreEqual(compilationContext.PrimaryDataSource.DefaultCollation.ToSqlString("2022-01-02"), value); } private ColumnReferenceExpression Col(string name) diff --git a/MarkMpn.Sql4Cds.Engine.Tests/FakeXrmEasyTestsBase.cs b/MarkMpn.Sql4Cds.Engine.Tests/FakeXrmEasyTestsBase.cs index e14228a9..5ee5f543 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/FakeXrmEasyTestsBase.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/FakeXrmEasyTestsBase.cs @@ -34,7 +34,7 @@ public FakeXrmEasyTestsBase() _context.AddGenericFakeMessageExecutor(SampleMessageExecutor.MessageName, new SampleMessageExecutor()); _service = _context.GetOrganizationService(); - _dataSource = new DataSource { Name = "uat", Connection = _service, Metadata = new AttributeMetadataCache(_service), TableSizeCache = new StubTableSizeCache(), MessageCache = new StubMessageCache() }; + _dataSource = new DataSource { Name = "uat", Connection = _service, Metadata = new AttributeMetadataCache(_service), TableSizeCache = new StubTableSizeCache(), MessageCache = new StubMessageCache(), DefaultCollation = Collation.USEnglish }; _context2 = new XrmFakedContext(); _context2.InitializeMetadata(Assembly.GetExecutingAssembly()); @@ -43,17 +43,20 @@ public FakeXrmEasyTestsBase() _context2.AddGenericFakeMessageExecutor(SampleMessageExecutor.MessageName, new SampleMessageExecutor()); _service2 = _context2.GetOrganizationService(); - _dataSource2 = new DataSource { Name = "prod", Connection = _service2, Metadata = new AttributeMetadataCache(_service2), TableSizeCache = new StubTableSizeCache(), MessageCache = new StubMessageCache() }; + _dataSource2 = new DataSource { Name = "prod", Connection = _service2, Metadata = new AttributeMetadataCache(_service2), TableSizeCache = new StubTableSizeCache(), MessageCache = new StubMessageCache(), DefaultCollation = Collation.USEnglish }; _dataSources = new[] { _dataSource, _dataSource2 }.ToDictionary(ds => ds.Name); _localDataSource = new Dictionary { - ["local"] = new DataSource { Name = "local", Connection = _service, Metadata = _dataSource.Metadata, TableSizeCache = _dataSource.TableSizeCache, MessageCache = _dataSource.MessageCache } + ["local"] = new DataSource { Name = "local", Connection = _service, Metadata = _dataSource.Metadata, TableSizeCache = _dataSource.TableSizeCache, MessageCache = _dataSource.MessageCache, DefaultCollation = Collation.USEnglish } }; SetPrimaryIdAttributes(_context); SetPrimaryIdAttributes(_context2); + SetPrimaryNameAttributes(_context); + SetPrimaryNameAttributes(_context2); + SetLookupTargets(_context); SetLookupTargets(_context2); @@ -64,6 +67,18 @@ public FakeXrmEasyTestsBase() SetMaxLength(_context2); } + private void SetPrimaryNameAttributes(XrmFakedContext context) + { + foreach (var entity in context.CreateMetadataQuery()) + { + if (entity.LogicalName != "contact") + continue; + + // Set the primary name attribute on contact + typeof(EntityMetadata).GetProperty(nameof(EntityMetadata.PrimaryNameAttribute)).SetValue(entity, "fullname"); + } + } + private void SetPrimaryIdAttributes(XrmFakedContext context) { foreach (var entity in context.CreateMetadataQuery()) @@ -119,11 +134,32 @@ private void SetAttributeOf(XrmFakedContext context) { foreach (var entity in context.CreateMetadataQuery()) { - if (entity.LogicalName != "new_customentity") + if (entity.LogicalName == "new_customentity") + { + var attr = entity.Attributes.Single(a => a.LogicalName == "new_optionsetvaluename"); + typeof(AttributeMetadata).GetProperty(nameof(AttributeMetadata.AttributeOf)).SetValue(attr, "new_optionsetvalue"); + + var valueAttr = (EnumAttributeMetadata)entity.Attributes.Single(a => a.LogicalName == "new_optionsetvalue"); + valueAttr.OptionSet = new OptionSetMetadata + { + Options = + { + new OptionMetadata(new Label { UserLocalizedLabel = new LocalizedLabel(Metadata.New_OptionSet.Value1.ToString(), 1033) }, (int) Metadata.New_OptionSet.Value1), + new OptionMetadata(new Label { UserLocalizedLabel = new LocalizedLabel(Metadata.New_OptionSet.Value2.ToString(), 1033) }, (int) Metadata.New_OptionSet.Value2), + new OptionMetadata(new Label { UserLocalizedLabel = new LocalizedLabel(Metadata.New_OptionSet.Value3.ToString(), 1033) }, (int) Metadata.New_OptionSet.Value3) + } + }; + } + else if (entity.LogicalName == "account") + { + // Add metadata for primarycontactidname virtual attribute + var nameAttr = entity.Attributes.Single(a => a.LogicalName == "primarycontactidname"); + nameAttr.GetType().GetProperty(nameof(AttributeMetadata.AttributeOf)).SetValue(nameAttr, "primarycontactid"); + } + else + { continue; - - var attr = entity.Attributes.Single(a => a.LogicalName == "new_optionsetvaluename"); - typeof(AttributeMetadata).GetProperty(nameof(AttributeMetadata.AttributeOf)).SetValue(attr, "new_optionsetvalue"); + } context.SetEntityMetadata(entity); } diff --git a/MarkMpn.Sql4Cds.Engine.Tests/Sql2FetchXmlTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/Sql2FetchXmlTests.cs index 1b56c914..de6ee303 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/Sql2FetchXmlTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/Sql2FetchXmlTests.cs @@ -1,4 +1,5 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Data; using System.Data.SqlTypes; @@ -11,6 +12,7 @@ using System.Threading; using System.Xml.Serialization; using FakeXrmEasy; +using FakeXrmEasy.Extensions; using FakeXrmEasy.FakeMessageExecutors; using MarkMpn.Sql4Cds.Engine.ExecutionPlan; using Microsoft.SqlServer.TransactSql.ScriptDom; @@ -59,8 +61,7 @@ public void SimpleSelect() { var query = "SELECT accountid, name FROM account"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -78,8 +79,7 @@ public void SelectSameFieldMultipleTimes() { var query = "SELECT accountid, name, name FROM account"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -104,8 +104,7 @@ public void SelectStar() { var query = "SELECT * FROM account"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -136,8 +135,7 @@ public void SelectStarAndField() { var query = "SELECT *, name FROM account"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -169,8 +167,7 @@ public void SimpleFilter() { var query = "SELECT accountid, name FROM account WHERE name = 'test'"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -191,8 +188,7 @@ public void BetweenFilter() { var query = "SELECT accountid, name FROM account WHERE employees BETWEEN 1 AND 10 AND turnover NOT BETWEEN 2 AND 20"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -218,8 +214,7 @@ public void FetchFilter() { var query = "SELECT contactid, firstname FROM contact WHERE createdon = lastxdays(7)"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -240,8 +235,7 @@ public void NestedFilters() { var query = "SELECT accountid, name FROM account WHERE name = 'test' OR (accountid is not null and name like 'foo%')"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -266,8 +260,7 @@ public void Sorts() { var query = "SELECT accountid, name FROM account ORDER BY name DESC, accountid"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -287,8 +280,7 @@ public void SortByColumnIndex() { var query = "SELECT accountid, name FROM account ORDER BY 2 DESC, 1"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -308,8 +300,7 @@ public void SortByAliasedColumn() { var query = "SELECT accountid, name as accountname FROM account ORDER BY name"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -328,8 +319,7 @@ public void Top() { var query = "SELECT TOP 10 accountid, name FROM account"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -347,8 +337,7 @@ public void TopBrackets() { var query = "SELECT TOP (10) accountid, name FROM account"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -366,8 +355,7 @@ public void Top10KUsesExtension() { var query = "SELECT TOP 10000 accountid, name FROM account"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -388,8 +376,7 @@ public void NoLock() { var query = "SELECT accountid, name FROM account (NOLOCK)"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -407,8 +394,7 @@ public void Distinct() { var query = "SELECT DISTINCT name FROM account"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -426,8 +412,7 @@ public void Offset() { var query = "SELECT accountid, name FROM account ORDER BY name OFFSET 100 ROWS FETCH NEXT 50 ROWS ONLY"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -446,8 +431,7 @@ public void SimpleJoin() { var query = "SELECT accountid, name FROM account INNER JOIN contact ON primarycontactid = contactid"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -467,8 +451,7 @@ public void SelfReferentialJoin() { var query = "SELECT contact.contactid, contact.firstname, manager.firstname FROM contact LEFT OUTER JOIN contact AS manager ON contact.parentcustomerid = manager.contactid"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -489,8 +472,7 @@ public void AdditionalJoinCriteria() { var query = "SELECT accountid, name FROM account INNER JOIN contact ON accountid = parentcustomerid AND (firstname = 'Mark' OR lastname = 'Carrington')"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -514,8 +496,7 @@ public void InvalidAdditionalJoinCriteria() { var query = "SELECT accountid, name FROM account INNER JOIN contact ON accountid = parentcustomerid OR (firstname = 'Mark' AND lastname = 'Carrington')"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); Assert.IsNotInstanceOfType(((SelectNode)queries[0]).Source, typeof(FetchXmlScan)); @@ -526,8 +507,7 @@ public void SortOnLinkEntity() { var query = "SELECT TOP 100 accountid, name FROM account INNER JOIN contact ON primarycontactid = contactid ORDER BY name, firstname"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -549,8 +529,7 @@ public void InvalidSortOnLinkEntity() { var query = "SELECT TOP 100 accountid, name FROM account INNER JOIN contact ON accountid = parentcustomerid ORDER BY name, firstname"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -575,8 +554,7 @@ public void SimpleAggregate() { var query = "SELECT count(*), count(name), count(DISTINCT name), max(name), min(name), avg(name) FROM account"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -598,8 +576,7 @@ public void GroupBy() { var query = "SELECT name, count(*) FROM account GROUP BY name"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -618,8 +595,7 @@ public void GroupBySorting() { var query = "SELECT name, count(*) FROM account GROUP BY name ORDER BY name, count(*)"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -639,8 +615,7 @@ public void GroupBySortingOnLinkEntity() { var query = "SELECT name, firstname, count(*) FROM account INNER JOIN contact ON parentcustomerid = account.accountid GROUP BY name, firstname ORDER BY firstname, name"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -663,8 +638,7 @@ public void GroupBySortingOnAliasedAggregate() { var query = "SELECT name, firstname, count(*) as count FROM account INNER JOIN contact ON parentcustomerid = account.accountid GROUP BY name, firstname ORDER BY count"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -686,8 +660,7 @@ public void UpdateFieldToValue() { var query = "UPDATE contact SET firstname = 'Mark'"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -704,8 +677,7 @@ public void SelectArithmetic() { var query = "SELECT employees + 1 AS a, employees * 2 AS b, turnover / 3 AS c, turnover - 4 AS d, turnover / employees AS e FROM account"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -728,7 +700,7 @@ public void SelectArithmetic() } }; - var dataReader = ((SelectNode)queries[0]).Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -745,8 +717,7 @@ public void WhereComparingTwoFields() { var query = "SELECT contactid FROM contact WHERE firstname = lastname"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), new OptionsWrapper(this) { ColumnComparisonAvailable = false }); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, new OptionsWrapper(this) { ColumnComparisonAvailable = false }); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -777,7 +748,7 @@ public void WhereComparingTwoFields() } }; - var dataReader = ((SelectNode)queries[0]).Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -790,8 +761,7 @@ public void WhereComparingExpression() { var query = "SELECT contactid FROM contact WHERE lastname = firstname + 'rington'"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -822,7 +792,7 @@ public void WhereComparingExpression() } }; - var dataReader = ((SelectNode)queries[0]).Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -835,8 +805,7 @@ public void BackToFrontLikeExpression() { var query = "SELECT contactid FROM contact WHERE 'Mark' like firstname"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -864,7 +833,7 @@ public void BackToFrontLikeExpression() } }; - var dataReader = ((SelectNode)queries[0]).Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -877,8 +846,7 @@ public void UpdateFieldToField() { var query = "UPDATE contact SET firstname = lastname"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -900,7 +868,7 @@ public void UpdateFieldToField() } }; - ((UpdateNode)queries[0]).Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), out _); + ((UpdateNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), out _); Assert.AreEqual("Carrington", _context.Data["contact"][guid]["firstname"]); } @@ -910,8 +878,7 @@ public void UpdateFieldToExpression() { var query = "UPDATE contact SET firstname = 'Hello ' + lastname"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -933,7 +900,7 @@ public void UpdateFieldToExpression() } }; - ((UpdateNode)queries[0]).Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), out _); + ((UpdateNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), out _); Assert.AreEqual("Hello Carrington", _context.Data["contact"][guid]["firstname"]); } @@ -943,8 +910,7 @@ public void UpdateReplace() { var query = "UPDATE contact SET firstname = REPLACE(firstname, 'Dataflex Pro', 'CDS') WHERE lastname = 'Carrington'"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -970,7 +936,7 @@ public void UpdateReplace() } }; - ((UpdateNode)queries[0]).Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), out _); + ((UpdateNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), out _); Assert.AreEqual("--CDS--", _context.Data["contact"][guid]["firstname"]); } @@ -980,8 +946,7 @@ public void StringFunctions() { var query = "SELECT trim(firstname) as trim, ltrim(firstname) as ltrim, rtrim(firstname) as rtrim, substring(firstname, 2, 3) as substring23, len(firstname) as len FROM contact"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1002,7 +967,7 @@ public void StringFunctions() } }; - var dataReader = ((SelectNode)queries[0]).Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1021,8 +986,7 @@ public void SelectExpression() { var query = "SELECT firstname, 'Hello ' + firstname AS greeting FROM contact"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1043,7 +1007,7 @@ public void SelectExpression() } }; - var dataReader = ((SelectNode)queries[0]).Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1060,7 +1024,8 @@ private IDictionary GetDataSources(XrmFakedContext context) Connection = context.GetOrganizationService(), Metadata = new AttributeMetadataCache(context.GetOrganizationService()), TableSizeCache = new StubTableSizeCache(), - MessageCache = new StubMessageCache() + MessageCache = new StubMessageCache(), + DefaultCollation = Collation.USEnglish }; return new Dictionary { ["local"] = dataSource }; @@ -1071,8 +1036,7 @@ public void SelectExpressionNullValues() { var query = "SELECT firstname, 'Hello ' + firstname AS greeting, case when createdon > '2020-01-01' then 'new' else 'old' end AS age FROM contact"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1093,7 +1057,7 @@ public void SelectExpressionNullValues() } }; - var dataReader = ((SelectNode)queries[0]).Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1108,8 +1072,7 @@ public void OrderByExpression() { var query = "SELECT firstname, lastname FROM contact ORDER BY lastname + ', ' + firstname"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1139,7 +1102,7 @@ public void OrderByExpression() } }; - var dataReader = ((SelectNode)queries[0]).Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1153,8 +1116,7 @@ public void OrderByAliasedField() { var query = "SELECT firstname, lastname AS surname FROM contact ORDER BY surname"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1185,7 +1147,7 @@ public void OrderByAliasedField() } }; - var dataReader = ((SelectNode)queries[0]).Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1199,8 +1161,7 @@ public void OrderByCalculatedField() { var query = "SELECT firstname, lastname, lastname + ', ' + firstname AS fullname FROM contact ORDER BY fullname"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1230,7 +1191,7 @@ public void OrderByCalculatedField() } }; - var dataReader = ((SelectNode)queries[0]).Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1244,8 +1205,7 @@ public void OrderByCalculatedFieldByIndex() { var query = "SELECT firstname, lastname, lastname + ', ' + firstname AS fullname FROM contact ORDER BY 3"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1275,7 +1235,7 @@ public void OrderByCalculatedFieldByIndex() } }; - var dataReader = ((SelectNode)queries[0]).Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1289,8 +1249,7 @@ public void DateCalculations() { var query = "SELECT contactid, DATEADD(day, 1, createdon) AS nextday, DATEPART(minute, createdon) AS minute FROM contact WHERE DATEDIFF(hour, '2020-01-01', createdon) < 1"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1318,7 +1277,7 @@ public void DateCalculations() } }; - var dataReader = ((SelectNode)queries[0]).Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1333,8 +1292,7 @@ public void TopAppliedAfterCustomFilter() { var query = "SELECT TOP 10 contactid FROM contact WHERE firstname = lastname"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), new OptionsWrapper(this) { ColumnComparisonAvailable = false }); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, new OptionsWrapper(this) { ColumnComparisonAvailable = false }); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1356,8 +1314,7 @@ public void CustomFilterAggregateHavingProjectionSortAndTop() { var query = "SELECT TOP 10 lastname, SUM(CASE WHEN firstname = 'Mark' THEN 1 ELSE 0 END) as nummarks, LEFT(lastname, 1) AS lastinitial FROM contact WHERE DATEDIFF(day, '2020-01-01', createdon) > 10 GROUP BY lastname HAVING count(*) > 1 ORDER BY 2 DESC"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1405,7 +1362,7 @@ public void CustomFilterAggregateHavingProjectionSortAndTop() } }; - var dataReader = ((SelectNode)queries[0]).Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1425,8 +1382,7 @@ public void FilterCaseInsensitive() { var query = "SELECT contactid FROM contact WHERE DATEDIFF(day, '2020-01-01', createdon) < 10 OR lastname = 'Carrington' ORDER BY createdon"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1469,7 +1425,7 @@ public void FilterCaseInsensitive() } }; - var dataReader = ((SelectNode)queries[0]).Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1484,8 +1440,7 @@ public void GroupCaseInsensitive() { var query = "SELECT lastname, count(*) FROM contact WHERE DATEDIFF(day, '2020-01-01', createdon) > 10 GROUP BY lastname ORDER BY 2 DESC"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1527,7 +1482,7 @@ public void GroupCaseInsensitive() } }; - var dataReader = ((SelectNode)queries[0]).Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1542,8 +1497,7 @@ public void AggregateExpressionsWithoutGrouping() { var query = "SELECT count(DISTINCT firstname + ' ' + lastname) AS distinctnames FROM contact"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1581,7 +1535,7 @@ public void AggregateExpressionsWithoutGrouping() } }; - var dataReader = ((SelectNode)queries[0]).Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1595,8 +1549,7 @@ public void AggregateQueryProducesAlternative() { var query = "SELECT name, count(*) FROM account GROUP BY name ORDER BY 2 DESC"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); var select = (SelectNode)queries[0]; @@ -1637,7 +1590,7 @@ public void AggregateQueryProducesAlternative() } }; - var dataReader = alternativeQuery.Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = alternativeQuery.Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1652,8 +1605,7 @@ public void GuidEntityReferenceInequality() { var query = "SELECT a.name FROM account a INNER JOIN contact c ON a.primarycontactid = c.contactid WHERE (c.parentcustomerid is null or a.accountid <> c.parentcustomerid)"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); var select = (SelectNode)queries[0]; @@ -1692,7 +1644,7 @@ public void GuidEntityReferenceInequality() } }; - var dataReader = select.Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = select.Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -1706,10 +1658,7 @@ public void UpdateGuidToEntityReference() { var query = "UPDATE a SET primarycontactid = c.contactid FROM account AS a INNER JOIN contact AS c ON a.accountid = c.parentcustomerid"; - var metadata = new AttributeMetadataCache(_service); - var lookup = (LookupAttributeMetadata)metadata["account"].Attributes.Single(a => a.LogicalName == "primarycontactid"); - lookup.Targets = new[] { "contact" }; - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); var update = (UpdateNode)queries[0]; @@ -1746,7 +1695,7 @@ public void UpdateGuidToEntityReference() } }; - update.Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), out _); + update.Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), out _); Assert.AreEqual(new EntityReference("contact", contact1), _context.Data["account"][account1].GetAttributeValue("primarycontactid")); Assert.AreEqual(new EntityReference("contact", contact2), _context.Data["account"][account2].GetAttributeValue("primarycontactid")); @@ -1757,8 +1706,7 @@ public void CompareDateFields() { var query = "DELETE c2 FROM contact c1 INNER JOIN contact c2 ON c1.parentcustomerid = c2.parentcustomerid AND c2.createdon > c1.createdon"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1784,8 +1732,7 @@ public void ColumnComparison() { var query = "SELECT firstname, lastname FROM contact WHERE firstname = lastname"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -1808,8 +1755,7 @@ public void QuotedIdentifierError() try { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), new OptionsWrapper(this) { QuotedIdentifiers = true }); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, new OptionsWrapper(this) { QuotedIdentifiers = true }); var queries = planBuilder.Build(query, null, out _); Assert.Fail("Expected exception"); @@ -1825,8 +1771,7 @@ public void FilterExpressionConstantValueToFetchXml() { var query = "SELECT firstname, lastname FROM contact WHERE firstname = 'Ma' + 'rk'"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, $@" @@ -1847,8 +1792,7 @@ public void Count1ConvertedToCountStar() { var query = "SELECT COUNT(1) FROM contact OPTION(USE HINT('RETRIEVE_TOTAL_RECORD_COUNT'))"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); var selectNode = (SelectNode)queries[0]; @@ -1861,8 +1805,7 @@ public void CaseInsensitive() { var query = "Select Name From Account"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, $@" @@ -1879,8 +1822,7 @@ public void ContainsValues1() { var query = "SELECT new_name FROM new_customentity WHERE CONTAINS(new_optionsetvaluecollection, '1')"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, $@" @@ -1902,8 +1844,7 @@ public void ContainsValuesFunction1() { var query = "SELECT new_name FROM new_customentity WHERE new_optionsetvaluecollection = containvalues(1)"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, $@" @@ -1925,8 +1866,7 @@ public void ContainsValues() { var query = "SELECT new_name FROM new_customentity WHERE CONTAINS(new_optionsetvaluecollection, '1 OR 2')"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, $@" @@ -1949,8 +1889,7 @@ public void ContainsValuesFunction() { var query = "SELECT new_name FROM new_customentity WHERE new_optionsetvaluecollection = containvalues(1, 2)"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, $@" @@ -1973,8 +1912,7 @@ public void NotContainsValues() { var query = "SELECT new_name FROM new_customentity WHERE NOT CONTAINS(new_optionsetvaluecollection, '1 OR 2')"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, $@" @@ -1997,12 +1935,13 @@ public void TSqlAggregates() { var query = "SELECT COUNT(*) AS count FROM account WHERE name IS NULL"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), new OptionsWrapper(this) { UseTDSEndpoint = true }); - var queries = planBuilder.Build(query, null, out var useTDSEndpointDirectly); - Assert.IsTrue(useTDSEndpointDirectly); - Assert.AreEqual(1, queries.Length); - Assert.IsInstanceOfType(queries[0], typeof(SqlNode)); + BuildTDSQuery(planBuilder => + { + var queries = planBuilder.Build(query, null, out var useTDSEndpointDirectly); + Assert.IsTrue(useTDSEndpointDirectly); + Assert.AreEqual(1, queries.Length); + Assert.IsInstanceOfType(queries[0], typeof(SqlNode)); + }); } [TestMethod] @@ -2010,8 +1949,7 @@ public void ImplicitTypeConversion() { var query = "SELECT employees / 2.0 AS half FROM account"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); var account1 = Guid.NewGuid(); @@ -2031,7 +1969,7 @@ public void ImplicitTypeConversion() } }; - var dataReader = ((SelectNode)queries[0]).Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -2044,8 +1982,7 @@ public void ImplicitTypeConversionComparison() { var query = "SELECT * FROM account WHERE turnover / 2 > 10"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); var account1 = Guid.NewGuid(); @@ -2065,7 +2002,7 @@ public void ImplicitTypeConversionComparison() } }; - var dataReader = ((SelectNode)queries[0]).Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -2077,9 +2014,8 @@ public void GlobalOptionSet() { var query = "SELECT displayname FROM metadata.globaloptionset WHERE name = 'test'"; - var metadata = new AttributeMetadataCache(_service); _context.AddFakeMessageExecutor(new RetrieveAllOptionSetsHandler()); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); Assert.IsInstanceOfType(queries.Single(), typeof(SelectNode)); @@ -2091,7 +2027,7 @@ public void GlobalOptionSet() Assert.AreEqual("name = 'test'", filterNode.Filter.ToSql()); var optionsetNode = (GlobalOptionSetQueryNode)filterNode.Source; - var dataReader = selectNode.Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = selectNode.Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -2103,9 +2039,8 @@ public void EntityDetails() { var query = "SELECT logicalname FROM metadata.entity ORDER BY 1"; - var metadata = new AttributeMetadataCache(_service); - _context.AddFakeMessageExecutor(new RetrieveMetadataChangesHandler(metadata)); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + _context.AddFakeMessageExecutor(new RetrieveMetadataChangesHandler(_localDataSource["local"].Metadata)); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); Assert.IsInstanceOfType(queries.Single(), typeof(SelectNode)); @@ -2114,7 +2049,7 @@ public void EntityDetails() var sortNode = (SortNode)selectNode.Source; var metadataNode = (MetadataQueryNode)sortNode.Source; - var dataReader = selectNode.Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = selectNode.Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -2129,12 +2064,11 @@ public void AttributeDetails() { var query = "SELECT e.logicalname, a.logicalname FROM metadata.entity e INNER JOIN metadata.attribute a ON e.logicalname = a.entitylogicalname WHERE e.logicalname = 'new_customentity' ORDER BY 1, 2"; - var metadata = new AttributeMetadataCache(_service); - _context.AddFakeMessageExecutor(new RetrieveMetadataChangesHandler(metadata)); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + _context.AddFakeMessageExecutor(new RetrieveMetadataChangesHandler(_localDataSource["local"].Metadata)); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); - var dataReader = ((SelectNode)queries[0]).Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -2154,11 +2088,7 @@ public void OptionSetNameSelect() { var query = "SELECT new_optionsetvalue, new_optionsetvaluename FROM new_customentity ORDER BY new_optionsetvaluename"; - var metadata = new AttributeMetadataCache(_service); - // Add metadata for new_optionsetvaluename virtual attribute - var attr = metadata["new_customentity"].Attributes.Single(a => a.LogicalName == "new_optionsetvaluename"); - attr.GetType().GetProperty(nameof(AttributeMetadata.AttributeOf)).SetValue(attr, "new_optionsetvalue"); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); var record1 = Guid.NewGuid(); @@ -2194,7 +2124,7 @@ public void OptionSetNameSelect() CollectionAssert.AreEqual(new[] { "new_optionsetvalue", "new_optionsetvaluename" }, select.ColumnSet.Select(c => c.OutputColumn).ToList()); - var dataReader = select.Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = select.Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -2209,22 +2139,7 @@ public void OptionSetNameFilter() { var query = "SELECT new_customentityid FROM new_customentity WHERE new_optionsetvaluename = 'Value1'"; - var metadata = new AttributeMetadataCache(_service); - - // Add metadata for new_optionsetvaluename virtual attribute - var nameAttr = metadata["new_customentity"].Attributes.Single(a => a.LogicalName == "new_optionsetvaluename"); - nameAttr.GetType().GetProperty(nameof(AttributeMetadata.AttributeOf)).SetValue(nameAttr, "new_optionsetvalue"); - var valueAttr = (EnumAttributeMetadata)metadata["new_customentity"].Attributes.Single(a => a.LogicalName == "new_optionsetvalue"); - valueAttr.OptionSet = new OptionSetMetadata - { - Options = - { - new OptionMetadata(new Label { UserLocalizedLabel = new LocalizedLabel(Metadata.New_OptionSet.Value1.ToString(), 1033) }, (int) Metadata.New_OptionSet.Value1), - new OptionMetadata(new Label { UserLocalizedLabel = new LocalizedLabel(Metadata.New_OptionSet.Value2.ToString(), 1033) }, (int) Metadata.New_OptionSet.Value2), - new OptionMetadata(new Label { UserLocalizedLabel = new LocalizedLabel(Metadata.New_OptionSet.Value3.ToString(), 1033) }, (int) Metadata.New_OptionSet.Value3) - } - }; - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -2243,12 +2158,7 @@ public void EntityReferenceNameSelect() { var query = "SELECT primarycontactid, primarycontactidname FROM account ORDER BY primarycontactidname"; - var metadata = new AttributeMetadataCache(_service); - - // Add metadata for primarycontactidname virtual attribute - var attr = metadata["account"].Attributes.Single(a => a.LogicalName == "primarycontactidname"); - attr.GetType().GetProperty(nameof(AttributeMetadata.AttributeOf)).SetValue(attr, "primarycontactid"); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -2269,19 +2179,7 @@ public void EntityReferenceNameFilter() { var query = "SELECT accountid FROM account WHERE primarycontactidname = 'Mark Carrington'"; - var metadata = new AttributeMetadataCache(_service); - - // Add metadata for primarycontactidname virtual attribute - var nameAttr = metadata["account"].Attributes.Single(a => a.LogicalName == "primarycontactidname"); - nameAttr.GetType().GetProperty(nameof(AttributeMetadata.AttributeOf)).SetValue(nameAttr, "primarycontactid"); - - var idAttr = (LookupAttributeMetadata) metadata["account"].Attributes.Single(a => a.LogicalName == "primarycontactid"); - idAttr.Targets = new[] { "contact" }; - - // Set the primary name attribute on contact - typeof(EntityMetadata).GetProperty(nameof(EntityMetadata.PrimaryNameAttribute)).SetValue(metadata["contact"], "fullname"); - - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -2300,10 +2198,7 @@ public void UpdateMissingAlias() { var query = "UPDATE account SET primarycontactid = c.contactid FROM account AS a INNER JOIN contact AS c ON a.name = c.fullname"; - var metadata = new AttributeMetadataCache(_service); - var lookup = (LookupAttributeMetadata)metadata["account"].Attributes.Single(a => a.LogicalName == "primarycontactid"); - lookup.Targets = new[] { "contact" }; - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); } @@ -2314,8 +2209,7 @@ public void UpdateMissingAliasAmbiguous() try { - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); Assert.Fail("Expected exception"); } @@ -2330,8 +2224,7 @@ public void ConvertIntToBool() { var query = "UPDATE new_customentity SET new_boolprop = CASE WHEN new_name = 'True' THEN 1 ELSE 0 END"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); } @@ -2342,8 +2235,7 @@ public void ImpersonateRevert() EXECUTE AS LOGIN = 'test1' REVERT"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); Assert.IsInstanceOfType(queries[0], typeof(ExecuteAsNode)); @@ -2368,8 +2260,7 @@ SELECT contact.fullname FROM contact INNER JOIN account ON contact.contactid = account.primarycontactid INNER JOIN new_customentity ON contact.parentcustomerid = new_customentity.new_parentid ORDER BY account.employees, contact.fullname"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(new[] { queries[0] }, @" @@ -2402,13 +2293,13 @@ UPDATE c FROM contact AS c INNER JOIN account ON c.parentcustomerid = account.accountid INNER JOIN new_customentity ON c.parentcustomerid = new_customentity.new_parentid WHERE c.fullname IN (SELECT fullname FROM contact WHERE firstname = 'Mark')"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), new OptionsWrapper(this) { UseTDSEndpoint = true }); - var queries = planBuilder.Build(query, null, out _); + BuildTDSQuery(planBuilder => + { + var queries = planBuilder.Build(query, null, out _); - var tds = (SqlNode)((UpdateNode)queries[0]).Source; + var tds = (SqlNode)((UpdateNode)queries[0]).Source; - Assert.AreEqual(Regex.Replace(@" + Assert.AreEqual(Regex.Replace(@" SELECT DISTINCT c.contactid AS contactid, account.accountid AS new_parentcustomerid, @@ -2421,6 +2312,7 @@ INNER JOIN new_customentity ON c.parentcustomerid = new_customentity.new_parentid WHERE c.fullname IN (SELECT fullname FROM contact WHERE firstname = 'Mark')", @"\s+", " ").Trim(), Regex.Replace(tds.Sql, @"\s+", " ").Trim()); + }); } [TestMethod] @@ -2431,13 +2323,13 @@ DELETE c FROM contact AS c INNER JOIN account ON c.parentcustomerid = account.accountid INNER JOIN new_customentity ON c.parentcustomerid = new_customentity.new_parentid WHERE c.fullname IN (SELECT fullname FROM contact WHERE firstname = 'Mark')"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), new OptionsWrapper(this) { UseTDSEndpoint = true }); - var queries = planBuilder.Build(query, null, out _); + BuildTDSQuery(planBuilder => + { + var queries = planBuilder.Build(query, null, out _); - var tds = (SqlNode)((DeleteNode)queries[0]).Source; + var tds = (SqlNode)((DeleteNode)queries[0]).Source; - Assert.AreEqual(Regex.Replace(@" + Assert.AreEqual(Regex.Replace(@" SELECT DISTINCT c.contactid AS contactid FROM @@ -2448,6 +2340,23 @@ INNER JOIN new_customentity ON c.parentcustomerid = new_customentity.new_parentid WHERE c.fullname IN (SELECT fullname FROM contact WHERE firstname = 'Mark')", @"\s+", " ").Trim(), Regex.Replace(tds.Sql, @"\s+", " ").Trim()); + }); + } + + private void BuildTDSQuery(Action action) + { + var ds = _localDataSource["local"]; + var con = ds.Connection; + ds.Connection = null; + try + { + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, new OptionsWrapper(this) { UseTDSEndpoint = true }); + action(planBuilder); + } + finally + { + ds.Connection = con; + } } [TestMethod] @@ -2455,8 +2364,7 @@ public void OrderByAggregateByIndex() { var query = "SELECT firstname, count(*) FROM contact GROUP BY firstname ORDER BY 2"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -2475,8 +2383,7 @@ public void OrderByAggregateJoinByIndex() { var query = "SELECT firstname, count(*) FROM contact INNER JOIN account ON contact.parentcustomerid = account.accountid GROUP BY firstname ORDER BY 2"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); AssertFetchXml(queries, @" @@ -2497,8 +2404,7 @@ public void AggregateAlternativeDoesNotOrderByLinkEntity() { var query = "SELECT name, count(*) FROM contact INNER JOIN account ON contact.parentcustomerid = account.accountid GROUP BY name"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); var select = (SelectNode)queries[0]; @@ -2528,8 +2434,7 @@ public void CharIndex() { var query = "SELECT CHARINDEX('a', fullname) AS ci0, CHARINDEX('a', fullname, 1) AS ci1, CHARINDEX('a', fullname, 2) AS ci2, CHARINDEX('a', fullname, 3) AS ci3, CHARINDEX('a', fullname, 8) AS ci8 FROM contact"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); var contact1 = Guid.NewGuid(); @@ -2543,7 +2448,7 @@ public void CharIndex() } }; - var dataReader = ((SelectNode)queries[0]).Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -2559,8 +2464,7 @@ public void CastDateTimeToDate() { var query = "SELECT CAST(createdon AS date) AS converted FROM contact"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); var contact1 = Guid.NewGuid(); @@ -2574,7 +2478,7 @@ public void CastDateTimeToDate() } }; - var dataReader = ((SelectNode)queries[0]).Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); @@ -2586,8 +2490,7 @@ public void GroupByPrimaryFunction() { var query = "SELECT left(firstname, 1) AS initial, count(*) AS count FROM contact GROUP BY left(firstname, 1) ORDER BY 2 DESC"; - var metadata = new AttributeMetadataCache(_service); - var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), new StubMessageCache(), this); + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); var queries = planBuilder.Build(query, null, out _); var contact1 = Guid.NewGuid(); @@ -2613,7 +2516,7 @@ public void GroupByPrimaryFunction() } }; - var dataReader = ((SelectNode)queries[0]).Execute(GetDataSources(_context), this, new Dictionary(), new Dictionary(), CommandBehavior.Default); + var dataReader = ((SelectNode)queries[0]).Execute(new NodeExecutionContext(GetDataSources(_context), this, new Dictionary(), new Dictionary()), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(dataReader); diff --git a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsDataReader.cs b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsDataReader.cs index d8099c07..c5b5bc47 100644 --- a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsDataReader.cs +++ b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsDataReader.cs @@ -62,6 +62,8 @@ public Sql4CdsDataReader(Sql4CdsCommand command, IQueryExecutionOptions options, private bool Execute(Dictionary parameterTypes, Dictionary parameterValues) { + var context = new NodeExecutionContext(_connection.DataSources, _options, parameterTypes, parameterValues); + try { while (_instructionPointer < _command.Plan.Length && !_options.CancellationToken.IsCancellationRequested) @@ -73,7 +75,7 @@ private bool Execute(Dictionary parameterTypes, Dicti if (_resultSetsReturned == 0 || (!_behavior.HasFlag(CommandBehavior.SingleResult) && !_behavior.HasFlag(CommandBehavior.SingleRow))) { _readerQuery = (IDataReaderExecutionPlanNode)dataSetNode.Clone(); - _reader = _readerQuery.Execute(_connection.DataSources, _options, parameterTypes, parameterValues, _behavior); + _reader = _readerQuery.Execute(context, _behavior); _resultSetsReturned++; _rows = 0; _instructionPointer++; @@ -88,7 +90,7 @@ private bool Execute(Dictionary parameterTypes, Dicti else if (node is IDmlQueryExecutionPlanNode dmlNode) { dmlNode = (IDmlQueryExecutionPlanNode)dmlNode.Clone(); - var msg = dmlNode.Execute(_connection.DataSources, _options, parameterTypes, parameterValues, out var recordsAffected); + var msg = dmlNode.Execute(context, out var recordsAffected); if (!String.IsNullOrEmpty(msg)) _connection.OnInfoMessage(dmlNode, msg); @@ -106,7 +108,7 @@ private bool Execute(Dictionary parameterTypes, Dicti else if (node is IGoToNode cond) { cond = (IGoToNode)cond.Clone(); - var label = cond.Execute(_connection.DataSources, _options, parameterTypes, parameterValues); + var label = cond.Execute(context); if (cond.GetSources().Any()) _command.OnStatementCompleted(cond, -1); diff --git a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsParameter.cs b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsParameter.cs index a20800cc..c8572558 100644 --- a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsParameter.cs +++ b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsParameter.cs @@ -297,7 +297,7 @@ internal INullable GetValue() else if (value is float f) value = (SqlSingle)f; else if (value is string str) - value = SqlTypeConverter.UseDefaultCollation(str); + value = new SqlString(str, LocaleId, CompareInfo); else if (value is DateTimeOffset dto) value = (SqlDateTime)dto.DateTime; else if (value is EntityReference er) diff --git a/MarkMpn.Sql4Cds.Engine/DataSource.cs b/MarkMpn.Sql4Cds.Engine/DataSource.cs index 994a378f..c6eee687 100644 --- a/MarkMpn.Sql4Cds.Engine/DataSource.cs +++ b/MarkMpn.Sql4Cds.Engine/DataSource.cs @@ -84,7 +84,7 @@ public DataSource() /// /// Returns the default collation used by this instance /// - internal Collation DefaultCollation { get; } + internal Collation DefaultCollation { get; set; } private Collation LoadDefaultCollation() { diff --git a/MarkMpn.Sql4Cds.Engine/DataTypeHelpers.cs b/MarkMpn.Sql4Cds.Engine/DataTypeHelpers.cs index 4bb48650..7af99588 100644 --- a/MarkMpn.Sql4Cds.Engine/DataTypeHelpers.cs +++ b/MarkMpn.Sql4Cds.Engine/DataTypeHelpers.cs @@ -597,6 +597,7 @@ internal static bool TryConvertCollation(SqlDataTypeReference lhsSql, SqlDataTyp } collationLabel = CollationLabel.CoercibleDefault; + collation = lhsSqlWithColl.Collation; return true; } } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/Aggregate.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/Aggregate.cs index 7b949b01..d5a5b6d5 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/Aggregate.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/Aggregate.cs @@ -49,7 +49,7 @@ class Aggregate /// A compiled version of the that takes the row values and parameter values and returns the value to add to the aggregate /// [Browsable(false)] - public Func, IQueryExecutionOptions, object> Expression { get; set; } + public Func Expression { get; set; } /// /// The type of value produced by the @@ -69,13 +69,13 @@ class Aggregate /// abstract class AggregateFunction { - private readonly Func _selector; + private readonly Func _selector; /// /// Creates a new /// /// The function that returns the value to aggregate from the source entity - public AggregateFunction(Func selector) + public AggregateFunction(Func selector) { _selector = selector; } @@ -83,24 +83,22 @@ public AggregateFunction(Func selector) /// /// Updates the aggregate function state based on the next in the sequence /// - /// The to take the value from and apply to this aggregation /// The current state of the aggregation /// The new state of the aggregation - public virtual void NextRecord(Entity entity, object state) + public virtual void NextRecord(object state) { - var value = _selector == null ? entity : _selector(entity); + var value = _selector(); Update(value, state); } /// /// Updates the aggregate function state based on the aggregate values for a partition /// - /// The that contains aggregated values from a partition of the available records /// The current state of the aggregation /// The new state of the aggregation - public virtual void NextPartition(Entity entity, object state) + public virtual void NextPartition(object state) { - var value = _selector(entity); + var value = _selector(); UpdatePartition(value, state); } @@ -171,7 +169,7 @@ public State(object sumState, object countState) /// Creates a new /// /// A function that extracts the value to calculate the average from - public Average(Func selector, DataTypeReference sourceType, DataTypeReference returnType) : base(selector) + public Average(Func selector, DataTypeReference sourceType, DataTypeReference returnType) : base(selector) { Type = returnType; @@ -179,12 +177,12 @@ public Average(Func selector, DataTypeReference sourceType, Data _count = new CountColumn(selector); } - public override void NextRecord(Entity entity, object state) + public override void NextRecord(object state) { var s = (State)state; - _sum.NextRecord(entity, s.SumState); - _count.NextRecord(entity, s.CountState); + _sum.NextRecord(s.SumState); + _count.NextRecord(s.CountState); } protected override void Update(object value, object state) @@ -244,7 +242,7 @@ class State /// Creates a new /// /// Unused - public Count(Func selector) : base(selector) + public Count(Func selector) : base(selector) { } @@ -288,7 +286,7 @@ class State /// Creates a new /// /// A function that extracts the value to count non-null instances of - public CountColumn(Func selector) : base(selector) + public CountColumn(Func selector) : base(selector) { } @@ -340,7 +338,7 @@ public State(IComparable value) /// Creates a new /// /// A function that extracts the value to find the maximum value of - public Max(Func selector, DataTypeReference type) : base(selector) + public Max(Func selector, DataTypeReference type) : base(selector) { Type = type; } @@ -396,7 +394,7 @@ public State(IComparable value) /// Creates a new /// /// A function that extracts the value to find the minimum value of - public Min(Func selector, DataTypeReference type) : base(selector) + public Min(Func selector, DataTypeReference type) : base(selector) { Type = type; } @@ -455,7 +453,7 @@ class State /// Creates a new /// /// A function that extracts the value to sum - public Sum(Func selector, DataTypeReference sourceType, DataTypeReference returnType) : base(selector) + public Sum(Func selector, DataTypeReference sourceType, DataTypeReference returnType) : base(selector) { Type = returnType; @@ -559,7 +557,7 @@ public State(object value) /// Creates a new /// /// A function that extracts the value to sum - public First(Func selector, DataTypeReference type) : base(selector) + public First(Func selector, DataTypeReference type) : base(selector) { Type = type; } @@ -602,9 +600,9 @@ class State } private readonly AggregateFunction _func; - private readonly Func _selector; + private readonly Func _selector; - public DistinctAggregate(AggregateFunction func, Func selector) : base(selector) + public DistinctAggregate(AggregateFunction func, Func selector) : base(selector) { _func = func; _selector = selector; @@ -615,13 +613,13 @@ public DistinctAggregate(AggregateFunction func, Func selector) public override DataTypeReference Type => _func.Type; - public override void NextRecord(Entity entity, object state) + public override void NextRecord(object state) { - var value = _selector(entity); + var value = _selector(); var s = (State)state; if (s.Distinct.Add(value)) - _func.NextRecord(entity, s.InnerState); + _func.NextRecord(s.InnerState); } protected override void UpdatePartition(object value, object state) diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AliasNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AliasNode.cs index f702a141..81fa4045 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AliasNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AliasNode.cs @@ -60,7 +60,7 @@ private AliasNode() [Browsable(false)] public IDataExecutionPlanNodeInternal Source { get; set; } - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { var mappings = ColumnSet.Where(col => !col.AllColumns).ToDictionary(col => col.OutputColumn, col => col.SourceColumn); ColumnSet.Clear(); @@ -85,16 +85,16 @@ public override void AddRequiredColumns(IDictionary dataSour } } - Source.AddRequiredColumns(dataSources, parameterTypes, requiredColumns); + Source.AddRequiredColumns(context, requiredColumns); } - public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) { - Source = Source.FoldQuery(dataSources, options, parameterTypes, hints); + Source = Source.FoldQuery(context, hints); Source.Parent = this; - SelectNode.FoldFetchXmlColumns(Source, ColumnSet, dataSources, parameterTypes); - SelectNode.ExpandWildcardColumns(Source, ColumnSet, dataSources, parameterTypes); + SelectNode.FoldFetchXmlColumns(Source, ColumnSet, context); + SelectNode.ExpandWildcardColumns(Source, ColumnSet, context); if (Source is FetchXmlScan fetchXml) { @@ -144,10 +144,10 @@ public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IDictionary parameterTypes) + public override INodeSchema GetSchema(NodeCompilationContext context) { // Map the base names to the alias names - var sourceSchema = Source.GetSchema(dataSources, parameterTypes); + var sourceSchema = Source.GetSchema(context); var schema = new Dictionary(StringComparer.OrdinalIgnoreCase); var aliases = new Dictionary>(); var primaryKey = (string)null; @@ -232,9 +232,9 @@ private void AddSchemaColumn(string outputColumn, string sourceColumn, Dictionar ((List)a).Add(mapped); } - protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { - foreach (var entity in Source.Execute(dataSources, options, parameterTypes, parameterValues)) + foreach (var entity in Source.Execute(context)) { foreach (var col in ColumnSet) { @@ -251,9 +251,9 @@ public override string ToString() return "Subquery Alias"; } - protected override RowCountEstimate EstimateRowsOutInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes) + protected override RowCountEstimate EstimateRowsOutInternal(NodeCompilationContext context) { - return Source.EstimateRowsOut(dataSources, options, parameterTypes); + return Source.EstimateRowsOut(context); } public override IEnumerable GetSources() diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AssertNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AssertNode.cs index 78b313e6..9125ce29 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AssertNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AssertNode.cs @@ -34,9 +34,9 @@ class AssertNode : BaseDataNode, ISingleSourceExecutionPlanNode [DisplayName("Error Message")] public string ErrorMessage { get; set; } - protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { - foreach (var entity in Source.Execute(dataSources, options, parameterTypes, parameterValues)) + foreach (var entity in Source.Execute(context)) { if (!Assertion(entity)) throw new ApplicationException(ErrorMessage); @@ -45,9 +45,9 @@ protected override IEnumerable ExecuteInternal(IDictionary dataSources, IDictionary parameterTypes) + public override INodeSchema GetSchema(NodeCompilationContext context) { - return Source.GetSchema(dataSources, parameterTypes); + return Source.GetSchema(context); } public override IEnumerable GetSources() @@ -55,21 +55,21 @@ public override IEnumerable GetSources() yield return Source; } - public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) { - Source = Source.FoldQuery(dataSources, options, parameterTypes, hints); + Source = Source.FoldQuery(context, hints); Source.Parent = this; return this; } - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { - Source.AddRequiredColumns(dataSources, parameterTypes, requiredColumns); + Source.AddRequiredColumns(context, requiredColumns); } - protected override RowCountEstimate EstimateRowsOutInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes) + protected override RowCountEstimate EstimateRowsOutInternal(NodeCompilationContext context) { - return Source.EstimateRowsOut(dataSources, options, parameterTypes); + return Source.EstimateRowsOut(context); } public override object Clone() diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AssignVariablesNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AssignVariablesNode.cs index 694dc9a5..3d6a6baf 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AssignVariablesNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AssignVariablesNode.cs @@ -34,7 +34,7 @@ class AssignVariablesNode : BaseDmlNode [Browsable(false)] public override bool BypassCustomPluginExecution { get; set; } - public override string Execute(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues, out int recordsAffected) + public override string Execute(NodeExecutionContext context, out int recordsAffected) { _executionCount++; @@ -42,18 +42,18 @@ public override string Execute(IDictionary dataSources, IQue { var count = 0; - var entities = GetDmlSourceEntities(dataSources, options, parameterTypes, parameterValues, out var schema); - var valueAccessors = CompileValueAccessors(schema, entities, parameterTypes); + var entities = GetDmlSourceEntities(context, out var schema); + var valueAccessors = CompileValueAccessors(schema, entities, context.ParameterTypes); foreach (var entity in entities) { foreach (var variable in Variables) - parameterValues[variable.VariableName] = valueAccessors[variable.VariableName](entity); + context.ParameterValues[variable.VariableName] = valueAccessors[variable.VariableName](entity); count++; } - parameterValues["@@ROWCOUNT"] = (SqlInt32)count; + context.ParameterValues["@@ROWCOUNT"] = (SqlInt32)count; } recordsAffected = -1; @@ -113,7 +113,7 @@ protected Dictionary> CompileValueAccessors(INodeSc return valueAccessors; } - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { foreach (var variable in Variables) { @@ -121,7 +121,7 @@ public override void AddRequiredColumns(IDictionary dataSour requiredColumns.Add(variable.SourceColumn); } - Source.AddRequiredColumns(dataSources, parameterTypes, requiredColumns); + Source.AddRequiredColumns(context, requiredColumns); } protected override void RenameSourceColumns(IDictionary columnRenamings) diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseAggregateNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseAggregateNode.cs index 2925e2ed..ad73a0fa 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseAggregateNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseAggregateNode.cs @@ -55,15 +55,15 @@ protected class AggregateFunctionState [Browsable(false)] public IDataExecutionPlanNodeInternal Source { get; set; } - protected void InitializeAggregates(DataSource primaryDataSource, INodeSchema schema, IDictionary parameterTypes) + protected void InitializeAggregates(ExpressionCompilationContext context) { foreach (var aggregate in Aggregates.Where(agg => agg.Value.SqlExpression != null)) { - aggregate.Value.SqlExpression.GetType(primaryDataSource, schema, null, parameterTypes, out var retType); + aggregate.Value.SqlExpression.GetType(context, out var retType); aggregate.Value.SourceType = retType; aggregate.Value.ReturnType = retType; - aggregate.Value.Expression = aggregate.Value.SqlExpression.Compile(primaryDataSource, schema, parameterTypes); + aggregate.Value.Expression = aggregate.Value.SqlExpression.Compile(context); // Return type of SUM and AVG is based on the input type with some modifications // https://docs.microsoft.com/en-us/sql/t-sql/functions/avg-transact-sql?view=sql-server-ver15#return-types @@ -82,13 +82,13 @@ protected void InitializeAggregates(DataSource primaryDataSource, INodeSchema sc } } - protected void InitializePartitionedAggregates(DataSource primaryDataSource, INodeSchema schema, IDictionary parameterTypes) + protected void InitializePartitionedAggregates(ExpressionCompilationContext context) { foreach (var aggregate in Aggregates) { var sourceExpression = aggregate.Key.ToColumnReference(); - aggregate.Value.Expression = sourceExpression.Compile(primaryDataSource, schema, parameterTypes); - sourceExpression.GetType(primaryDataSource, schema, null, parameterTypes, out var retType); + aggregate.Value.Expression = sourceExpression.Compile(context); + sourceExpression.GetType(context, out var retType); aggregate.Value.SourceType = retType; aggregate.Value.ReturnType = retType; } @@ -108,16 +108,18 @@ protected List GetGroupingColumns(INodeSchema schema) return groupByCols; } - protected Dictionary CreateAggregateFunctions(IDictionary parameterValues, IQueryExecutionOptions options, bool partitioned) + protected Dictionary CreateAggregateFunctions(ExpressionExecutionContext context, bool partitioned) { var values = new Dictionary(); foreach (var aggregate in Aggregates) { - Func selector = null; + Func selector = null; if (partitioned || aggregate.Value.AggregateType != AggregateType.CountStar) - selector = e => aggregate.Value.Expression(e, parameterValues, options); + selector = () => aggregate.Value.Expression(context); + else + selector = () => null; switch (aggregate.Value.AggregateType) { @@ -172,9 +174,10 @@ protected IEnumerable> GetValues(Dictionary new KeyValuePair(kvp.Key, kvp.Value.AggregateFunction.GetValue(kvp.Value.State))); } - public override INodeSchema GetSchema(IDictionary dataSources, IDictionary parameterTypes) + public override INodeSchema GetSchema(NodeCompilationContext context) { - var sourceSchema = Source.GetSchema(dataSources, parameterTypes); + var sourceSchema = Source.GetSchema(context); + var expressionContext = new ExpressionCompilationContext(context, sourceSchema, null); var schema = new Dictionary(StringComparer.OrdinalIgnoreCase); var aliases = new Dictionary>(StringComparer.OrdinalIgnoreCase); var primaryKey = (string)null; @@ -213,7 +216,7 @@ public override INodeSchema GetSchema(IDictionary dataSource break; default: - aggregate.Value.SqlExpression.GetType(dataSources[options.PrimaryDataSource], sourceSchema, null, parameterTypes, out aggregateType); + aggregate.Value.SqlExpression.GetType(expressionContext, out aggregateType); // Return type of SUM and AVG is based on the input type with some modifications // https://docs.microsoft.com/en-us/sql/t-sql/functions/avg-transact-sql?view=sql-server-ver15#return-types @@ -307,17 +310,17 @@ protected bool IsCompositeAddressPluginBug(OrganizationServiceFault fault) return false; } - protected override RowCountEstimate EstimateRowsOutInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes) + protected override RowCountEstimate EstimateRowsOutInternal(NodeCompilationContext context) { if (GroupBy.Count == 0) return RowCountEstimateDefiniteRange.ExactlyOne; - var rows = Source.EstimateRowsOut(dataSources, options, parameterTypes).Value * 4 / 10; + var rows = Source.EstimateRowsOut(context).Value * 4 / 10; return new RowCountEstimate(rows); } - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { // Columns required by previous nodes must be derived from this node, so no need to pass them through. // Just calculate the columns that are required to calculate the groups & aggregates @@ -327,7 +330,7 @@ public override void AddRequiredColumns(IDictionary dataSour scalarRequiredColumns.AddRange(Aggregates.Where(agg => agg.Value.SqlExpression != null).SelectMany(agg => agg.Value.SqlExpression.GetColumns()).Distinct()); - Source.AddRequiredColumns(dataSources, parameterTypes, scalarRequiredColumns); + Source.AddRequiredColumns(context, scalarRequiredColumns); } protected override IEnumerable GetVariablesInternal() diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDataNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDataNode.cs index 79a852f9..1d3ed01c 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDataNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDataNode.cs @@ -54,15 +54,11 @@ abstract class BaseDataNode : BaseNode, IDataExecutionPlanNodeInternal /// /// Executes the query and produces a stram of data in the results /// - /// The to use to get the data - /// The to use to get metadata - /// to indicate how the query can be executed - /// A mapping of parameter names to their related types - /// A mapping of parameter names to their current values + /// The context in which the node is being executed /// A stream of records - public IEnumerable Execute(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + public IEnumerable Execute(NodeExecutionContext context) { - if (options.CancellationToken.IsCancellationRequested) + if (context.Options.CancellationToken.IsCancellationRequested) yield break; // Track execution times roughly using Environment.TickCount. Stopwatch provides more accurate results @@ -75,7 +71,7 @@ public IEnumerable Execute(IDictionary dataSources, { _executionCount++; - enumerator = ExecuteInternal(dataSources, options, parameterTypes, parameterValues).GetEnumerator(); + enumerator = ExecuteInternal(context).GetEnumerator(); } catch (QueryExecutionException ex) { @@ -90,7 +86,7 @@ public IEnumerable Execute(IDictionary dataSources, } } - while (!options.CancellationToken.IsCancellationRequested) + while (!context.Options.CancellationToken.IsCancellationRequested) { Entity current; @@ -128,14 +124,14 @@ public IEnumerable Execute(IDictionary dataSources, /// A mapping of parameter names to their related types /// A cache of the number of records in each table /// The number of rows the node is estimated to return - public RowCountEstimate EstimateRowsOut(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes) + public RowCountEstimate EstimateRowsOut(NodeCompilationContext context) { - var estimate = EstimateRowsOutInternal(dataSources, options, parameterTypes); + var estimate = EstimateRowsOutInternal(context); EstimatedRowsOut = estimate.Value; return estimate; } - protected abstract RowCountEstimate EstimateRowsOutInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes); + protected abstract RowCountEstimate EstimateRowsOutInternal(NodeCompilationContext context); protected void ParseEstimate(RowCountEstimate estimate, out int min, out int max, out bool isRange) { @@ -167,50 +163,41 @@ public void MergeStatsFrom(BaseDataNode other) /// /// Produces the data for the node without keeping track of any execution time statistics /// - /// The to use to get the data - /// The to use to get metadata - /// to indicate how the query can be executed - /// A mapping of parameter names to their related types - /// A mapping of parameter names to their current values + /// The context in which the node is being executed /// A stream of records - protected abstract IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues); + protected abstract IEnumerable ExecuteInternal(NodeExecutionContext context); /// /// Gets the details of columns produced by the node /// - /// The to use to get metadata - /// A mapping of parameter names to their related types + /// The context in which the node is being built /// Details of the columns produced by the node - public abstract INodeSchema GetSchema(IDictionary dataSources, IDictionary parameterTypes); + public abstract INodeSchema GetSchema(NodeCompilationContext context); /// /// Attempts to fold this node into its source to simplify the query /// - /// The data sources the query can use - /// to indicate how the query can be executed - /// A mapping of parameter names to their related types + /// The context in which the node is being built /// Any optimizer hints to apply /// The node that should be used in place of this node - public abstract IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints); + public abstract IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints); /// /// Translates filter criteria from ScriptDom to FetchXML /// - /// The main for this connection + /// The context the query is being built in /// The to use to get metadata - /// to indicate how the query can be executed /// The SQL criteria to attempt to translate to FetchXML /// The schema of the node that the criteria apply to /// The prefix of the table that the can be translated for, or null if any tables can be referenced /// The logical name of the root entity that the FetchXML query is targetting /// The alias of the root entity that the FetchXML query is targetting /// The child items of the root entity in the FetchXML query - /// The types of any parameters that can be used /// The FetchXML version of the that is generated by this method /// true if the can be translated to FetchXML, or false otherwise - protected bool TranslateFetchXMLCriteria(DataSource primaryDataSource, IAttributeMetadataCache metadata, IQueryExecutionOptions options, BooleanExpression criteria, INodeSchema schema, string allowedPrefix, string targetEntityName, string targetEntityAlias, object[] items, IDictionary parameterTypes, out filter filter) + protected bool TranslateFetchXMLCriteria(NodeCompilationContext context, IAttributeMetadataCache metadata, BooleanExpression criteria, INodeSchema schema, string allowedPrefix, string targetEntityName, string targetEntityAlias, object[] items, out filter filter) { - if (!TranslateFetchXMLCriteria(primaryDataSource, metadata, options, criteria, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, parameterTypes, out var condition, out filter)) + if (!TranslateFetchXMLCriteria(context, metadata, criteria, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, out var condition, out filter)) return false; if (condition != null) @@ -222,29 +209,27 @@ protected bool TranslateFetchXMLCriteria(DataSource primaryDataSource, IAttribut /// /// Translates filter criteria from ScriptDom to FetchXML /// - /// The main for this connection + /// The context the query is being built in /// The to use to get metadata - /// to indicate how the query can be executed /// The SQL criteria to attempt to translate to FetchXML /// The schema of the node that the criteria apply to /// The prefix of the table that the can be translated for, or null if any tables can be referenced /// The logical name of the root entity that the FetchXML query is targetting /// The alias of the root entity that the FetchXML query is targetting /// The child items of the root entity in the FetchXML query - /// The types of any parameters that can be used /// The FetchXML version of the that is generated by this method when it covers multiple conditions /// The FetchXML version of the that is generated by this method when it is for a single condition only /// true if the can be translated to FetchXML, or false otherwise - private bool TranslateFetchXMLCriteria(DataSource primaryDataSource, IAttributeMetadataCache metadata, IQueryExecutionOptions options, BooleanExpression criteria, INodeSchema schema, string allowedPrefix, string targetEntityName, string targetEntityAlias, object[] items, IDictionary parameterTypes, out condition condition, out filter filter) + private bool TranslateFetchXMLCriteria(NodeCompilationContext context, IAttributeMetadataCache metadata, BooleanExpression criteria, INodeSchema schema, string allowedPrefix, string targetEntityName, string targetEntityAlias, object[] items, out condition condition, out filter filter) { condition = null; filter = null; if (criteria is BooleanBinaryExpression binary) { - if (!TranslateFetchXMLCriteria(primaryDataSource, metadata, options, binary.FirstExpression, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, parameterTypes, out var lhsCondition, out var lhsFilter)) + if (!TranslateFetchXMLCriteria(context, metadata, binary.FirstExpression, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, out var lhsCondition, out var lhsFilter)) return false; - if (!TranslateFetchXMLCriteria(primaryDataSource, metadata, options, binary.SecondExpression, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, parameterTypes, out var rhsCondition, out var rhsFilter)) + if (!TranslateFetchXMLCriteria(context, metadata, binary.SecondExpression, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, out var rhsCondition, out var rhsFilter)) return false; filter = new filter @@ -261,7 +246,7 @@ private bool TranslateFetchXMLCriteria(DataSource primaryDataSource, IAttributeM if (criteria is BooleanParenthesisExpression paren) { - return TranslateFetchXMLCriteria(primaryDataSource, metadata, options, paren.Expression, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, parameterTypes, out condition, out filter); + return TranslateFetchXMLCriteria(context, metadata, paren.Expression, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, out condition, out filter); } if (criteria is BooleanComparisonExpression comparison) @@ -284,7 +269,7 @@ private bool TranslateFetchXMLCriteria(DataSource primaryDataSource, IAttributeM { // The operator is comparing two attributes. This is allowed in join criteria, // but not in filter conditions before version 9.1.0.19251 - if (!options.ColumnComparisonAvailable) + if (!context.Options.ColumnComparisonAvailable) return false; } @@ -321,8 +306,10 @@ private bool TranslateFetchXMLCriteria(DataSource primaryDataSource, IAttributeM } } + var expressionContext = new ExpressionCompilationContext(context, schema, null); + // If we still couldn't find the column name and value, this isn't a pattern we can support in FetchXML - if (field == null || (literal == null && func == null && variable == null && parameterless == null && globalVariable == null && (field2 == null || !options.ColumnComparisonAvailable) && !expr.IsConstantValueExpression(primaryDataSource, schema, options, out literal))) + if (field == null || (literal == null && func == null && variable == null && parameterless == null && globalVariable == null && (field2 == null || !context.Options.ColumnComparisonAvailable) && !expr.IsConstantValueExpression(expressionContext, out literal))) return false; // Select the correct FetchXML operator @@ -427,14 +414,14 @@ private bool TranslateFetchXMLCriteria(DataSource primaryDataSource, IAttributeM } else if (func != null) { - if (func.IsConstantValueExpression(primaryDataSource, schema, options, out literal)) + if (func.IsConstantValueExpression(expressionContext, out literal)) values = new[] { literal }; else return false; } else if (parameterless != null) { - if (parameterless.IsConstantValueExpression(primaryDataSource, schema, options, out literal)) + if (parameterless.IsConstantValueExpression(expressionContext, out literal)) { values = new[] { literal }; } @@ -474,7 +461,7 @@ private bool TranslateFetchXMLCriteria(DataSource primaryDataSource, IAttributeM if (field2 == null) { - return TranslateFetchXMLCriteriaWithVirtualAttributes(meta, entityAlias, attrName, op, values, metadata, options, targetEntityAlias, items, parameterTypes, out condition, out filter); + return TranslateFetchXMLCriteriaWithVirtualAttributes(context, meta, entityAlias, attrName, op, values, metadata, targetEntityAlias, items, out condition, out filter); } else { @@ -537,7 +524,7 @@ private bool TranslateFetchXMLCriteria(DataSource primaryDataSource, IAttributeM var entityName = AliasToEntityName(targetEntityAlias, targetEntityName, items, entityAlias); var meta = metadata[entityName]; - return TranslateFetchXMLCriteriaWithVirtualAttributes(meta, entityAlias, attrName, inPred.NotDefined ? @operator.notin : @operator.@in, inPred.Values.Cast().ToArray(), metadata, options, targetEntityAlias, items, parameterTypes, out condition, out filter); + return TranslateFetchXMLCriteriaWithVirtualAttributes(context, meta, entityAlias, attrName, inPred.NotDefined ? @operator.notin : @operator.@in, inPred.Values.Cast().ToArray(), metadata, targetEntityAlias, items, out condition, out filter); } if (criteria is BooleanIsNullExpression isNull) @@ -560,7 +547,7 @@ private bool TranslateFetchXMLCriteria(DataSource primaryDataSource, IAttributeM var entityName = AliasToEntityName(targetEntityAlias, targetEntityName, items, entityAlias); var meta = metadata[entityName]; - return TranslateFetchXMLCriteriaWithVirtualAttributes(meta, entityAlias, attrName, isNull.IsNot ? @operator.notnull : @operator.@null, null, metadata, options, targetEntityAlias, items, parameterTypes, out condition, out filter); + return TranslateFetchXMLCriteriaWithVirtualAttributes(context, meta, entityAlias, attrName, isNull.IsNot ? @operator.notnull : @operator.@null, null, metadata, targetEntityAlias, items, out condition, out filter); } if (criteria is LikePredicate like) @@ -589,7 +576,7 @@ private bool TranslateFetchXMLCriteria(DataSource primaryDataSource, IAttributeM var entityName = AliasToEntityName(targetEntityAlias, targetEntityName, items, entityAlias); var meta = metadata[entityName]; - return TranslateFetchXMLCriteriaWithVirtualAttributes(meta, entityAlias, attrName, like.NotDefined ? @operator.notlike : @operator.like, new[] { value }, metadata, options, targetEntityAlias, items, parameterTypes, out condition, out filter); + return TranslateFetchXMLCriteriaWithVirtualAttributes(context, meta, entityAlias, attrName, like.NotDefined ? @operator.notlike : @operator.like, new[] { value }, metadata, targetEntityAlias, items, out condition, out filter); } if (criteria is FullTextPredicate || @@ -636,7 +623,7 @@ private bool TranslateFetchXMLCriteria(DataSource primaryDataSource, IAttributeM if (!(attr is MultiSelectPicklistAttributeMetadata)) return false; - return TranslateFetchXMLCriteriaWithVirtualAttributes(meta, entityAlias, attrName, not == null ? @operator.containvalues : @operator.notcontainvalues, valueParts.Select(v => new IntegerLiteral { Value = v }).ToArray(), metadata, options, targetEntityAlias, items, parameterTypes, out condition, out filter); + return TranslateFetchXMLCriteriaWithVirtualAttributes(context, meta, entityAlias, attrName, not == null ? @operator.containvalues : @operator.notcontainvalues, valueParts.Select(v => new IntegerLiteral { Value = v }).ToArray(), metadata, targetEntityAlias, items, out condition, out filter); } return false; @@ -645,6 +632,7 @@ private bool TranslateFetchXMLCriteria(DataSource primaryDataSource, IAttributeM /// /// Handles special cases for virtual attributes in FetchXML conditions /// + /// The context the query is being built in /// The metadata for the target entity the condition is for /// The alias of the entity in the query the condition is for /// The logical name of the attribute the condition is for @@ -653,12 +641,10 @@ private bool TranslateFetchXMLCriteria(DataSource primaryDataSource, IAttributeM /// The to use to get metadata /// The alias of the root entity that the FetchXML query is targetting /// The child items of the root entity in the FetchXML query - /// The options that control how the query will be executed - /// The types of any parameters that can be used /// The FetchXML version of the that is generated by this method when it covers multiple conditions /// The FetchXML version of the that is generated by this method when it is for a single condition only /// true if the condition can be translated to FetchXML, or false otherwise - private bool TranslateFetchXMLCriteriaWithVirtualAttributes(EntityMetadata meta, string entityAlias, string attrName, @operator op, ValueExpression[] literals, IAttributeMetadataCache metadata, IQueryExecutionOptions options, string targetEntityAlias, object[] items, IDictionary parameterTypes, out condition condition, out filter filter) + private bool TranslateFetchXMLCriteriaWithVirtualAttributes(NodeCompilationContext context, EntityMetadata meta, string entityAlias, string attrName, @operator op, ValueExpression[] literals, IAttributeMetadataCache metadata, string targetEntityAlias, object[] items, out condition condition, out filter filter) { condition = null; filter = null; @@ -729,7 +715,7 @@ private bool TranslateFetchXMLCriteriaWithVirtualAttributes(EntityMetadata meta, DateTimeOffset dto; - if (options.UseLocalTimeZone) + if (context.Options.UseLocalTimeZone) dto = new DateTimeOffset(dt, TimeZoneInfo.Local.GetUtcOffset(dt)); else dto = new DateTimeOffset(dt, TimeSpan.Zero); @@ -857,7 +843,7 @@ private bool TranslateFetchXMLCriteriaWithVirtualAttributes(EntityMetadata meta, if (values[i].IsVariable) { // Variables must be an integer type - var variableType = parameterTypes[values[i].Value].ToNetType(out _); + var variableType = context.ParameterTypes[values[i].Value].ToNetType(out _); if (variableType != typeof(SqlInt32)) return false; diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs index 9394d61c..416f6780 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs @@ -113,27 +113,23 @@ public void Dispose() /// /// Executes the DML query and returns an appropriate log message /// - /// The to use to get the data - /// The to use to get metadata - /// to indicate how the query can be executed - /// A mapping of parameter names to their related types - /// A mapping of parameter names to their current values + /// The context in which the node is being executed + /// The number of records that were affected by the query /// A log message to display - public abstract string Execute(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues, out int recordsAffected); + public abstract string Execute(NodeExecutionContext context, out int recordsAffected); /// /// Attempts to fold this node into its source to simplify the query /// - /// The to use to get metadata - /// to indicate how the query can be executed - /// A mapping of parameter names to their related types + /// The context in which the node is being built + /// Any hints that can control the folding of this node /// The node that should be used in place of this node - public virtual IRootExecutionPlanNodeInternal[] FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public virtual IRootExecutionPlanNodeInternal[] FoldQuery(NodeCompilationContext context, IList hints) { if (Source is IDataExecutionPlanNodeInternal dataNode) - Source = dataNode.FoldQuery(dataSources, options, parameterTypes, hints); + Source = dataNode.FoldQuery(context, hints); else if (Source is IDataReaderExecutionPlanNode dataSetNode) - Source = dataSetNode.FoldQuery(dataSources, options, parameterTypes, hints).Single(); + Source = dataSetNode.FoldQuery(context, hints).Single(); if (Source is AliasNode alias) { @@ -142,16 +138,16 @@ public virtual IRootExecutionPlanNodeInternal[] FoldQuery(IDictionary alias.Alias + "." + col.OutputColumn, col => col.SourceColumn, StringComparer.OrdinalIgnoreCase)); } - MaxDOP = GetMaxDOP(hints, options); - BypassCustomPluginExecution = GetBypassPluginExecution(hints, options); + MaxDOP = GetMaxDOP(context, hints); + BypassCustomPluginExecution = GetBypassPluginExecution(context, hints); return new[] { this }; } - private int GetMaxDOP(IList queryHints, IQueryExecutionOptions options) + private int GetMaxDOP(NodeCompilationContext context, IList queryHints) { if (queryHints == null) - return options.MaxDegreeOfParallelism; + return context.Options.MaxDegreeOfParallelism; var maxDopHint = queryHints .OfType() @@ -166,20 +162,20 @@ private int GetMaxDOP(IList queryHints, IQueryExecutionOptions op return value; } - return options.MaxDegreeOfParallelism; + return context.Options.MaxDegreeOfParallelism; } - private bool GetBypassPluginExecution(IList queryHints, IQueryExecutionOptions options) + private bool GetBypassPluginExecution(NodeCompilationContext context, IList queryHints) { if (queryHints == null) - return options.BypassCustomPlugins; + return context.Options.BypassCustomPlugins; var bypassPluginExecution = queryHints .OfType() .Where(hint => hint.Hints.Any(s => s.Value.Equals("BYPASS_CUSTOM_PLUGIN_EXECUTION", StringComparison.OrdinalIgnoreCase))) .Any(); - return bypassPluginExecution || options.BypassCustomPlugins; + return bypassPluginExecution || context.Options.BypassCustomPlugins; } /// @@ -196,31 +192,27 @@ public override IEnumerable GetSources() /// /// Gets the records to perform the DML operation on /// - /// The to use to get the data - /// The to use to get metadata - /// to indicate how the query can be executed - /// A mapping of parameter names to their related types - /// A mapping of parameter names to their current values + /// The context in which the node is being executed /// The schema of the data source /// The entities to perform the DML operation on - protected List GetDmlSourceEntities(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues, out INodeSchema schema) + protected List GetDmlSourceEntities(NodeExecutionContext context, out INodeSchema schema) { List entities; if (Source is IDataExecutionPlanNodeInternal dataSource) { - schema = dataSource.GetSchema(dataSources, parameterTypes); - entities = dataSource.Execute(dataSources, options, parameterTypes, parameterValues).ToList(); + schema = dataSource.GetSchema(context); + entities = dataSource.Execute(context).ToList(); } else if (Source is IDataReaderExecutionPlanNode dataSetSource) { - var dataReader = dataSetSource.Execute(dataSources, options, parameterTypes, parameterValues, CommandBehavior.Default); + var dataReader = dataSetSource.Execute(context, CommandBehavior.Default); // Store the values under the column index as well as name for compatibility with INSERT ... SELECT ... var dataTable = new DataTable(); var schemaTable = dataReader.GetSchemaTable(); var columnTypes = new Dictionary(StringComparer.OrdinalIgnoreCase); - var targetDataSource = dataSources[DataSource]; + var targetDataSource = context.DataSources[DataSource]; for (var i = 0; i < schemaTable.Rows.Count; i++) { diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseJoinNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseJoinNode.cs index c6f072d2..1b46b3fa 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseJoinNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseJoinNode.cs @@ -105,15 +105,15 @@ public override IEnumerable GetSources() yield return RightSource; } - public override INodeSchema GetSchema(IDictionary dataSources, IDictionary parameterTypes) + public override INodeSchema GetSchema(NodeCompilationContext context) { - return GetSchema(dataSources, parameterTypes, false); + return GetSchema(context, false); } - protected virtual INodeSchema GetSchema(IDictionary dataSources, IDictionary parameterTypes, bool includeSemiJoin) + protected virtual INodeSchema GetSchema(NodeCompilationContext context, bool includeSemiJoin) { - var outerSchema = LeftSource.GetSchema(dataSources, parameterTypes); - var innerSchema = GetRightSchema(dataSources, parameterTypes); + var outerSchema = LeftSource.GetSchema(context); + var innerSchema = GetRightSchema(context); if (outerSchema == _lastLeftSchema && innerSchema == _lastRightSchema) return _lastSchema; @@ -178,9 +178,9 @@ protected virtual IReadOnlyList GetSortOrder(INodeSchema outerSchema, IN return null; } - protected virtual INodeSchema GetRightSchema(IDictionary dataSources, IDictionary parameterTypes) + protected virtual INodeSchema GetRightSchema(NodeCompilationContext context) { - return RightSource.GetSchema(dataSources, parameterTypes); + return RightSource.GetSchema(context); } public override string ToString() diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseNode.cs index 35e2d453..1060f2fb 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseNode.cs @@ -43,10 +43,9 @@ abstract class BaseNode : IExecutionPlanNode /// /// Adds columns to the data source that are required by parent nodes /// - /// The to use to get metadata - /// A mapping of parameter names to their related types + /// The context in which the node is being built /// The names of columns that are required by the parent node - public abstract void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns); + public abstract void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns); /// /// Gets the name to show for an entity diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BulkDeleteJobNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BulkDeleteJobNode.cs index 0cc9ecdf..0e10cf80 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BulkDeleteJobNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BulkDeleteJobNode.cs @@ -39,22 +39,22 @@ class BulkDeleteJobNode : BaseNode, IDmlQueryExecutionPlanNode [Browsable(false)] public FetchXmlScan Source { get; set; } - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { } - public string Execute(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues, out int recordsAffected) + public string Execute(NodeExecutionContext context, out int recordsAffected) { _executionCount++; try { - if (!dataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) throw new QueryExecutionException("Missing datasource " + DataSource); using (_timer.Run()) { - Source.ApplyParameterValues(options, parameterValues); + Source.ApplyParameterValues(context); var query = ((FetchXmlToQueryExpressionResponse)dataSource.Connection.Execute(new FetchXmlToQueryExpressionRequest { FetchXml = Source.FetchXmlString })).Query; var meta = dataSource.Metadata[query.EntityName]; @@ -72,7 +72,7 @@ public string Execute(IDictionary dataSources, IQueryExecuti var resp = (BulkDeleteResponse)dataSource.Connection.Execute(req); recordsAffected = 1; - parameterValues["@@IDENTITY"] = new SqlEntityReference(DataSource, "asyncoperation", resp.JobId); + context.ParameterValues["@@IDENTITY"] = new SqlEntityReference(DataSource, "asyncoperation", resp.JobId); return $"Bulk delete job started"; } } @@ -89,7 +89,7 @@ public string Execute(IDictionary dataSources, IQueryExecuti } } - public IRootExecutionPlanNodeInternal[] FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public IRootExecutionPlanNodeInternal[] FoldQuery(NodeCompilationContext context, IList hints) { return new[] { this }; } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ComputeScalarNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ComputeScalarNode.cs index 533ecb76..f6bd3489 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ComputeScalarNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ComputeScalarNode.cs @@ -28,26 +28,32 @@ class ComputeScalarNode : BaseDataNode, ISingleSourceExecutionPlanNode [Browsable(false)] public IDataExecutionPlanNodeInternal Source { get; set; } - protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { - var schema = Source.GetSchema(dataSources, parameterTypes); + var schema = Source.GetSchema(context); + var expressionCompilationContext = new ExpressionCompilationContext(context, schema, null); var columns = Columns - .Select(kvp => new { Name = kvp.Key, Expression = kvp.Value.Compile(dataSources[options.PrimaryDataSource], schema, parameterTypes) }) + .Select(kvp => new { Name = kvp.Key, Expression = kvp.Value.Compile(expressionCompilationContext) }) .ToList(); - foreach (var entity in Source.Execute(dataSources, options, parameterTypes, parameterValues)) + var expressionContext = new ExpressionExecutionContext(context); + + foreach (var entity in Source.Execute(context)) { + expressionContext.Entity = entity; + foreach (var col in columns) - entity[col.Name] = col.Expression(entity, parameterValues, options); + entity[col.Name] = col.Expression(expressionContext); yield return entity; } } - public override INodeSchema GetSchema(IDictionary dataSources, IDictionary parameterTypes) + public override INodeSchema GetSchema(NodeCompilationContext context) { // Copy the source schema and add in the additional computed columns - var sourceSchema = Source.GetSchema(dataSources, parameterTypes); + var sourceSchema = Source.GetSchema(context); + var expressionCompilationContext = new ExpressionCompilationContext(context, sourceSchema, null); var schema = new Dictionary(sourceSchema.Schema.Count, StringComparer.OrdinalIgnoreCase); foreach (var col in sourceSchema.Schema) @@ -55,7 +61,7 @@ public override INodeSchema GetSchema(IDictionary dataSource foreach (var calc in Columns) { - calc.Value.GetType(dataSources[options.PrimaryDataSource], sourceSchema, null, parameterTypes, out var calcType); + calc.Value.GetType(expressionCompilationContext, out var calcType); schema[calc.Key] = calcType; } @@ -72,9 +78,9 @@ public override IEnumerable GetSources() yield return Source; } - public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) { - Source = Source.FoldQuery(dataSources, options, parameterTypes, hints); + Source = Source.FoldQuery(context, hints); // Combine multiple ComputeScalar nodes. Calculations in this node might be dependent on those in the previous node, so rewrite any references // to the earlier computed columns @@ -93,9 +99,11 @@ public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary(); + var expressionContext = new ExpressionCompilationContext(context, null, null); foreach (var calc in Columns) { @@ -118,7 +126,7 @@ calc.Value is CastCall c2 && c2.Parameter is Literal || } else { - calc.Value.GetType(dataSources[options.PrimaryDataSource], null, null, parameterTypes, out var calcType); + calc.Value.GetType(expressionContext, out var calcType); constant.Schema[calc.Key] = calcType; } } @@ -135,9 +143,9 @@ calc.Value is CastCall c2 && c2.Parameter is Literal || return this; } - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { - var schema = Source.GetSchema(dataSources, parameterTypes); + var schema = Source.GetSchema(context); var calcSourceColumns = Columns.Values .SelectMany(expr => expr.GetColumns()); @@ -151,12 +159,12 @@ public override void AddRequiredColumns(IDictionary dataSour requiredColumns.Add(normalized); } - Source.AddRequiredColumns(dataSources, parameterTypes, requiredColumns); + Source.AddRequiredColumns(context, requiredColumns); } - protected override RowCountEstimate EstimateRowsOutInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes) + protected override RowCountEstimate EstimateRowsOutInternal(NodeCompilationContext context) { - return Source.EstimateRowsOut(dataSources, options, parameterTypes); + return Source.EstimateRowsOut(context); } protected override IEnumerable GetVariablesInternal() diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ConcatenateNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ConcatenateNode.cs index 2ab3ced4..478d93e7 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ConcatenateNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ConcatenateNode.cs @@ -28,13 +28,13 @@ class ConcatenateNode : BaseDataNode [DisplayName("Column Set")] public List ColumnSet { get; } = new List(); - protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { for (var i = 0; i < Sources.Count; i++) { var source = Sources[i]; - foreach (var entity in source.Execute(dataSources, options, parameterTypes, parameterValues)) + foreach (var entity in source.Execute(context)) { var result = new Entity(entity.LogicalName, entity.Id); @@ -46,11 +46,11 @@ protected override IEnumerable ExecuteInternal(IDictionary dataSources, IDictionary parameterTypes) + public override INodeSchema GetSchema(NodeCompilationContext context) { var schema = new Dictionary(StringComparer.OrdinalIgnoreCase); - var sourceSchema = Sources[0].GetSchema(dataSources, parameterTypes); + var sourceSchema = Sources[0].GetSchema(context); foreach (var col in ColumnSet) schema[col.OutputColumn] = sourceSchema.Schema[col.SourceColumns[0]]; @@ -68,25 +68,25 @@ public override IEnumerable GetSources() return Sources; } - public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) { for (var i = 0; i < Sources.Count; i++) { - Sources[i] = Sources[i].FoldQuery(dataSources, options, parameterTypes, hints); + Sources[i] = Sources[i].FoldQuery(context, hints); Sources[i].Parent = this; } // Work out the column types - var sourceColumnTypes = Sources.Select((source, index) => GetColumnTypes(index, dataSources, parameterTypes)).ToArray(); + var sourceColumnTypes = Sources.Select((source, index) => GetColumnTypes(index, context)).ToArray(); var types = (DataTypeReference[]) sourceColumnTypes[0].Clone(); for (var i = 1; i < Sources.Count; i++) { - var nextTypes = GetColumnTypes(i, dataSources, parameterTypes); + var nextTypes = GetColumnTypes(i, context); for (var colIndex = 0; colIndex < types.Length; colIndex++) { - if (!SqlTypeConverter.CanMakeConsistentTypes(types[colIndex], nextTypes[colIndex], dataSources[options.PrimaryDataSource], out var colType)) + if (!SqlTypeConverter.CanMakeConsistentTypes(types[colIndex], nextTypes[colIndex], context.PrimaryDataSource, out var colType)) throw new NotSupportedQueryFragmentException("No available implicit type conversion", ColumnSet[colIndex].SourceExpressions[i]); types[colIndex] = colType; @@ -122,7 +122,7 @@ public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary 0) { - Sources[i] = conversion.FoldQuery(dataSources, options, parameterTypes, hints); + Sources[i] = conversion.FoldQuery(context, hints); Sources[i].Parent = this; if (Sources[i] is ComputeScalarNode foldedConversion) @@ -171,9 +171,9 @@ public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IDictionary parameterTypes) + private DataTypeReference[] GetColumnTypes(int sourceIndex, NodeCompilationContext context) { - var schema = Sources[sourceIndex].GetSchema(dataSources, parameterTypes); + var schema = Sources[sourceIndex].GetSchema(context); var types = new DataTypeReference[ColumnSet.Count]; for (var i = 0; i < ColumnSet.Count; i++) @@ -182,7 +182,7 @@ private DataTypeReference[] GetColumnTypes(int sourceIndex, IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { for (var i = 0; i < Sources.Count; i++) { @@ -191,13 +191,13 @@ public override void AddRequiredColumns(IDictionary dataSour .Distinct() .ToList(); - Sources[i].AddRequiredColumns(dataSources, parameterTypes, sourceRequiredColumns); + Sources[i].AddRequiredColumns(context, sourceRequiredColumns); } } - protected override RowCountEstimate EstimateRowsOutInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes) + protected override RowCountEstimate EstimateRowsOutInternal(NodeCompilationContext context) { - return new RowCountEstimate(Sources.Sum(s => s.EstimateRowsOut(dataSources, options, parameterTypes).Value)); + return new RowCountEstimate(Sources.Sum(s => s.EstimateRowsOut(context).Value)); } public override object Clone() diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ConditionalNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ConditionalNode.cs index c147d3fc..1727f31d 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ConditionalNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ConditionalNode.cs @@ -47,39 +47,39 @@ class ConditionalNode : BaseNode, IGoToNode internal string FalseLabel { get; private set; } - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { if (Source != null) - Source.AddRequiredColumns(dataSources, parameterTypes, new List(requiredColumns)); + Source.AddRequiredColumns(context, new List(requiredColumns)); foreach (var node in TrueStatements) - node.AddRequiredColumns(dataSources, parameterTypes, new List(requiredColumns)); + node.AddRequiredColumns(context, new List(requiredColumns)); if (FalseStatements != null) { foreach (var node in FalseStatements) - node.AddRequiredColumns(dataSources, parameterTypes, new List(requiredColumns)); + node.AddRequiredColumns(context, new List(requiredColumns)); } } - public string Execute(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + public string Execute(NodeExecutionContext context) { throw new NotSupportedException("Conditional node should have been converted to GOTO during query plan building"); } - public IRootExecutionPlanNodeInternal[] FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public IRootExecutionPlanNodeInternal[] FoldQuery(NodeCompilationContext context, IList hints) { TrueLabel = Guid.NewGuid().ToString(); FalseLabel = Guid.NewGuid().ToString(); - Source = Source?.FoldQuery(dataSources, options, parameterTypes, hints); + Source = Source?.FoldQuery(context, hints); TrueStatements = TrueStatements - .SelectMany(s => s.FoldQuery(dataSources, options, parameterTypes, hints)) + .SelectMany(s => s.FoldQuery(context, hints)) .ToArray(); FalseStatements = FalseStatements - ?.SelectMany(s => s.FoldQuery(dataSources, options, parameterTypes, hints)) + ?.SelectMany(s => s.FoldQuery(context, hints)) ?.ToArray(); if (hints != null && hints.OfType().Any()) @@ -94,13 +94,13 @@ public IRootExecutionPlanNodeInternal[] FoldQuery(IDictionary Schema { get; private set; } = new Dictionary(); - protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { - foreach (var expressions in Values) + var compilationContext = new ExpressionCompilationContext(context, null, null); + var executionContext = new ExpressionExecutionContext(context); + + foreach (var row in Values) { var value = new Entity(); foreach (var col in Schema) - value[PrefixWithAlias(col.Key)] = expressions[col.Key].Compile(dataSources[options.PrimaryDataSource], null, parameterTypes)(null, parameterValues, options); + value[PrefixWithAlias(col.Key)] = row[col.Key].Compile(compilationContext)(executionContext); yield return value; } @@ -51,7 +54,7 @@ public override IEnumerable GetSources() return Array.Empty(); } - public override INodeSchema GetSchema(IDictionary dataSources, IDictionary parameterTypes) + public override INodeSchema GetSchema(NodeCompilationContext context) { return new NodeSchema( primaryKey: null, @@ -69,16 +72,16 @@ private string PrefixWithAlias(string columnName) return Alias + "." + columnName; } - public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) { return this; } - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { } - protected override RowCountEstimate EstimateRowsOutInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes) + protected override RowCountEstimate EstimateRowsOutInternal(NodeCompilationContext context) { return new RowCountEstimateDefiniteRange(Values.Count, Values.Count); } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ContinueBreakNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ContinueBreakNode.cs index 4e79bd52..31c7563b 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ContinueBreakNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ContinueBreakNode.cs @@ -25,7 +25,7 @@ class ContinueBreakNode : BaseNode, IGoToNode [Browsable(false)] public ContinueBreakNodeType Type { get; set; } - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { } @@ -40,12 +40,12 @@ public object Clone() }; } - public string Execute(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + public string Execute(NodeExecutionContext context) { throw new NotSupportedException(Type.ToString() + " node should have been converted to GOTO during query plan building"); } - public IRootExecutionPlanNodeInternal[] FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public IRootExecutionPlanNodeInternal[] FoldQuery(NodeCompilationContext context, IList hints) { if (hints != null && hints.OfType().Any()) return new[] { this }; @@ -71,14 +71,14 @@ public IRootExecutionPlanNodeInternal[] FoldQuery(IDictionary Variables { get; } = new Dictionary(); - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { } - public string Execute(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues, out int recordsAffected) + public string Execute(NodeExecutionContext context, out int recordsAffected) { _executionCount++; @@ -51,8 +51,8 @@ public string Execute(IDictionary dataSources, IQueryExecuti { foreach (var variable in Variables) { - parameterTypes[variable.Key] = variable.Value; - parameterValues[variable.Key] = SqlTypeConverter.GetNullValue(variable.Value.ToNetType(out _)); + context.ParameterTypes[variable.Key] = variable.Value; + context.ParameterValues[variable.Key] = SqlTypeConverter.GetNullValue(variable.Value.ToNetType(out _)); } } @@ -60,7 +60,7 @@ public string Execute(IDictionary dataSources, IQueryExecuti return null; } - public IRootExecutionPlanNodeInternal[] FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public IRootExecutionPlanNodeInternal[] FoldQuery(NodeCompilationContext context, IList hints) { return new[] { this }; } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/DeleteNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/DeleteNode.cs index 47414999..a7552813 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/DeleteNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/DeleteNode.cs @@ -53,7 +53,7 @@ class DeleteNode : BaseDmlNode [Category("Delete")] public override bool BypassCustomPluginExecution { get; set; } - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { if (!requiredColumns.Contains(PrimaryIdSource)) requiredColumns.Add(PrimaryIdSource); @@ -61,21 +61,21 @@ public override void AddRequiredColumns(IDictionary dataSour if (SecondaryIdSource != null && !requiredColumns.Contains(SecondaryIdSource)) requiredColumns.Add(SecondaryIdSource); - Source.AddRequiredColumns(dataSources, parameterTypes, requiredColumns); + Source.AddRequiredColumns(context, requiredColumns); } - public override IRootExecutionPlanNodeInternal[] FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public override IRootExecutionPlanNodeInternal[] FoldQuery(NodeCompilationContext context, IList hints) { - var result = base.FoldQuery(dataSources, options, parameterTypes, hints); + var result = base.FoldQuery(context, hints); if (result.Length != 1 || result[0] != this) return result; - if (!dataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) throw new NotSupportedQueryFragmentException("Missing datasource " + DataSource); // Use bulk delete if requested & possible - if ((options.UseBulkDelete || LogicalName == "audit") && + if ((context.Options.UseBulkDelete || LogicalName == "audit") && Source is FetchXmlScan fetch && LogicalName == fetch.Entity.name && PrimaryIdSource.Equals($"{fetch.Alias}.{dataSource.Metadata[LogicalName].PrimaryIdAttribute}") && @@ -96,13 +96,13 @@ protected override void RenameSourceColumns(IDictionary columnRe SecondaryIdSource = secondaryIdSourceRenamed; } - public override string Execute(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues, out int recordsAffected) + public override string Execute(NodeExecutionContext context, out int recordsAffected) { _executionCount++; try { - if (!dataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) throw new QueryExecutionException("Missing datasource " + DataSource); List entities; @@ -112,11 +112,11 @@ public override string Execute(IDictionary dataSources, IQue using (_timer.Run()) { - entities = GetDmlSourceEntities(dataSources, options, parameterTypes, parameterValues, out var schema); + entities = GetDmlSourceEntities(context, out var schema); // Precompile mappings with type conversions meta = dataSource.Metadata[LogicalName]; - var dateTimeKind = options.UseLocalTimeZone ? DateTimeKind.Local : DateTimeKind.Utc; + var dateTimeKind = context.Options.UseLocalTimeZone ? DateTimeKind.Local : DateTimeKind.Utc; var primaryKey = meta.PrimaryIdAttribute; string secondaryKey = null; @@ -150,9 +150,9 @@ public override string Execute(IDictionary dataSources, IQue // Check again that the update is allowed. Don't count any UI interaction in the execution time var confirmArgs = new ConfirmDmlStatementEventArgs(entities.Count, meta, BypassCustomPluginExecution); - if (options.CancellationToken.IsCancellationRequested) + if (context.Options.CancellationToken.IsCancellationRequested) confirmArgs.Cancel = true; - options.ConfirmDelete(confirmArgs); + context.Options.ConfirmDelete(confirmArgs); if (confirmArgs.Cancel) throw new OperationCanceledException("DELETE cancelled by user"); @@ -160,7 +160,7 @@ public override string Execute(IDictionary dataSources, IQue { return ExecuteDmlOperation( dataSource.Connection, - options, + context.Options, entities, meta, entity => CreateDeleteRequest(meta, entity, primaryIdAccessor, secondaryIdAccessor), @@ -171,7 +171,7 @@ public override string Execute(IDictionary dataSources, IQue CompletedLowercase = "deleted" }, out recordsAffected, - parameterValues); + context.ParameterValues); } } catch (QueryExecutionException ex) diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/DistinctNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/DistinctNode.cs index e40825ae..f098294a 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/DistinctNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/DistinctNode.cs @@ -29,20 +29,20 @@ class DistinctNode : BaseDataNode, ISingleSourceExecutionPlanNode [Browsable(false)] public IDataExecutionPlanNodeInternal Source { get; set; } - protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { var distinct = new HashSet(new DistinctEqualityComparer(Columns)); - foreach (var entity in Source.Execute(dataSources, options, parameterTypes, parameterValues)) + foreach (var entity in Source.Execute(context)) { if (distinct.Add(entity)) yield return entity; } } - public override INodeSchema GetSchema(IDictionary dataSources, IDictionary parameterTypes) + public override INodeSchema GetSchema(NodeCompilationContext context) { - var schema = Source.GetSchema(dataSources, parameterTypes); + var schema = Source.GetSchema(context); // If this is a distinct list of one column we know the values in that column will be unique if (Columns.Count == 1) @@ -61,9 +61,9 @@ public override IEnumerable GetSources() yield return Source; } - public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) { - Source = Source.FoldQuery(dataSources, options, parameterTypes, hints); + Source = Source.FoldQuery(context, hints); Source.Parent = this; // Remove any duplicated column names @@ -75,7 +75,7 @@ public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { foreach (var col in Columns) { @@ -153,15 +153,15 @@ public override void AddRequiredColumns(IDictionary dataSour requiredColumns.Add(col); } - Source.AddRequiredColumns(dataSources, parameterTypes, requiredColumns); + Source.AddRequiredColumns(context, requiredColumns); } - protected override RowCountEstimate EstimateRowsOutInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes) + protected override RowCountEstimate EstimateRowsOutInternal(NodeCompilationContext context) { // TODO: Is there any metadata available that could help give a better estimate for this? // Maybe get the schema and check if any of the columns included in the DISTINCT list are the // primary key and if so return the entire count, if some are optionset then there's a known list - var totalCount = Source.EstimateRowsOut(dataSources, options, parameterTypes); + var totalCount = Source.EstimateRowsOut(context); if (totalCount is RowCountEstimateDefiniteRange range && range.Maximum == 1) return totalCount; diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteAsNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteAsNode.cs index 0a5383ce..30a5a4ab 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteAsNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteAsNode.cs @@ -37,15 +37,15 @@ class ExecuteAsNode : BaseDmlNode, IImpersonateRevertExecutionPlanNode [Browsable(false)] public override bool BypassCustomPluginExecution { get; set; } - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { if (!requiredColumns.Contains(UserIdSource)) requiredColumns.Add(UserIdSource); - Source.AddRequiredColumns(dataSources, parameterTypes, requiredColumns); + Source.AddRequiredColumns(context, requiredColumns); } - public override string Execute(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues, out int recordsAffected) + public override string Execute(NodeExecutionContext context, out int recordsAffected) { _executionCount++; @@ -53,10 +53,10 @@ public override string Execute(IDictionary dataSources, IQue { using (_timer.Run()) { - if (!dataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) throw new QueryExecutionException("Missing datasource " + DataSource); - var entities = GetDmlSourceEntities(dataSources, options, parameterTypes, parameterValues, out var schema); + var entities = GetDmlSourceEntities(context, out var schema); if (entities.Count == 0) throw new QueryExecutionException("Cannot find user to impersonate"); diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteMessageNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteMessageNode.cs index 3030db97..33a2469d 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteMessageNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteMessageNode.cs @@ -15,7 +15,7 @@ namespace MarkMpn.Sql4Cds.Engine.ExecutionPlan { class ExecuteMessageNode : BaseDataNode, IDmlQueryExecutionPlanNode { - private Dictionary, IQueryExecutionOptions, object>> _inputParameters; + private Dictionary> _inputParameters; private string _primaryKeyColumn; /// @@ -110,23 +110,25 @@ class ExecuteMessageNode : BaseDataNode, IDmlQueryExecutionPlanNode [DisplayName("Pages Retrieved")] public int PagesRetrieved { get; set; } - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { } - public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) { + var expressionContext = new ExpressionCompilationContext(context, null, null); + _inputParameters = Values .ToDictionary(value => value.Key, value => { - var exprType = value.Value.GetType(dataSources[options.PrimaryDataSource], null, null, parameterTypes, out _); + var exprType = value.Value.GetType(expressionContext, out _); var expectedType = ValueTypes[value.Key]; - var expr = value.Value.Compile(dataSources[options.PrimaryDataSource], null, parameterTypes); + var expr = value.Value.Compile(expressionContext); var conversion = SqlTypeConverter.GetConversion(exprType, expectedType); - return (Func, IQueryExecutionOptions, object>) ((IDictionary parameterValues, IQueryExecutionOptions opts) => conversion(expr(null, parameterValues, opts))); + return (Func) ((ExpressionExecutionContext ctx) => conversion(expr(ctx))); }); - BypassCustomPluginExecution = GetBypassPluginExecution(hints, options); + BypassCustomPluginExecution = GetBypassPluginExecution(hints, context.Options); return this; } @@ -144,7 +146,7 @@ private bool GetBypassPluginExecution(IList queryHints, IQueryExe return bypassPluginExecution || options.BypassCustomPlugins; } - public override INodeSchema GetSchema(IDictionary dataSources, IDictionary parameterTypes) + public override INodeSchema GetSchema(NodeCompilationContext context) { return new NodeSchema( primaryKey: null, @@ -235,7 +237,7 @@ public override IEnumerable GetSources() return Array.Empty(); } - protected override RowCountEstimate EstimateRowsOutInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes) + protected override RowCountEstimate EstimateRowsOutInternal(NodeCompilationContext context) { if (EntityCollectionResponseParameter == null) return RowCountEstimateDefiniteRange.ExactlyOne; @@ -248,24 +250,25 @@ protected override IEnumerable GetVariablesInternal() return Values.Values.SelectMany(v => v.GetVariables()); } - protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { PagesRetrieved = 0; - if (!dataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) throw new NotSupportedQueryFragmentException("Missing datasource " + DataSource); - options.Progress(0, $"Executing {MessageName}..."); + context.Options.Progress(0, $"Executing {MessageName}..."); // Get the first page of results - if (!options.ContinueRetrieve(0)) + if (!context.Options.ContinueRetrieve(0)) yield break; var request = new OrganizationRequest(MessageName); var pageNumber = 1; + var expressionContext = new ExpressionExecutionContext(context); foreach (var value in _inputParameters) - request[value.Key] = value.Value(parameterValues, options); + request[value.Key] = value.Value(expressionContext); if (BypassCustomPluginExecution) request.Parameters["BypassCustomPluginExecution"] = true; @@ -278,11 +281,11 @@ protected override IEnumerable ExecuteInternal(IDictionary ExecuteInternal(IDictionary parameterTypes) + public static ExecuteMessageNode FromMessage(SchemaObjectFunctionTableReference tvf, DataSource dataSource, ExpressionCompilationContext context) { // All messages are in the "dbo" schema if (tvf.SchemaObject.SchemaIdentifier != null && !String.IsNullOrEmpty(tvf.SchemaObject.SchemaIdentifier.Value) && @@ -539,8 +542,8 @@ public static ExecuteMessageNode FromMessage(SchemaObjectFunctionTableReference { var f = expectedInputParameters[i]; var sourceExpression = tvf.Parameters[i]; - sourceExpression.GetType(primaryDataSource, null, null, parameterTypes, out var sourceType); - var expectedType = SqlTypeConverter.NetToSqlType(f.Type).ToSqlType(primaryDataSource); + sourceExpression.GetType(context, out var sourceType); + var expectedType = SqlTypeConverter.NetToSqlType(f.Type).ToSqlType(context.PrimaryDataSource); if (!SqlTypeConverter.CanChangeTypeImplicit(sourceType, expectedType)) throw new NotSupportedQueryFragmentException($"Cannot convert value of type {sourceType.ToSql()} to {expectedType.ToSql()}", tvf.Parameters[f.Position]); @@ -566,7 +569,7 @@ public static ExecuteMessageNode FromMessage(SchemaObjectFunctionTableReference return node; } - public static ExecuteMessageNode FromMessage(ExecutableProcedureReference sproc, DataSource dataSource, DataSource primaryDataSource, IDictionary parameterTypes) + public static ExecuteMessageNode FromMessage(ExecutableProcedureReference sproc, DataSource dataSource, ExpressionCompilationContext context) { // All messages are in the "dbo" schema if (sproc.ProcedureReference.ProcedureReference.Name.SchemaIdentifier != null && !String.IsNullOrEmpty(sproc.ProcedureReference.ProcedureReference.Name.SchemaIdentifier.Value) && @@ -647,8 +650,8 @@ public static ExecuteMessageNode FromMessage(ExecutableProcedureReference sproc, throw new NotSupportedQueryFragmentException("Unknown parameter", sproc.Parameters[i]); var sourceExpression = sproc.Parameters[i].ParameterValue; - sourceExpression.GetType(primaryDataSource, null, null, parameterTypes, out var sourceType); - var expectedType = SqlTypeConverter.NetToSqlType(targetParam.Type).ToSqlType(primaryDataSource); + sourceExpression.GetType(context, out var sourceType); + var expectedType = SqlTypeConverter.NetToSqlType(targetParam.Type).ToSqlType(context.PrimaryDataSource); if (!SqlTypeConverter.CanChangeTypeImplicit(sourceType, expectedType)) throw new NotSupportedQueryFragmentException($"Cannot convert value of type {sourceType.ToSql()} to {expectedType.ToSql()}", sproc.Parameters[i].ParameterValue); @@ -684,15 +687,15 @@ public static ExecuteMessageNode FromMessage(ExecutableProcedureReference sproc, return node; } - public string Execute(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues, out int recordsAffected) + public string Execute(NodeExecutionContext context, out int recordsAffected) { - recordsAffected = Execute(dataSources, options, parameterTypes, parameterValues).Count(); + recordsAffected = Execute(context).Count(); return "Executed " + MessageName; } - IRootExecutionPlanNodeInternal[] IRootExecutionPlanNodeInternal.FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + IRootExecutionPlanNodeInternal[] IRootExecutionPlanNodeInternal.FoldQuery(NodeCompilationContext context, IList hints) { - FoldQuery(dataSources, options, parameterTypes, hints); + FoldQuery(context, hints); return new[] { this }; } } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs index 96841cf6..eb1c4b50 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs @@ -25,18 +25,13 @@ static class ExpressionExtensions /// Gets the type of value that will be generated by an expression /// /// The expression to get the type of - /// The schema of the node that the expression will be evaluated in the context of - /// For aggregate queries, the schema of the data prior to applying the aggregation - /// A mapping of parameter names to their types that are available to the expression + /// The context the expression is being compiled in /// The SQL data type that will be returned /// The type of value that will be returned by the expression - public static Type GetType(this TSqlFragment expr, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, out DataTypeReference sqlType) + public static Type GetType(this TSqlFragment expr, ExpressionCompilationContext context, out DataTypeReference sqlType) { - var entityParam = Expression.Parameter(typeof(Entity)); - var parameterParam = Expression.Parameter(typeof(IDictionary)); - var optionsParam = Expression.Parameter(typeof(IQueryExecutionOptions)); - - var expression = ToExpression(expr, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + var contextParam = Expression.Parameter(typeof(ExpressionExecutionContext)); + var expression = ToExpression(expr, context, contextParam, out sqlType); return expression.Type; } @@ -44,110 +39,102 @@ public static Type GetType(this TSqlFragment expr, DataSource primaryDataSource, /// Compiles an expression to a function /// /// The expression to be compiled - /// The schema of the node that the expression will be evaluated in the context of - /// A mapping of parameter names to their types that are available to the expression - /// A function that accepts a representing the data values of a record, a holding parameter values and an defining how the query should be run and returns the value of the expression - public static Func, IQueryExecutionOptions, object> Compile(this TSqlFragment expr, DataSource primaryDataSource, INodeSchema schema, IDictionary parameterTypes) + /// The context the expression is being compiled in + /// A function that accepts an representing the context the expression is being evaluated in and returns the value of the expression + public static Func Compile(this TSqlFragment expr, ExpressionCompilationContext context) { - var entityParam = Expression.Parameter(typeof(Entity)); - var parameterParam = Expression.Parameter(typeof(IDictionary)); - var optionsParam = Expression.Parameter(typeof(IQueryExecutionOptions)); - - var expression = ToExpression(expr, primaryDataSource, schema, null, parameterTypes, entityParam, parameterParam, optionsParam, out _); + var contextParam = Expression.Parameter(typeof(ExpressionExecutionContext)); + var expression = ToExpression(expr, context, contextParam, out _); expression = Expr.Box(expression); - return Expression.Lambda, IQueryExecutionOptions, object>>(expression, entityParam, parameterParam, optionsParam).Compile(); + return Expression.Lambda>(expression, contextParam).Compile(); } /// /// Compiles a boolean expression to a function /// /// The expression to be compiled - /// The schema of the node that the expression will be evaluated in the context of - /// A mapping of parameter names to their types that are available to the expression - /// A function that accepts a representing the data values of a record, a holding parameter values and an defining how the query should be run and returns the value of the expression - public static Func, IQueryExecutionOptions, bool> Compile(this BooleanExpression b, DataSource primaryDataSource, INodeSchema schema, IDictionary parameterTypes) + /// The context the expression is being compiled in + /// A function that accepts aan representing the context the expression is being evaluated in and returns the value of the expression + public static Func Compile(this BooleanExpression b, ExpressionCompilationContext context) { - var entityParam = Expression.Parameter(typeof(Entity)); - var parameterParam = Expression.Parameter(typeof(IDictionary)); - var optionsParam = Expression.Parameter(typeof(IQueryExecutionOptions)); - - var expression = ToExpression(b, primaryDataSource, schema, null, parameterTypes, entityParam, parameterParam, optionsParam, out _); + var contextParam = Expression.Parameter(typeof(ExpressionExecutionContext)); + var expression = ToExpression(b, context, contextParam, out _); expression = Expression.IsTrue(expression); - return Expression.Lambda, IQueryExecutionOptions, bool>>(expression, entityParam, parameterParam, optionsParam).Compile(); + return Expression.Lambda>(expression, contextParam).Compile(); } - private static Expression ToExpression(this TSqlFragment expr, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this TSqlFragment expr, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { if (expr is ColumnReferenceExpression col) - return ToExpression(col, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(col, context, contextParam, out sqlType); else if (expr is IdentifierLiteral guid) - return ToExpression(guid, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(guid, context, contextParam, out sqlType); else if (expr is IntegerLiteral i) - return ToExpression(i, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(i, context, contextParam, out sqlType); else if (expr is MoneyLiteral money) - return ToExpression(money, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(money, context, contextParam, out sqlType); else if (expr is NullLiteral n) - return ToExpression(n, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(n, context, contextParam, out sqlType); else if (expr is NumericLiteral num) - return ToExpression(num, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(num, context, contextParam, out sqlType); else if (expr is RealLiteral real) - return ToExpression(real, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(real, context, contextParam, out sqlType); else if (expr is StringLiteral str) - return ToExpression(str, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(str, context, contextParam, out sqlType); else if (expr is OdbcLiteral odbc) - return ToExpression(odbc, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(odbc, context, contextParam, out sqlType); else if (expr is BooleanBinaryExpression boolBin) - return ToExpression(boolBin, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(boolBin, context, contextParam, out sqlType); else if (expr is BooleanComparisonExpression cmp) - return ToExpression(cmp, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(cmp, context, contextParam, out sqlType); else if (expr is BooleanParenthesisExpression boolParen) - return ToExpression(boolParen, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(boolParen, context, contextParam, out sqlType); else if (expr is InPredicate inPred) - return ToExpression(inPred, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(inPred, context, contextParam, out sqlType); else if (expr is BooleanIsNullExpression isNull) - return ToExpression(isNull, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(isNull, context, contextParam, out sqlType); else if (expr is LikePredicate like) - return ToExpression(like, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(like, context, contextParam, out sqlType); else if (expr is BooleanNotExpression not) - return ToExpression(not, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(not, context, contextParam, out sqlType); else if (expr is FullTextPredicate fullText) - return ToExpression(fullText, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(fullText, context, contextParam, out sqlType); else if (expr is Microsoft.SqlServer.TransactSql.ScriptDom.BinaryExpression bin) - return ToExpression(bin, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(bin, context, contextParam, out sqlType); else if (expr is FunctionCall func) - return ToExpression(func, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(func, context, contextParam, out sqlType); else if (expr is ParenthesisExpression paren) - return ToExpression(paren, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(paren, context, contextParam, out sqlType); else if (expr is Microsoft.SqlServer.TransactSql.ScriptDom.UnaryExpression unary) - return ToExpression(unary, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(unary, context, contextParam, out sqlType); else if (expr is VariableReference var) - return ToExpression(var, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(var, context, contextParam, out sqlType); else if (expr is SimpleCaseExpression simpleCase) - return ToExpression(simpleCase, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(simpleCase, context, contextParam, out sqlType); else if (expr is SearchedCaseExpression searchedCase) - return ToExpression(searchedCase, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(searchedCase, context, contextParam, out sqlType); else if (expr is ConvertCall convert) - return ToExpression(convert, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(convert, context, contextParam, out sqlType); else if (expr is CastCall cast) - return ToExpression(cast, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(cast, context, contextParam, out sqlType); else if (expr is ParameterlessCall parameterless) - return ToExpression(parameterless, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(parameterless, context, contextParam, out sqlType); else if (expr is GlobalVariableExpression global) - return ToExpression(global, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(global, context, contextParam, out sqlType); else throw new NotSupportedQueryFragmentException("Unhandled expression type", expr); } - private static Expression ToExpression(ColumnReferenceExpression col, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(ColumnReferenceExpression col, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { var name = col.GetColumnName(); - if (schema == null || !schema.ContainsColumn(name, out var normalizedName)) + if (context.Schema == null || !context.Schema.ContainsColumn(name, out var normalizedName)) { - if (schema == null || !schema.Aliases.TryGetValue(name, out var normalized)) + if (context.Schema == null || !context.Schema.Aliases.TryGetValue(name, out var normalized)) { - if (nonAggregateSchema != null && nonAggregateSchema.ContainsColumn(name, out _)) + if (context.NonAggregateSchema != null && context.NonAggregateSchema.ContainsColumn(name, out _)) throw new NotSupportedQueryFragmentException("Column is invalid in the select list because it is not contained in either an aggregate function or the GROUP BY clause", col); var ex = new NotSupportedQueryFragmentException("Unknown column", col); @@ -168,52 +155,53 @@ private static Expression ToExpression(ColumnReferenceExpression col, DataSource }; } - sqlType = schema.Schema[normalizedName]; - var expr = Expression.Property(entityParam, typeof(Entity).GetCustomAttribute().MemberName, Expression.Constant(normalizedName)); + sqlType = context.Schema.Schema[normalizedName]; + var entity = Expression.Property(contextParam, nameof(ExpressionExecutionContext.Entity)); + var expr = Expression.Property(entity, typeof(Entity).GetCustomAttribute().MemberName, Expression.Constant(normalizedName)); return Expression.Convert(expr, sqlType.ToNetType(out _)); } - private static Expression ToExpression(IdentifierLiteral guid, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(IdentifierLiteral guid, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { sqlType = DataTypeHelpers.UniqueIdentifier; return Expression.Constant(new SqlGuid(guid.Value)); } - private static Expression ToExpression(IntegerLiteral i, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(IntegerLiteral i, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { sqlType = DataTypeHelpers.Int; return Expression.Constant(new SqlInt32(Int32.Parse(i.Value, CultureInfo.InvariantCulture))); } - private static Expression ToExpression(MoneyLiteral money, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(MoneyLiteral money, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { sqlType = DataTypeHelpers.Money; return Expression.Constant(new SqlMoney(Decimal.Parse(money.Value, CultureInfo.InvariantCulture))); } - private static Expression ToExpression(NullLiteral n, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(NullLiteral n, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { sqlType = DataTypeHelpers.ImplicitIntForNullLiteral; return Expression.Constant(SqlInt32.Null); } - private static Expression ToExpression(NumericLiteral num, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(NumericLiteral num, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { var value = new SqlDecimal(Decimal.Parse(num.Value, CultureInfo.InvariantCulture)); sqlType = DataTypeHelpers.Decimal(value.Precision, value.Scale); return Expression.Constant(value); } - private static Expression ToExpression(RealLiteral real, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(RealLiteral real, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { sqlType = DataTypeHelpers.Real; return Expression.Constant(new SqlDouble(Double.Parse(real.Value, CultureInfo.InvariantCulture))); } - private static Expression ToExpression(StringLiteral str, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(StringLiteral str, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { var collationLabel = CollationLabel.CoercibleDefault; - var collation = GetCollation(primaryDataSource, str.Collation, ref collationLabel); + var collation = GetCollation(context.PrimaryDataSource, str.Collation, ref collationLabel); sqlType = str.IsNational ? DataTypeHelpers.NVarChar(str.Value.Length, collation, collationLabel) @@ -222,7 +210,7 @@ private static Expression ToExpression(StringLiteral str, DataSource primaryData return Expression.Constant(collation.ToSqlString(str.Value)); } - private static Expression ToExpression(OdbcLiteral odbc, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(OdbcLiteral odbc, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { switch (odbc.OdbcLiteralType) { @@ -243,7 +231,7 @@ private static Expression ToExpression(OdbcLiteral odbc, DataSource primaryDataS } } - private static Expression ToExpression(BooleanComparisonExpression cmp, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(BooleanComparisonExpression cmp, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { // Special case for field = func() where func is defined in FetchXmlConditionMethods if (cmp.FirstExpression is ColumnReferenceExpression && @@ -253,15 +241,15 @@ cmp.SecondExpression is FunctionCall func { var parameters = func.Parameters.Select(p => { - var paramExpr = p.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var paramType); + var paramExpr = p.ToExpression(context, contextParam, out var paramType); return new KeyValuePair(paramExpr, paramType); }).ToList(); - var colExpr = cmp.FirstExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var colType); + var colExpr = cmp.FirstExpression.ToExpression(context, contextParam, out var colType); parameters.Insert(0, new KeyValuePair(colExpr, colType)); var paramTypes = parameters.Select(p => p.Value).ToArray(); var paramExpressions = parameters.Select(p => p.Key).ToArray(); - var fetchXmlComparison = GetMethod(typeof(FetchXmlConditionMethods), primaryDataSource, func, paramTypes, false, optionsParam, ref paramExpressions, out sqlType); + var fetchXmlComparison = GetMethod(typeof(FetchXmlConditionMethods), context.PrimaryDataSource, func, paramTypes, false, contextParam, ref paramExpressions, out sqlType); if (fetchXmlComparison != null) return Expr.Call(fetchXmlComparison, paramExpressions); @@ -269,10 +257,10 @@ cmp.SecondExpression is FunctionCall func sqlType = DataTypeHelpers.Bit; - var lhs = cmp.FirstExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var lhsType); - var rhs = cmp.SecondExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var rhsType); + var lhs = cmp.FirstExpression.ToExpression(context, contextParam, out var lhsType); + var rhs = cmp.SecondExpression.ToExpression(context, contextParam, out var rhsType); - if (!SqlTypeConverter.CanMakeConsistentTypes(lhsType, rhsType, primaryDataSource, out var type)) + if (!SqlTypeConverter.CanMakeConsistentTypes(lhsType, rhsType, context.PrimaryDataSource, out var type)) { // Special case - we can filter on entity reference types by string if (lhs.Type == typeof(SqlEntityReference) && rhs.Type == typeof(SqlString) || @@ -299,7 +287,7 @@ cmp.SecondExpression is StringLiteral str && || type.IsType(SqlDataTypeOption.UniqueIdentifier) && !Guid.TryParse(str.Value, out _) ) && - schema.ContainsColumn(col.GetColumnName() + "name", out var nameCol)) + context.Schema.ContainsColumn(col.GetColumnName() + "name", out var nameCol)) { throw new NotSupportedQueryFragmentException($"Cannot convert text value to {type.ToSql()}", str) { @@ -338,12 +326,12 @@ cmp.SecondExpression is StringLiteral str && } } - private static Expression ToExpression(BooleanBinaryExpression bin, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(BooleanBinaryExpression bin, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { sqlType = DataTypeHelpers.Bit; - var lhs = bin.FirstExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out _); - var rhs = bin.SecondExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out _); + var lhs = bin.FirstExpression.ToExpression(context, contextParam, out _); + var rhs = bin.SecondExpression.ToExpression(context, contextParam, out _); if (bin.BinaryExpressionType == BooleanBinaryExpressionType.And) return Expression.AndAlso(lhs, rhs); @@ -351,17 +339,17 @@ private static Expression ToExpression(BooleanBinaryExpression bin, DataSource p return Expression.OrElse(lhs, rhs); } - private static Expression ToExpression(BooleanParenthesisExpression paren, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(BooleanParenthesisExpression paren, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { - return paren.Expression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return paren.Expression.ToExpression(context, contextParam, out sqlType); } - private static Expression ToExpression(Microsoft.SqlServer.TransactSql.ScriptDom.BinaryExpression bin, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(Microsoft.SqlServer.TransactSql.ScriptDom.BinaryExpression bin, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { - var lhs = bin.FirstExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var lhsSqlType); - var rhs = bin.SecondExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var rhsSqlType); + var lhs = bin.FirstExpression.ToExpression(context, contextParam, out var lhsSqlType); + var rhs = bin.SecondExpression.ToExpression(context, contextParam, out var rhsSqlType); - if (!SqlTypeConverter.CanMakeConsistentTypes(lhsSqlType, rhsSqlType, primaryDataSource, out var type)) + if (!SqlTypeConverter.CanMakeConsistentTypes(lhsSqlType, rhsSqlType, context.PrimaryDataSource, out var type)) throw new NotSupportedQueryFragmentException($"No implicit conversion exists for types {lhsSqlType.ToSql()} and {rhsSqlType.ToSql()}", bin); // For decimal types, need to work out the precision and scale of the result depending on the type of operation @@ -451,8 +439,7 @@ rhsSqlType is SqlDataTypeReferenceWithCollation rhsSql && lhs.Type == typeof(SqlString) && rhs.Type == typeof(SqlString) && lhsSql.Parameters.Count == 1 && - rhsSql.Parameters.Count == 1 && - sqlType is SqlDataTypeReferenceWithCollation sqlTypeWithColl) + rhsSql.Parameters.Count == 1) { int lhsLength; int rhsLength; @@ -467,12 +454,15 @@ rhsSqlType is SqlDataTypeReferenceWithCollation rhsSql && var length = lhsLength + rhsLength; + if (!SqlDataTypeReferenceWithCollation.TryConvertCollation(lhsSql, rhsSql, out var collation, out var collationLabel)) + throw new NotSupportedQueryFragmentException($"Cannot resolve collation conflict between '{lhsSql.Collation.Name}' and {rhsSql.Collation.Name}' in add operation", bin); + sqlType = new SqlDataTypeReferenceWithCollation { SqlDataTypeOption = ((SqlDataTypeReference)type).SqlDataTypeOption, Parameters = { length <= 8000 ? (Literal)new IntegerLiteral { Value = length.ToString(CultureInfo.InvariantCulture) } : new MaxLiteral() }, - Collation = sqlTypeWithColl.Collation, - CollationLabel = sqlTypeWithColl.CollationLabel + Collation = collation, + CollationLabel = collationLabel }; } break; @@ -517,7 +507,7 @@ rhsSqlType is SqlDataTypeReferenceWithCollation rhsSql && sqlType = type; if (sqlType == null) - sqlType = expr.Type.ToSqlType(primaryDataSource); + sqlType = expr.Type.ToSqlType(context.PrimaryDataSource); return expr; } @@ -542,7 +532,7 @@ private static SqlDateTime SubtractSqlDateTime(SqlDateTime lhs, SqlDateTime rhs) return lhs - ts; } - private static MethodInfo GetMethod(FunctionCall func, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out Expression[] paramExpressions, out DataTypeReference sqlType) + private static MethodInfo GetMethod(FunctionCall func, ExpressionCompilationContext context, ParameterExpression contextParam, out Expression[] paramExpressions, out DataTypeReference sqlType) { KeyValuePair[] paramExpressionsWithType; @@ -569,10 +559,10 @@ private static MethodInfo GetMethod(FunctionCall func, DataSource primaryDataSou throw new NotSupportedQueryFragmentException("Expected a datepart name", param); } - return new KeyValuePair(Expression.Constant(col.MultiPartIdentifier.Identifiers.Single().Value), DataTypeHelpers.NVarChar(col.MultiPartIdentifier.Identifiers.Single().Value.Length, primaryDataSource.DefaultCollation, CollationLabel.CoercibleDefault)); + return new KeyValuePair(Expression.Constant(col.MultiPartIdentifier.Identifiers.Single().Value), DataTypeHelpers.NVarChar(col.MultiPartIdentifier.Identifiers.Single().Value.Length, context.PrimaryDataSource.DefaultCollation, CollationLabel.CoercibleDefault)); } - var paramExpr = param.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var paramType); + var paramExpr = param.ToExpression(context, contextParam, out var paramType); return new KeyValuePair(paramExpr, paramType); }) .ToArray(); @@ -582,7 +572,7 @@ private static MethodInfo GetMethod(FunctionCall func, DataSource primaryDataSou paramExpressionsWithType = func.Parameters .Select(param => { - var paramExpr = param.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var paramType); + var paramExpr = param.ToExpression(context, contextParam, out var paramType); return new KeyValuePair(paramExpr, paramType); }) .ToArray(); @@ -592,10 +582,10 @@ private static MethodInfo GetMethod(FunctionCall func, DataSource primaryDataSou .Select(kvp => kvp.Key) .ToArray(); - return GetMethod(typeof(ExpressionFunctions), primaryDataSource, func, paramExpressionsWithType.Select(kvp => kvp.Value).ToArray(), true, optionsParam, ref paramExpressions, out sqlType); + return GetMethod(typeof(ExpressionFunctions), context.PrimaryDataSource, func, paramExpressionsWithType.Select(kvp => kvp.Value).ToArray(), true, contextParam, ref paramExpressions, out sqlType); } - private static MethodInfo GetMethod(Type targetType, DataSource primaryDataSource, FunctionCall func, DataTypeReference[] paramTypes, bool throwOnMissing, Expression optionsParam, ref Expression[] paramExpressions, out DataTypeReference sqlType) + private static MethodInfo GetMethod(Type targetType, DataSource primaryDataSource, FunctionCall func, DataTypeReference[] paramTypes, bool throwOnMissing, ParameterExpression contextParam, ref Expression[] paramExpressions, out DataTypeReference sqlType) { // Find a method that implements this function var methods = targetType @@ -617,7 +607,7 @@ private static MethodInfo GetMethod(Type targetType, DataSource primaryDataSourc .Select(m => new { Method = m, Parameters = m.GetParameters() }) .Where(m => { - var allowedParameters = m.Parameters.Where(p => p.GetCustomAttribute() == null && p.ParameterType != typeof(IQueryExecutionOptions)); + var allowedParameters = m.Parameters.Where(p => p.GetCustomAttribute() == null && p.ParameterType != typeof(ExpressionExecutionContext)); var requiredParameters = allowedParameters.Where(p => p.GetCustomAttribute() == null); var isArrayParameter = requiredParameters.Any() && requiredParameters.Last().ParameterType.IsArray; @@ -697,11 +687,11 @@ private static MethodInfo GetMethod(Type targetType, DataSource primaryDataSourc if (i == parameters.Length - 1 && paramTypes.Length >= parameters.Length && paramType.IsArray) paramType = paramType.GetElementType(); - if (i == parameters.Length - 1 && paramTypes.Length < parameters.Length && paramType == typeof(IQueryExecutionOptions)) + if (i == parameters.Length - 1 && paramTypes.Length < parameters.Length && paramType == typeof(ExpressionExecutionContext)) { var paramsWithOptions = new Expression[paramExpressions.Length + 1]; paramExpressions.CopyTo(paramsWithOptions, 0); - paramsWithOptions[paramExpressions.Length] = optionsParam; + paramsWithOptions[paramExpressions.Length] = contextParam; paramExpressions = paramsWithOptions; break; } @@ -760,13 +750,13 @@ private static MethodInfo GetMethod(Type targetType, DataSource primaryDataSourc return method; } - private static Expression ToExpression(this FunctionCall func, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this FunctionCall func, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { if (func.OverClause != null) throw new NotSupportedQueryFragmentException("Window functions are not supported", func); // Find the method to call and get the expressions for the parameter values - var method = GetMethod(func, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var paramValues, out sqlType); + var method = GetMethod(func, context, contextParam, out var paramValues, out sqlType); // Convert the parameters to the expected types var parameters = method.GetParameters(); @@ -785,14 +775,14 @@ private static Expression ToExpression(this FunctionCall func, DataSource primar return expr; } - private static Expression ToExpression(this ParenthesisExpression paren, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this ParenthesisExpression paren, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { - return paren.Expression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return paren.Expression.ToExpression(context, contextParam, out sqlType); } - private static Expression ToExpression(this Microsoft.SqlServer.TransactSql.ScriptDom.UnaryExpression unary, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this Microsoft.SqlServer.TransactSql.ScriptDom.UnaryExpression unary, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { - var value = unary.Expression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + var value = unary.Expression.ToExpression(context, contextParam, out sqlType); switch (unary.UnaryExpressionType) { @@ -810,20 +800,20 @@ private static Expression ToExpression(this Microsoft.SqlServer.TransactSql.Scri } } - private static Expression ToExpression(this InPredicate inPred, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this InPredicate inPred, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { if (inPred.Subquery != null) throw new NotSupportedQueryFragmentException("Subquery should have been eliminated by query plan", inPred); - var exprValue = inPred.Expression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var exprType); + var exprValue = inPred.Expression.ToExpression(context, contextParam, out var exprType); Expression result = null; foreach (var value in inPred.Values) { - var comparisonValue = value.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var comparisonType); + var comparisonValue = value.ToExpression(context, contextParam, out var comparisonType); - if (!SqlTypeConverter.CanMakeConsistentTypes(exprType, comparisonType, primaryDataSource, out var type)) + if (!SqlTypeConverter.CanMakeConsistentTypes(exprType, comparisonType, context.PrimaryDataSource, out var type)) throw new NotSupportedQueryFragmentException($"No implicit conversion exists for types {exprType.ToSql()} and {comparisonType.ToSql()}", inPred); var convertedExprValue = exprValue; @@ -846,27 +836,29 @@ private static Expression ToExpression(this InPredicate inPred, DataSource prima return result; } - private static Expression ToExpression(this VariableReference var, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this VariableReference var, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { - if (parameterTypes == null || !parameterTypes.TryGetValue(var.Name, out sqlType)) + if (context.ParameterTypes == null || !context.ParameterTypes.TryGetValue(var.Name, out sqlType)) throw new NotSupportedQueryFragmentException("Undefined variable", var); - var expr = Expression.Property(parameterParam, typeof(IDictionary).GetCustomAttribute().MemberName, Expression.Constant(var.Name)); + var parameters = Expression.Property(contextParam, nameof(ExpressionExecutionContext.ParameterValues)); + var expr = Expression.Property(parameters, typeof(IDictionary).GetCustomAttribute().MemberName, Expression.Constant(var.Name)); return Expression.Convert(expr, sqlType.ToNetType(out _)); } - private static Expression ToExpression(this GlobalVariableExpression var, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this GlobalVariableExpression var, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { - if (parameterTypes == null || !parameterTypes.TryGetValue(var.Name, out sqlType)) + if (context.ParameterTypes == null || !context.ParameterTypes.TryGetValue(var.Name, out sqlType)) throw new NotSupportedQueryFragmentException("Undefined variable", var); - var expr = Expression.Property(parameterParam, typeof(IDictionary).GetCustomAttribute().MemberName, Expression.Constant(var.Name)); + var parameters = Expression.Property(contextParam, nameof(ExpressionExecutionContext.ParameterValues)); + var expr = Expression.Property(parameters, typeof(IDictionary).GetCustomAttribute().MemberName, Expression.Constant(var.Name)); return Expression.Convert(expr, sqlType.ToNetType(out _)); } - private static Expression ToExpression(this BooleanIsNullExpression isNull, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this BooleanIsNullExpression isNull, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { - var value = isNull.Expression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out _); + var value = isNull.Expression.ToExpression(context, contextParam, out _); value = SqlTypeConverter.NullCheck(value); if (isNull.IsNot) @@ -877,17 +869,17 @@ private static Expression ToExpression(this BooleanIsNullExpression isNull, Data return value; } - private static Expression ToExpression(this LikePredicate like, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this LikePredicate like, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { DataTypeReference escapeType = null; - var value = like.FirstExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var valueType); - var pattern = like.SecondExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var patternType); - var escape = like.EscapeExpression?.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out escapeType); + var value = like.FirstExpression.ToExpression(context, contextParam, out var valueType); + var pattern = like.SecondExpression.ToExpression(context, contextParam, out var patternType); + var escape = like.EscapeExpression?.ToExpression(context, contextParam, out escapeType); // TODO: Use the collations of the value/pattern and ensure they are consistent sqlType = DataTypeHelpers.Bit; - var stringType = DataTypeHelpers.NVarChar(Int32.MaxValue, primaryDataSource.DefaultCollation, CollationLabel.CoercibleDefault); + var stringType = DataTypeHelpers.NVarChar(Int32.MaxValue, context.PrimaryDataSource.DefaultCollation, CollationLabel.CoercibleDefault); if (value.Type != typeof(SqlString)) { @@ -1034,23 +1026,23 @@ private static SqlBoolean Like(SqlString value, Regex pattern, bool not) return result; } - private static Expression ToExpression(this SimpleCaseExpression simpleCase, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this SimpleCaseExpression simpleCase, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { // Convert all the different elements to expressions - var value = simpleCase.InputExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var valueType); + var value = simpleCase.InputExpression.ToExpression(context, contextParam, out var valueType); var whenClauses = simpleCase.WhenClauses.Select(when => { - var whenExpr = when.WhenExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var whenType); + var whenExpr = when.WhenExpression.ToExpression(context, contextParam, out var whenType); return new { Expression = whenExpr, Type = whenType }; }).ToList(); var caseTypes = new DataTypeReference[whenClauses.Count]; var thenClauses = simpleCase.WhenClauses.Select(when => { - var thenExpr = when.ThenExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var thenType); + var thenExpr = when.ThenExpression.ToExpression(context, contextParam, out var thenType); return new { Expression = thenExpr, Type = thenType }; }).ToList(); DataTypeReference elseType = null; - var elseValue = simpleCase.ElseExpression?.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out elseType); + var elseValue = simpleCase.ElseExpression?.ToExpression(context, contextParam, out elseType); // First pass to determine final return type DataTypeReference type = null; @@ -1059,7 +1051,7 @@ private static Expression ToExpression(this SimpleCaseExpression simpleCase, Dat { var whenType = whenClauses[i].Type; - if (!SqlTypeConverter.CanMakeConsistentTypes(valueType, whenType, primaryDataSource, out var caseType)) + if (!SqlTypeConverter.CanMakeConsistentTypes(valueType, whenType, context.PrimaryDataSource, out var caseType)) throw new NotSupportedQueryFragmentException($"Cannot compare values of type {value.Type} and {whenType}", simpleCase.WhenClauses[i].WhenExpression); caseTypes[i] = caseType; @@ -1068,7 +1060,7 @@ private static Expression ToExpression(this SimpleCaseExpression simpleCase, Dat if (type == null) type = thenType; - else if (!SqlTypeConverter.CanMakeConsistentTypes(type, thenType, primaryDataSource, out type)) + else if (!SqlTypeConverter.CanMakeConsistentTypes(type, thenType, context.PrimaryDataSource, out type)) throw new NotSupportedQueryFragmentException($"Cannot determine return type", simpleCase); } @@ -1076,7 +1068,7 @@ private static Expression ToExpression(this SimpleCaseExpression simpleCase, Dat { if (type == null) type = elseType; - else if (!SqlTypeConverter.CanMakeConsistentTypes(type, elseType, primaryDataSource, out type)) + else if (!SqlTypeConverter.CanMakeConsistentTypes(type, elseType, context.PrimaryDataSource, out type)) throw new NotSupportedQueryFragmentException($"Cannot determine return type", simpleCase); } @@ -1123,21 +1115,21 @@ private static Expression ToExpression(this SimpleCaseExpression simpleCase, Dat return result; } - private static Expression ToExpression(this SearchedCaseExpression searchedCase, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this SearchedCaseExpression searchedCase, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { // Convert all the different elements to expressions var whenClauses = searchedCase.WhenClauses.Select(when => { - var whenExpr = when.WhenExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var whenType); + var whenExpr = when.WhenExpression.ToExpression(context, contextParam, out var whenType); return new { Expression = whenExpr, Type = whenType }; }).ToList(); var thenClauses = searchedCase.WhenClauses.Select(when => { - var thenExpr = when.ThenExpression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var thenType); + var thenExpr = when.ThenExpression.ToExpression(context, contextParam, out var thenType); return new { Expression = thenExpr, Type = thenType }; }).ToList(); DataTypeReference elseType = null; - var elseValue = searchedCase.ElseExpression?.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out elseType); + var elseValue = searchedCase.ElseExpression?.ToExpression(context, contextParam, out elseType); // First pass to determine final return type DataTypeReference type = null; @@ -1148,7 +1140,7 @@ private static Expression ToExpression(this SearchedCaseExpression searchedCase, if (type == null) type = thenType; - else if (!SqlTypeConverter.CanMakeConsistentTypes(type, thenType, primaryDataSource, out type)) + else if (!SqlTypeConverter.CanMakeConsistentTypes(type, thenType, context.PrimaryDataSource, out type)) throw new NotSupportedQueryFragmentException($"Cannot determine return type", searchedCase); } @@ -1156,7 +1148,7 @@ private static Expression ToExpression(this SearchedCaseExpression searchedCase, { if (type == null) type = elseType; - else if (!SqlTypeConverter.CanMakeConsistentTypes(type, elseType, primaryDataSource, out type)) + else if (!SqlTypeConverter.CanMakeConsistentTypes(type, elseType, context.PrimaryDataSource, out type)) throw new NotSupportedQueryFragmentException($"Cannot determine return type", searchedCase); } @@ -1197,9 +1189,9 @@ private static Expression ToExpression(this SearchedCaseExpression searchedCase, return result; } - private static Expression ToExpression(this BooleanNotExpression not, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this BooleanNotExpression not, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { - var value = not.Expression.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + var value = not.Expression.ToExpression(context, contextParam, out sqlType); return Expression.Not(value); } @@ -1305,16 +1297,16 @@ public static bool IsType(this DataTypeReference type, SqlDataTypeOption sqlType public static DataTypeReference ToSqlType(this Type type, DataSource dataSource) { if (type == typeof(SqlString)) - return DataTypeHelpers.NVarChar(Int32.MaxValue, dataSource?.DefaultCollation ?? Collation.USEnglish, CollationLabel.CoercibleDefault); + return DataTypeHelpers.NVarChar(Int32.MaxValue, dataSource?.DefaultCollation ?? Collation.USEnglish, CollationLabel.Implicit); return _netTypeMapping[type]; } - private static Expression ToExpression(this ConvertCall convert, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this ConvertCall convert, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { - var value = convert.Parameter.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var valueType); + var value = convert.Parameter.ToExpression(context, contextParam, out var valueType); DataTypeReference styleType = null; - var style = convert.Style?.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out styleType); + var style = convert.Style?.ToExpression(context, contextParam, out styleType); sqlType = convert.DataType; @@ -1333,14 +1325,14 @@ private static Expression ToExpression(this ConvertCall convert, DataSource prim return SqlTypeConverter.Convert(value, valueType, sqlType, style, styleType, convert); } - private static Expression ToExpression(this CastCall cast, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this CastCall cast, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { - return ToExpression(new ConvertCall { Parameter = cast.Parameter, DataType = cast.DataType, Collation = cast.Collation }, primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out sqlType); + return ToExpression(new ConvertCall { Parameter = cast.Parameter, DataType = cast.DataType, Collation = cast.Collation }, context, contextParam, out sqlType); } private static readonly Regex _containsParser = new Regex("^\\S+( OR \\S+)*$", RegexOptions.IgnoreCase | RegexOptions.Compiled); - private static Expression ToExpression(this FullTextPredicate fullText, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this FullTextPredicate fullText, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { // Only support simple CONTAINS calls to handle multi-select optionsets for now if (fullText.FullTextFunctionType != FullTextFunctionType.Contains) @@ -1358,8 +1350,8 @@ private static Expression ToExpression(this FullTextPredicate fullText, DataSour if (fullText.LanguageTerm != null) throw new NotSupportedQueryFragmentException("LANGUAGE is not currently supported", fullText.LanguageTerm); - var col = fullText.Columns[0].ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var colType); - var stringType = DataTypeHelpers.NVarChar(Int32.MaxValue, primaryDataSource.DefaultCollation, CollationLabel.CoercibleDefault); + var col = fullText.Columns[0].ToExpression(context, contextParam, out var colType); + var stringType = DataTypeHelpers.NVarChar(Int32.MaxValue, context.PrimaryDataSource.DefaultCollation, CollationLabel.CoercibleDefault); if (!SqlTypeConverter.CanChangeTypeImplicit(colType, stringType)) throw new NotSupportedQueryFragmentException("Only string columns are supported", fullText.Columns[0]); @@ -1376,7 +1368,7 @@ private static Expression ToExpression(this FullTextPredicate fullText, DataSour return Expr.Call(() => Contains(Expr.Arg(), Expr.Arg()), col, Expression.Constant(words)); } - var value = fullText.Value.ToExpression(primaryDataSource, schema, nonAggregateSchema, parameterTypes, entityParam, parameterParam, optionsParam, out var valueType); + var value = fullText.Value.ToExpression(context, contextParam, out var valueType); if (!SqlTypeConverter.CanChangeTypeImplicit(valueType, stringType)) throw new NotSupportedQueryFragmentException($"Expected string value to match, got {value.Type}", fullText.Value); @@ -1419,31 +1411,31 @@ private static Regex[] GetContainsWords(string pattern, bool compile) .ToArray(); } - private static Expression ToExpression(this ParameterlessCall parameterless, DataSource primaryDataSource, INodeSchema schema, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ParameterExpression entityParam, ParameterExpression parameterParam, ParameterExpression optionsParam, out DataTypeReference sqlType) + private static Expression ToExpression(this ParameterlessCall parameterless, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { switch (parameterless.ParameterlessCallType) { case ParameterlessCallType.CurrentTimestamp: sqlType = DataTypeHelpers.DateTime; - return Expr.Call(() => GetCurrentTimestamp(Expr.Arg()), optionsParam); + return Expr.Call(() => GetCurrentTimestamp(Expr.Arg()), contextParam); default: sqlType = DataTypeHelpers.EntityReference; - return Expr.Call(() => GetCurrentUser(Expr.Arg()), optionsParam); + return Expr.Call(() => GetCurrentUser(Expr.Arg()), contextParam); } } - private static SqlDateTime GetCurrentTimestamp(IQueryExecutionOptions options) + private static SqlDateTime GetCurrentTimestamp(ExpressionExecutionContext context) { - if (options.UseLocalTimeZone) + if (context.Options.UseLocalTimeZone) return new SqlDateTime(DateTime.Now); else return new SqlDateTime(DateTime.UtcNow); } - private static SqlEntityReference GetCurrentUser(IQueryExecutionOptions options) + private static SqlEntityReference GetCurrentUser(ExpressionExecutionContext context) { - return new SqlEntityReference(options.PrimaryDataSource, "systemuser", options.UserId); + return new SqlEntityReference(context.Options.PrimaryDataSource, "systemuser", context.Options.UserId); } /// @@ -1546,10 +1538,10 @@ public static ColumnReferenceExpression ToColumnReference(this string colName) /// Checks if an expression has a constant value /// /// The expression to check - /// The schema that the expression is evaluated in + /// The context the expression is being evaluated in /// The equivalent literal value /// true if the expression has a constant value, or false if it can change depending on the current data record - public static bool IsConstantValueExpression(this ScalarExpression expr, DataSource primaryDataSource, INodeSchema schema, IQueryExecutionOptions options, out Literal literal) + public static bool IsConstantValueExpression(this ScalarExpression expr, ExpressionCompilationContext context, out Literal literal) { literal = expr as Literal; @@ -1574,7 +1566,8 @@ public static bool IsConstantValueExpression(this ScalarExpression expr, DataSou if (parameterlessVisitor.ParameterlessCalls.Any(p => p.ParameterlessCallType != ParameterlessCallType.CurrentTimestamp)) return false; - var value = expr.Compile(primaryDataSource, schema, null)(null, null, options); + var evaluationContext = new ExpressionExecutionContext(context); + var value = expr.Compile(context)(evaluationContext); if (value == null || value is INullable n && n.IsNull) literal = new NullLiteral(); diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FetchXmlScan.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FetchXmlScan.cs index 0092c0c3..2f86555e 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FetchXmlScan.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FetchXmlScan.cs @@ -196,17 +196,17 @@ public bool RequiresCustomPaging(IDictionary dataSources) return false; } - protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { PagesRetrieved = 0; - if (!dataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) throw new NotSupportedQueryFragmentException("Missing datasource " + DataSource); ReturnFullSchema = false; - var schema = GetSchema(dataSources, parameterTypes); + var schema = GetSchema(context); - ApplyParameterValues(options, parameterValues); + ApplyParameterValues(context); FindEntityNameGroupings(dataSource.Metadata); @@ -215,10 +215,10 @@ protected override IEnumerable ExecuteInternal(IDictionary ExecuteInternal(IDictionary ExecuteInternal(IDictionary /// Updates the with current parameter values /// - /// The options to control how the query is executed - /// The parameter values to apply - public void ApplyParameterValues(IQueryExecutionOptions options, IDictionary parameterValues) + /// The context the node is being executed in + public void ApplyParameterValues(NodeExecutionContext context) { - if (parameterValues == null) + if (context.ParameterValues == null) return; if (_parameterizedConditions == null) _parameterizedConditions = FindParameterizedConditions(); - foreach (var param in parameterValues) + foreach (var param in context.ParameterValues) { if (_parameterizedConditions.TryGetValue(param.Key, out var conditions)) { foreach (var condition in conditions) - condition.SetValue(param.Value, options); + condition.SetValue(param.Value, context.Options); } } } @@ -391,7 +390,7 @@ public void RemoveSorts() } } - private void OnRetrievedEntity(Entity entity, INodeSchema schema, IQueryExecutionOptions options, IAttributeMetadataCache metadata) + private void OnRetrievedEntity(Entity entity, INodeSchema schema, IQueryExecutionOptions options, DataSource dataSource) { // Expose any formatted values for OptionSetValue and EntityReference values foreach (var formatted in entity.FormattedValues) @@ -418,7 +417,7 @@ private void OnRetrievedEntity(Entity entity, INodeSchema schema, IQueryExecutio if (value is DateTime dt) { - var meta = metadata[entityName]; + var meta = dataSource.Metadata[entityName]; var attrMeta = (DateTimeAttributeMetadata) meta.Attributes.Single(a => a.LogicalName == attributeName); if (attrMeta.DateTimeBehavior == DateTimeBehavior.UserLocal) @@ -454,7 +453,7 @@ private void OnRetrievedEntity(Entity entity, INodeSchema schema, IQueryExecutio else throw new QueryExecutionException($"Expected ObjectTypeCode value, got {aliasedValue.Value} ({aliasedValue.Value?.GetType()})"); - var meta = metadata[otc]; + var meta = dataSource.Metadata[otc]; entity[attribute.Key] = meta.LogicalName; } else @@ -476,7 +475,7 @@ private void OnRetrievedEntity(Entity entity, INodeSchema schema, IQueryExecutio object sqlValue; if (entity.Attributes.TryGetValue(col.Key, out var value) && value != null) - sqlValue = SqlTypeConverter.NetToSqlType(DataSource, value, col.Value); + sqlValue = SqlTypeConverter.NetToSqlType(dataSource, value, col.Value); else sqlValue = SqlTypeConverter.GetNullValue(col.Value.ToNetType(out _)); @@ -641,9 +640,9 @@ public override IEnumerable GetSources() return Array.Empty(); } - public override INodeSchema GetSchema(IDictionary dataSources, IDictionary parameterTypes) + public override INodeSchema GetSchema(NodeCompilationContext context) { - if (!dataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) throw new NotSupportedQueryFragmentException("Missing datasource " + DataSource); var fetchXmlString = FetchXmlString; @@ -990,7 +989,7 @@ private void AddSchemaAttribute(Dictionary schema, Di ((List)simpleColumnNameAliases).Add(fullName); } - public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) { NormalizeFilters(); @@ -1277,12 +1276,12 @@ private void MergeNestedFilters(filter filter) filter.Items = items.ToArray(); } - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { - if (!dataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) throw new NotSupportedQueryFragmentException("Missing datasource " + DataSource); - var schema = GetSchema(dataSources, parameterTypes); + var schema = GetSchema(context); // Add columns to FetchXml foreach (var col in requiredColumns) @@ -1358,7 +1357,7 @@ public override void AddRequiredColumns(IDictionary dataSour Entity.AddItem(new FetchAttributeType { name = metadata.PrimaryIdAttribute }); } - if (RequiresCustomPaging(dataSources)) + if (RequiresCustomPaging(context.DataSources)) { RemoveSorts(); @@ -1372,8 +1371,8 @@ public override void AddRequiredColumns(IDictionary dataSour AddPrimaryIdAttribute(linkEntity, dataSource); } - NormalizeAttributes(dataSources); - SetDefaultPageSize(dataSources, parameterTypes); + NormalizeAttributes(context.DataSources); + SetDefaultPageSize(context); } private void AddPrimaryIdAttribute(FetchEntityType entity, DataSource dataSource) @@ -1454,7 +1453,7 @@ private object[] NormalizeAttributes(IDictionary dataSources .ToArray(); } - private void SetDefaultPageSize(IDictionary dataSources, IDictionary parameterTypes) + private void SetDefaultPageSize(NodeCompilationContext context) { if (!String.IsNullOrEmpty(FetchXml.count) || !String.IsNullOrEmpty(FetchXml.top)) return; @@ -1464,7 +1463,7 @@ private void SetDefaultPageSize(IDictionary dataSources, IDi var fullSchema = ReturnFullSchema; ReturnFullSchema = false; - var schema = GetSchema(dataSources, parameterTypes); + var schema = GetSchema(context); if (schema.Schema.Count > 100) FetchXml.count = "1000"; @@ -1475,7 +1474,7 @@ private void SetDefaultPageSize(IDictionary dataSources, IDi ReturnFullSchema = fullSchema; } - protected override RowCountEstimate EstimateRowsOutInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes) + protected override RowCountEstimate EstimateRowsOutInternal(NodeCompilationContext context) { if (FetchXml.aggregateSpecified && FetchXml.aggregate) { @@ -1484,10 +1483,10 @@ protected override RowCountEstimate EstimateRowsOutInternal(IDictionary ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { - var schema = Source.GetSchema(dataSources, parameterTypes); - var filter = Filter.Compile(dataSources[options.PrimaryDataSource], schema, parameterTypes); + var schema = Source.GetSchema(context); + var expressionCompilationContext = new ExpressionCompilationContext(context, schema, null); + var filter = Filter.Compile(expressionCompilationContext); + var expressionContext = new ExpressionExecutionContext(context); - foreach (var entity in Source.Execute(dataSources, options, parameterTypes, parameterValues)) + foreach (var entity in Source.Execute(context)) { - if (filter(entity, parameterValues, options)) + expressionContext.Entity = entity; + + if (filter(expressionContext)) yield return entity; } } @@ -49,9 +53,9 @@ public override IEnumerable GetSources() yield return Source; } - public override INodeSchema GetSchema(IDictionary dataSources, IDictionary parameterTypes) + public override INodeSchema GetSchema(NodeCompilationContext context) { - var schema = Source.GetSchema(dataSources, parameterTypes); + var schema = Source.GetSchema(context); var notNullColumns = new HashSet(schema.NotNullColumns, StringComparer.OrdinalIgnoreCase); AddNotNullColumns(schema, notNullColumns, Filter, false); @@ -125,21 +129,21 @@ private void AddNotNullColumn(INodeSchema schema, HashSet notNullColumns notNullColumns.Add(colName); } - public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) { - Source = Source.FoldQuery(dataSources, options, parameterTypes, hints); + Source = Source.FoldQuery(context, hints); Source.Parent = this; var foldedFilters = false; foldedFilters |= FoldConsecutiveFilters(); - foldedFilters |= FoldNestedLoopFiltersToJoins(dataSources, options, parameterTypes, hints); - foldedFilters |= FoldInExistsToFetchXml(dataSources, options, parameterTypes, hints, out var addedLinks); - foldedFilters |= FoldTableSpoolToIndexSpool(dataSources, options, parameterTypes, hints); - foldedFilters |= FoldFiltersToDataSources(dataSources, options, parameterTypes); + foldedFilters |= FoldNestedLoopFiltersToJoins(context, hints); + foldedFilters |= FoldInExistsToFetchXml(context, hints, out var addedLinks); + foldedFilters |= FoldTableSpoolToIndexSpool(context, hints); + foldedFilters |= FoldFiltersToDataSources(context); - if (FoldColumnComparisonsWithKnownValues(dataSources, parameterTypes)) - foldedFilters |= FoldFiltersToDataSources(dataSources, options, parameterTypes); + if (FoldColumnComparisonsWithKnownValues(context)) + foldedFilters |= FoldFiltersToDataSources(context); foreach (var addedLink in addedLinks) { @@ -150,7 +154,7 @@ public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + private bool FoldNestedLoopFiltersToJoins(NodeCompilationContext context, IList hints) { // Queries like "FROM table1, table2 WHERE table1.col = table2.col" are created as: // Filter: table1.col = table2.col @@ -184,7 +188,7 @@ private bool FoldNestedLoopFiltersToJoins(IDictionary dataSo // -> FetchXml // -> Table Spool // -> FetchXml - if (FoldNestedLoopFiltersToJoins(Source as BaseJoinNode, dataSources, options, parameterTypes, hints, out var foldedJoin)) + if (FoldNestedLoopFiltersToJoins(Source as BaseJoinNode, context, hints, out var foldedJoin)) { Source = foldedJoin; return true; @@ -193,7 +197,7 @@ private bool FoldNestedLoopFiltersToJoins(IDictionary dataSo return false; } - private bool FoldNestedLoopFiltersToJoins(BaseJoinNode join, IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints, out FoldableJoinNode foldedJoin) + private bool FoldNestedLoopFiltersToJoins(BaseJoinNode join, NodeCompilationContext context, IList hints, out FoldableJoinNode foldedJoin) { foldedJoin = null; @@ -208,8 +212,8 @@ private bool FoldNestedLoopFiltersToJoins(BaseJoinNode join, IDictionary(ref T first, ref T second) second = temp; } - private bool FoldInExistsToFetchXml(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints, out Dictionary addedLinks) + private bool FoldInExistsToFetchXml(NodeCompilationContext context, IList hints, out Dictionary addedLinks) { var foldedFilters = false; @@ -387,7 +391,7 @@ private bool FoldInExistsToFetchXml(IDictionary dataSources, if (!leftFetch.DataSource.Equals(rightFetch.DataSource, StringComparison.OrdinalIgnoreCase)) break; - var rightSchema = rightFetch.GetSchema(dataSources, parameterTypes); + var rightSchema = rightFetch.GetSchema(context); if (!rightSchema.ContainsColumn(merge.RightAttribute.GetColumnName(), out var attribute)) break; @@ -410,8 +414,8 @@ private bool FoldInExistsToFetchXml(IDictionary dataSources, // Can't do this if there is any conflict in join aliases if (notNullFilterRemovable && leftFetch.Entity.name == rightFetch.Entity.name && - merge.LeftAttribute.GetColumnName() == leftFetch.Alias + "." + dataSources[leftFetch.DataSource].Metadata[leftFetch.Entity.name].PrimaryIdAttribute && - merge.RightAttribute.GetColumnName() == rightFetch.Alias + "." + dataSources[rightFetch.DataSource].Metadata[rightFetch.Entity.name].PrimaryIdAttribute && + merge.LeftAttribute.GetColumnName() == leftFetch.Alias + "." + context.DataSources[leftFetch.DataSource].Metadata[leftFetch.Entity.name].PrimaryIdAttribute && + merge.RightAttribute.GetColumnName() == rightFetch.Alias + "." + context.DataSources[rightFetch.DataSource].Metadata[rightFetch.Entity.name].PrimaryIdAttribute && !leftFetch.Entity.GetLinkEntities().Select(l => l.alias).Intersect(rightFetch.Entity.GetLinkEntities().Select(l => l.alias), StringComparer.OrdinalIgnoreCase).Any() && (leftFetch.FetchXml.top == null || rightFetch.FetchXml.top == null)) { @@ -485,7 +489,7 @@ private bool FoldInExistsToFetchXml(IDictionary dataSources, } // We need to use an "in" join type - check that's supported else if (notNullFilterRemovable && - options.JoinOperatorsAvailable.Contains(JoinOperator.Any)) + context.Options.JoinOperatorsAvailable.Contains(JoinOperator.Any)) { // Remove the filter and replace with an "in" link-entity Filter = Filter.RemoveCondition(notNullFilter); @@ -521,7 +525,7 @@ private bool FoldInExistsToFetchXml(IDictionary dataSources, else if (join is NestedLoopNode loop) { // Check we meet all the criteria for a foldable correlated EXISTS query - if (!options.JoinOperatorsAvailable.Contains(JoinOperator.Exists)) + if (!context.Options.JoinOperatorsAvailable.Contains(JoinOperator.Exists)) break; if (loop.JoinCondition != null || @@ -628,13 +632,13 @@ private bool FoldInExistsToFetchXml(IDictionary dataSources, return foldedFilters; } - private bool FoldTableSpoolToIndexSpool(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + private bool FoldTableSpoolToIndexSpool(NodeCompilationContext context, IList hints) { // If we've got a filter matching a column and a variable (key lookup in a nested loop) from a table spool, replace it with a index spool if (!(Source is TableSpoolNode tableSpool)) return false; - var schema = Source.GetSchema(dataSources, parameterTypes); + var schema = Source.GetSchema(context); if (!ExtractKeyLookupFilter(Filter, out var filter, out var indexColumn, out var seekVariable) || !schema.ContainsColumn(indexColumn, out indexColumn)) return false; @@ -660,28 +664,28 @@ private bool FoldTableSpoolToIndexSpool(IDictionary dataSour Source = spoolSource, KeyColumn = indexColumn, SeekValue = seekVariable - }.FoldQuery(dataSources, options, parameterTypes, hints); + }.FoldQuery(context, hints); Filter = filter; return true; } - private bool FoldFiltersToDataSources(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes) + private bool FoldFiltersToDataSources(NodeCompilationContext context) { var foldedFilters = false; // Find all the data source nodes we could fold this into. Include direct data sources, those from either side of an inner join, or the main side of an outer join foreach (var source in GetFoldableSources(Source)) { - var schema = source.GetSchema(dataSources, parameterTypes); + var schema = source.GetSchema(context); if (source is FetchXmlScan fetchXml && !fetchXml.FetchXml.aggregate) { - if (!dataSources.TryGetValue(fetchXml.DataSource, out var dataSource)) + if (!context.DataSources.TryGetValue(fetchXml.DataSource, out var dataSource)) throw new NotSupportedQueryFragmentException("Missing datasource " + fetchXml.DataSource); // If the criteria are ANDed, see if any of the individual conditions can be translated to FetchXML - Filter = ExtractFetchXMLFilters(dataSources[options.PrimaryDataSource], dataSource.Metadata, options, Filter, schema, null, fetchXml.Entity.name, fetchXml.Alias, fetchXml.Entity.Items, parameterTypes, out var fetchFilter); + Filter = ExtractFetchXMLFilters(context, dataSource.Metadata, Filter, schema, null, fetchXml.Entity.name, fetchXml.Alias, fetchXml.Entity.Items, out var fetchFilter); if (fetchFilter != null) { @@ -706,7 +710,7 @@ private bool FoldFiltersToDataSources(IDictionary dataSource if (source is MetadataQueryNode meta) { // If the criteria are ANDed, see if any of the individual conditions can be translated to the metadata query - Filter = ExtractMetadataFilters(dataSources[options.PrimaryDataSource], Filter, meta, options, out var entityFilter, out var attributeFilter, out var relationshipFilter); + Filter = ExtractMetadataFilters(context, Filter, meta, out var entityFilter, out var attributeFilter, out var relationshipFilter); meta.Query.AddFilter(entityFilter); @@ -728,7 +732,7 @@ private bool FoldFiltersToDataSources(IDictionary dataSource return foldedFilters; } - private bool FoldColumnComparisonsWithKnownValues(IDictionary dataSources, IDictionary parameterTypes) + private bool FoldColumnComparisonsWithKnownValues(NodeCompilationContext context) { var foldedFilters = false; @@ -737,12 +741,12 @@ private bool FoldColumnComparisonsWithKnownValues(IDictionary dataSources, IDictionary parameterTypes, FetchXmlScan fetchXml, INodeSchema schema, BooleanExpression filter) + private BooleanExpression FoldColumnComparisonsWithKnownValues(NodeCompilationContext context, FetchXmlScan fetchXml, INodeSchema schema, BooleanExpression filter) { if (filter is BooleanComparisonExpression cmp && cmp.FirstExpression is ColumnReferenceExpression col1 && cmp.SecondExpression is ColumnReferenceExpression col2) { - if (HasKnownValue(dataSources, parameterTypes, fetchXml, col1, schema, out var value)) + if (HasKnownValue(context, fetchXml, col1, schema, out var value)) { return new BooleanComparisonExpression { @@ -773,7 +777,7 @@ cmp.FirstExpression is ColumnReferenceExpression col1 && SecondExpression = col2 }; } - else if (HasKnownValue(dataSources, parameterTypes, fetchXml, col2, schema, out value)) + else if (HasKnownValue(context, fetchXml, col2, schema, out value)) { return new BooleanComparisonExpression { @@ -786,8 +790,8 @@ cmp.FirstExpression is ColumnReferenceExpression col1 && else if (filter is BooleanBinaryExpression bin && bin.BinaryExpressionType == BooleanBinaryExpressionType.And) { - var bin1 = FoldColumnComparisonsWithKnownValues(dataSources, parameterTypes, fetchXml, schema, bin.FirstExpression); - var bin2 = FoldColumnComparisonsWithKnownValues(dataSources, parameterTypes, fetchXml, schema, bin.SecondExpression); + var bin1 = FoldColumnComparisonsWithKnownValues(context, fetchXml, schema, bin.FirstExpression); + var bin2 = FoldColumnComparisonsWithKnownValues(context, fetchXml, schema, bin.SecondExpression); if (bin1 != bin.FirstExpression || bin2 != bin.SecondExpression) { @@ -803,7 +807,7 @@ cmp.FirstExpression is ColumnReferenceExpression col1 && return filter; } - private bool HasKnownValue(IDictionary dataSources, IDictionary parameterTypes, FetchXmlScan fetchXml, ColumnReferenceExpression col, INodeSchema schema, out Literal value) + private bool HasKnownValue(NodeCompilationContext context, FetchXmlScan fetchXml, ColumnReferenceExpression col, INodeSchema schema, out Literal value) { value = null; @@ -996,9 +1000,9 @@ private bool ExtractKeyLookupFilter(BooleanExpression filter, out BooleanExpress return false; } - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { - var schema = Source.GetSchema(dataSources, parameterTypes); + var schema = Source.GetSchema(context); foreach (var col in Filter.GetColumns()) { @@ -1009,12 +1013,12 @@ public override void AddRequiredColumns(IDictionary dataSour requiredColumns.Add(normalized); } - Source.AddRequiredColumns(dataSources, parameterTypes, requiredColumns); + Source.AddRequiredColumns(context, requiredColumns); } - private BooleanExpression ExtractFetchXMLFilters(DataSource primaryDataSource, IAttributeMetadataCache metadata, IQueryExecutionOptions options, BooleanExpression criteria, INodeSchema schema, string allowedPrefix, string targetEntityName, string targetEntityAlias, object[] items, IDictionary parameterTypes, out filter filter) + private BooleanExpression ExtractFetchXMLFilters(NodeCompilationContext context, IAttributeMetadataCache metadata, BooleanExpression criteria, INodeSchema schema, string allowedPrefix, string targetEntityName, string targetEntityAlias, object[] items, out filter filter) { - if (TranslateFetchXMLCriteria(primaryDataSource, metadata, options, criteria, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, parameterTypes, out filter)) + if (TranslateFetchXMLCriteria(context, metadata, criteria, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, out filter)) return null; if (!(criteria is BooleanBinaryExpression bin)) @@ -1023,8 +1027,8 @@ private BooleanExpression ExtractFetchXMLFilters(DataSource primaryDataSource, I if (bin.BinaryExpressionType != BooleanBinaryExpressionType.And) return criteria; - bin.FirstExpression = ExtractFetchXMLFilters(primaryDataSource, metadata, options, bin.FirstExpression, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, parameterTypes, out var lhsFilter); - bin.SecondExpression = ExtractFetchXMLFilters(primaryDataSource, metadata, options, bin.SecondExpression, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, parameterTypes, out var rhsFilter); + bin.FirstExpression = ExtractFetchXMLFilters(context, metadata, bin.FirstExpression, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, out var lhsFilter); + bin.SecondExpression = ExtractFetchXMLFilters(context, metadata, bin.SecondExpression, schema, allowedPrefix, targetEntityName, targetEntityAlias, items, out var rhsFilter); filter = (lhsFilter != null && rhsFilter != null) ? new filter { Items = new object[] { lhsFilter, rhsFilter } } : lhsFilter ?? rhsFilter; @@ -1034,9 +1038,9 @@ private BooleanExpression ExtractFetchXMLFilters(DataSource primaryDataSource, I return bin.FirstExpression ?? bin.SecondExpression; } - protected BooleanExpression ExtractMetadataFilters(DataSource primaryDataSource, BooleanExpression criteria, MetadataQueryNode meta, IQueryExecutionOptions options, out MetadataFilterExpression entityFilter, out MetadataFilterExpression attributeFilter, out MetadataFilterExpression relationshipFilter) + protected BooleanExpression ExtractMetadataFilters(NodeCompilationContext context, BooleanExpression criteria, MetadataQueryNode meta, out MetadataFilterExpression entityFilter, out MetadataFilterExpression attributeFilter, out MetadataFilterExpression relationshipFilter) { - if (TranslateMetadataCriteria(primaryDataSource, criteria, meta, options, out entityFilter, out attributeFilter, out relationshipFilter)) + if (TranslateMetadataCriteria(context, criteria, meta, out entityFilter, out attributeFilter, out relationshipFilter)) return null; if (!(criteria is BooleanBinaryExpression bin)) @@ -1045,8 +1049,8 @@ protected BooleanExpression ExtractMetadataFilters(DataSource primaryDataSource, if (bin.BinaryExpressionType != BooleanBinaryExpressionType.And) return criteria; - bin.FirstExpression = ExtractMetadataFilters(primaryDataSource, bin.FirstExpression, meta, options, out var lhsEntityFilter, out var lhsAttributeFilter, out var lhsRelationshipFilter); - bin.SecondExpression = ExtractMetadataFilters(primaryDataSource, bin.SecondExpression, meta, options, out var rhsEntityFilter, out var rhsAttributeFilter, out var rhsRelationshipFilter); + bin.FirstExpression = ExtractMetadataFilters(context, bin.FirstExpression, meta, out var lhsEntityFilter, out var lhsAttributeFilter, out var lhsRelationshipFilter); + bin.SecondExpression = ExtractMetadataFilters(context, bin.SecondExpression, meta, out var rhsEntityFilter, out var rhsAttributeFilter, out var rhsRelationshipFilter); entityFilter = (lhsEntityFilter != null && rhsEntityFilter != null) ? new MetadataFilterExpression { Filters = { lhsEntityFilter, rhsEntityFilter } } : lhsEntityFilter ?? rhsEntityFilter; attributeFilter = (lhsAttributeFilter != null && rhsAttributeFilter != null) ? new MetadataFilterExpression { Filters = { lhsAttributeFilter, rhsAttributeFilter } } : lhsAttributeFilter ?? rhsAttributeFilter; @@ -1057,17 +1061,21 @@ protected BooleanExpression ExtractMetadataFilters(DataSource primaryDataSource, return bin.FirstExpression ?? bin.SecondExpression; } - protected bool TranslateMetadataCriteria(DataSource primaryDataSource, BooleanExpression criteria, MetadataQueryNode meta, IQueryExecutionOptions options, out MetadataFilterExpression entityFilter, out MetadataFilterExpression attributeFilter, out MetadataFilterExpression relationshipFilter) + + protected bool TranslateMetadataCriteria(NodeCompilationContext context, BooleanExpression criteria, MetadataQueryNode meta, out MetadataFilterExpression entityFilter, out MetadataFilterExpression attributeFilter, out MetadataFilterExpression relationshipFilter) { entityFilter = null; attributeFilter = null; relationshipFilter = null; + var expressionCompilationContext = new ExpressionCompilationContext(context, null, null); + var expressionExecutionContext = new ExpressionExecutionContext(new NodeExecutionContext(context.DataSources, context.Options, context.ParameterTypes, null)); + if (criteria is BooleanBinaryExpression binary) { - if (!TranslateMetadataCriteria(primaryDataSource, binary.FirstExpression, meta, options, out var lhsEntityFilter, out var lhsAttributeFilter, out var lhsRelationshipFilter)) + if (!TranslateMetadataCriteria(context, binary.FirstExpression, meta, out var lhsEntityFilter, out var lhsAttributeFilter, out var lhsRelationshipFilter)) return false; - if (!TranslateMetadataCriteria(primaryDataSource, binary.SecondExpression, meta, options, out var rhsEntityFilter, out var rhsAttributeFilter, out var rhsRelationshipFilter)) + if (!TranslateMetadataCriteria(context, binary.SecondExpression, meta, out var rhsEntityFilter, out var rhsAttributeFilter, out var rhsRelationshipFilter)) return false; if (binary.BinaryExpressionType == BooleanBinaryExpressionType.Or) @@ -1140,7 +1148,7 @@ protected bool TranslateMetadataCriteria(DataSource primaryDataSource, BooleanEx if (col == null || literal == null) return false; - var schema = meta.GetSchema(null, null); + var schema = meta.GetSchema(context); if (!schema.ContainsColumn(col.GetColumnName(), out var colName)) return false; @@ -1174,7 +1182,7 @@ protected bool TranslateMetadataCriteria(DataSource primaryDataSource, BooleanEx throw new InvalidOperationException(); } - var condition = new MetadataConditionExpression(parts[1], op, literal.Compile(primaryDataSource, null, null)(null, null, options)); + var condition = new MetadataConditionExpression(parts[1], op, literal.Compile(expressionCompilationContext)(expressionExecutionContext)); return TranslateMetadataCondition(condition, parts[0], meta, out entityFilter, out attributeFilter, out relationshipFilter); } @@ -1186,7 +1194,7 @@ protected bool TranslateMetadataCriteria(DataSource primaryDataSource, BooleanEx if (col == null) return false; - var schema = meta.GetSchema(null, null); + var schema = meta.GetSchema(context); if (!schema.ContainsColumn(col.GetColumnName(), out var colName)) return false; @@ -1198,7 +1206,7 @@ protected bool TranslateMetadataCriteria(DataSource primaryDataSource, BooleanEx if (inPred.Values.Any(val => !(val is Literal))) return false; - var condition = new MetadataConditionExpression(parts[1], inPred.NotDefined ? MetadataConditionOperator.NotIn : MetadataConditionOperator.In, inPred.Values.Select(val => val.Compile(primaryDataSource, null, null)(null, null, options)).ToArray()); + var condition = new MetadataConditionExpression(parts[1], inPred.NotDefined ? MetadataConditionOperator.NotIn : MetadataConditionOperator.In, inPred.Values.Select(val => val.Compile(expressionCompilationContext)(expressionExecutionContext)).ToArray()); return TranslateMetadataCondition(condition, parts[0], meta, out entityFilter, out attributeFilter, out relationshipFilter); } @@ -1210,7 +1218,7 @@ protected bool TranslateMetadataCriteria(DataSource primaryDataSource, BooleanEx if (col == null) return false; - var schema = meta.GetSchema(null, null); + var schema = meta.GetSchema(context); if (!schema.ContainsColumn(col.GetColumnName(), out var colName)) return false; @@ -1357,9 +1365,9 @@ private bool TranslateMetadataCondition(MetadataConditionExpression condition, s return false; } - protected override RowCountEstimate EstimateRowsOutInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes) + protected override RowCountEstimate EstimateRowsOutInternal(NodeCompilationContext context) { - return new RowCountEstimate(Source.EstimateRowsOut(dataSources, options, parameterTypes).Value * 8 / 10); + return new RowCountEstimate(Source.EstimateRowsOut(context).Value * 8 / 10); } protected override IEnumerable GetVariablesInternal() diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FoldableJoinNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FoldableJoinNode.cs index 25d1c078..e30513b4 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FoldableJoinNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FoldableJoinNode.cs @@ -49,18 +49,18 @@ abstract class FoldableJoinNode : BaseJoinNode [DisplayName("Additional Join Criteria")] public BooleanExpression AdditionalJoinCriteria { get; set; } - public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) { - LeftSource = LeftSource.FoldQuery(dataSources, options, parameterTypes, hints); + LeftSource = LeftSource.FoldQuery(context, hints); LeftSource.Parent = this; - RightSource = RightSource.FoldQuery(dataSources, options, parameterTypes, hints); + RightSource = RightSource.FoldQuery(context, hints); RightSource.Parent = this; if (SemiJoin) return this; - var leftSchema = LeftSource.GetSchema(dataSources, parameterTypes); - var rightSchema = RightSource.GetSchema(dataSources, parameterTypes); + var leftSchema = LeftSource.GetSchema(context); + var rightSchema = RightSource.GetSchema(context); var leftFilter = JoinType == QualifiedJoinType.Inner || JoinType == QualifiedJoinType.LeftOuter ? LeftSource as FilterNode : null; var rightFilter = JoinType == QualifiedJoinType.Inner || JoinType == QualifiedJoinType.RightOuter ? RightSource as FilterNode : null; var leftFetch = (leftFilter?.Source ?? LeftSource) as FetchXmlScan; @@ -70,35 +70,35 @@ public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints, params FilterNode[] filters) + private IDataExecutionPlanNodeInternal PrependFilters(IDataExecutionPlanNodeInternal folded, NodeCompilationContext context, IList hints, params FilterNode[] filters) { foreach (var filter in filters) { @@ -106,15 +106,15 @@ private IDataExecutionPlanNodeInternal PrependFilters(IDataExecutionPlanNodeInte continue; filter.Source = folded; - folded = filter.FoldQuery(dataSources, options, parameterTypes, hints); + folded = filter.FoldQuery(context, hints); } return folded; } - private IDataExecutionPlanNodeInternal AddNotNullFilter(IDataExecutionPlanNodeInternal source, ColumnReferenceExpression attribute, IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + private IDataExecutionPlanNodeInternal AddNotNullFilter(IDataExecutionPlanNodeInternal source, ColumnReferenceExpression attribute, NodeCompilationContext context, IList hints) { - var schema = source.GetSchema(dataSources, parameterTypes); + var schema = source.GetSchema(context); if (!schema.ContainsColumn(attribute.GetColumnName(), out var colName)) return source; @@ -131,7 +131,7 @@ private IDataExecutionPlanNodeInternal AddNotNullFilter(IDataExecutionPlanNodeIn } }; - var folded = filter.FoldQuery(dataSources, options, parameterTypes, hints); + var folded = filter.FoldQuery(context, hints); if (folded != filter) { @@ -142,7 +142,7 @@ private IDataExecutionPlanNodeInternal AddNotNullFilter(IDataExecutionPlanNodeIn return source; } - private bool FoldFetchXmlJoin(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints, BaseJoinNode join, FetchXmlScan fetch, INodeSchema fetchSchema, out IDataExecutionPlanNodeInternal folded) + private bool FoldFetchXmlJoin(NodeCompilationContext context, IList hints, BaseJoinNode join, FetchXmlScan fetch, INodeSchema fetchSchema, out IDataExecutionPlanNodeInternal folded) { folded = null; @@ -161,7 +161,7 @@ private bool FoldFetchXmlJoin(IDictionary dataSources, IQuer if ((join.JoinType == QualifiedJoinType.Inner || join.JoinType == QualifiedJoinType.RightOuter) && join.LeftSource is FetchXmlScan leftInnerFetch) { var leftSource = leftInnerFetch; - var leftSchema = leftInnerFetch.GetSchema(dataSources, parameterTypes); + var leftSchema = leftInnerFetch.GetSchema(context); var rightSource = fetch; var rightSchema = fetchSchema; @@ -171,11 +171,11 @@ private bool FoldFetchXmlJoin(IDictionary dataSources, IQuer Swap(ref leftSchema, ref rightSchema); } - if (FoldFetchXmlJoin(dataSources, options, parameterTypes, hints, leftSource, leftSchema, rightSource, rightSchema, out folded)) + if (FoldFetchXmlJoin(context, hints, leftSource, leftSchema, rightSource, rightSchema, out folded)) { folded.Parent = join; join.LeftSource = folded; - folded = ConvertManyToManyMergeJoinToHashJoin(join, dataSources, options, parameterTypes, hints); + folded = ConvertManyToManyMergeJoinToHashJoin(join, context, hints); return true; } } @@ -183,7 +183,7 @@ private bool FoldFetchXmlJoin(IDictionary dataSources, IQuer if ((join.JoinType == QualifiedJoinType.Inner || join.JoinType == QualifiedJoinType.LeftOuter) && join.RightSource is FetchXmlScan rightInnerFetch) { var leftSource = rightInnerFetch; - var leftSchema = rightInnerFetch.GetSchema(dataSources, parameterTypes); + var leftSchema = rightInnerFetch.GetSchema(context); var rightSource = fetch; var rightSchema = fetchSchema; @@ -193,7 +193,7 @@ private bool FoldFetchXmlJoin(IDictionary dataSources, IQuer Swap(ref leftSchema, ref rightSchema); } - if (FoldFetchXmlJoin(dataSources, options, parameterTypes, hints, leftSource, leftSchema, rightSource, rightSchema, out folded)) + if (FoldFetchXmlJoin(context, hints, leftSource, leftSchema, rightSource, rightSchema, out folded)) { folded.Parent = join; join.RightSource = folded; @@ -205,14 +205,14 @@ private bool FoldFetchXmlJoin(IDictionary dataSources, IQuer return false; } - private IDataExecutionPlanNodeInternal ConvertManyToManyMergeJoinToHashJoin(BaseJoinNode join, IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + private IDataExecutionPlanNodeInternal ConvertManyToManyMergeJoinToHashJoin(BaseJoinNode join, NodeCompilationContext context, IList hints) { // If folding the inner join has caused a one-to-many merge join to become a many-to-many merge join, // which we don't currently support, switch it to be a hash join if (!(join is MergeJoinNode merge)) return join; - var leftSchema = join.LeftSource.GetSchema(dataSources, parameterTypes); + var leftSchema = join.LeftSource.GetSchema(context); if (leftSchema.ContainsColumn(merge.LeftAttribute.GetColumnName(), out var leftKey) && leftKey == leftSchema.PrimaryKey) return join; @@ -246,10 +246,10 @@ private IDataExecutionPlanNodeInternal ConvertManyToManyMergeJoinToHashJoin(Base hash.LeftSource.Parent = hash; hash.RightSource.Parent = hash; - return hash.FoldQuery(dataSources, options, parameterTypes, hints); + return hash.FoldQuery(context, hints); } - private bool FoldFetchXmlJoin(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints, FetchXmlScan leftFetch, INodeSchema leftSchema, FetchXmlScan rightFetch, INodeSchema rightSchema, out IDataExecutionPlanNodeInternal folded) + private bool FoldFetchXmlJoin(NodeCompilationContext context, IList hints, FetchXmlScan leftFetch, INodeSchema leftSchema, FetchXmlScan rightFetch, INodeSchema rightSchema, out IDataExecutionPlanNodeInternal folded) { folded = null; @@ -283,7 +283,7 @@ private bool FoldFetchXmlJoin(IDictionary dataSources, IQuer return false; // If the entities are from different virtual entity data providers it's probably not going to work - if (!dataSources.TryGetValue(leftFetch.DataSource, out var dataSource)) + if (!context.DataSources.TryGetValue(leftFetch.DataSource, out var dataSource)) throw new NotSupportedQueryFragmentException("Missing datasource " + leftFetch.DataSource); if (dataSource.Metadata[leftFetch.Entity.name].DataProviderId != dataSource.Metadata[rightFetch.Entity.name].DataProviderId) @@ -318,7 +318,7 @@ private bool FoldFetchXmlJoin(IDictionary dataSources, IQuer // in the new link entity or we must be using an inner join so we can use a post-filter node var additionalCriteria = AdditionalJoinCriteria; - if (TranslateFetchXMLCriteria(dataSources[options.PrimaryDataSource], dataSource.Metadata, options, additionalCriteria, rightSchema, rightFetch.Alias, rightEntity.name, rightFetch.Alias, rightEntity.Items, parameterTypes, out var filter)) + if (TranslateFetchXMLCriteria(context, dataSource.Metadata, additionalCriteria, rightSchema, rightFetch.Alias, rightEntity.name, rightFetch.Alias, rightEntity.Items, out var filter)) { rightEntity.AddItem(filter); additionalCriteria = null; @@ -360,7 +360,7 @@ private bool FoldFetchXmlJoin(IDictionary dataSources, IQuer if (additionalCriteria != null) { - folded = new FilterNode { Filter = additionalCriteria, Source = leftFetch }.FoldQuery(dataSources, options, parameterTypes, hints); + folded = new FilterNode { Filter = additionalCriteria, Source = leftFetch }.FoldQuery(context, hints); return true; } @@ -382,7 +382,7 @@ private Version GetVersion(IOrganizationService org) return new Version(resp.Version); } - private bool FoldMetadataJoin(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints, MetadataQueryNode leftMeta, INodeSchema leftSchema, MetadataQueryNode rightMeta, INodeSchema rightSchema, out IDataExecutionPlanNodeInternal folded) + private bool FoldMetadataJoin(NodeCompilationContext context, IList hints, MetadataQueryNode leftMeta, INodeSchema leftSchema, MetadataQueryNode rightMeta, INodeSchema rightSchema, out IDataExecutionPlanNodeInternal folded) { folded = null; @@ -469,7 +469,7 @@ private bool FoldMetadataJoin(IDictionary dataSources, IQuer return false; } - private bool FoldSingleRowJoinToNestedLoop(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints, INodeSchema leftSchema, INodeSchema rightSchema, out IDataExecutionPlanNodeInternal folded) + private bool FoldSingleRowJoinToNestedLoop(NodeCompilationContext context, IList hints, INodeSchema leftSchema, INodeSchema rightSchema, out IDataExecutionPlanNodeInternal folded) { folded = null; @@ -483,8 +483,8 @@ private bool FoldSingleRowJoinToNestedLoop(IDictionary dataS var rightSource = RightSource; var leftAttribute = LeftAttribute; var rightAttribute = RightAttribute; - leftSource.EstimateRowsOut(dataSources, options, parameterTypes); - rightSource.EstimateRowsOut(dataSources, options, parameterTypes); + leftSource.EstimateRowsOut(context); + rightSource.EstimateRowsOut(context); leftSchema.ContainsColumn(leftAttribute.GetColumnName(), out var leftAttr); rightSchema.ContainsColumn(rightAttribute.GetColumnName(), out var rightAttr); @@ -514,7 +514,7 @@ private bool FoldSingleRowJoinToNestedLoop(IDictionary dataS SecondExpression = new VariableReference { Name = outerReference } } }; - var foldedRightSource = filteredRightSource.FoldQuery(dataSources, options, parameterTypes, hints); + var foldedRightSource = filteredRightSource.FoldQuery(context, hints); // If we can't fold the filter down to the data source, there's no benefit from doing this so stick with the // original join type @@ -544,7 +544,7 @@ private bool FoldSingleRowJoinToNestedLoop(IDictionary dataS else if (nestedLoop.RightSource is FetchXmlScan rightFetch) rightFetch.RemoveSorts(); - folded = nestedLoop.FoldQuery(dataSources, options, parameterTypes, hints); + folded = nestedLoop.FoldQuery(context, hints); return true; } @@ -555,7 +555,7 @@ private static void Swap(ref T left, ref T right) right = temp; } - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { if (AdditionalJoinCriteria != null) { @@ -567,8 +567,8 @@ public override void AddRequiredColumns(IDictionary dataSour } // Work out which columns need to be pushed down to which source - var leftSchema = LeftSource.GetSchema(dataSources, parameterTypes); - var rightSchema = RightSource.GetSchema(dataSources, parameterTypes); + var leftSchema = LeftSource.GetSchema(context); + var rightSchema = RightSource.GetSchema(context); var leftColumns = requiredColumns .Where(col => leftSchema.ContainsColumn(col, out _)) @@ -580,15 +580,15 @@ public override void AddRequiredColumns(IDictionary dataSour leftColumns.Add(LeftAttribute.GetColumnName()); rightColumns.Add(RightAttribute.GetColumnName()); - LeftSource.AddRequiredColumns(dataSources, parameterTypes, leftColumns); - RightSource.AddRequiredColumns(dataSources, parameterTypes, rightColumns); + LeftSource.AddRequiredColumns(context, leftColumns); + RightSource.AddRequiredColumns(context, rightColumns); } - protected override RowCountEstimate EstimateRowsOutInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes) + protected override RowCountEstimate EstimateRowsOutInternal(NodeCompilationContext context) { - var leftEstimate = LeftSource.EstimateRowsOut(dataSources, options, parameterTypes); + var leftEstimate = LeftSource.EstimateRowsOut(context); ParseEstimate(leftEstimate, out var leftMin, out var leftMax, out var leftIsRange); - var rightEstimate = RightSource.EstimateRowsOut(dataSources, options, parameterTypes); + var rightEstimate = RightSource.EstimateRowsOut(context); ParseEstimate(rightEstimate, out var rightMin, out var rightMax, out var rightIsRange); if (JoinType == QualifiedJoinType.LeftOuter && SemiJoin) @@ -599,8 +599,8 @@ protected override RowCountEstimate EstimateRowsOutInternal(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { _optionsetCols = new Dictionary(); @@ -105,17 +105,17 @@ public override void AddRequiredColumns(IDictionary dataSour } } - protected override RowCountEstimate EstimateRowsOutInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes) + protected override RowCountEstimate EstimateRowsOutInternal(NodeCompilationContext context) { return new RowCountEstimate(100); } - public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) { return this; } - public override INodeSchema GetSchema(IDictionary dataSources, IDictionary parameterTypes) + public override INodeSchema GetSchema(NodeCompilationContext context) { var schema = new Dictionary(StringComparer.OrdinalIgnoreCase); var aliases = new Dictionary>(StringComparer.OrdinalIgnoreCase); @@ -147,9 +147,9 @@ public override IEnumerable GetSources() return Array.Empty(); } - protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { - if (!dataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) throw new NotSupportedQueryFragmentException("Missing datasource " + DataSource); var resp = (RetrieveAllOptionSetsResponse)dataSource.Connection.Execute(new RetrieveAllOptionSetsRequest()); diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/GoToNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/GoToNode.cs index 5cc822fb..ec0395c6 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/GoToNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/GoToNode.cs @@ -13,7 +13,7 @@ class GoToNode : BaseNode, IGoToNode { private int _executionCount; private readonly Timer _timer = new Timer(); - private Func, IQueryExecutionOptions, bool> _condition; + private Func _condition; public override int ExecutionCount => _executionCount; @@ -41,10 +41,10 @@ class GoToNode : BaseNode, IGoToNode [Browsable(false)] public string SourceColumn { get; set; } - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { if (Source != null) - Source.AddRequiredColumns(dataSources, parameterTypes, new List(requiredColumns)); + Source.AddRequiredColumns(context, new List(requiredColumns)); } public object Clone() @@ -62,7 +62,7 @@ public object Clone() }; } - public string Execute(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + public string Execute(NodeExecutionContext context) { using (_timer.Run()) { @@ -74,11 +74,11 @@ public string Execute(IDictionary dataSources, IQueryExecuti if (_condition != null) { - result = _condition(null, parameterValues, options); + result = _condition(new ExpressionExecutionContext(context)); } else if (Source != null) { - var record = Source.Execute(dataSources, options, parameterTypes, parameterValues).First(); + var record = Source.Execute(context).First(); result = ((SqlInt32)record[SourceColumn]).Value == 1; } else @@ -105,12 +105,12 @@ public string Execute(IDictionary dataSources, IQueryExecuti } } - public IRootExecutionPlanNodeInternal[] FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public IRootExecutionPlanNodeInternal[] FoldQuery(NodeCompilationContext context, IList hints) { if (Source != null) - Source = Source.FoldQuery(dataSources, options, parameterTypes, hints); + Source = Source.FoldQuery(context, hints); - _condition = Condition?.Compile(dataSources[options.PrimaryDataSource], null, parameterTypes); + _condition = Condition?.Compile(new ExpressionCompilationContext(context, null, null)); return new[] { this }; } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/GotoLabelNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/GotoLabelNode.cs index c2807bf3..78328ee7 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/GotoLabelNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/GotoLabelNode.cs @@ -24,7 +24,7 @@ class GotoLabelNode : BaseNode, IRootExecutionPlanNodeInternal [Category("Label")] public string Label { get; set; } - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { } @@ -39,7 +39,7 @@ public object Clone() }; } - public IRootExecutionPlanNodeInternal[] FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public IRootExecutionPlanNodeInternal[] FoldQuery(NodeCompilationContext context, IList hints) { return new[] { this }; } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/HashJoinNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/HashJoinNode.cs index 811550df..42d4a458 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/HashJoinNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/HashJoinNode.cs @@ -23,36 +23,40 @@ class OuterRecord private IDictionary> _hashTable; - protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { _hashTable = new Dictionary>(); - var mergedSchema = GetSchema(dataSources, parameterTypes, true); - var additionalJoinCriteria = AdditionalJoinCriteria?.Compile(dataSources[options.PrimaryDataSource], mergedSchema, parameterTypes); + var mergedSchema = GetSchema(context, true); + var additionalJoinCriteria = AdditionalJoinCriteria?.Compile(new ExpressionCompilationContext(context, mergedSchema, null)); // Build the hash table - var leftSchema = LeftSource.GetSchema(dataSources, parameterTypes); + var leftSchema = LeftSource.GetSchema(context); leftSchema.ContainsColumn(LeftAttribute.GetColumnName(), out var leftCol); var leftColType = leftSchema.Schema[leftCol]; - var rightSchema = RightSource.GetSchema(dataSources, parameterTypes); + var rightSchema = RightSource.GetSchema(context); rightSchema.ContainsColumn(RightAttribute.GetColumnName(), out var rightCol); var rightColType = rightSchema.Schema[rightCol]; - if (!SqlTypeConverter.CanMakeConsistentTypes(leftColType, rightColType, dataSources[options.PrimaryDataSource], out var keyType)) + if (!SqlTypeConverter.CanMakeConsistentTypes(leftColType, rightColType, context.PrimaryDataSource, out var keyType)) throw new QueryExecutionException($"Cannot match key types {leftColType.ToSql()} and {rightColType.ToSql()}"); var leftKeyAccessor = (ScalarExpression) leftCol.ToColumnReference(); if (!leftColType.IsSameAs(keyType)) leftKeyAccessor = new ConvertCall { Parameter = leftKeyAccessor, DataType = keyType }; - var leftKeyConverter = leftKeyAccessor.Compile(dataSources[options.PrimaryDataSource], leftSchema, parameterTypes); + var leftKeyConverter = leftKeyAccessor.Compile(new ExpressionCompilationContext(context, leftSchema, null)); var rightKeyAccessor = (ScalarExpression)rightCol.ToColumnReference(); if (!rightColType.IsSameAs(keyType)) rightKeyAccessor = new ConvertCall { Parameter = rightKeyAccessor, DataType = keyType }; - var rightKeyConverter = rightKeyAccessor.Compile(dataSources[options.PrimaryDataSource], rightSchema, parameterTypes); + var rightKeyConverter = rightKeyAccessor.Compile(new ExpressionCompilationContext(context, rightSchema, null)); - foreach (var entity in LeftSource.Execute(dataSources, options, parameterTypes, parameterValues)) + var expressionContext = new ExpressionExecutionContext(context); + + foreach (var entity in LeftSource.Execute(context)) { - var key = leftKeyConverter(entity, parameterValues, options); + expressionContext.Entity = entity; + + var key = leftKeyConverter(expressionContext); if (!_hashTable.TryGetValue(key, out var list)) { @@ -64,9 +68,11 @@ protected override IEnumerable ExecuteInternal(IDictionary ExecuteInternal(IDictionary GetSortOrder(INodeSchema outerSchema, I return null; } - public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) { - var folded = base.FoldQuery(dataSources, options, parameterTypes, hints); + var folded = base.FoldQuery(context, hints); if (folded != this) return folded; @@ -118,8 +125,8 @@ public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary RightSource.EstimatedRowsOut) { diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/HashMatchAggregateNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/HashMatchAggregateNode.cs index 567fc6c8..61235a74 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/HashMatchAggregateNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/HashMatchAggregateNode.cs @@ -23,14 +23,15 @@ class HashMatchAggregateNode : BaseAggregateNode { private bool _folded; - protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { - var schema = Source.GetSchema(dataSources, parameterTypes); + var schema = Source.GetSchema(context); var groupByCols = GetGroupingColumns(schema); var groups = new Dictionary>(new DistinctEqualityComparer(groupByCols)); - InitializeAggregates(dataSources[options.PrimaryDataSource], schema, parameterTypes); - var aggregates = CreateAggregateFunctions(parameterValues, options, false); + InitializeAggregates(new ExpressionCompilationContext(context, schema, null)); + var executionContext = new ExpressionExecutionContext(context); + var aggregates = CreateAggregateFunctions(executionContext, false); if (IsScalarAggregate) { @@ -39,7 +40,7 @@ protected override IEnumerable ExecuteInternal(IDictionary ExecuteInternal(IDictionary ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) { if (_folded) return this; - Source = Source.FoldQuery(dataSources, options, parameterTypes, hints); + Source = Source.FoldQuery(context, hints); Source.Parent = this; // Special case for using RetrieveTotalRecordCount instead of FetchXML @@ -85,10 +88,10 @@ Source is FetchXmlScan fetch && GroupBy.Count == 0 && Aggregates.Count == 1 && Aggregates.Single().Value.AggregateType == AggregateType.CountStar && - dataSources[fetch.DataSource].Metadata[fetch.Entity.name].DataProviderId == null) // RetrieveTotalRecordCountRequest is not valid for virtual entities + context.DataSources[fetch.DataSource].Metadata[fetch.Entity.name].DataProviderId == null) // RetrieveTotalRecordCountRequest is not valid for virtual entities { var count = new RetrieveTotalRecordCountNode { DataSource = fetch.DataSource, EntityName = fetch.Entity.name }; - var countName = count.GetSchema(dataSources, parameterTypes).Schema.Single().Key; + var countName = count.GetSchema(context).Schema.Single().Key; if (countName == Aggregates.Single().Key) return count; @@ -234,7 +237,7 @@ Source is FetchXmlScan fetch && } // FetchXML dategrouping always uses local timezone. If we're using UTC we can't use it - if (!options.UseLocalTimeZone) + if (!context.Options.UseLocalTimeZone) { canUseFetchXmlAggregate = false; break; @@ -242,14 +245,14 @@ Source is FetchXmlScan fetch && } } - var metadata = dataSources[fetchXml.DataSource].Metadata; + var metadata = context.DataSources[fetchXml.DataSource].Metadata; // FetchXML is translated to QueryExpression for virtual entities, which doesn't support aggregates if (metadata[fetchXml.Entity.name].DataProviderId != null) canUseFetchXmlAggregate = false; // Check FetchXML supports grouping by each of the requested attributes - var fetchSchema = fetchXml.GetSchema(dataSources, parameterTypes); + var fetchSchema = fetchXml.GetSchema(context); foreach (var group in GroupBy) { if (!fetchSchema.ContainsColumn(group.GetColumnName(), out var groupCol)) @@ -304,7 +307,8 @@ Source is FetchXmlScan fetch && fetchXml.FetchXml.aggregateSpecified = true; fetchXml.FetchXml = fetchXml.FetchXml; - var schema = Source.GetSchema(dataSources, parameterTypes); + var schema = Source.GetSchema(context); + var expressionCompilationContext = new ExpressionCompilationContext(context, schema, null); foreach (var grouping in GroupBy) { @@ -340,7 +344,7 @@ Source is FetchXmlScan fetch && attribute.dategrouping = dateGrouping.Value; attribute.dategroupingSpecified = true; } - else if (grouping.GetType(dataSources[options.PrimaryDataSource], schema, null, parameterTypes, out _) == typeof(SqlDateTime)) + else if (grouping.GetType(expressionCompilationContext, out _) == typeof(SqlDateTime)) { // Can't group on datetime columns without a DATEPART specification canUseFetchXmlAggregate = false; @@ -435,7 +439,7 @@ Source is FetchXmlScan fetch && // Check how we should execute this aggregate if the FetchXML aggregate fails or is not available. Use stream aggregate // for scalar aggregates or where all the grouping fields can be folded into sorts. - var nonFetchXmlAggregate = FoldToStreamAggregate(dataSources, options, parameterTypes, hints); + var nonFetchXmlAggregate = FoldToStreamAggregate(context, hints); if (!canUseFetchXmlAggregate) return nonFetchXmlAggregate; @@ -583,10 +587,10 @@ Source is FetchXmlScan fetch && return tryCatch; } - return FoldToStreamAggregate(dataSources, options, parameterTypes, hints); + return FoldToStreamAggregate(context, hints); } - private IDataExecutionPlanNodeInternal FoldToStreamAggregate(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + private IDataExecutionPlanNodeInternal FoldToStreamAggregate(NodeCompilationContext context, IList hints) { // Use stream aggregate where possible - if there are no grouping fields or the groups can be folded into sorts var streamAggregate = new StreamAggregateNode { Source = Source }; @@ -606,7 +610,7 @@ private IDataExecutionPlanNodeInternal FoldToStreamAggregate(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { // Columns required by previous nodes must be derived from this node, so no need to pass them through. // Just calculate the columns that are required to calculate the groups & aggregates @@ -626,7 +630,7 @@ public override void AddRequiredColumns(IDictionary dataSour scalarRequiredColumns.AddRange(Aggregates.Where(agg => agg.Value.SqlExpression != null).SelectMany(agg => agg.Value.SqlExpression.GetColumns()).Distinct()); - Source.AddRequiredColumns(dataSources, parameterTypes, scalarRequiredColumns); + Source.AddRequiredColumns(context, scalarRequiredColumns); } public override object Clone() diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IDataExecutionPlanNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IDataExecutionPlanNode.cs index 7fc917b8..58aa56ba 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IDataExecutionPlanNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IDataExecutionPlanNode.cs @@ -29,26 +29,31 @@ internal interface IDataExecutionPlanNodeInternal : IDataExecutionPlanNode, IExe /// /// Populates with an estimate of the number of rows that will be returned by this node /// - RowCountEstimate EstimateRowsOut(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes); + /// The context in which the node is being built + /// An estimate of how many rows will be returned by the node + RowCountEstimate EstimateRowsOut(NodeCompilationContext context); /// /// Executes the execution plan /// - /// The to use to execute the plan + /// The context in which the node is being executed /// A sequence of entities matched by the query - IEnumerable Execute(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues); + IEnumerable Execute(NodeExecutionContext context); /// /// Attempts to fold the query operator down into its source /// + /// The context in which the node is being built + /// Any optimizer hints which may affect how the query is folded /// The final execution plan node to execute - IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints); + IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints); /// /// Gets the schema of the dataset returned by the node /// - /// - INodeSchema GetSchema(IDictionary dataSources, IDictionary parameterTypes); + /// The context in which the node is being built + /// The schema of data that will be produced by the node + INodeSchema GetSchema(NodeCompilationContext context); /// /// Gets the variables that are in use by this node and optionally its sources diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IDataSetExecutionPlanNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IDataSetExecutionPlanNode.cs index 55ec6c75..f645d160 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IDataSetExecutionPlanNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IDataSetExecutionPlanNode.cs @@ -18,12 +18,9 @@ internal interface IDataReaderExecutionPlanNode : IRootExecutionPlanNodeInternal /// /// Executes the execution plan /// - /// The data sources that can be used in the query - /// The options that control how the query should be executed - /// The types of the parameters that are available to the query - /// The values of the parameters that are available to the query + /// The context in which the node is being executed /// Additional options to control how the command should be executed /// A that contains the results of the query - DbDataReader Execute(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues, CommandBehavior behavior); + DbDataReader Execute(NodeExecutionContext context, CommandBehavior behavior); } } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IDmlQueryExecutionPlanNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IDmlQueryExecutionPlanNode.cs index bcadbf0b..d865ab29 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IDmlQueryExecutionPlanNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IDmlQueryExecutionPlanNode.cs @@ -16,8 +16,9 @@ internal interface IDmlQueryExecutionPlanNode : IRootExecutionPlanNodeInternal /// /// Executes the execution plan /// - /// The to use to execute the plan + /// The context in which the node is being executed + /// The number of records that were affected by the query /// A status message for the results of the query - string Execute(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues, out int recordsAffected); + string Execute(NodeExecutionContext context, out int recordsAffected); } } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IExecutionPlanNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IExecutionPlanNode.cs index 714ff7f0..d6f2850b 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IExecutionPlanNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IExecutionPlanNode.cs @@ -52,9 +52,8 @@ internal interface IExecutionPlanNodeInternal : IExecutionPlanNode, ICloneable /// /// Adds columns into the query which are required by preceding nodes /// - /// - /// - /// - void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns); + /// The context in which the node is being built + /// The columns which are required by the parent node + void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns); } } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IGoToNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IGoToNode.cs index 39abc073..834ce2af 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IGoToNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IGoToNode.cs @@ -13,11 +13,8 @@ interface IGoToNode : IRootExecutionPlanNodeInternal /// /// Checks which nodes should be executed next /// - /// The data sources that can be accessed by the query - /// The options which describe how the query should be executed - /// The types of any parameters available to the query - /// The values of any parameters available to the query + /// The context in which the node is being executed /// The label which should be executed next - string Execute(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues); + string Execute(NodeExecutionContext context); } } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IRootExecutionPlanNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IRootExecutionPlanNode.cs index 23947dd1..ad383417 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IRootExecutionPlanNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IRootExecutionPlanNode.cs @@ -33,11 +33,9 @@ internal interface IRootExecutionPlanNodeInternal : IRootExecutionPlanNode, IExe /// /// Attempts to fold this node into its source to simplify the query /// - /// The data sources that the query can use - /// to indicate how the query can be executed - /// A mapping of parameter names to their related types + /// The context in which the node is being built /// Any optimizer hints to apply /// The node that should be used in place of this node - IRootExecutionPlanNodeInternal[] FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints); + IRootExecutionPlanNodeInternal[] FoldQuery(NodeCompilationContext context, IList hints); } } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IndexSpoolNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IndexSpoolNode.cs index 11128996..15d0e341 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IndexSpoolNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/IndexSpoolNode.cs @@ -40,16 +40,16 @@ public IndexSpoolNode() { } [DisplayName("Seek Value")] public string SeekValue { get; set; } - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { requiredColumns.Add(KeyColumn); - Source.AddRequiredColumns(dataSources, parameterTypes, requiredColumns); + Source.AddRequiredColumns(context, requiredColumns); } - protected override RowCountEstimate EstimateRowsOutInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes) + protected override RowCountEstimate EstimateRowsOutInternal(NodeCompilationContext context) { - var rows = Source.EstimateRowsOut(dataSources, options, parameterTypes); + var rows = Source.EstimateRowsOut(context); if (rows is RowCountEstimateDefiniteRange range && range.Maximum == 1) return range; @@ -57,15 +57,15 @@ protected override RowCountEstimate EstimateRowsOutInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) { - Source = Source.FoldQuery(dataSources, options, parameterTypes, hints); + Source = Source.FoldQuery(context, hints); // Index and seek values must be the same type - var indexType = Source.GetSchema(dataSources, parameterTypes).Schema[KeyColumn]; - var seekType = parameterTypes[SeekValue]; + var indexType = Source.GetSchema(context).Schema[KeyColumn]; + var seekType = context.ParameterTypes[SeekValue]; - if (!SqlTypeConverter.CanMakeConsistentTypes(indexType, seekType, dataSources[options.PrimaryDataSource], out var consistentType)) + if (!SqlTypeConverter.CanMakeConsistentTypes(indexType, seekType, context.PrimaryDataSource, out var consistentType)) throw new QueryExecutionException($"No type conversion available for {indexType.ToSql()} and {seekType.ToSql()}"); _keySelector = SqlTypeConverter.GetConversion(indexType, consistentType); @@ -74,9 +74,9 @@ public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IDictionary parameterTypes) + public override INodeSchema GetSchema(NodeCompilationContext context) { - return Source.GetSchema(dataSources, parameterTypes); + return Source.GetSchema(context); } public override IEnumerable GetSources() @@ -84,17 +84,17 @@ public override IEnumerable GetSources() yield return Source; } - protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { // Build an internal hash table of the source indexed by the key column if (_hashTable == null) { - _hashTable = Source.Execute(dataSources, options, parameterTypes, parameterValues) + _hashTable = Source.Execute(context) .GroupBy(e => _keySelector((INullable)e[KeyColumn])) .ToDictionary(g => g.Key, g => g.ToList()); } - var keyValue = _seekSelector((INullable)parameterValues[SeekValue]); + var keyValue = _seekSelector((INullable)context.ParameterValues[SeekValue]); if (!_hashTable.TryGetValue(keyValue, out var matches)) return Array.Empty(); diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/InsertNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/InsertNode.cs index 5b303db5..9252e6e4 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/InsertNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/InsertNode.cs @@ -45,7 +45,7 @@ class InsertNode : BaseDmlNode [Category("Insert")] public override bool BypassCustomPluginExecution { get; set; } - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { foreach (var col in ColumnMappings.Values) { @@ -53,16 +53,16 @@ public override void AddRequiredColumns(IDictionary dataSour requiredColumns.Add(col); } - Source.AddRequiredColumns(dataSources, parameterTypes, requiredColumns); + Source.AddRequiredColumns(context, requiredColumns); } - public override string Execute(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues, out int recordsAffected) + public override string Execute(NodeExecutionContext context, out int recordsAffected) { _executionCount++; try { - if (!dataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) throw new QueryExecutionException("Missing datasource " + DataSource); List entities; @@ -73,21 +73,21 @@ public override string Execute(IDictionary dataSources, IQue using (_timer.Run()) { - entities = GetDmlSourceEntities(dataSources, options, parameterTypes, parameterValues, out var schema); + entities = GetDmlSourceEntities(context, out var schema); // Precompile mappings with type conversions meta = dataSource.Metadata[LogicalName]; attributes = meta.Attributes.ToDictionary(a => a.LogicalName, StringComparer.OrdinalIgnoreCase); - var dateTimeKind = options.UseLocalTimeZone ? DateTimeKind.Local : DateTimeKind.Utc; + var dateTimeKind = context.Options.UseLocalTimeZone ? DateTimeKind.Local : DateTimeKind.Utc; attributeAccessors = CompileColumnMappings(dataSource, LogicalName, ColumnMappings, schema, dateTimeKind, entities); attributeAccessors.TryGetValue(meta.PrimaryIdAttribute, out primaryIdAccessor); } // Check again that the update is allowed. Don't count any UI interaction in the execution time var confirmArgs = new ConfirmDmlStatementEventArgs(entities.Count, meta, BypassCustomPluginExecution); - if (options.CancellationToken.IsCancellationRequested) + if (context.Options.CancellationToken.IsCancellationRequested) confirmArgs.Cancel = true; - options.ConfirmInsert(confirmArgs); + context.Options.ConfirmInsert(confirmArgs); if (confirmArgs.Cancel) throw new OperationCanceledException("INSERT cancelled by user"); @@ -95,7 +95,7 @@ public override string Execute(IDictionary dataSources, IQue { return ExecuteDmlOperation( dataSource.Connection, - options, + context.Options, entities, meta, entity => CreateInsertRequest(meta, entity, attributeAccessors, primaryIdAccessor, attributes), @@ -106,8 +106,8 @@ public override string Execute(IDictionary dataSources, IQue CompletedLowercase = "inserted" }, out recordsAffected, - parameterValues, - LogicalName == "listmember" || meta.IsIntersect == true ? null : (Action) ((r) => SetIdentity(r, parameterValues)) + context.ParameterValues, + LogicalName == "listmember" || meta.IsIntersect == true ? null : (Action) ((r) => SetIdentity(r, context.ParameterValues)) ); } } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/MergeJoinNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/MergeJoinNode.cs index e9fe5921..a3804361 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/MergeJoinNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/MergeJoinNode.cs @@ -16,11 +16,8 @@ namespace MarkMpn.Sql4Cds.Engine.ExecutionPlan /// class MergeJoinNode : FoldableJoinNode { - protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { - if (!dataSources.TryGetValue(options.PrimaryDataSource, out var dataSource)) - throw new QueryExecutionException("Invalid data source"); - // https://sqlserverfast.com/epr/merge-join/ // Implemented inner, left outer, right outer and full outer variants // Not implemented semi joins @@ -28,12 +25,14 @@ protected override IEnumerable ExecuteInternal(IDictionary ExecuteInternal(IDictionary ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) { - var folded = base.FoldQuery(dataSources, options, parameterTypes, hints); + var folded = base.FoldQuery(context, hints); if (folded != this) return folded; @@ -132,7 +132,7 @@ public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { _entityCols = new Dictionary(); _attributeCols = new Dictionary(); @@ -373,7 +373,7 @@ private void NormalizeProperties(MetadataQueryExpression query, IEnumerable dataSources, IQueryExecutionOptions options, IDictionary parameterTypes) + protected override RowCountEstimate EstimateRowsOutInternal(NodeCompilationContext context) { var entityCount = 100; var attributesPerEntity = 1; @@ -434,12 +434,12 @@ private bool HasEqualityFilter(MetadataFilterExpression filter, string propName) return false; } - public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) { return this; } - public override INodeSchema GetSchema(IDictionary dataSources, IDictionary parameterTypes) + public override INodeSchema GetSchema(NodeCompilationContext context) { var schema = new Dictionary(StringComparer.OrdinalIgnoreCase); var aliases = new Dictionary>(); @@ -689,7 +689,13 @@ internal static Func GetPropertyAccessor(PropertyInfo prop, Type if (directConversionType == typeof(SqlString) && value.Type != typeof(string)) value = Expression.Call(value, nameof(Object.ToString), Array.Empty()); - var converted = SqlTypeConverter.Convert(value, directConversionType); + Expression converted; + + if (value.Type == typeof(string) && directConversionType == typeof(SqlString) && targetType == typeof(SqlString)) + converted = Expr.Call(() => ApplyCollation(Expr.Arg()), value); + else + converted = SqlTypeConverter.Convert(value, directConversionType); + if (targetType != directConversionType) converted = SqlTypeConverter.Convert(converted, targetType); @@ -709,12 +715,21 @@ internal static Func GetPropertyAccessor(PropertyInfo prop, Type return func; } + private static SqlString ApplyCollation(string value) + { + if (value == null) + return SqlString.Null; + + // Assume all metadata values should use standard collation rather than datasource specific? + return Collation.USEnglish.ToSqlString(value); + } + public override IEnumerable GetSources() { return Array.Empty(); } - protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { if (MetadataSource.HasFlag(MetadataSource.Attribute)) { @@ -756,7 +771,7 @@ protected override IEnumerable ExecuteInternal(IDictionary OuterReferences { get; set; } - protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { - var leftSchema = LeftSource.GetSchema(dataSources, parameterTypes); - var innerParameterTypes = GetInnerParameterTypes(leftSchema, parameterTypes); + var leftSchema = LeftSource.GetSchema(context); + var innerParameterTypes = GetInnerParameterTypes(leftSchema, context.ParameterTypes); if (OuterReferences != null) { - if (parameterTypes == null) + if (context.ParameterTypes == null) innerParameterTypes = new Dictionary(StringComparer.OrdinalIgnoreCase); else - innerParameterTypes = new Dictionary(parameterTypes, StringComparer.OrdinalIgnoreCase); + innerParameterTypes = new Dictionary(context.ParameterTypes, StringComparer.OrdinalIgnoreCase); foreach (var kvp in OuterReferences) innerParameterTypes[kvp.Value] = leftSchema.Schema[kvp.Key]; } - var rightSchema = RightSource.GetSchema(dataSources, innerParameterTypes); - var mergedSchema = GetSchema(dataSources, parameterTypes, true); - var joinCondition = JoinCondition?.Compile(dataSources[options.PrimaryDataSource], mergedSchema, parameterTypes); + var rightCompilationContext = new NodeCompilationContext(context.DataSources, context.Options, innerParameterTypes); + var rightSchema = RightSource.GetSchema(rightCompilationContext); + var mergedSchema = GetSchema(context, true); + var joinCondition = JoinCondition?.Compile(new ExpressionCompilationContext(context, mergedSchema, null)); + var joinConditionContext = joinCondition == null ? null : new ExpressionExecutionContext(context); - foreach (var left in LeftSource.Execute(dataSources, options, parameterTypes, parameterValues)) + foreach (var left in LeftSource.Execute(context)) { - var innerParameters = parameterValues; + var innerParameters = context.ParameterValues; if (OuterReferences != null) { - if (parameterValues == null) + if (innerParameters == null) innerParameters = new Dictionary(); else - innerParameters = new Dictionary(parameterValues); + innerParameters = new Dictionary(innerParameters); foreach (var kvp in OuterReferences) { @@ -69,12 +71,21 @@ protected override IEnumerable ExecuteInternal(IDictionary GetInnerParameterTypes(INodeSchem return innerParameterTypes; } - public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) { - var leftSchema = LeftSource.GetSchema(dataSources, parameterTypes); - LeftSource = LeftSource.FoldQuery(dataSources, options, parameterTypes, hints); + var leftSchema = LeftSource.GetSchema(context); + LeftSource = LeftSource.FoldQuery(context, hints); LeftSource.Parent = this; - var innerParameterTypes = GetInnerParameterTypes(leftSchema, parameterTypes); - RightSource = RightSource.FoldQuery(dataSources, options, innerParameterTypes, hints); + var innerParameterTypes = GetInnerParameterTypes(leftSchema, context.ParameterTypes); + var innerContext = new NodeCompilationContext(context.DataSources, context.Options, innerParameterTypes); + RightSource = RightSource.FoldQuery(innerContext, hints); RightSource.Parent = this; if (LeftSource is ConstantScanNode constant && @@ -121,7 +133,7 @@ public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary 5000) + LeftSource.EstimateRowsOut(context).Value < 100 && + fetch.EstimateRowsOut(innerContext).Value > 5000) { // Scalar subquery was folded to use an index spool due to an expected large number of outer records, // but the estimate has now changed (e.g. due to a TopNode being folded). Remove the index spool and replace @@ -174,13 +186,13 @@ indexSpool.Source is FetchXmlScan fetch && }, Parent = aggregate }; - aggregate.Source = filter.FoldQuery(dataSources, options, innerParameterTypes, hints); + aggregate.Source = filter.FoldQuery(innerContext, hints); } return this; } - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { if (JoinCondition != null) { @@ -191,39 +203,42 @@ public override void AddRequiredColumns(IDictionary dataSour } } - var leftSchema = LeftSource.GetSchema(dataSources, parameterTypes); + var leftSchema = LeftSource.GetSchema(context); var leftColumns = requiredColumns .Where(col => leftSchema.ContainsColumn(col, out _)) .Concat((IEnumerable) OuterReferences?.Keys ?? Array.Empty()) .Distinct() .ToList(); - var innerParameterTypes = GetInnerParameterTypes(leftSchema, parameterTypes); - var rightSchema = RightSource.GetSchema(dataSources, innerParameterTypes); + var innerParameterTypes = GetInnerParameterTypes(leftSchema, context.ParameterTypes); + var innerContext = new NodeCompilationContext(context.DataSources, context.Options, innerParameterTypes); + var rightSchema = RightSource.GetSchema(innerContext); var rightColumns = requiredColumns .Where(col => rightSchema.ContainsColumn(col, out _)) .Concat(DefinedValues.Values) .Distinct() .ToList(); - LeftSource.AddRequiredColumns(dataSources, parameterTypes, leftColumns); - RightSource.AddRequiredColumns(dataSources, parameterTypes, rightColumns); + LeftSource.AddRequiredColumns(context, leftColumns); + RightSource.AddRequiredColumns(context, rightColumns); } - protected override INodeSchema GetRightSchema(IDictionary dataSources, IDictionary parameterTypes) + protected override INodeSchema GetRightSchema(NodeCompilationContext context) { - var leftSchema = LeftSource.GetSchema(dataSources, parameterTypes); - var innerParameterTypes = GetInnerParameterTypes(leftSchema, parameterTypes); - return RightSource.GetSchema(dataSources, innerParameterTypes); + var leftSchema = LeftSource.GetSchema(context); + var innerParameterTypes = GetInnerParameterTypes(leftSchema, context.ParameterTypes); + var innerContext = new NodeCompilationContext(context.DataSources, context.Options, innerParameterTypes); + return RightSource.GetSchema(innerContext); } - protected override RowCountEstimate EstimateRowsOutInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes) + protected override RowCountEstimate EstimateRowsOutInternal(NodeCompilationContext context) { - var leftEstimate = LeftSource.EstimateRowsOut(dataSources, options, parameterTypes); + var leftEstimate = LeftSource.EstimateRowsOut(context); ParseEstimate(leftEstimate, out var leftMin, out var leftMax, out var leftIsRange); - var leftSchema = LeftSource.GetSchema(dataSources, parameterTypes); - var innerParameterTypes = GetInnerParameterTypes(leftSchema, parameterTypes); + var leftSchema = LeftSource.GetSchema(context); + var innerParameterTypes = GetInnerParameterTypes(leftSchema, context.ParameterTypes); + var innerContext = new NodeCompilationContext(context.DataSources, context.Options, innerParameterTypes); - var rightEstimate = RightSource.EstimateRowsOut(dataSources, options, innerParameterTypes); + var rightEstimate = RightSource.EstimateRowsOut(innerContext); ParseEstimate(rightEstimate, out var rightMin, out var rightMax, out var rightIsRange); if (JoinType == QualifiedJoinType.LeftOuter && SemiJoin) diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/OffsetFetchNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/OffsetFetchNode.cs index fb12acd8..00a0da7f 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/OffsetFetchNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/OffsetFetchNode.cs @@ -32,10 +32,12 @@ class OffsetFetchNode : BaseDataNode, ISingleSourceExecutionPlanNode [Browsable(false)] public IDataExecutionPlanNodeInternal Source { get; set; } - protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { - var offset = SqlTypeConverter.ChangeType(Offset.Compile(dataSources[options.PrimaryDataSource], null, parameterTypes)(null, parameterValues, options)); - var fetch = SqlTypeConverter.ChangeType(Fetch.Compile(dataSources[options.PrimaryDataSource], null, parameterTypes)(null, parameterValues, options)); + var expressionCompilationContext = new ExpressionCompilationContext(context, null, null); + var expressionExecutionContext = new ExpressionExecutionContext(context); + var offset = SqlTypeConverter.ChangeType(Offset.Compile(expressionCompilationContext)(expressionExecutionContext)); + var fetch = SqlTypeConverter.ChangeType(Fetch.Compile(expressionCompilationContext)(expressionExecutionContext)); if (offset < 0) throw new QueryExecutionException("The offset specified in a OFFSET clause may not be negative."); @@ -43,14 +45,14 @@ protected override IEnumerable ExecuteInternal(IDictionary dataSources, IDictionary parameterTypes) + public override INodeSchema GetSchema(NodeCompilationContext context) { - return Source.GetSchema(dataSources, parameterTypes); + return Source.GetSchema(context); } public override IEnumerable GetSources() @@ -58,19 +60,22 @@ public override IEnumerable GetSources() yield return Source; } - public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) { - Source = Source.FoldQuery(dataSources, options, parameterTypes, hints); + Source = Source.FoldQuery(context, hints); Source.Parent = this; - if (!Offset.IsConstantValueExpression(dataSources[options.PrimaryDataSource], null, options, out var offsetLiteral) || - !Fetch.IsConstantValueExpression(dataSources[options.PrimaryDataSource], null, options, out var fetchLiteral)) + var expressionCompilationContext = new ExpressionCompilationContext(context.DataSources, context.Options, null, null, null); + + if (!Offset.IsConstantValueExpression(expressionCompilationContext, out var offsetLiteral) || + !Fetch.IsConstantValueExpression(expressionCompilationContext, out var fetchLiteral)) return this; if (Source is FetchXmlScan fetchXml) { - var offset = SqlTypeConverter.ChangeType(offsetLiteral.Compile(dataSources[options.PrimaryDataSource], null, null)(null, null, options)); - var count = SqlTypeConverter.ChangeType(fetchLiteral.Compile(dataSources[options.PrimaryDataSource], null, null)(null, null, options)); + var expressionExecutionContext = new ExpressionExecutionContext(expressionCompilationContext); + var offset = SqlTypeConverter.ChangeType(offsetLiteral.Compile(expressionCompilationContext)(expressionExecutionContext)); + var count = SqlTypeConverter.ChangeType(fetchLiteral.Compile(expressionCompilationContext)(expressionExecutionContext)); var page = offset / count; if (page * count == offset && count <= 5000) @@ -85,17 +90,18 @@ public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { - Source.AddRequiredColumns(dataSources, parameterTypes, requiredColumns); + Source.AddRequiredColumns(context, requiredColumns); } - protected override RowCountEstimate EstimateRowsOutInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes) + protected override RowCountEstimate EstimateRowsOutInternal(NodeCompilationContext context) { - var sourceCount = Source.EstimateRowsOut(dataSources, options, parameterTypes); + var sourceCount = Source.EstimateRowsOut(context); + var expressionCompilationContext = new ExpressionCompilationContext(context, null, null); - if (!Offset.IsConstantValueExpression(dataSources[options.PrimaryDataSource], null, options, out var offsetLiteral) || - !Fetch.IsConstantValueExpression(dataSources[options.PrimaryDataSource], null, options, out var fetchLiteral)) + if (!Offset.IsConstantValueExpression(expressionCompilationContext, out var offsetLiteral) || + !Fetch.IsConstantValueExpression(expressionCompilationContext, out var fetchLiteral)) return sourceCount; var offset = Int32.Parse(offsetLiteral.Value, CultureInfo.InvariantCulture); diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/PartitionedAggregateNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/PartitionedAggregateNode.cs index 6ba84929..9bac5b4d 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/PartitionedAggregateNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/PartitionedAggregateNode.cs @@ -46,36 +46,38 @@ class Partition private int _pendingPartitions; private object _lock; - public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) { - Source = Source.FoldQuery(dataSources, options, parameterTypes, hints); + Source = Source.FoldQuery(context, hints); Source.Parent = this; return this; } - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { // All required columns must already have been added during the original folding of the HashMatchAggregateNode } - protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { - var schema = Source.GetSchema(dataSources, parameterTypes); + var schema = Source.GetSchema(context); var groupByCols = GetGroupingColumns(schema); var groups = new ConcurrentDictionary>(new DistinctEqualityComparer(groupByCols)); + var expressionCompilationContext = new ExpressionCompilationContext(context, schema, null); + var expressionExecutionContext = new ExpressionExecutionContext(context); - InitializePartitionedAggregates(dataSources[options.PrimaryDataSource], schema, parameterTypes); - var aggregates = CreateAggregateFunctions(parameterValues, options, true); + InitializePartitionedAggregates(expressionCompilationContext); + var aggregates = CreateAggregateFunctions(expressionExecutionContext, true); var fetchXmlNode = (FetchXmlScan)Source; var name = fetchXmlNode.Entity.name; - var meta = dataSources[fetchXmlNode.DataSource].Metadata[name]; - options.Progress(0, $"Partitioning {GetDisplayName(0, meta)}..."); + var meta = context.DataSources[fetchXmlNode.DataSource].Metadata[name]; + context.Options.Progress(0, $"Partitioning {GetDisplayName(0, meta)}..."); // Get the minimum and maximum primary keys from the source - var minKey = GetMinMaxKey(fetchXmlNode, dataSources, options, parameterTypes, parameterValues, false); - var maxKey = GetMinMaxKey(fetchXmlNode, dataSources, options, parameterTypes, parameterValues, true); + var minKey = GetMinMaxKey(fetchXmlNode, context, false); + var maxKey = GetMinMaxKey(fetchXmlNode, context, true); if (minKey.IsNull || maxKey.IsNull || minKey == maxKey) throw new QueryExecutionException("Cannot partition query"); @@ -96,9 +98,9 @@ protected override IEnumerable ExecuteInternal(IDictionary ExecuteInternal(IDictionary ExecuteInternal(IDictionary ExecuteInternal(IDictionary ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues, Dictionary aggregates, ConcurrentDictionary> groups, FetchXmlScan fetchXmlNode, SqlDateTime minValue, SqlDateTime maxValue) + private void ExecuteAggregate(NodeExecutionContext context, ExpressionExecutionContext expressionContext, Dictionary aggregates, ConcurrentDictionary> groups, FetchXmlScan fetchXmlNode, SqlDateTime minValue, SqlDateTime maxValue) { - parameterValues["@PartitionStart"] = minValue; - parameterValues["@PartitionEnd"] = maxValue; + context.ParameterValues["@PartitionStart"] = minValue; + context.ParameterValues["@PartitionEnd"] = maxValue; - var results = fetchXmlNode.Execute(dataSources, options, parameterTypes, parameterValues); + var results = fetchXmlNode.Execute(context); foreach (var entity in results) { // Update aggregates var values = groups.GetOrAdd(entity, _ => ResetAggregates(aggregates)); - lock (values) + lock (expressionContext) { + expressionContext.Entity = entity; + foreach (var func in values.Values) - func.AggregateFunction.NextPartition(entity, func.State); + func.AggregateFunction.NextPartition(func.State); } } } - private SqlDateTime GetMinMaxKey(FetchXmlScan fetchXmlNode, IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues, bool max) + private SqlDateTime GetMinMaxKey(FetchXmlScan fetchXmlNode, NodeExecutionContext context, bool max) { // Create a new FetchXmlScan node with a copy of the original query var minMaxNode = new FetchXmlScan @@ -333,7 +337,7 @@ private SqlDateTime GetMinMaxKey(FetchXmlScan fetchXmlNode, IDictionary, IQueryExecutionOptions, object> _expression; + private Func _expression; public override int ExecutionCount => _executionCount; @@ -31,7 +31,7 @@ class PrintNode : BaseNode, IDmlQueryExecutionPlanNode [Description("The value to print")] public ScalarExpression Expression { get; set; } - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { } @@ -47,14 +47,14 @@ public object Clone() }; } - public string Execute(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues, out int recordsAffected) + public string Execute(NodeExecutionContext context, out int recordsAffected) { _executionCount++; recordsAffected = -1; using (_timer.Run()) { - var value = (SqlString)_expression(null, parameterValues, options); + var value = (SqlString)_expression(new ExpressionExecutionContext(context)); if (value.IsNull) return null; @@ -63,9 +63,9 @@ public string Execute(IDictionary dataSources, IQueryExecuti } } - public IRootExecutionPlanNodeInternal[] FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public IRootExecutionPlanNodeInternal[] FoldQuery(NodeCompilationContext context, IList hints) { - _expression = Expression.Compile(dataSources[options.PrimaryDataSource], null, parameterTypes); + _expression = Expression.Compile(new ExpressionCompilationContext(context, null, null)); return new[] { this }; } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/RetrieveTotalRecordCountNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/RetrieveTotalRecordCountNode.cs index b6f7e0d6..acc1b475 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/RetrieveTotalRecordCountNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/RetrieveTotalRecordCountNode.cs @@ -27,10 +27,9 @@ class RetrieveTotalRecordCountNode : BaseDataNode [Description("The logical name of the entity to get the record count for")] public string EntityName { get; set; } - protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { - - if (!dataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) throw new QueryExecutionException("Missing datasource " + DataSource); var count = ((RetrieveTotalRecordCountResponse)dataSource.Connection.Execute(new RetrieveTotalRecordCountRequest { EntityNames = new[] { EntityName } })).EntityRecordCountCollection[EntityName]; @@ -48,7 +47,7 @@ public override IEnumerable GetSources() return Array.Empty(); } - public override INodeSchema GetSchema(IDictionary dataSources, IDictionary parameterTypes) + public override INodeSchema GetSchema(NodeCompilationContext context) { return new NodeSchema( primaryKey: null, @@ -64,16 +63,16 @@ public override INodeSchema GetSchema(IDictionary dataSource sortOrder: null); } - public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) { return this; } - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { } - protected override RowCountEstimate EstimateRowsOutInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes) + protected override RowCountEstimate EstimateRowsOutInternal(NodeCompilationContext context) { return RowCountEstimateDefiniteRange.ExactlyOne; } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/RevertNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/RevertNode.cs index 81b28e98..5029ccb1 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/RevertNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/RevertNode.cs @@ -48,11 +48,11 @@ class RevertNode : BaseNode, IDmlQueryExecutionPlanNode, IImpersonateRevertExecu public override TimeSpan Duration => _timer.Duration; - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { } - public string Execute(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues, out int recordsAffected) + public string Execute(NodeExecutionContext context, out int recordsAffected) { _executionCount++; @@ -60,7 +60,7 @@ public string Execute(IDictionary dataSources, IQueryExecuti { using (_timer.Run()) { - if (!dataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) throw new QueryExecutionException("Missing datasource " + DataSource); #if NETCOREAPP @@ -94,7 +94,7 @@ public string Execute(IDictionary dataSources, IQueryExecuti } } - public IRootExecutionPlanNodeInternal[] FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public IRootExecutionPlanNodeInternal[] FoldQuery(NodeCompilationContext context, IList hints) { return new[] { this }; } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SelectNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SelectNode.cs index 08c6bf75..812202d5 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SelectNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SelectNode.cs @@ -49,18 +49,18 @@ class SelectNode : BaseNode, ISingleSourceExecutionPlanNode, IDataReaderExecutio public override int ExecutionCount => _executionCount; - public DbDataReader Execute(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues, CommandBehavior behavior) + public DbDataReader Execute(NodeExecutionContext context, CommandBehavior behavior) { _executionCount++; var timer = _timer.Run(); - var schema = Source.GetSchema(dataSources, parameterTypes); - var source = behavior.HasFlag(CommandBehavior.SchemaOnly) ? Array.Empty() : Source.Execute(dataSources, options, parameterTypes, parameterValues); + var schema = Source.GetSchema(context); + var source = behavior.HasFlag(CommandBehavior.SchemaOnly) ? Array.Empty() : Source.Execute(context); if (behavior.HasFlag(CommandBehavior.SingleRow)) source = source.Take(1); - return new SelectDataReader(ColumnSet, timer, schema, source, parameterValues); + return new SelectDataReader(ColumnSet, timer, schema, source, context.ParameterValues); } public override IEnumerable GetSources() @@ -68,15 +68,15 @@ public override IEnumerable GetSources() yield return Source; } - public IRootExecutionPlanNodeInternal[] FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public IRootExecutionPlanNodeInternal[] FoldQuery(NodeCompilationContext context, IList hints) { - Source = Source.FoldQuery(dataSources, options, parameterTypes, hints); + Source = Source.FoldQuery(context, hints); Source.Parent = this; - FoldFetchXmlColumns(Source, ColumnSet, dataSources, parameterTypes); + FoldFetchXmlColumns(Source, ColumnSet, context); FoldMetadataColumns(Source, ColumnSet); - ExpandWildcardColumns(dataSources, parameterTypes); + ExpandWildcardColumns(context); if (Source is AliasNode alias) { @@ -92,15 +92,15 @@ public IRootExecutionPlanNodeInternal[] FoldQuery(IDictionary columnSet, IDictionary dataSources, IDictionary parameterTypes) + internal static void FoldFetchXmlColumns(IDataExecutionPlanNode source, List columnSet, NodeCompilationContext context) { if (source is FetchXmlScan fetchXml) { - if (!dataSources.TryGetValue(fetchXml.DataSource, out var dataSource)) + if (!context.DataSources.TryGetValue(fetchXml.DataSource, out var dataSource)) throw new NotSupportedQueryFragmentException("Missing datasource " + fetchXml.DataSource); // Check if there are any aliases we can apply to the source FetchXml - var schema = fetchXml.GetSchema(dataSources, parameterTypes); + var schema = fetchXml.GetSchema(context); var hasStar = columnSet.Any(col => col.AllColumns && col.SourceColumn == null); var aliasStars = new HashSet(columnSet.Where(col => col.AllColumns && col.SourceColumn != null).Select(col => col.SourceColumn.Replace(".*", "")).Distinct(StringComparer.OrdinalIgnoreCase), StringComparer.OrdinalIgnoreCase); @@ -255,17 +255,17 @@ private void FoldMetadataColumns(IDataExecutionPlanNode source, List dataSources, IDictionary parameterTypes) + public void ExpandWildcardColumns(NodeCompilationContext context) { - ExpandWildcardColumns(Source, ColumnSet, dataSources, parameterTypes); + ExpandWildcardColumns(Source, ColumnSet, context); } - internal static void ExpandWildcardColumns(IDataExecutionPlanNodeInternal source, List columnSet, IDictionary dataSources, IDictionary parameterTypes) + internal static void ExpandWildcardColumns(IDataExecutionPlanNodeInternal source, List columnSet, NodeCompilationContext context) { // Expand any AllColumns if (columnSet.Any(col => col.AllColumns)) { - var schema = source.GetSchema(dataSources, parameterTypes); + var schema = source.GetSchema(context); var expanded = new List(); foreach (var col in columnSet) @@ -291,7 +291,7 @@ internal static void ExpandWildcardColumns(IDataExecutionPlanNodeInternal source } } - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { foreach (var col in ColumnSet.Select(c => c.SourceColumn + (c.AllColumns ? ".*" : ""))) { @@ -299,7 +299,7 @@ public override void AddRequiredColumns(IDictionary dataSour requiredColumns.Add(col); } - Source.AddRequiredColumns(dataSources, parameterTypes, requiredColumns); + Source.AddRequiredColumns(context, requiredColumns); } public override string ToString() diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SortNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SortNode.cs index 01202570..cf939e60 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SortNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SortNode.cs @@ -39,11 +39,13 @@ class SortNode : BaseDataNode, ISingleSourceExecutionPlanNode [Browsable(false)] public IDataExecutionPlanNodeInternal Source { get; set; } - protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { - var source = Source.Execute(dataSources, options, parameterTypes, parameterValues); - var schema = GetSchema(dataSources, parameterTypes); - var expressions = Sorts.Select(sort => sort.Expression.Compile(dataSources[options.PrimaryDataSource], schema, parameterTypes)).ToList(); + var source = Source.Execute(context); + var schema = GetSchema(context); + var expressionCompilationContext = new ExpressionCompilationContext(context, schema, null); + var expressionExecutionContext = new ExpressionExecutionContext(context); + var expressions = Sorts.Select(sort => sort.Expression.Compile(expressionCompilationContext)).ToList(); if (PresortedCount == 0) { @@ -52,18 +54,18 @@ protected override IEnumerable ExecuteInternal(IDictionary sortedSource; if (Sorts[0].SortOrder == SortOrder.Descending) - sortedSource = source.OrderByDescending(e => expressions[0](e, parameterValues, options)); + sortedSource = source.OrderByDescending(e => { expressionExecutionContext.Entity = e; return expressions[0](expressionExecutionContext); }); else - sortedSource = source.OrderBy(e => expressions[0](e, parameterValues, options)); + sortedSource = source.OrderBy(e => { expressionExecutionContext.Entity = e; return expressions[0](expressionExecutionContext); }); for (var i = 1; i < Sorts.Count; i++) { var expr = expressions[i]; if (Sorts[i].SortOrder == SortOrder.Descending) - sortedSource = sortedSource.ThenByDescending(e => expr(e, parameterValues, options)); + sortedSource = sortedSource.ThenByDescending(e => { expressionExecutionContext.Entity = e; return expr(expressionExecutionContext); }); else - sortedSource = sortedSource.ThenBy(e => expr(e, parameterValues, options)); + sortedSource = sortedSource.ThenBy(e => { expressionExecutionContext.Entity = e; return expr(expressionExecutionContext); }); } foreach (var entity in sortedSource) @@ -84,10 +86,12 @@ protected override IEnumerable ExecuteInternal(IDictionary expr(next, parameterValues, options)) + .Select(expr => expr(expressionExecutionContext)) .ToList(); // If we've already got a subset to work on, check if this fits in the same subset @@ -95,7 +99,7 @@ protected override IEnumerable ExecuteInternal(IDictionary ExecuteInternal(IDictionary subset, INodeSchema schema, IDictionary parameterTypes, IDictionary parameterValues, List, IQueryExecutionOptions, object>> expressions, IQueryExecutionOptions options) + private void SortSubset(List subset, ExpressionExecutionContext context, List> expressions) { // Simple case if there's no need to do any further sorting if (subset.Count <= 1) @@ -123,7 +127,13 @@ private void SortSubset(List subset, INodeSchema schema, IDictionary entity, entity => expressions.Skip(PresortedCount).Select(expr => expr(entity, parameterValues, options)).ToList()); + .ToDictionary( + entity => entity, + entity => expressions + .Skip(PresortedCount) + .Select(expr => { context.Entity = entity; return expr(context); }) + .ToList() + ); // Sort the list according to these sort keys subset.Sort((x, y) => @@ -153,9 +163,9 @@ public override IEnumerable GetSources() yield return Source; } - public override INodeSchema GetSchema(IDictionary dataSources, IDictionary parameterTypes) + public override INodeSchema GetSchema(NodeCompilationContext context) { - var schema = new NodeSchema(Source.GetSchema(dataSources, parameterTypes)); + var schema = new NodeSchema(Source.GetSchema(context)); var sortOrder = new List(); foreach (var sort in Sorts) @@ -177,9 +187,9 @@ public override INodeSchema GetSchema(IDictionary dataSource sortOrder: sortOrder); } - public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) { - Source = Source.FoldQuery(dataSources, options, parameterTypes, hints); + Source = Source.FoldQuery(context, hints); // These sorts will override any previous sort if (Source is SortNode prevSort) @@ -187,10 +197,10 @@ public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes) + private IDataExecutionPlanNodeInternal FoldSorts(NodeCompilationContext context) { PresortedCount = 0; @@ -206,7 +216,7 @@ private IDataExecutionPlanNodeInternal FoldSorts(IDictionary var fetchAggregateSort = new SortNode { Source = tryFetch }; fetchAggregateSort.Sorts.AddRange(Sorts); - var sortedFetchResult = fetchAggregateSort.FoldSorts(dataSources, options, parameterTypes); + var sortedFetchResult = fetchAggregateSort.FoldSorts(context); // If we managed to fold any of the sorts in to the FetchXML, do the same for the non-FetchXML version and remove this node if (sortedFetchResult == tryFetch || (sortedFetchResult == fetchAggregateSort && fetchAggregateSort.PresortedCount > 0)) @@ -219,7 +229,7 @@ private IDataExecutionPlanNodeInternal FoldSorts(IDictionary var nonFetchAggregateSort = new SortNode { Source = tryCatch.CatchSource }; nonFetchAggregateSort.Sorts.AddRange(Sorts); - var sortedNonFetchResult = nonFetchAggregateSort.FoldSorts(dataSources, options, parameterTypes); + var sortedNonFetchResult = nonFetchAggregateSort.FoldSorts(context); tryCatch.CatchSource = sortedNonFetchResult; sortedNonFetchResult.Parent = tryCatch; @@ -246,14 +256,14 @@ private IDataExecutionPlanNodeInternal FoldSorts(IDictionary fetchXml = source as FetchXmlScan; } - if (fetchXml != null && !fetchXml.RequiresCustomPaging(dataSources)) + if (fetchXml != null && !fetchXml.RequiresCustomPaging(context.DataSources)) { - if (!dataSources.TryGetValue(fetchXml.DataSource, out var dataSource)) + if (!context.DataSources.TryGetValue(fetchXml.DataSource, out var dataSource)) throw new QueryExecutionException("Missing datasource " + fetchXml.DataSource); fetchXml.RemoveSorts(); - var fetchSchema = fetchXml.GetSchema(dataSources, parameterTypes); + var fetchSchema = fetchXml.GetSchema(context); var entity = fetchXml.Entity; var items = entity.Items; @@ -365,9 +375,11 @@ private IDataExecutionPlanNodeInternal FoldSorts(IDictionary if (top == null && offset == null) return this; + var expressionCompilationContext = new ExpressionCompilationContext(context, null, null); + if (top != null) { - if (!top.Top.IsConstantValueExpression(dataSources[options.PrimaryDataSource], null, options, out var topLiteral)) + if (!top.Top.IsConstantValueExpression(expressionCompilationContext, out var topLiteral)) return this; if (Int32.Parse(topLiteral.Value, CultureInfo.InvariantCulture) > 50000) @@ -375,8 +387,8 @@ private IDataExecutionPlanNodeInternal FoldSorts(IDictionary } else if (offset != null) { - if (!offset.Offset.IsConstantValueExpression(dataSources[options.PrimaryDataSource], null, options, out var offsetLiteral) || - !offset.Fetch.IsConstantValueExpression(dataSources[options.PrimaryDataSource], null, options, out var fetchLiteral)) + if (!offset.Offset.IsConstantValueExpression(expressionCompilationContext, out var offsetLiteral) || + !offset.Fetch.IsConstantValueExpression(expressionCompilationContext, out var fetchLiteral)) return this; if (Int32.Parse(offsetLiteral.Value, CultureInfo.InvariantCulture) + Int32.Parse(fetchLiteral.Value, CultureInfo.InvariantCulture) > 50000) @@ -395,7 +407,7 @@ private IDataExecutionPlanNodeInternal FoldSorts(IDictionary } // Check if the data is already sorted by any prefix of our sorts - var schema = Source.GetSchema(dataSources, parameterTypes); + var schema = Source.GetSchema(context); for (var i = 0; i < Sorts.Count && i < schema.SortOrder.Count; i++) { @@ -441,7 +453,7 @@ private string FindEntityWithAttributeAlias(string alias, object[] items, string return null; } - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { var sortColumns = Sorts.SelectMany(s => s.Expression.GetColumns()).Distinct(); @@ -451,12 +463,12 @@ public override void AddRequiredColumns(IDictionary dataSour requiredColumns.Add(col); } - Source.AddRequiredColumns(dataSources, parameterTypes, requiredColumns); + Source.AddRequiredColumns(context, requiredColumns); } - protected override RowCountEstimate EstimateRowsOutInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes) + protected override RowCountEstimate EstimateRowsOutInternal(NodeCompilationContext context) { - return Source.EstimateRowsOut(dataSources, options, parameterTypes); + return Source.EstimateRowsOut(context); } protected override IEnumerable GetVariablesInternal() diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SqlNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SqlNode.cs index 142fbfbf..44c57902 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SqlNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SqlNode.cs @@ -50,11 +50,11 @@ public SqlNode() { } [Browsable(false)] public HashSet Parameters { get; private set; } = new HashSet(StringComparer.OrdinalIgnoreCase); - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { } - public DbDataReader Execute(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues, CommandBehavior behavior) + public DbDataReader Execute(NodeExecutionContext context, CommandBehavior behavior) { _executionCount++; @@ -62,10 +62,10 @@ public DbDataReader Execute(IDictionary dataSources, IQueryE { try { - if (!dataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) throw new QueryExecutionException("Missing datasource " + DataSource); - if (options.UseLocalTimeZone) + if (context.Options.UseLocalTimeZone) throw new QueryExecutionException("Cannot use automatic local time zone conversion with the TDS Endpoint"); #if NETCOREAPP @@ -86,9 +86,9 @@ public DbDataReader Execute(IDictionary dataSources, IQueryE var cmd = con.CreateCommand(); cmd.CommandTimeout = (int)TimeSpan.FromMinutes(2).TotalSeconds; - cmd.CommandText = ApplyCommandBehavior(Sql, behavior, options); + cmd.CommandText = ApplyCommandBehavior(Sql, behavior, context.Options); - foreach (var paramValue in parameterValues) + foreach (var paramValue in context.ParameterValues) { if (paramValue.Key.StartsWith("@@")) continue; @@ -107,12 +107,12 @@ public DbDataReader Execute(IDictionary dataSources, IQueryE cmd.Parameters.Add(param); } - options.CancellationToken.Register(() => cmd.Cancel()); + context.Options.CancellationToken.Register(() => cmd.Cancel()); if (Parent == null) { cmd.StatementCompleted += (s, e) => { - parameterValues["@@ROWCOUNT"] = (SqlInt32)e.RecordCount; + context.ParameterValues["@@ROWCOUNT"] = (SqlInt32)e.RecordCount; }; } @@ -207,7 +207,7 @@ internal static string ApplyCommandBehavior(string sql, CommandBehavior behavior return script.ToSql(); } - public IRootExecutionPlanNodeInternal[] FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public IRootExecutionPlanNodeInternal[] FoldQuery(NodeCompilationContext context, IList hints) { return new[] { this }; } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SqlTypeConverter.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SqlTypeConverter.cs index 6fae711b..4c4907ba 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SqlTypeConverter.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SqlTypeConverter.cs @@ -53,7 +53,7 @@ class SqlTypeConverter private static ConcurrentDictionary> _conversions; private static ConcurrentDictionary> _sqlConversions; private static Dictionary _netToSqlTypeConversions; - private static Dictionary> _netToSqlTypeConversionFuncs; + private static Dictionary> _netToSqlTypeConversionFuncs; private static Dictionary _sqlToNetTypeConversions; private static Dictionary> _sqlToNetTypeConversionFuncs; @@ -89,7 +89,7 @@ static SqlTypeConverter() _sqlConversions = new ConcurrentDictionary>(); _netToSqlTypeConversions = new Dictionary(); - _netToSqlTypeConversionFuncs = new Dictionary>(); + _netToSqlTypeConversionFuncs = new Dictionary>(); _sqlToNetTypeConversions = new Dictionary(); _sqlToNetTypeConversionFuncs = new Dictionary>(); @@ -105,7 +105,7 @@ static SqlTypeConverter() AddTypeConversion((ds, v, dt) => v, v => v.Value); AddTypeConversion((ds, v, dt) => v, v => v.Value); AddTypeConversion((ds, v, dt) => v, v => v.Value); - AddNullableTypeConversion((ds, v, dt) => UseDefaultCollation(v), v => v.Value); + AddNullableTypeConversion((ds, v, dt) => ((SqlDataTypeReferenceWithCollation)dt).Collation.ToSqlString(v), v => v.Value); AddTypeConversion((ds, v, dt) => v, v => v.Value); AddTypeConversion((ds, v, dt) => (SqlDateTime)v, v => v.Value); AddTypeConversion((ds, v, dt) => (SqlDateTime)v, v => v.Value); @@ -114,12 +114,12 @@ static SqlTypeConverter() AddNullableTypeConversion((ds, v, dt) => v.Value, null); AddNullableTypeConversion((ds, v, dt) => v.Value, null); - AddNullableTypeConversion((ds, v, dt) => UseDefaultCollation(String.Join(",", v.Select(osv => osv.Value))), null); - AddNullableTypeConversion((ds, v, dt) => UseDefaultCollation(String.Join(",", v.Entities.Select(e => FormatEntityCollectionEntry(e)))), null); - AddNullableTypeConversion((ds, v, dt) => new SqlEntityReference(ds, v), v => v); + AddNullableTypeConversion((ds, v, dt) => ds.DefaultCollation.ToSqlString(String.Join(",", v.Select(osv => osv.Value))), null); + AddNullableTypeConversion((ds, v, dt) => ds.DefaultCollation.ToSqlString(String.Join(",", v.Entities.Select(e => FormatEntityCollectionEntry(e)))), null); + AddNullableTypeConversion((ds, v, dt) => new SqlEntityReference(ds.Name, v), v => v); } - private static void AddTypeConversion(Func netToSql, Func sqlToNet) + private static void AddTypeConversion(Func netToSql, Func sqlToNet) where TSql: INullable where TNet: struct { @@ -153,7 +153,7 @@ private static void AddTypeConversion(Func(Func netToSql, Func sqlToNet) + private static void AddNullableTypeConversion(Func netToSql, Func sqlToNet) where TSql : INullable { if (netToSql != null) @@ -435,6 +435,7 @@ public static bool CanChangeTypeExplicit(DataTypeReference from, DataTypeReferen /// /// The expression that generates the values to convert /// The type to convert to + /// The expression which contains the the expression will be evaluated in /// An expression to generate values of the required type public static Expression Convert(Expression expr, Type to) { @@ -493,13 +494,26 @@ public static Expression Convert(Expression expr, Type to) expr = Expression.Convert(expr, to); - if (to == typeof(SqlString)) - expr = Expr.Call(() => UseDefaultCollation(Expr.Arg()), expr); + //if (to == typeof(SqlString)) + // expr = Expr.Call(() => ApplyCollation(Expr.Arg(), Expr.Arg()), Expression.Constant(collation), expr); } return expr; } + private static SqlString ApplyCollation(ExpressionExecutionContext context, SqlString sqlString) + { + return ApplyCollation(context.PrimaryDataSource.DefaultCollation, sqlString); + } + + private static SqlString ApplyCollation(Collation collation, SqlString sqlString) + { + if (sqlString.IsNull) + return sqlString; + + return collation.ToSqlString(sqlString.Value); + } + /// /// Produces the required expression to convert values to a specific type /// @@ -537,12 +551,14 @@ public static Expression Convert(Expression expr, DataTypeReference from, DataTy if (!CanChangeTypeImplicit(styleType, DataTypeHelpers.Int)) throw new NotSupportedQueryFragmentException($"No type conversion available from {styleType.ToSql()} to {DataTypeHelpers.Int.ToSql()}", convert.Style); + var targetCollation = (to as SqlDataTypeReferenceWithCollation)?.Collation; + if (fromSqlType != null && (fromSqlType.SqlDataTypeOption.IsDateTimeType() || fromSqlType.SqlDataTypeOption == SqlDataTypeOption.Date || fromSqlType.SqlDataTypeOption == SqlDataTypeOption.Time) && targetType == typeof(SqlString)) - expr = Expr.Call(() => Convert(Expr.Arg(), Expr.Arg(), Expr.Arg(), Expr.Arg(), Expr.Arg(), Expr.Arg(), Expr.Arg()), Convert(expr, typeof(SqlDateTime)), Expression.Constant(fromSqlType.SqlDataTypeOption != SqlDataTypeOption.Time), Expression.Constant(fromSqlType.SqlDataTypeOption != SqlDataTypeOption.Date), Expression.Constant(from.GetScale()), Expression.Constant(from), Expression.Constant(to), style); + expr = Expr.Call(() => Convert(Expr.Arg(), Expr.Arg(), Expr.Arg(), Expr.Arg(), Expr.Arg(), Expr.Arg(), Expr.Arg(), Expr.Arg()), Convert(expr, typeof(SqlDateTime)), Expression.Constant(fromSqlType.SqlDataTypeOption != SqlDataTypeOption.Time), Expression.Constant(fromSqlType.SqlDataTypeOption != SqlDataTypeOption.Date), Expression.Constant(from.GetScale()), Expression.Constant(from), Expression.Constant(to), style, Expression.Constant(targetCollation)); else if ((expr.Type == typeof(SqlDouble) || expr.Type == typeof(SqlSingle)) && targetType == typeof(SqlString)) - expr = Expr.Call(() => Convert(Expr.Arg(), Expr.Arg()), expr, style); + expr = Expr.Call(() => Convert(Expr.Arg(), Expr.Arg(), Expr.Arg()), expr, style, Expression.Constant(targetCollation)); else if (expr.Type == typeof(SqlMoney) && targetType == typeof(SqlString)) - expr = Expr.Call(() => Convert(Expr.Arg(), Expr.Arg()), expr, style); + expr = Expr.Call(() => Convert(Expr.Arg(), Expr.Arg(), Expr.Arg()), expr, style, Expression.Constant(targetCollation)); if (expr.Type != targetType) expr = Convert(expr, targetType); @@ -639,12 +655,12 @@ private static SqlString Truncate(SqlString value, int maxLength, string valueOn return value; if (valueOnTruncate != null) - return SqlTypeConverter.UseDefaultCollation(new SqlString(valueOnTruncate)); + return new SqlString(valueOnTruncate, value.LCID, value.SqlCompareOptions); if (exceptionOnTruncate != null) throw exceptionOnTruncate; - return SqlTypeConverter.UseDefaultCollation(new SqlString(value.Value.Substring(0, maxLength))); + return new SqlString(value.Value.Substring(0, maxLength), value.LCID, value.SqlCompareOptions); } private static EntityCollection ParseEntityCollection(SqlString value) @@ -702,9 +718,12 @@ private static OptionSetValueCollection ParseOptionSetValueCollection(SqlString /// Indicates if the date part should be included /// Indicates if the time part should be included /// The scale of the fractional seconds part + /// The original SQL type that is being converted from + /// The original SQL type that is being converted to /// The style to apply + /// The collation to use for the returned result /// The converted string - private static SqlString Convert(SqlDateTime value, bool date, bool time, int timeScale, DataTypeReference fromType, DataTypeReference toType, SqlInt32 style) + private static SqlString Convert(SqlDateTime value, bool date, bool time, int timeScale, DataTypeReference fromType, DataTypeReference toType, SqlInt32 style, Collation collation) { if (value.IsNull || style.IsNull) return SqlString.Null; @@ -892,7 +911,7 @@ private static SqlString Convert(SqlDateTime value, bool date, bool time, int ti formatString += timeFormatString; var formatted = value.Value.ToString(formatString, cultureInfo); - return UseDefaultCollation(formatted); + return collation.ToSqlString(formatted); } /// @@ -900,8 +919,9 @@ private static SqlString Convert(SqlDateTime value, bool date, bool time, int ti /// /// The value to convert /// The style to apply + /// The collation to use for the returned result /// The converted string - public static SqlString Convert(SqlDouble value, SqlInt32 style) + public static SqlString Convert(SqlDouble value, SqlInt32 style, Collation collation) { if (value.IsNull || style.IsNull) return SqlString.Null; @@ -928,7 +948,7 @@ public static SqlString Convert(SqlDouble value, SqlInt32 style) } var formatted = value.Value.ToString(formatString); - return UseDefaultCollation(formatted); + return collation.ToSqlString(formatted); } /// @@ -936,8 +956,9 @@ public static SqlString Convert(SqlDouble value, SqlInt32 style) /// /// The value to convert /// The style to apply + /// The collation to use for the returned result /// The converted string - public static SqlString Convert(SqlMoney value, SqlInt32 style) + public static SqlString Convert(SqlMoney value, SqlInt32 style, Collation collation) { if (value.IsNull || style.IsNull) return SqlString.Null; @@ -961,23 +982,7 @@ public static SqlString Convert(SqlMoney value, SqlInt32 style) } var formatted = value.Value.ToString(formatString); - return UseDefaultCollation(formatted); - } - - /// - /// Converts a value to the default collation - /// - /// The value to convert - /// A value using the default collation - public static SqlString UseDefaultCollation(SqlString value) - { - if (value.IsNull) - return value; - - if (value.LCID == CultureInfo.CurrentCulture.LCID && value.SqlCompareOptions == (SqlCompareOptions.IgnoreCase | SqlCompareOptions.IgnoreWidth)) - return value; - - return new SqlString((string)value, CultureInfo.CurrentCulture.LCID, SqlCompareOptions.IgnoreCase | SqlCompareOptions.IgnoreNonSpace); + return collation.ToSqlString(formatted); } /// @@ -1013,11 +1018,11 @@ public static Type NetToSqlType(Type type) /// /// Converts a value from a CLR type to the equivalent SQL type. /// - /// The name of the data source the was obtained from + /// The data source the was obtained from /// The value in a standard CLR type /// The expected data type /// The value converted to a SQL type - public static INullable NetToSqlType(string dataSource, object value, DataTypeReference dataType) + public static INullable NetToSqlType(DataSource dataSource, object value, DataTypeReference dataType) { var type = value.GetType(); @@ -1039,7 +1044,7 @@ public static INullable NetToSqlType(string dataSource, object value, DataTypeRe } // Convert any other complex types (e.g. from metadata queries) to strings - func = (_, v, __) => UseDefaultCollation(v.ToString()); + func = (ds, v, __) => ds.DefaultCollation.ToSqlString(v.ToString()); _netToSqlTypeConversionFuncs[originalType] = func; return func(dataSource, value, dataType); } @@ -1110,6 +1115,7 @@ public static T ChangeType(object value) /// /// Converts a value from one type to another /// + /// The context in which the conversion is being performed /// The value to convert /// The type to convert the value to /// The value converted to the requested type @@ -1164,8 +1170,8 @@ private static Func CompileConversion(Type sourceType, Type dest expression = Expression.Convert(expression, destType); } - if (destType == typeof(SqlString)) - expression = Expr.Call(() => UseDefaultCollation(Expr.Arg()), expression); + //if (destType == typeof(SqlString)) + // expression = Expr.Call(() => ApplyCollation(Expr.Arg(), Expr.Arg()), contextParam, expression); expression = Expression.Convert(expression, typeof(object)); return Expression.Lambda>(expression, param).Compile(); @@ -1180,6 +1186,15 @@ private static Func CompileConversion(Type sourceType, Type dest public static Func GetConversion(DataTypeReference sourceType, DataTypeReference destType) { var key = sourceType.ToSql() + " -> " + destType.ToSql(); + + if (destType is SqlDataTypeReferenceWithCollation collation) + { + if (!String.IsNullOrEmpty(collation.Collation.Name)) + key += " COLLATE " + collation.Collation.Name; + else + key += " COLLATE " + collation.Collation.LCID + ":" + collation.Collation.CompareOptions; + } + return _sqlConversions.GetOrAdd(key, _ => CompileConversion(sourceType, destType)); } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/StreamAggregateNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/StreamAggregateNode.cs index aa2d4ea9..dfbaa286 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/StreamAggregateNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/StreamAggregateNode.cs @@ -11,16 +11,16 @@ namespace MarkMpn.Sql4Cds.Engine.ExecutionPlan /// class StreamAggregateNode : BaseAggregateNode { - public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) { - Source = Source.FoldQuery(dataSources, options, parameterTypes, hints); + Source = Source.FoldQuery(context, hints); Source.Parent = this; return this; } - public override INodeSchema GetSchema(IDictionary dataSources, IDictionary parameterTypes) + public override INodeSchema GetSchema(NodeCompilationContext context) { - var schema = base.GetSchema(dataSources, parameterTypes); + var schema = base.GetSchema(context); var groupByCols = GetGroupingColumns(schema); return new NodeSchema( @@ -31,20 +31,22 @@ public override INodeSchema GetSchema(IDictionary dataSource sortOrder: groupByCols); } - protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { - var schema = Source.GetSchema(dataSources, parameterTypes); + var schema = Source.GetSchema(context); var groupByCols = GetGroupingColumns(schema); + var expressionCompilationContext = new ExpressionCompilationContext(context, schema, null); + var expressionExecutionContext = new ExpressionExecutionContext(context); var isScalarAggregate = IsScalarAggregate; - InitializeAggregates(schema, parameterTypes); + InitializeAggregates(expressionCompilationContext); Entity currentGroup = null; var comparer = new DistinctEqualityComparer(groupByCols); - var aggregates = CreateAggregateFunctions(parameterValues, options, false); + var aggregates = CreateAggregateFunctions(expressionExecutionContext, false); var states = isScalarAggregate ? ResetAggregates(aggregates) : null; - foreach (var entity in Source.Execute(dataSources, options, parameterTypes, parameterValues)) + foreach (var entity in Source.Execute(context)) { if (!isScalarAggregate || currentGroup != null) { @@ -73,8 +75,10 @@ protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + internal int GetCount(NodeExecutionContext context) { if (_eagerSpool == null) - _eagerSpool = Source.Execute(dataSources, options, parameterTypes, parameterValues).ToArray(); + _eagerSpool = Source.Execute(context).ToArray(); return _eagerSpool.Length; } - protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { if (SpoolType == SpoolType.Eager) { if (_eagerSpool == null) - _eagerSpool = Source.Execute(dataSources, options, parameterTypes, parameterValues).ToArray(); + _eagerSpool = Source.Execute(context).ToArray(); return _eagerSpool; } else { if (_lazyCache == null) - _lazyCache = new CachedList(Source.Execute(dataSources, options, parameterTypes, parameterValues)); + _lazyCache = new CachedList(Source.Execute(context)); return _lazyCache; } @@ -127,14 +127,14 @@ public override IEnumerable GetSources() yield return Source; } - public override INodeSchema GetSchema(IDictionary dataSources, IDictionary parameterTypes) + public override INodeSchema GetSchema(NodeCompilationContext context) { - return Source.GetSchema(dataSources, parameterTypes); + return Source.GetSchema(context); } - public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) { - Source = Source.FoldQuery(dataSources, options, parameterTypes, hints); + Source = Source.FoldQuery(context, hints); if (hints != null && hints.Any(hint => hint.HintKind == OptimizerHintKind.NoPerformanceSpool)) return Source; @@ -143,14 +143,14 @@ public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { - Source.AddRequiredColumns(dataSources, parameterTypes, requiredColumns); + Source.AddRequiredColumns(context, requiredColumns); } - protected override RowCountEstimate EstimateRowsOutInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes) + protected override RowCountEstimate EstimateRowsOutInternal(NodeCompilationContext context) { - return Source.EstimateRowsOut(dataSources, options, parameterTypes); + return Source.EstimateRowsOut(context); } public override string ToString() diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/TopNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/TopNode.cs index 995581ba..bee2d57e 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/TopNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/TopNode.cs @@ -47,15 +47,18 @@ class TopNode : BaseDataNode, ISingleSourceExecutionPlanNode [Browsable(false)] public IDataExecutionPlanNodeInternal Source { get; set; } - protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { int topCount; - Top.GetType(dataSources[options.PrimaryDataSource], null, null, parameterTypes, out var topType); + + var expressionCompilationContext = new ExpressionCompilationContext(context, null, null); + var expressionExecutionContext = new ExpressionExecutionContext(context); + Top.GetType(expressionCompilationContext, out var topType); if (Percent) { var top = new ConvertCall { Parameter = Top, DataType = DataTypeHelpers.Float }; - var topPercent = (SqlDouble)top.Compile(dataSources[options.PrimaryDataSource], null, parameterTypes)(null, parameterValues, options); + var topPercent = (SqlDouble)top.Compile(expressionCompilationContext)(expressionExecutionContext); if (topPercent.IsNull) { @@ -66,9 +69,9 @@ protected override IEnumerable ExecuteInternal(IDictionary ExecuteInternal(IDictionary ExecuteInternal(IDictionary { if (index == topCount - 1) @@ -110,9 +113,9 @@ protected override IEnumerable ExecuteInternal(IDictionary dataSources, IDictionary parameterTypes) + public override INodeSchema GetSchema(NodeCompilationContext context) { - return Source.GetSchema(dataSources, parameterTypes); + return Source.GetSchema(context); } public override IEnumerable GetSources() @@ -120,12 +123,14 @@ public override IEnumerable GetSources() yield return Source; } - public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) { - Source = Source.FoldQuery(dataSources, options, parameterTypes, hints); + Source = Source.FoldQuery(context, hints); Source.Parent = this; - if (!Top.IsConstantValueExpression(dataSources[options.PrimaryDataSource], null, options, out var literal)) + var expressionCompilationContext = new ExpressionCompilationContext(context, null, null); + + if (!Top.IsConstantValueExpression(expressionCompilationContext, out var literal)) return this; // FetchXML can support TOP directly provided it's for no more than 5,000 records @@ -155,23 +160,25 @@ public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { - Source.AddRequiredColumns(dataSources, parameterTypes, requiredColumns); + Source.AddRequiredColumns(context, requiredColumns); } - protected override RowCountEstimate EstimateRowsOutInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes) + protected override RowCountEstimate EstimateRowsOutInternal(NodeCompilationContext context) { - var sourceCount = Source.EstimateRowsOut(dataSources, options, parameterTypes); + var sourceCount = Source.EstimateRowsOut(context); + + var expressionCompilationContext = new ExpressionCompilationContext(context, null, null); - if (!Top.IsConstantValueExpression(dataSources[options.PrimaryDataSource], null, options, out var topLiteral)) + if (!Top.IsConstantValueExpression(expressionCompilationContext, out var topLiteral)) return sourceCount; var top = Decimal.Parse(topLiteral.Value, CultureInfo.InvariantCulture); diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/TryCatchNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/TryCatchNode.cs index 1606b499..74f01da6 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/TryCatchNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/TryCatchNode.cs @@ -31,14 +31,14 @@ class TryCatchNode : BaseDataNode [Description("The error generated by the Try branch that caused execution to move to the Catch branch")] public string CaughtException { get; set; } - protected override IEnumerable ExecuteInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues) + protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { var useCatchSource = false; IEnumerator enumerator; try { - enumerator = TrySource.Execute(dataSources, options, parameterTypes, parameterValues).GetEnumerator(); + enumerator = TrySource.Execute(context).GetEnumerator(); } catch (Exception ex) { @@ -52,7 +52,7 @@ protected override IEnumerable ExecuteInternal(IDictionary ExecuteInternal(IDictionary dataSources, IDictionary parameterTypes) + public override INodeSchema GetSchema(NodeCompilationContext context) { - var trySchema = TrySource.GetSchema(dataSources, parameterTypes); - var catchSchema = CatchSource.GetSchema(dataSources, parameterTypes); + var trySchema = TrySource.GetSchema(context); + var catchSchema = CatchSource.GetSchema(context); // Columns should be the same but sort order may be different if (trySchema.SortOrder.SequenceEqual(catchSchema.SortOrder, StringComparer.OrdinalIgnoreCase)) @@ -112,24 +112,24 @@ public override IEnumerable GetSources() yield return CatchSource; } - public override IDataExecutionPlanNodeInternal FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) { - TrySource = TrySource.FoldQuery(dataSources, options, parameterTypes, hints); + TrySource = TrySource.FoldQuery(context, hints); TrySource.Parent = this; - CatchSource = CatchSource.FoldQuery(dataSources, options, parameterTypes, hints); + CatchSource = CatchSource.FoldQuery(context, hints); CatchSource.Parent = this; return this; } - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { - TrySource.AddRequiredColumns(dataSources, parameterTypes, requiredColumns); - CatchSource.AddRequiredColumns(dataSources, parameterTypes, requiredColumns); + TrySource.AddRequiredColumns(context, requiredColumns); + CatchSource.AddRequiredColumns(context, requiredColumns); } - protected override RowCountEstimate EstimateRowsOutInternal(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes) + protected override RowCountEstimate EstimateRowsOutInternal(NodeCompilationContext context) { - return TrySource.EstimateRowsOut(dataSources, options, parameterTypes); + return TrySource.EstimateRowsOut(context); } public override object Clone() diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/UpdateNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/UpdateNode.cs index cfd58197..fb77ceca 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/UpdateNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/UpdateNode.cs @@ -60,7 +60,7 @@ class UpdateNode : BaseDmlNode [DisplayName("State Transitions")] public IDictionary StateTransitionsDisplay => StateTransitions == null ? null : StateTransitions.Values.ToDictionary(s => $"{s.Name} ({s.StatusCode})", s => new Transitions(s.Transitions.Keys.Select(t => $"{t.Name} ({t.StatusCode})").OrderBy(n => n))); - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { if (!requiredColumns.Contains(PrimaryIdSource)) requiredColumns.Add(PrimaryIdSource); @@ -74,16 +74,16 @@ public override void AddRequiredColumns(IDictionary dataSour requiredColumns.Add(col.NewValueColumn); } - Source.AddRequiredColumns(dataSources, parameterTypes, requiredColumns); + Source.AddRequiredColumns(context, requiredColumns); } - public override string Execute(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues, out int recordsAffected) + public override string Execute(NodeExecutionContext context, out int recordsAffected) { _executionCount++; try { - if (!dataSources.TryGetValue(DataSource, out var dataSource)) + if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) throw new QueryExecutionException("Missing datasource " + DataSource); List entities; @@ -95,12 +95,12 @@ public override string Execute(IDictionary dataSources, IQue using (_timer.Run()) { - entities = GetDmlSourceEntities(dataSources, options, parameterTypes, parameterValues, out var schema); + entities = GetDmlSourceEntities(context, out var schema); // Precompile mappings with type conversions meta = dataSource.Metadata[LogicalName]; attributes = meta.Attributes.ToDictionary(a => a.LogicalName, StringComparer.OrdinalIgnoreCase); - var dateTimeKind = options.UseLocalTimeZone ? DateTimeKind.Local : DateTimeKind.Utc; + var dateTimeKind = context.Options.UseLocalTimeZone ? DateTimeKind.Local : DateTimeKind.Utc; var fullMappings = new Dictionary(ColumnMappings); fullMappings[meta.PrimaryIdAttribute] = new UpdateMapping { OldValueColumn = PrimaryIdSource, NewValueColumn = PrimaryIdSource }; newAttributeAccessors = CompileColumnMappings(dataSource, LogicalName, fullMappings.Where(kvp => kvp.Value.NewValueColumn != null).ToDictionary(kvp => kvp.Key, kvp => kvp.Value.NewValueColumn), schema, dateTimeKind, entities); @@ -110,9 +110,9 @@ public override string Execute(IDictionary dataSources, IQue // Check again that the update is allowed. Don't count any UI interaction in the execution time var confirmArgs = new ConfirmDmlStatementEventArgs(entities.Count, meta, BypassCustomPluginExecution); - if (options.CancellationToken.IsCancellationRequested) + if (context.Options.CancellationToken.IsCancellationRequested) confirmArgs.Cancel = true; - options.ConfirmUpdate(confirmArgs); + context.Options.ConfirmUpdate(confirmArgs); if (confirmArgs.Cancel) throw new OperationCanceledException("UPDATE cancelled by user"); @@ -158,7 +158,7 @@ public override string Execute(IDictionary dataSources, IQue { return ExecuteDmlOperation( dataSource.Connection, - options, + context.Options, entities, meta, entity => @@ -237,7 +237,7 @@ public override string Execute(IDictionary dataSources, IQue CompletedLowercase = "updated" }, out recordsAffected, - parameterValues); + context.ParameterValues); } } catch (QueryExecutionException ex) diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/WaitForNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/WaitForNode.cs index 3b4061cb..f6bcb962 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/WaitForNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/WaitForNode.cs @@ -17,7 +17,7 @@ namespace MarkMpn.Sql4Cds.Engine.ExecutionPlan class WaitForNode : BaseNode, IDmlQueryExecutionPlanNode { private int _executionCount; - private Func, IQueryExecutionOptions, object> _timeExpr; + private Func _timeExpr; private readonly Timer _timer = new Timer(); [Category("Wait")] @@ -51,11 +51,11 @@ class WaitForNode : BaseNode, IDmlQueryExecutionPlanNode public override TimeSpan Duration => _timer.Duration; - public override void AddRequiredColumns(IDictionary dataSources, IDictionary parameterTypes, IList requiredColumns) + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { } - public string Execute(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues, out int recordsAffected) + public string Execute(NodeExecutionContext context, out int recordsAffected) { _executionCount++; @@ -64,9 +64,9 @@ public string Execute(IDictionary dataSources, IQueryExecuti using (_timer.Run()) { if (_timeExpr == null) - _timeExpr = Time.Compile(dataSources[options.PrimaryDataSource], null, parameterTypes); + _timeExpr = Time.Compile(new ExpressionCompilationContext(context, null, null)); - var time = (SqlTime) _timeExpr(null, parameterValues, options); + var time = (SqlTime) _timeExpr(new ExpressionExecutionContext(context)); if (time.IsNull) { @@ -84,8 +84,8 @@ public string Execute(IDictionary dataSources, IQueryExecuti delay = delay + TimeSpan.FromDays(1) - DateTime.Now.TimeOfDay; } - options.Progress(null, $"Waiting for {delay}..."); - options.CancellationToken.WaitHandle.WaitOne(delay); + context.Options.Progress(null, $"Waiting for {delay}..."); + context.Options.CancellationToken.WaitHandle.WaitOne(delay); } recordsAffected = -1; @@ -105,7 +105,7 @@ public string Execute(IDictionary dataSources, IQueryExecuti } } - public IRootExecutionPlanNodeInternal[] FoldQuery(IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IList hints) + public IRootExecutionPlanNodeInternal[] FoldQuery(NodeCompilationContext context, IList hints) { return new[] { this }; } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs index 1d2a1162..c89ff4e6 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs @@ -20,11 +20,8 @@ class ExecutionPlanBuilder { private int _colNameCounter; private IDictionary _parameterTypes; - - public ExecutionPlanBuilder(IAttributeMetadataCache metadata, ITableSizeCache tableSize, IMessageCache messageCache, IQueryExecutionOptions options) - : this(new[] { new DataSource { Name = "local", Metadata = metadata, TableSizeCache = tableSize, MessageCache = messageCache } }, options) - { - } + private ExpressionCompilationContext _staticContext; + private NodeCompilationContext _nodeContext; public ExecutionPlanBuilder(IEnumerable dataSources, IQueryExecutionOptions options) { @@ -45,7 +42,7 @@ public ExecutionPlanBuilder(IEnumerable dataSources, IQueryExecution /// /// Indicates how the query will be executed /// - public IQueryExecutionOptions Options { get; set; } + public IQueryExecutionOptions Options { get; } /// /// Indicates if only a simplified plan for display purposes is required @@ -66,6 +63,8 @@ public IRootExecutionPlanNode[] Build(string sql, IDictionary(StringComparer.OrdinalIgnoreCase); + _staticContext = new ExpressionCompilationContext(DataSources, Options, _parameterTypes, null, null); + _nodeContext = new NodeCompilationContext(DataSources, Options, _parameterTypes); if (parameters != null) { @@ -245,8 +244,8 @@ private IRootExecutionPlanNodeInternal[] ConvertExecuteStatement(ExecuteStatemen var dataSource = SelectDataSource(sproc.ProcedureReference.ProcedureReference.Name); - var node = ExecuteMessageNode.FromMessage(sproc, dataSource, PrimaryDataSource, _parameterTypes); - var schema = node.GetSchema(DataSources, _parameterTypes); + var node = ExecuteMessageNode.FromMessage(sproc, dataSource, _staticContext); + var schema = node.GetSchema(new NodeCompilationContext(DataSources, Options, _parameterTypes)); dataSource.MessageCache.TryGetValue(node.MessageName, out var message); @@ -355,7 +354,7 @@ private IRootExecutionPlanNodeInternal ConvertWaitForStatement(WaitForStatement if (waitFor.WaitForOption == WaitForOption.Statement) throw new NotSupportedQueryFragmentException("WAITFOR is not supported", waitFor); - waitFor.Parameter.GetType(PrimaryDataSource, null, null, _parameterTypes, out var paramSqlType); + waitFor.Parameter.GetType(_staticContext, out var paramSqlType); var timeType = DataTypeHelpers.Time(3); if (!SqlTypeConverter.CanChangeTypeImplicit(paramSqlType, timeType)) @@ -410,7 +409,7 @@ private IRootExecutionPlanNodeInternal ConvertPrintStatement(PrintStatement prin // Check the expression for errors. Ensure it can be converted to a string var expr = print.Expression.Clone(); - if (expr.GetType(PrimaryDataSource, null, null, _parameterTypes, out _) != typeof(SqlString)) + if (expr.GetType(_staticContext, out _) != typeof(SqlString)) { expr = new ConvertCall { @@ -418,7 +417,7 @@ private IRootExecutionPlanNodeInternal ConvertPrintStatement(PrintStatement prin Parameter = expr }; - expr.GetType(PrimaryDataSource, null, null, _parameterTypes, out _); + expr.GetType(_staticContext, out _); } return new PrintNode @@ -438,7 +437,7 @@ private IRootExecutionPlanNodeInternal ConvertIfWhileStatement(ConditionalNodeTy if (subqueryVisitor.Subqueries.Count == 0) { // Check the predicate for errors - predicate.GetType(PrimaryDataSource, null, null, _parameterTypes, out _); + predicate.GetType(_staticContext, out _); } else { @@ -863,7 +862,7 @@ private InsertNode ConvertInsertSpecification(NamedTableReference target, IList< var attributes = metadata.Attributes.ToDictionary(attr => attr.LogicalName, StringComparer.OrdinalIgnoreCase); var attributeNames = new HashSet(StringComparer.OrdinalIgnoreCase); var virtualTypeAttributes = new HashSet(StringComparer.OrdinalIgnoreCase); - var schema = sourceColumns == null ? null : ((IDataExecutionPlanNodeInternal)source).GetSchema(DataSources, _parameterTypes); + var schema = sourceColumns == null ? null : ((IDataExecutionPlanNodeInternal)source).GetSchema(_nodeContext); // Check all target columns are valid for create foreach (var col in targetColumns) @@ -987,7 +986,7 @@ attr is LookupAttributeMetadata lookupAttr && if (targetLookupAttribute.Targets.Length > 1 && !virtualTypeAttributes.Contains(targetAttrName + "type") && targetLookupAttribute.AttributeType != AttributeTypeCode.PartyList && - (schema == null || (node.ColumnMappings[targetAttrName].ToColumnReference().GetType(PrimaryDataSource, schema, null, null, out var lookupType) != typeof(SqlEntityReference) && lookupType != DataTypeHelpers.ImplicitIntForNullLiteral))) + (schema == null || (node.ColumnMappings[targetAttrName].ToColumnReference().GetType(GetExpressionContext(schema), out var lookupType) != typeof(SqlEntityReference) && lookupType != DataTypeHelpers.ImplicitIntForNullLiteral))) { // Special case: not required for listmember.entityid if (metadata.LogicalName == "listmember" && targetLookupAttribute.LogicalName == "entityid") @@ -1429,7 +1428,8 @@ private UpdateNode ConvertSetClause(IList setClauses, HashSet update.Source = select.Source; update.PrimaryIdSource = $"{targetAlias}.{targetMetadata.PrimaryIdAttribute}"; - var schema = select.Source.GetSchema(DataSources, _parameterTypes); + var schema = select.Source.GetSchema(_nodeContext); + var expressionContext = new ExpressionCompilationContext(DataSources, Options, _parameterTypes, schema, null); foreach (var assignment in setClauses.Cast()) { @@ -1461,7 +1461,7 @@ private UpdateNode ConvertSetClause(IList setClauses, HashSet var sourceColName = select.ColumnSet.Single(col => col.OutputColumn == "new_" + targetAttrName.ToLower()).SourceColumn; var sourceCol = sourceColName.ToColumnReference(); - sourceCol.GetType(PrimaryDataSource, schema, null, null, out var sourceType); + sourceCol.GetType(expressionContext, out var sourceType); if (!SqlTypeConverter.CanChangeTypeImplicit(sourceType, targetType)) throw new NotSupportedQueryFragmentException($"Cannot convert value of type {sourceType.ToSql()} to {targetType.ToSql()}", assignment); @@ -1797,7 +1797,7 @@ private SelectNode ConvertSelectQuerySpec(QuerySpecification querySpec, IList(); - var preOrderSchema = node.GetSchema(DataSources, parameterTypes); + var preOrderSchema = node.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); foreach (var el in querySpec.SelectElements) { if (el is SelectScalarExpression expr) @@ -1859,12 +1859,12 @@ private IDataExecutionPlanNodeInternal ConvertInSubqueries(IDataExecutionPlanNod var computeScalar = source as ComputeScalarNode; var rewrites = new Dictionary(); - var schema = source.GetSchema(DataSources, parameterTypes); + var schema = source.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); foreach (var inSubquery in visitor.InSubqueries) { // Validate the LHS expression - inSubquery.Expression.GetType(PrimaryDataSource, schema, null, parameterTypes, out _); + inSubquery.Expression.GetType(GetExpressionContext(schema, parameterTypes), out _); // Each query of the format "col1 IN (SELECT col2 FROM source)" becomes a left outer join: // LEFT JOIN source ON col1 = col2 @@ -1912,7 +1912,7 @@ private IDataExecutionPlanNodeInternal ConvertInSubqueries(IDataExecutionPlanNod else { // We need the inner list to be distinct to avoid creating duplicates during the join - var innerSchema = innerQuery.Source.GetSchema(DataSources, parameters); + var innerSchema = innerQuery.Source.GetSchema(new NodeCompilationContext(DataSources, Options, parameters)); if (innerQuery.ColumnSet[0].SourceColumn != innerSchema.PrimaryKey && !(innerQuery.Source is DistinctNode)) { innerQuery.Source = new DistinctNode @@ -2000,7 +2000,7 @@ private IDataExecutionPlanNodeInternal ConvertExistsSubqueries(IDataExecutionPla return source; var rewrites = new Dictionary(); - var schema = source.GetSchema(DataSources, parameterTypes); + var schema = source.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); foreach (var existsSubquery in visitor.ExistsSubqueries) { @@ -2008,7 +2008,7 @@ private IDataExecutionPlanNodeInternal ConvertExistsSubqueries(IDataExecutionPla var parameters = parameterTypes == null ? new Dictionary(StringComparer.OrdinalIgnoreCase) : new Dictionary(parameterTypes, StringComparer.OrdinalIgnoreCase); var references = new Dictionary(); var innerQuery = ConvertSelectStatement(existsSubquery.Subquery.QueryExpression, hints, schema, references, parameters); - var innerSchema = innerQuery.Source.GetSchema(DataSources, parameters); + var innerSchema = innerQuery.Source.GetSchema(new NodeCompilationContext(DataSources, Options, parameters)); var innerSchemaPrimaryKey = innerSchema.PrimaryKey; // Create the join @@ -2141,7 +2141,7 @@ private IDataExecutionPlanNodeInternal ConvertHavingClause(IDataExecutionPlanNod ConvertScalarSubqueries(havingClause.SearchCondition, hints, ref source, computeScalar, parameterTypes, query); // Validate the final expression - havingClause.SearchCondition.GetType(PrimaryDataSource, source.GetSchema(DataSources, parameterTypes), nonAggregateSchema, parameterTypes, out _); + havingClause.SearchCondition.GetType(GetExpressionContext(source.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)), parameterTypes, nonAggregateSchema), out _); return new FilterNode { @@ -2169,7 +2169,7 @@ private IDataExecutionPlanNodeInternal ConvertGroupByAggregates(IDataExecutionPl throw new NotSupportedQueryFragmentException("Unhandled GROUP BY option", querySpec.GroupByClause); } - var schema = source.GetSchema(DataSources, parameterTypes); + var schema = source.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); // Create the grouping expressions. Grouping is done on single columns only - if a grouping is a more complex expression, // create a new calculated column using a Compute Scalar node first. @@ -2185,7 +2185,7 @@ private IDataExecutionPlanNodeInternal ConvertGroupByAggregates(IDataExecutionPl throw new NotSupportedQueryFragmentException("Unhandled GROUP BY expression", grouping); // Validate the GROUP BY expression - exprGroup.Expression.GetType(PrimaryDataSource, schema, null, parameterTypes, out _); + exprGroup.Expression.GetType(GetExpressionContext(schema, parameterTypes), out _); if (exprGroup.Expression is ColumnReferenceExpression col) { @@ -2341,7 +2341,7 @@ private IDataExecutionPlanNodeInternal ConvertGroupByAggregates(IDataExecutionPl if (converted.AggregateType == AggregateType.CountStar) converted.SqlExpression = null; else - converted.SqlExpression.GetType(PrimaryDataSource, schema, null, parameterTypes, out _); + converted.SqlExpression.GetType(GetExpressionContext(schema, parameterTypes), out _); // Create a name for the column that holds the aggregate value in the result set. string aggregateName; @@ -2420,8 +2420,8 @@ private IDataExecutionPlanNodeInternal ConvertOffsetClause(IDataExecutionPlanNod if (offsetClause == null) return source; - offsetClause.OffsetExpression.GetType(PrimaryDataSource, null, null, parameterTypes, out var offsetType); - offsetClause.FetchExpression.GetType(PrimaryDataSource, null, null, parameterTypes, out var fetchType); + offsetClause.OffsetExpression.GetType(_staticContext, out var offsetType); + offsetClause.FetchExpression.GetType(_staticContext, out var fetchType); var intType = DataTypeHelpers.Int; if (!SqlTypeConverter.CanChangeTypeImplicit(offsetType, intType)) @@ -2443,7 +2443,7 @@ private IDataExecutionPlanNodeInternal ConvertTopClause(IDataExecutionPlanNodeIn if (topRowFilter == null) return source; - topRowFilter.Expression.GetType(PrimaryDataSource, null, null, parameterTypes, out var topType); + topRowFilter.Expression.GetType(_staticContext, out var topType); var targetType = topRowFilter.Percent ? DataTypeHelpers.Float : DataTypeHelpers.BigInt; if (!SqlTypeConverter.CanChangeTypeImplicit(topType, targetType)) @@ -2456,7 +2456,7 @@ private IDataExecutionPlanNodeInternal ConvertTopClause(IDataExecutionPlanNodeIn if (orderByClause == null) throw new NotSupportedQueryFragmentException("The TOP N WITH TIES clause is not allowed without a corresponding ORDER BY clause", topRowFilter); - var schema = source.GetSchema(DataSources, parameterTypes); + var schema = source.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); foreach (var sort in orderByClause.OrderByElements) { @@ -2506,7 +2506,7 @@ private IDataExecutionPlanNodeInternal ConvertOrderByClause(IDataExecutionPlanNo var computeScalar = new ComputeScalarNode { Source = source }; ConvertScalarSubqueries(orderByClause, hints, ref source, computeScalar, parameterTypes, query); - var schema = source.GetSchema(DataSources, parameterTypes); + var schema = source.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); var sort = new SortNode { Source = source }; // Sorts can use aliases from the SELECT clause @@ -2555,13 +2555,13 @@ private IDataExecutionPlanNodeInternal ConvertOrderByClause(IDataExecutionPlanNo { var calculated = ComputeScalarExpression(orderBy.Expression, hints, query, computeScalar, nonAggregateSchema, parameterTypes, ref source); sort.Source = source; - schema = source.GetSchema(DataSources, parameterTypes); + schema = source.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); calculationRewrites[orderBy.Expression] = calculated.ToColumnReference(); } // Validate the expression - orderBy.Expression.GetType(PrimaryDataSource, schema, nonAggregateSchema, parameterTypes, out _); + orderBy.Expression.GetType(GetExpressionContext(schema, parameterTypes, nonAggregateSchema), out _); sort.Sorts.Add(orderBy.Clone()); } @@ -2590,7 +2590,7 @@ private IDataExecutionPlanNodeInternal ConvertWhereClause(IDataExecutionPlanNode ConvertScalarSubqueries(whereClause.SearchCondition, hints, ref source, computeScalar, parameterTypes, query); // Validate the final expression - whereClause.SearchCondition.GetType(PrimaryDataSource, source.GetSchema(DataSources, parameterTypes), null, parameterTypes, out _); + whereClause.SearchCondition.GetType(GetExpressionContext(source.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)), parameterTypes), out _); return new FilterNode { @@ -2607,7 +2607,7 @@ private TSqlFragment CaptureOuterReferences(INodeSchema outerSchema, IDataExecut // We're in a subquery. Check if any columns in the WHERE clause are from the outer query // so we know which columns to pass through and rewrite the filter to use parameters var rewrites = new Dictionary(); - var innerSchema = source?.GetSchema(DataSources, parameterTypes); + var innerSchema = source?.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); var columns = query.GetColumns(); foreach (var column in columns) @@ -2645,7 +2645,7 @@ private TSqlFragment CaptureOuterReferences(INodeSchema outerSchema, IDataExecut private SelectNode ConvertSelectClause(IList selectElements, IList hints, IDataExecutionPlanNodeInternal node, DistinctNode distinct, TSqlFragment query, IDictionary parameterTypes, INodeSchema outerSchema, IDictionary outerReferences, INodeSchema nonAggregateSchema) { - var schema = node.GetSchema(DataSources, parameterTypes); + var schema = node.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); var select = new SelectNode { @@ -2670,7 +2670,7 @@ private SelectNode ConvertSelectClause(IList selectElements, ILis if (!schema.ContainsColumn(colName, out colName)) { // Column name isn't valid. Use the expression extensions to throw a consistent error message - col.GetType(PrimaryDataSource, schema, nonAggregateSchema, parameterTypes, out _); + col.GetType(GetExpressionContext(schema, parameterTypes, nonAggregateSchema), out _); } var alias = scalar.ColumnName?.Value ?? col.MultiPartIdentifier.Identifiers.Last().Value; @@ -2737,7 +2737,7 @@ private SelectNode ConvertSelectClause(IList selectElements, ILis { if (col.AllColumns) { - var distinctSchema = distinct.GetSchema(DataSources, parameterTypes); + var distinctSchema = distinct.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); distinct.Columns.AddRange(distinctSchema.Schema.Keys.Where(k => col.SourceColumn == null || (k.Split('.')[0] + ".*") == col.SourceColumn)); } else @@ -2758,8 +2758,8 @@ private string ComputeScalarExpression(ScalarExpression expression, IList(); var innerParameterTypes = parameterTypes == null ? new Dictionary(StringComparer.OrdinalIgnoreCase) : new Dictionary(parameterTypes, StringComparer.OrdinalIgnoreCase); var subqueryPlan = ConvertSelectStatement(subquery.QueryExpression, hints, outerSchema, outerReferences, innerParameterTypes); @@ -2807,7 +2807,7 @@ private ColumnReferenceExpression ConvertScalarSubqueries(TSqlFragment expressio // Unless the subquery has got an explicit TOP 1 clause, insert an aggregate and assertion nodes // to check for one row - if (!(subqueryPlan.Source.EstimateRowsOut(DataSources, Options, parameterTypes) is RowCountEstimateDefiniteRange range) || range.Maximum > 1) + if (!(subqueryPlan.Source.EstimateRowsOut(new NodeCompilationContext(DataSources, Options, parameterTypes)) is RowCountEstimateDefiniteRange range) || range.Maximum > 1) { subqueryCol = $"Expr{++_colNameCounter}"; var rowCountCol = $"Expr{++_colNameCounter}"; @@ -2955,8 +2955,8 @@ private bool UseMergeJoin(IDataExecutionPlanNodeInternal node, IDataExecutionPla if (outerKey == null) return false; - var outerSchema = node.GetSchema(DataSources, null); - var innerSchema = fetch.GetSchema(DataSources, null); + var outerSchema = node.GetSchema(new NodeCompilationContext(DataSources, Options, null)); + var innerSchema = fetch.GetSchema(new NodeCompilationContext(DataSources, Options, null)); if (!outerSchema.ContainsColumn(outerKey, out outerKey) || !innerSchema.ContainsColumn(innerKey, out innerKey)) @@ -2995,7 +2995,7 @@ private bool UseMergeJoin(IDataExecutionPlanNodeInternal node, IDataExecutionPla if (semiJoin) { // Regenerate the schema after changing the alias - innerSchema = fetch.GetSchema(DataSources, null); + innerSchema = fetch.GetSchema(new NodeCompilationContext(DataSources, Options, null)); if (innerSchema.PrimaryKey != rightAttribute.GetColumnName() && !(merge.RightSource is DistinctNode)) { @@ -3098,7 +3098,7 @@ private int EstimateRowsOut(IExecutionPlanNode source, IDictionary(); var innerParameterTypes = parameterTypes == null ? new Dictionary(StringComparer.OrdinalIgnoreCase) : new Dictionary(parameterTypes, StringComparer.OrdinalIgnoreCase); var subqueryPlan = ConvertTableReference(unqualifiedJoin.SecondTableReference, hints, query, lhsSchema, lhsReferences, innerParameterTypes); @@ -3515,12 +3515,12 @@ private IDataExecutionPlanNodeInternal ConvertTableReference(TableReference refe else if (computeScalar.Columns.Count > 0) source = computeScalar; - var scalarSubquerySchema = source?.GetSchema(DataSources, parameterTypes); + var scalarSubquerySchema = source?.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); var scalarSubqueryReferences = new Dictionary(); CaptureOuterReferences(scalarSubquerySchema, null, tvf, parameterTypes, scalarSubqueryReferences); var dataSource = SelectDataSource(tvf.SchemaObject); - var execute = ExecuteMessageNode.FromMessage(tvf, dataSource, PrimaryDataSource, parameterTypes); + var execute = ExecuteMessageNode.FromMessage(tvf, dataSource, GetExpressionContext(null, parameterTypes)); if (source == null) return execute; @@ -3661,5 +3661,10 @@ private QuerySpecification CreateSelectRow(RowValue row, IList colum return querySpec; } + + private ExpressionCompilationContext GetExpressionContext(INodeSchema schema, IDictionary parameterTypes = null, INodeSchema nonAggregateSchema = null) + { + return new ExpressionCompilationContext(DataSources, Options, parameterTypes ?? _parameterTypes, schema, nonAggregateSchema); + } } } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlanOptimizer.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlanOptimizer.cs index b34a9c7f..617c3c2c 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlanOptimizer.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlanOptimizer.cs @@ -49,16 +49,18 @@ public IRootExecutionPlanNodeInternal[] Optimize(IRootExecutionPlanNodeInternal hints.Add(new ConditionalNode.DoNotCompileConditionsHint()); } + var context = new NodeCompilationContext(DataSources, Options, ParameterTypes); + // Move any additional operators down to the FetchXml var nodes = hints != null && hints.OfType().Any(list => list.Hints.Any(h => h.Value.Equals("DEBUG_BYPASS_OPTIMIZATION", StringComparison.OrdinalIgnoreCase))) ? new[] { node } - : node.FoldQuery(DataSources, Options, ParameterTypes, hints); + : node.FoldQuery(context, hints); foreach (var n in nodes) { // Ensure all required columns are added to the FetchXML - n.AddRequiredColumns(DataSources, ParameterTypes, new List()); + n.AddRequiredColumns(context, new List()); // Sort the items in the FetchXml nodes to match examples in documentation SortFetchXmlElements(n); diff --git a/MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs b/MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs index 33f4c389..be33eeb3 100644 --- a/MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs +++ b/MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs @@ -341,7 +341,7 @@ public static SqlString Left(SqlString s, [MaxLength] SqlInt32 length) if (s.Value.Length <= length) return s; - return SqlTypeConverter.UseDefaultCollation(s.Value.Substring(0, length.Value)); + return new SqlString(s.Value.Substring(0, length.Value), s.LCID, s.SqlCompareOptions); } /// @@ -358,7 +358,7 @@ public static SqlString Right(SqlString s, [MaxLength] SqlInt32 length) if (s.Value.Length <= length) return s; - return SqlTypeConverter.UseDefaultCollation(s.Value.Substring(s.Value.Length - length.Value, length.Value)); + return new SqlString(s.Value.Substring(s.Value.Length - length.Value, length.Value), s.LCID, s.SqlCompareOptions); } /// @@ -373,7 +373,7 @@ public static SqlString Replace(SqlString input, SqlString find, SqlString repla if (input.IsNull || find.IsNull || replace.IsNull) return SqlString.Null; - return SqlTypeConverter.UseDefaultCollation(Regex.Replace(input.Value, Regex.Escape(find.Value), replace.Value.Replace("$", "$$"), RegexOptions.IgnoreCase)); + return new SqlString(Regex.Replace(input.Value, Regex.Escape(find.Value), replace.Value.Replace("$", "$$"), RegexOptions.IgnoreCase), input.LCID, input.SqlCompareOptions); } /// @@ -440,14 +440,14 @@ public static SqlString Substring(SqlString expression, SqlInt32 start, [MaxLeng start = 1; if (start > expression.Value.Length) - return SqlTypeConverter.UseDefaultCollation(String.Empty); + return new SqlString(String.Empty, expression.LCID, expression.SqlCompareOptions); start -= 1; if (start + length > expression.Value.Length) length = expression.Value.Length - start; - return SqlTypeConverter.UseDefaultCollation(expression.Value.Substring(start.Value, length.Value)); + return new SqlString(expression.Value.Substring(start.Value, length.Value), expression.LCID, expression.SqlCompareOptions); } /// @@ -460,7 +460,7 @@ public static SqlString Trim([MaxLength] SqlString expression) if (expression.IsNull) return expression; - return SqlTypeConverter.UseDefaultCollation(expression.Value.Trim(' ')); + return new SqlString(expression.Value.Trim(' '), expression.LCID, expression.SqlCompareOptions); } /// @@ -473,7 +473,7 @@ public static SqlString LTrim([MaxLength] SqlString expression) if (expression.IsNull) return expression; - return SqlTypeConverter.UseDefaultCollation(expression.Value.TrimStart(' ')); + return new SqlString(expression.Value.TrimStart(' '), expression.LCID, expression.SqlCompareOptions); } /// @@ -486,7 +486,7 @@ public static SqlString RTrim([MaxLength] SqlString expression) if (expression.IsNull) return expression; - return SqlTypeConverter.UseDefaultCollation(expression.Value.TrimEnd(' ')); + return new SqlString(expression.Value.TrimEnd(' '), expression.LCID, expression.SqlCompareOptions); } /// @@ -526,12 +526,12 @@ public static SqlInt32 CharIndex(SqlString find, SqlString search, SqlInt32 star /// /// An integer from 0 through 255 /// - public static SqlString Char(SqlInt32 value) + public static SqlString Char(SqlInt32 value, ExpressionExecutionContext context) { if (value.IsNull || value.Value < 0 || value.Value > 255) return SqlString.Null; - return SqlTypeConverter.UseDefaultCollation(new string((char)value.Value, 1)); + return context.PrimaryDataSource.DefaultCollation.ToSqlString(new string((char)value.Value, 1)); } /// @@ -539,12 +539,12 @@ public static SqlString Char(SqlInt32 value) /// /// An integer from 0 through 255 /// - public static SqlString NChar(SqlInt32 value) + public static SqlString NChar(SqlInt32 value, ExpressionExecutionContext context) { if (value.IsNull || value.Value < 0 || value.Value > 0x10FFFF) return SqlString.Null; - return SqlTypeConverter.UseDefaultCollation(new string((char)value.Value, 1)); + return context.PrimaryDataSource.DefaultCollation.ToSqlString(new string((char)value.Value, 1)); } /// @@ -577,11 +577,11 @@ public static SqlInt32 Unicode(SqlString value) /// /// Returns the identifier of the user /// - /// The options that provide access to the user details + /// The context in which the expression is being executed /// - public static SqlEntityReference User_Name(IQueryExecutionOptions options) + public static SqlEntityReference User_Name(ExpressionExecutionContext context) { - return new SqlEntityReference(options.PrimaryDataSource, "systemuser", options.UserId); + return new SqlEntityReference(context.Options.PrimaryDataSource, "systemuser", context.Options.UserId); } /// @@ -608,7 +608,7 @@ public static T IsNull(T check, T replacement) /// Format pattern /// Optional argument specifying a culture /// - public static SqlString Format(T value, SqlString format, [Optional] SqlString culture) + public static SqlString Format(T value, SqlString format, [Optional] SqlString culture, ExpressionExecutionContext context) where T : INullable { if (value.IsNull) @@ -620,10 +620,10 @@ public static SqlString Format(T value, SqlString format, [Optional] SqlStrin throw new QueryExecutionException("Invalid type for FORMAT function"); var innerValue = (IFormattable)valueProp.GetValue(value); - return Format(innerValue, format, culture); + return Format(innerValue, format, culture, context); } - private static SqlString Format(IFormattable value, SqlString format, SqlString culture) + private static SqlString Format(IFormattable value, SqlString format, SqlString culture, ExpressionExecutionContext context) { if (value == null) return SqlString.Null; @@ -639,7 +639,7 @@ private static SqlString Format(IFormattable value, SqlString format, SqlString cultureInfo = CultureInfo.GetCultureInfo(culture.Value); var formatted = value.ToString(format.Value, cultureInfo); - return SqlTypeConverter.UseDefaultCollation(formatted); + return context.PrimaryDataSource.DefaultCollation.ToSqlString(formatted); } catch { diff --git a/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.projitems b/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.projitems index 8fdd4258..17601662 100644 --- a/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.projitems +++ b/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.projitems @@ -26,6 +26,7 @@ + diff --git a/MarkMpn.Sql4Cds.Engine/MetaMetadataCache.cs b/MarkMpn.Sql4Cds.Engine/MetaMetadataCache.cs index b0c1aff1..8ffa4367 100644 --- a/MarkMpn.Sql4Cds.Engine/MetaMetadataCache.cs +++ b/MarkMpn.Sql4Cds.Engine/MetaMetadataCache.cs @@ -32,7 +32,7 @@ static MetaMetadataCache() metadataNode.ManyToOneRelationshipAlias = "relationship_n_1"; metadataNode.ManyToManyRelationshipAlias = "relationship_n_n"; - var metadataSchema = metadataNode.GetSchema(null, null); + var metadataSchema = metadataNode.GetSchema(new NodeCompilationContext(null, null, null)); _customMetadata["metadata." + metadataNode.EntityAlias] = SchemaToMetadata(metadataSchema, metadataNode.EntityAlias); _customMetadata["metadata." + metadataNode.AttributeAlias] = SchemaToMetadata(metadataSchema, metadataNode.AttributeAlias); @@ -43,7 +43,7 @@ static MetaMetadataCache() var optionsetNode = new GlobalOptionSetQueryNode(); optionsetNode.Alias = "globaloptionset"; - var optionsetSchema = optionsetNode.GetSchema(null, null); + var optionsetSchema = optionsetNode.GetSchema(new NodeCompilationContext(null, null, null)); _customMetadata["metadata." + optionsetNode.Alias] = SchemaToMetadata(optionsetSchema, optionsetNode.Alias); } diff --git a/MarkMpn.Sql4Cds.Engine/MetadataExtensions.cs b/MarkMpn.Sql4Cds.Engine/MetadataExtensions.cs index 8a845ac3..4a9b4115 100644 --- a/MarkMpn.Sql4Cds.Engine/MetadataExtensions.cs +++ b/MarkMpn.Sql4Cds.Engine/MetadataExtensions.cs @@ -85,7 +85,7 @@ public static Type GetAttributeType(this AttributeMetadata attrMetadata) public static DataTypeReference GetAttributeSqlType(this AttributeMetadata attrMetadata, DataSource dataSource, bool write) { if (attrMetadata is MultiSelectPicklistAttributeMetadata) - return DataTypeHelpers.NVarChar(Int32.MaxValue, dataSource.DefaultCollation, CollationLabel.CoercibleDefault); + return DataTypeHelpers.NVarChar(Int32.MaxValue, dataSource.DefaultCollation, CollationLabel.Implicit); var typeCode = attrMetadata.AttributeType; @@ -114,7 +114,7 @@ public static DataTypeReference GetAttributeSqlType(this AttributeMetadata attrM return DataTypeHelpers.Float; if (attrMetadata is EntityNameAttributeMetadata || typeCode == AttributeTypeCode.EntityName) - return DataTypeHelpers.NVarChar(EntityLogicalNameMaxLength, dataSource.DefaultCollation, CollationLabel.CoercibleDefault); + return DataTypeHelpers.NVarChar(EntityLogicalNameMaxLength, dataSource.DefaultCollation, CollationLabel.Implicit); if (attrMetadata is ImageAttributeMetadata) return DataTypeHelpers.VarBinary(Int32.MaxValue); @@ -126,7 +126,7 @@ public static DataTypeReference GetAttributeSqlType(this AttributeMetadata attrM return DataTypeHelpers.BigInt; if (typeCode == AttributeTypeCode.PartyList) - return DataTypeHelpers.NVarChar(Int32.MaxValue, dataSource.DefaultCollation, CollationLabel.CoercibleDefault); + return DataTypeHelpers.NVarChar(Int32.MaxValue, dataSource.DefaultCollation, CollationLabel.Implicit); if (attrMetadata is LookupAttributeMetadata || attrMetadata.IsPrimaryId == true || typeCode == AttributeTypeCode.Lookup || typeCode == AttributeTypeCode.Customer || typeCode == AttributeTypeCode.Owner) return DataTypeHelpers.EntityReference; @@ -170,7 +170,7 @@ public static DataTypeReference GetAttributeSqlType(this AttributeMetadata attrM maxLength = maxLengthSetting.Value; } - return DataTypeHelpers.NVarChar(maxLength, dataSource.DefaultCollation, CollationLabel.CoercibleDefault); + return DataTypeHelpers.NVarChar(maxLength, dataSource.DefaultCollation, CollationLabel.Implicit); } if (attrMetadata is UniqueIdentifierAttributeMetadata || typeCode == AttributeTypeCode.Uniqueidentifier) @@ -180,7 +180,7 @@ public static DataTypeReference GetAttributeSqlType(this AttributeMetadata attrM return DataTypeHelpers.UniqueIdentifier; if (attrMetadata.AttributeType == AttributeTypeCode.Virtual) - return DataTypeHelpers.NVarChar(Int32.MaxValue, dataSource.DefaultCollation, CollationLabel.CoercibleDefault); + return DataTypeHelpers.NVarChar(Int32.MaxValue, dataSource.DefaultCollation, CollationLabel.Implicit); throw new ApplicationException("Unknown attribute type " + attrMetadata.GetType()); } diff --git a/MarkMpn.Sql4Cds.Engine/NodeContext.cs b/MarkMpn.Sql4Cds.Engine/NodeContext.cs new file mode 100644 index 00000000..24d1da40 --- /dev/null +++ b/MarkMpn.Sql4Cds.Engine/NodeContext.cs @@ -0,0 +1,196 @@ +using System; +using System.Collections.Generic; +using System.Text; +using MarkMpn.Sql4Cds.Engine.ExecutionPlan; +using Microsoft.SqlServer.TransactSql.ScriptDom; +using Microsoft.Xrm.Sdk; + +namespace MarkMpn.Sql4Cds.Engine +{ + /// + /// Provides access to the context in which a node will be executed + /// + class NodeCompilationContext + { + /// + /// Creates a new + /// + /// The data sources that are available to the query + /// The options that the query will be executed with + /// The names and types of the parameters that are available to the query + public NodeCompilationContext( + IDictionary dataSources, + IQueryExecutionOptions options, + IDictionary parameterTypes) + { + DataSources = dataSources; + Options = options; + ParameterTypes = parameterTypes; + } + + /// + /// Returns the data sources that are available to the query + /// + public IDictionary DataSources { get; } + + /// + /// Returns the options that the query will be executed with + /// + public IQueryExecutionOptions Options { get; } + + /// + /// Returns the names and types of the parameters that are available to the query + /// + public IDictionary ParameterTypes { get; } + + /// + /// Returns the details of the primary data source + /// + public DataSource PrimaryDataSource => DataSources[Options.PrimaryDataSource]; + } + + /// + /// Provides access to the context in which a node is being executed + /// + class NodeExecutionContext : NodeCompilationContext + { + /// + /// Creates a new + /// + /// The data sources that are available to the query + /// The options that the query is being executed with + /// The names and types of the parameters that are available to the query + /// The current value of each parameter + public NodeExecutionContext( + IDictionary dataSources, + IQueryExecutionOptions options, + IDictionary parameterTypes, + IDictionary parameterValues) + : base(dataSources, options, parameterTypes) + { + ParameterValues = parameterValues; + } + + /// + /// Returns the current value of each parameter + /// + public IDictionary ParameterValues { get; } + } + + /// + /// Provides access to the context in which an expression will be evaluated + /// + class ExpressionCompilationContext : NodeCompilationContext + { + /// + /// Creates a new + /// + /// The data sources that are available to the query + /// The options that the query is being executed with + /// The names and types of the parameters that are available to the query + /// The schema of data which is available to the expression + /// The schema of data prior to aggregation + public ExpressionCompilationContext( + IDictionary dataSources, + IQueryExecutionOptions options, + IDictionary parameterTypes, + INodeSchema schema, + INodeSchema nonAggregateSchema) + : base(dataSources, options, parameterTypes) + { + Schema = schema; + NonAggregateSchema = nonAggregateSchema; + } + + /// + /// Creates a new based on a + /// + /// The to copy options from + /// The schema of data which is available to the expression + /// The schema of data prior to aggregation + public ExpressionCompilationContext( + NodeCompilationContext nodeContext, + INodeSchema schema, + INodeSchema nonAggregateSchema) + : base(nodeContext.DataSources, nodeContext.Options, nodeContext.ParameterTypes) + { + Schema = schema; + NonAggregateSchema = nonAggregateSchema; + } + + /// + /// Returns the schema of data which is available to the expression + /// + public INodeSchema Schema { get; } + + /// + /// Returns the schema of data prior to aggregation + /// + /// + /// Used to provide more helpful error messages when a non-aggregated field is incorrectly referenced after aggregation + /// + public INodeSchema NonAggregateSchema { get; } + } + + /// + /// Provides access to the context in which an expression is being evaluated + /// + class ExpressionExecutionContext : NodeExecutionContext + { + /// + /// Creates a new + /// + /// The data sources that are available to the query + /// The options that the query is being executed with + /// The values for the current row the expression is being evaluated for + /// The current value of each parameter + public ExpressionExecutionContext( + IDictionary dataSources, + IQueryExecutionOptions options, + IDictionary parameterTypes, + IDictionary parameterValues, + Entity entity) + : base(dataSources, options, parameterTypes, parameterValues) + { + Entity = entity; + } + + /// + /// Creates a new based on a + /// + /// The to copy options from + /// + /// The returned instance is designed to be reused for multiple rows within the same node. The + /// property will initially be null - set it to the + /// representing each row as it is processed. + /// + public ExpressionExecutionContext(NodeExecutionContext nodeContext) + : base(nodeContext.DataSources, nodeContext.Options, nodeContext.ParameterTypes, nodeContext.ParameterValues) + { + Entity = null; + } + + /// + /// Creates a new based on a + /// + /// The to copy options from + /// + /// As no parameter values are specified in this constructor, this is only suitable for use when the expression to + /// be evaluated does not consume any parameters. + /// + /// The returned instance is designed to be reused for multiple rows within the same node. The + /// property will initially be null - set it to the + /// representing each row as it is processed. + /// + public ExpressionExecutionContext(ExpressionCompilationContext compilationContext) + : base(compilationContext.DataSources, compilationContext.Options, compilationContext.ParameterTypes, null) + { + Entity = null; + } + + /// + /// Returns the values for the current row the expression is being evaluated for + /// + public Entity Entity { get; set; } + } +} From 079bce20a5dce26b635b28d93cd93d906f3f581b Mon Sep 17 00:00:00 2001 From: Mark Carrington Date: Wed, 22 Mar 2023 08:32:58 +0000 Subject: [PATCH 04/34] Apply explicit collation on all primary expression types --- .../ExecutionPlan/ExpressionExtensions.cs | 109 ++++++++++-------- 1 file changed, 63 insertions(+), 46 deletions(-) diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs index eb1c4b50..b7cbafd4 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs @@ -66,64 +66,96 @@ public static Func Compile(this BooleanExpress private static Expression ToExpression(this TSqlFragment expr, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { + Expression expression; + if (expr is ColumnReferenceExpression col) - return ToExpression(col, context, contextParam, out sqlType); + expression = ToExpression(col, context, contextParam, out sqlType); else if (expr is IdentifierLiteral guid) - return ToExpression(guid, context, contextParam, out sqlType); + expression = ToExpression(guid, context, contextParam, out sqlType); else if (expr is IntegerLiteral i) - return ToExpression(i, context, contextParam, out sqlType); + expression = ToExpression(i, context, contextParam, out sqlType); else if (expr is MoneyLiteral money) - return ToExpression(money, context, contextParam, out sqlType); + expression = ToExpression(money, context, contextParam, out sqlType); else if (expr is NullLiteral n) - return ToExpression(n, context, contextParam, out sqlType); + expression = ToExpression(n, context, contextParam, out sqlType); else if (expr is NumericLiteral num) - return ToExpression(num, context, contextParam, out sqlType); + expression = ToExpression(num, context, contextParam, out sqlType); else if (expr is RealLiteral real) - return ToExpression(real, context, contextParam, out sqlType); + expression = ToExpression(real, context, contextParam, out sqlType); else if (expr is StringLiteral str) - return ToExpression(str, context, contextParam, out sqlType); + expression = ToExpression(str, context, contextParam, out sqlType); else if (expr is OdbcLiteral odbc) - return ToExpression(odbc, context, contextParam, out sqlType); + expression = ToExpression(odbc, context, contextParam, out sqlType); else if (expr is BooleanBinaryExpression boolBin) - return ToExpression(boolBin, context, contextParam, out sqlType); + expression = ToExpression(boolBin, context, contextParam, out sqlType); else if (expr is BooleanComparisonExpression cmp) - return ToExpression(cmp, context, contextParam, out sqlType); + expression = ToExpression(cmp, context, contextParam, out sqlType); else if (expr is BooleanParenthesisExpression boolParen) - return ToExpression(boolParen, context, contextParam, out sqlType); + expression = ToExpression(boolParen, context, contextParam, out sqlType); else if (expr is InPredicate inPred) - return ToExpression(inPred, context, contextParam, out sqlType); + expression = ToExpression(inPred, context, contextParam, out sqlType); else if (expr is BooleanIsNullExpression isNull) - return ToExpression(isNull, context, contextParam, out sqlType); + expression = ToExpression(isNull, context, contextParam, out sqlType); else if (expr is LikePredicate like) - return ToExpression(like, context, contextParam, out sqlType); + expression = ToExpression(like, context, contextParam, out sqlType); else if (expr is BooleanNotExpression not) - return ToExpression(not, context, contextParam, out sqlType); + expression = ToExpression(not, context, contextParam, out sqlType); else if (expr is FullTextPredicate fullText) - return ToExpression(fullText, context, contextParam, out sqlType); + expression = ToExpression(fullText, context, contextParam, out sqlType); else if (expr is Microsoft.SqlServer.TransactSql.ScriptDom.BinaryExpression bin) - return ToExpression(bin, context, contextParam, out sqlType); + expression = ToExpression(bin, context, contextParam, out sqlType); else if (expr is FunctionCall func) - return ToExpression(func, context, contextParam, out sqlType); + expression = ToExpression(func, context, contextParam, out sqlType); else if (expr is ParenthesisExpression paren) - return ToExpression(paren, context, contextParam, out sqlType); + expression = ToExpression(paren, context, contextParam, out sqlType); else if (expr is Microsoft.SqlServer.TransactSql.ScriptDom.UnaryExpression unary) - return ToExpression(unary, context, contextParam, out sqlType); + expression = ToExpression(unary, context, contextParam, out sqlType); else if (expr is VariableReference var) - return ToExpression(var, context, contextParam, out sqlType); + expression = ToExpression(var, context, contextParam, out sqlType); else if (expr is SimpleCaseExpression simpleCase) - return ToExpression(simpleCase, context, contextParam, out sqlType); + expression = ToExpression(simpleCase, context, contextParam, out sqlType); else if (expr is SearchedCaseExpression searchedCase) - return ToExpression(searchedCase, context, contextParam, out sqlType); + expression = ToExpression(searchedCase, context, contextParam, out sqlType); else if (expr is ConvertCall convert) - return ToExpression(convert, context, contextParam, out sqlType); + expression = ToExpression(convert, context, contextParam, out sqlType); else if (expr is CastCall cast) - return ToExpression(cast, context, contextParam, out sqlType); + expression = ToExpression(cast, context, contextParam, out sqlType); else if (expr is ParameterlessCall parameterless) - return ToExpression(parameterless, context, contextParam, out sqlType); + expression = ToExpression(parameterless, context, contextParam, out sqlType); else if (expr is GlobalVariableExpression global) - return ToExpression(global, context, contextParam, out sqlType); + expression = ToExpression(global, context, contextParam, out sqlType); else throw new NotSupportedQueryFragmentException("Unhandled expression type", expr); + + if (expr is PrimaryExpression primary && primary.Collation != null) + { + if (!Collation.TryParse(primary.Collation.Value, out var coll)) + throw new NotSupportedQueryFragmentException("Invalid collation", primary.Collation); + + if (expression.Type == typeof(SqlString) && sqlType is SqlDataTypeReferenceWithCollation sqlTypeWithCollation) + { + expression = Expr.Call(() => ConvertCollation(Expr.Arg(), Expr.Arg()), expression, Expression.Constant(coll)); + sqlType = new SqlDataTypeReferenceWithCollation + { + SqlDataTypeOption = sqlTypeWithCollation.SqlDataTypeOption, + Collation = coll, + CollationLabel = CollationLabel.Explicit + }; + + foreach (var param in sqlTypeWithCollation.Parameters) + ((SqlDataTypeReferenceWithCollation)sqlType).Parameters.Add(param); + } + } + + return expression; + } + + private static SqlString ConvertCollation(SqlString value, Collation collation) + { + if (value.IsNull) + return value; + + return collation.ToSqlString(value.Value); } private static Expression ToExpression(ColumnReferenceExpression col, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) @@ -200,14 +232,11 @@ private static Expression ToExpression(RealLiteral real, ExpressionCompilationCo private static Expression ToExpression(StringLiteral str, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { - var collationLabel = CollationLabel.CoercibleDefault; - var collation = GetCollation(context.PrimaryDataSource, str.Collation, ref collationLabel); - sqlType = str.IsNational - ? DataTypeHelpers.NVarChar(str.Value.Length, collation, collationLabel) - : DataTypeHelpers.VarChar(str.Value.Length, collation, collationLabel); + ? DataTypeHelpers.NVarChar(str.Value.Length, context.PrimaryDataSource.DefaultCollation, CollationLabel.CoercibleDefault) + : DataTypeHelpers.VarChar(str.Value.Length, context.PrimaryDataSource.DefaultCollation, CollationLabel.CoercibleDefault); - return Expression.Constant(collation.ToSqlString(str.Value)); + return Expression.Constant(context.PrimaryDataSource.DefaultCollation.ToSqlString(str.Value)); } private static Expression ToExpression(OdbcLiteral odbc, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) @@ -1590,17 +1619,5 @@ public static bool IsConstantValueExpression(this ScalarExpression expr, Express return true; } - - private static Collation GetCollation(DataSource dataSource, Identifier collation, ref CollationLabel collationLabel) - { - if (collation == null) - return dataSource.DefaultCollation; - - if (!Collation.TryParse(collation.Value, out var coll)) - throw new NotSupportedQueryFragmentException("Invalid collation", collation); - - collationLabel = CollationLabel.Explicit; - return coll; - } } } From bb25ed878a0dc272b9dcad2302ad878157ad0731 Mon Sep 17 00:00:00 2001 From: Mark Carrington Date: Wed, 22 Mar 2023 08:33:14 +0000 Subject: [PATCH 05/34] Removed test - https://github.com/MicrosoftDocs/sql-docs/issues/8123 --- MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs index f42b08df..5d860885 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs @@ -5092,16 +5092,5 @@ public void ExplicitCollation() planBuilder.Build(query, null, out _); } - - [ExpectedException(typeof(NotSupportedQueryFragmentException))] - [TestMethod] - public void TwoExplicitCollationsError() - { - var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); - - var query = "SELECT ('abc' COLLATE French_CI_AS) COLLATE French_CS_AS"; - - planBuilder.Build(query, null, out _); - } } } From 2fd258066082bbe4a99f43af8d35c2f1d9a623a5 Mon Sep 17 00:00:00 2001 From: Mark Carrington Date: Fri, 24 Mar 2023 20:10:17 +0000 Subject: [PATCH 06/34] Collation progress --- .../AdoProviderTests.cs | 41 ++ .../ExecutionPlanTests.cs | 55 ++- .../FakeXrmEasyTestsBase.cs | 15 +- MarkMpn.Sql4Cds.Engine/DataTypeHelpers.cs | 2 +- .../ExecutionPlan/ExpressionExtensions.cs | 108 ++++- .../ExecutionPlan/FilterNode.cs | 107 +++-- .../ExecutionPlan/FoldableJoinNode.cs | 4 + .../ExecutionPlan/HashJoinNode.cs | 16 +- .../ExecutionPlan/SqlTypeConverter.cs | 17 + .../ExecutionPlanBuilder.cs | 369 ++++++++++-------- MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs | 27 ++ .../MarkMpn.Sql4Cds.Engine.projitems | 1 + MarkMpn.Sql4Cds.Engine/NodeContext.cs | 30 ++ .../Visitors/ExplicitCollationVisitor.cs | 38 ++ 14 files changed, 612 insertions(+), 218 deletions(-) create mode 100644 MarkMpn.Sql4Cds.Engine/Visitors/ExplicitCollationVisitor.cs diff --git a/MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs index e040dfb8..f29011d4 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs @@ -1009,5 +1009,46 @@ INSERT INTO account (name) VALUES (@name) } } } + + [TestMethod] + public void SortByCollation() + { + using (var con = new Sql4CdsConnection(_localDataSource)) + using (var cmd = con.CreateCommand()) + { + cmd.CommandText = "INSERT INTO account (name) VALUES ('Chiapas'),('Colima'), ('Cinco Rios'), ('California')"; + cmd.ExecuteNonQuery(); + + cmd.CommandText = "SELECT name FROM account ORDER BY name COLLATE Latin1_General_CS_AS ASC"; + + using (var reader = cmd.ExecuteReader()) + { + var results = new List(); + + while (reader.Read()) + results.Add(reader.GetString(0)); + + var expected = new[] { "California", "Chiapas", "Cinco Rios", "Colima" }; + + for (var i = 0; i < expected.Length; i++) + Assert.AreEqual(expected[i], results[i]); + } + + cmd.CommandText = "SELECT name FROM account ORDER BY name COLLATE Traditional_Spanish_ci_ai ASC"; + + using (var reader = cmd.ExecuteReader()) + { + var results = new List(); + + while (reader.Read()) + results.Add(reader.GetString(0)); + + var expected = new[] { "California", "Cinco Rios", "Colima", "Chiapas" }; + + for (var i = 0; i < expected.Length; i++) + Assert.AreEqual(expected[i], results[i]); + } + } + } } } diff --git a/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs index 5d860885..bfd70c78 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs @@ -5083,13 +5083,64 @@ public void FoldMultipleJoinConditionsWithKnownValue() "); } + [TestMethod] + [ExpectedException(typeof(NotSupportedQueryFragmentException))] + public void CollationConflict() + { + var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var query = "SELECT * FROM prod.dbo.account p, french.dbo.account f WHERE p.name = f.name"; + planBuilder.Build(query, null, out _); + } + [TestMethod] public void ExplicitCollation() { - var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); + var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var query = "SELECT * FROM prod.dbo.account p, french.dbo.account f WHERE p.name = f.name COLLATE French_CI_AS"; + var plans = planBuilder.Build(query, null, out _); + + Assert.AreEqual(1, plans.Length); + var select = AssertNode(plans[0]); + var join = AssertNode(select.Source); + Assert.AreEqual("p.name", join.LeftAttribute.ToSql()); + Assert.AreEqual("Expr1", join.RightAttribute.ToSql()); + var fetch1 = AssertNode(join.LeftSource); + var computeScalar = AssertNode(join.RightSource); + Assert.AreEqual("ExplicitCollation(f.name COLLATE French_CI_AS)", computeScalar.Columns["Expr1"].ToSql()); + var fetch2 = AssertNode(computeScalar.Source); + } + + [TestMethod] + [ExpectedException(typeof(NotSupportedQueryFragmentException))] + public void NoCollationSelectListError() + { + var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var query = "SELECT (CASE WHEN p.employees > f.employees THEN p.name ELSE f.name END) FROM prod.dbo.account p, french.dbo.account f"; + planBuilder.Build(query, null, out _); + } + + [TestMethod] + public void NoCollationExprWithExplicitCollationSelectList() + { + var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var query = "SELECT (CASE WHEN p.employees > f.employees THEN p.name ELSE f.name END) COLLATE Latin1_General_CI_AS FROM prod.dbo.account p, french.dbo.account f"; + planBuilder.Build(query, null, out _); + } - var query = "SELECT 'abc' COLLATE French_CI_AS"; + [TestMethod] + [ExpectedException(typeof(NotSupportedQueryFragmentException))] + public void NoCollationCollationSensitiveFunctionError() + { + var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var query = "SELECT PATINDEX((CASE WHEN p.employees > f.employees THEN p.name ELSE f.name END), 'a') FROM prod.dbo.account p, french.dbo.account f"; + planBuilder.Build(query, null, out _); + } + [TestMethod] + public void NoCollationExprWithExplicitCollationCollationSensitiveFunctionError() + { + var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var query = "SELECT PATINDEX((CASE WHEN p.employees > f.employees THEN p.name ELSE f.name END) COLLATE Latin1_General_CI_AS, 'a') FROM prod.dbo.account p, french.dbo.account f"; planBuilder.Build(query, null, out _); } } diff --git a/MarkMpn.Sql4Cds.Engine.Tests/FakeXrmEasyTestsBase.cs b/MarkMpn.Sql4Cds.Engine.Tests/FakeXrmEasyTestsBase.cs index 5ee5f543..cc9b16cf 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/FakeXrmEasyTestsBase.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/FakeXrmEasyTestsBase.cs @@ -22,6 +22,9 @@ public class FakeXrmEasyTestsBase protected readonly IOrganizationService _service2; protected readonly XrmFakedContext _context2; protected readonly DataSource _dataSource2; + protected readonly IOrganizationService _service3; + protected readonly XrmFakedContext _context3; + protected readonly DataSource _dataSource3; protected readonly IDictionary _dataSources; protected readonly IDictionary _localDataSource; @@ -45,7 +48,17 @@ public FakeXrmEasyTestsBase() _service2 = _context2.GetOrganizationService(); _dataSource2 = new DataSource { Name = "prod", Connection = _service2, Metadata = new AttributeMetadataCache(_service2), TableSizeCache = new StubTableSizeCache(), MessageCache = new StubMessageCache(), DefaultCollation = Collation.USEnglish }; - _dataSources = new[] { _dataSource, _dataSource2 }.ToDictionary(ds => ds.Name); + _context3 = new XrmFakedContext(); + _context3.InitializeMetadata(Assembly.GetExecutingAssembly()); + _context3.CallerId = _context.CallerId; + _context3.AddFakeMessageExecutor(new RetrieveVersionRequestExecutor()); + _context3.AddGenericFakeMessageExecutor(SampleMessageExecutor.MessageName, new SampleMessageExecutor()); + + _service3 = _context3.GetOrganizationService(); + Collation.TryParse("French_CI_AI", out var frenchCIAI); + _dataSource3 = new DataSource { Name = "french", Connection = _service3, Metadata = new AttributeMetadataCache(_service3), TableSizeCache = new StubTableSizeCache(), MessageCache = new StubMessageCache(), DefaultCollation = frenchCIAI }; + + _dataSources = new[] { _dataSource, _dataSource2, _dataSource3 }.ToDictionary(ds => ds.Name); _localDataSource = new Dictionary { ["local"] = new DataSource { Name = "local", Connection = _service, Metadata = _dataSource.Metadata, TableSizeCache = _dataSource.TableSizeCache, MessageCache = _dataSource.MessageCache, DefaultCollation = Collation.USEnglish } diff --git a/MarkMpn.Sql4Cds.Engine/DataTypeHelpers.cs b/MarkMpn.Sql4Cds.Engine/DataTypeHelpers.cs index 7af99588..ffe375a6 100644 --- a/MarkMpn.Sql4Cds.Engine/DataTypeHelpers.cs +++ b/MarkMpn.Sql4Cds.Engine/DataTypeHelpers.cs @@ -480,7 +480,7 @@ public bool Equals(DataTypeReference x, DataTypeReference y) return false; } - if (xColl != null && yColl != null &&!xColl.Collation.Equals(yColl.Collation)) + if (xColl != null && yColl != null && (xColl.Collation == null ^ yColl.Collation == null || xColl.Collation != null && yColl.Collation != null && !xColl.Collation.Equals(yColl.Collation))) return false; return true; diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs index b7cbafd4..9dcf6427 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs @@ -134,7 +134,7 @@ private static Expression ToExpression(this TSqlFragment expr, ExpressionCompila if (expression.Type == typeof(SqlString) && sqlType is SqlDataTypeReferenceWithCollation sqlTypeWithCollation) { - expression = Expr.Call(() => ConvertCollation(Expr.Arg(), Expr.Arg()), expression, Expression.Constant(coll)); + expression = Expr.Call(() => SqlTypeConverter.ConvertCollation(Expr.Arg(), Expr.Arg()), expression, Expression.Constant(coll)); sqlType = new SqlDataTypeReferenceWithCollation { SqlDataTypeOption = sqlTypeWithCollation.SqlDataTypeOption, @@ -150,14 +150,6 @@ private static Expression ToExpression(this TSqlFragment expr, ExpressionCompila return expression; } - private static SqlString ConvertCollation(SqlString value, Collation collation) - { - if (value.IsNull) - return value; - - return collation.ToSqlString(value.Value); - } - private static Expression ToExpression(ColumnReferenceExpression col, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { var name = col.GetColumnName(); @@ -327,6 +319,8 @@ cmp.SecondExpression is StringLiteral str && rhs = SqlTypeConverter.Convert(rhs, rhsType, type); } + AssertCollationSensitive(type, cmp.ComparisonType.ToString().ToLowerInvariant() + " operation", cmp); + switch (cmp.ComparisonType) { case BooleanComparisonType.Equals: @@ -685,9 +679,6 @@ private static MethodInfo GetMethod(Type targetType, DataSource primaryDataSourc // Use the [MaxLength(value)] attribute from the method where available var methodMaxLength = method.GetCustomAttribute(); - // TODO: Add an attribute to indicate if the collation should be taken from a parameter to the function - // or use the default collation for the connection - if (methodMaxLength?.MaxLength != null) sqlType = DataTypeHelpers.NVarChar(methodMaxLength.MaxLength.Value, primaryDataSource.DefaultCollation, CollationLabel.CoercibleDefault); @@ -773,6 +764,64 @@ private static MethodInfo GetMethod(Type targetType, DataSource primaryDataSourc throw new NotSupportedQueryFragmentException($"Cannot convert {paramTypes[i].ToSql()} to {paramType.ToSqlType(primaryDataSource).ToSql()}", i < paramOffset ? func : func.Parameters[i - paramOffset]); } + if (method.GetCustomAttribute(typeof(CollationSensitiveAttribute)) != null) + { + // If method is collation sensitive: + // 1. check all string parameters can be converted to a consistent collation + // 2. check the consistent collation label is not no-collation + // 3. use the same collation for the return type + SqlDataTypeReferenceWithCollation collation = null; + + foreach (var paramType in paramTypes) + { + if (!(paramType is SqlDataTypeReferenceWithCollation collationParam)) + continue; + + if (collation == null) + { + collation = collationParam; + continue; + } + + if (!SqlDataTypeReferenceWithCollation.TryConvertCollation(collation, collationParam, out var consistentCollation, out var collationLabel)) + throw new NotSupportedQueryFragmentException($"Cannot resolve collation conflict between '{collation.Collation.Name}' and {collationParam.Collation.Name}' in {func.FunctionName.Value.ToLowerInvariant()} operation", func); + + collation = new SqlDataTypeReferenceWithCollation + { + Collation = consistentCollation, + CollationLabel = collationLabel + }; + } + + AssertCollationSensitive(collation, func.FunctionName.Value.ToLowerInvariant() + " operation", func); + + for (var i = 0; i < paramTypes.Length; i++) + { + if (!(paramTypes[i] is SqlDataTypeReferenceWithCollation collationParam)) + continue; + + if (!collationParam.Collation.Equals(collation.Collation)) + { + paramExpressions[i] = Expr.Call(() => SqlTypeConverter.ConvertCollation(Expr.Arg(), Expr.Arg()), paramExpressions[i], Expression.Constant(collation.Collation)); + paramTypes[i] = new SqlDataTypeReferenceWithCollation + { + SqlDataTypeOption = collationParam.SqlDataTypeOption, + Collation = collation.Collation, + CollationLabel = CollationLabel.Explicit + }; + + foreach (var param in collationParam.Parameters) + ((SqlDataTypeReferenceWithCollation)paramTypes[i]).Parameters.Add(param); + } + } + + if (sqlType is SqlDataTypeReferenceWithCollation outputCollation) + { + outputCollation.Collation = collation.Collation; + outputCollation.CollationLabel = collation.CollationLabel; + } + } + if (sqlType == null) sqlType = method.ReturnType.ToSqlType(primaryDataSource); @@ -784,6 +833,24 @@ private static Expression ToExpression(this FunctionCall func, ExpressionCompila if (func.OverClause != null) throw new NotSupportedQueryFragmentException("Window functions are not supported", func); + // Special case: ExplicitCollation is a pseudo-function that's introduced by the ExplicitCollationVisitor to wrap + // primary expressions with a collation definition. The inner expression will already have applied the collation + // change so we can return it without any further processing + if (func.FunctionName.Value == "ExplicitCollation" && func.Parameters.Count == 1) + { + var converted = func.Parameters[0].ToExpression(context, contextParam, out sqlType); + + if (!(sqlType is SqlDataTypeReferenceWithCollation coll) || + !coll.SqlDataTypeOption.IsStringType() || + coll.Collation == null || + coll.CollationLabel != CollationLabel.Explicit) + { + throw new NotSupportedQueryFragmentException("Unknown function", func); + } + + return converted; + } + // Find the method to call and get the expressions for the parameter values var method = GetMethod(func, context, contextParam, out var paramValues, out sqlType); @@ -934,6 +1001,8 @@ private static Expression ToExpression(this LikePredicate like, ExpressionCompil escape = SqlTypeConverter.Convert(escape, escapeType, stringType); } + AssertCollationSensitive(stringType, "like operation", like); + if (escape == null) escape = Expression.Constant(SqlString.Null); @@ -1344,10 +1413,14 @@ private static Expression ToExpression(this ConvertCall convert, ExpressionCompi sqlTargetType.SqlDataTypeOption.IsStringType() && sqlTargetType.Parameters.Count == 0) { - sqlType = new SqlDataTypeReference + var coll = sqlType as SqlDataTypeReferenceWithCollation; + + sqlType = new SqlDataTypeReferenceWithCollation { SqlDataTypeOption = sqlTargetType.SqlDataTypeOption, - Parameters = { new IntegerLiteral { Value = "30" } } + Parameters = { new IntegerLiteral { Value = "30" } }, + Collation = coll?.Collation ?? context.PrimaryDataSource.DefaultCollation, + CollationLabel = coll?.CollationLabel ?? CollationLabel.CoercibleDefault }; } @@ -1619,5 +1692,12 @@ public static bool IsConstantValueExpression(this ScalarExpression expr, Express return true; } + + private static void AssertCollationSensitive(DataTypeReference sqlType, string description, TSqlFragment fragment) + { + if (sqlType is SqlDataTypeReferenceWithCollation collation && + collation.CollationLabel == CollationLabel.NoCollation) + throw new NotSupportedQueryFragmentException($"Cannot resolve collation conflict for {description}", fragment); + } } } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FilterNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FilterNode.cs index 664075f2..9b8ac5ea 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FilterNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FilterNode.cs @@ -215,7 +215,7 @@ join is NestedLoopNode loop && var leftSchema = join.LeftSource.GetSchema(context); var rightSchema = join.RightSource.GetSchema(context); - if (ExtractJoinCondition(Filter, loop, leftSchema, rightSchema, out foldedJoin, out var removedCondition)) + if (ExtractJoinCondition(Filter, loop, context, leftSchema, rightSchema, out foldedJoin, out var removedCondition)) { Filter = Filter.RemoveCondition(removedCondition); foldedFilters = true; @@ -238,38 +238,46 @@ join is NestedLoopNode loop && return foldedFilters; } - private bool ExtractJoinCondition(BooleanExpression filter, NestedLoopNode join, INodeSchema leftSchema, INodeSchema rightSchema, out FoldableJoinNode foldedJoin, out BooleanExpression removedCondition) + private bool ExtractJoinCondition(BooleanExpression filter, NestedLoopNode join, NodeCompilationContext context, INodeSchema leftSchema, INodeSchema rightSchema, out FoldableJoinNode foldedJoin, out BooleanExpression removedCondition) { if (filter is BooleanComparisonExpression cmp && - cmp.ComparisonType == BooleanComparisonType.Equals && - cmp.FirstExpression is ColumnReferenceExpression col1 && - cmp.SecondExpression is ColumnReferenceExpression col2) + cmp.ComparisonType == BooleanComparisonType.Equals) { var leftSource = join.LeftSource; var rightSource = join.RightSource; + var col1 = cmp.FirstExpression as ColumnReferenceExpression; + var col2 = cmp.SecondExpression as ColumnReferenceExpression; + + // If join is not directly on a.col = b.col, it may be something that we can calculate such as + // a.col1 + a.col2 = left(b.col3, 10) + // Create a ComputeScalar node for each side so the join can work on a single column + // This only works if each side of the equality expression references columns only from one side of the join + var originalLeftSource = leftSource; + var originalRightSource = rightSource; + + if (col1 == null) + col1 = ComputeColumn(context, cmp.FirstExpression, ref leftSource, ref leftSchema, ref rightSource, ref rightSchema); + + if (col2 == null) + col2 = ComputeColumn(context, cmp.SecondExpression, ref leftSource, ref leftSchema, ref rightSource, ref rightSchema); // Equality expression may be written in the opposite order to the join - swap the tables if necessary - if (rightSchema.ContainsColumn(col1.GetColumnName(), out _) && + if (col1 != null && + col2 != null && + rightSchema.ContainsColumn(col1.GetColumnName(), out _) && leftSchema.ContainsColumn(col2.GetColumnName(), out _)) { Swap(ref leftSource, ref rightSource); Swap(ref leftSchema, ref rightSchema); } - if (leftSchema.ContainsColumn(col1.GetColumnName(), out var leftCol) && + if (col1 != null && + col2 != null && + leftSchema.ContainsColumn(col1.GetColumnName(), out var leftCol) && rightSchema.ContainsColumn(col2.GetColumnName(), out var rightCol)) { - if (leftSource is TableSpoolNode leftSpool) - { - leftSpool.Source.Parent = leftSource.Parent; - leftSource = leftSpool.Source; - } - - if (rightSource is TableSpoolNode rightSpool) - { - rightSpool.Source.Parent = rightSource.Parent; - rightSource = rightSpool.Source; - } + leftSource = RemoveTableSpool(leftSource); + rightSource = RemoveTableSpool(rightSource); // Prefer to use a merge join if either of the join keys are the primary key. // Swap the tables if necessary to use the primary key from the right source. @@ -277,6 +285,7 @@ cmp.FirstExpression is ColumnReferenceExpression col1 && { Swap(ref leftSource, ref rightSource); Swap(ref leftSchema, ref rightSchema); + Swap(ref col1, ref col2); Swap(ref leftCol, ref rightCol); } @@ -303,8 +312,8 @@ cmp.FirstExpression is ColumnReferenceExpression col1 && bin.BinaryExpressionType == BooleanBinaryExpressionType.And) { // Recurse into ANDs but not into ORs - if (ExtractJoinCondition(bin.FirstExpression, join, leftSchema, rightSchema, out foldedJoin, out removedCondition) || - ExtractJoinCondition(bin.SecondExpression, join, leftSchema, rightSchema, out foldedJoin, out removedCondition)) + if (ExtractJoinCondition(bin.FirstExpression, join, context, leftSchema, rightSchema, out foldedJoin, out removedCondition) || + ExtractJoinCondition(bin.SecondExpression, join, context, leftSchema, rightSchema, out foldedJoin, out removedCondition)) { return true; } @@ -315,6 +324,52 @@ cmp.FirstExpression is ColumnReferenceExpression col1 && return false; } + private IDataExecutionPlanNodeInternal RemoveTableSpool(IDataExecutionPlanNodeInternal source) + { + if (source is TableSpoolNode spool) + { + spool.Source.Parent = source.Parent; + return spool.Source; + } + + if (source is ComputeScalarNode computeScalar && computeScalar.Source is TableSpoolNode computeScalarSpool) + { + computeScalarSpool.Source.Parent = computeScalar; + computeScalar.Source = computeScalarSpool.Source; + } + + return source; + } + + private ColumnReferenceExpression ComputeColumn(NodeCompilationContext context, ScalarExpression expression, ref IDataExecutionPlanNodeInternal leftSource, ref INodeSchema leftSchema, ref IDataExecutionPlanNodeInternal rightSource, ref INodeSchema rightSchema) + { + return ComputeColumn(context, expression, ref leftSource, ref leftSchema) ?? ComputeColumn(context, expression, ref rightSource, ref rightSchema); + } + + private ColumnReferenceExpression ComputeColumn(NodeCompilationContext context, ScalarExpression expression, ref IDataExecutionPlanNodeInternal source, ref INodeSchema schema) + { + var columns = expression.GetColumns().ToList(); + var s = schema; + + if (columns.Count == 0 || !columns.All(c => s.ContainsColumn(c, out _))) + return null; + + var exprName = context.GetExpressionName(); + var computeScalar = new ComputeScalarNode + { + Source = source, + Columns = + { + [exprName] = expression + } + }; + + source = computeScalar; + schema = computeScalar.GetSchema(context); + + return exprName.ToColumnReference(); + } + private void Swap(ref T first, ref T second) { var temp = first; @@ -388,6 +443,10 @@ private bool FoldInExistsToFetchXml(NodeCompilationContext context, IList ExecuteInternal(NodeExecutionContext cont // Build the hash table var leftSchema = LeftSource.GetSchema(context); - leftSchema.ContainsColumn(LeftAttribute.GetColumnName(), out var leftCol); - var leftColType = leftSchema.Schema[leftCol]; + var leftCompilationContext = new ExpressionCompilationContext(context, leftSchema, null); + LeftAttribute.GetType(leftCompilationContext, out var leftColType); var rightSchema = RightSource.GetSchema(context); - rightSchema.ContainsColumn(RightAttribute.GetColumnName(), out var rightCol); - var rightColType = rightSchema.Schema[rightCol]; + var rightCompilationContext = new ExpressionCompilationContext(context, rightSchema, null); + RightAttribute.GetType(rightCompilationContext, out var rightColType); if (!SqlTypeConverter.CanMakeConsistentTypes(leftColType, rightColType, context.PrimaryDataSource, out var keyType)) throw new QueryExecutionException($"Cannot match key types {leftColType.ToSql()} and {rightColType.ToSql()}"); - var leftKeyAccessor = (ScalarExpression) leftCol.ToColumnReference(); + var leftKeyAccessor = (ScalarExpression)LeftAttribute; if (!leftColType.IsSameAs(keyType)) leftKeyAccessor = new ConvertCall { Parameter = leftKeyAccessor, DataType = keyType }; - var leftKeyConverter = leftKeyAccessor.Compile(new ExpressionCompilationContext(context, leftSchema, null)); + var leftKeyConverter = leftKeyAccessor.Compile(leftCompilationContext); - var rightKeyAccessor = (ScalarExpression)rightCol.ToColumnReference(); + var rightKeyAccessor = (ScalarExpression)RightAttribute; if (!rightColType.IsSameAs(keyType)) rightKeyAccessor = new ConvertCall { Parameter = rightKeyAccessor, DataType = keyType }; - var rightKeyConverter = rightKeyAccessor.Compile(new ExpressionCompilationContext(context, rightSchema, null)); + var rightKeyConverter = rightKeyAccessor.Compile(rightCompilationContext); var expressionContext = new ExpressionExecutionContext(context); diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SqlTypeConverter.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SqlTypeConverter.cs index 4c4907ba..e71fd7d9 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SqlTypeConverter.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SqlTypeConverter.cs @@ -612,6 +612,9 @@ public static Expression Convert(Expression expr, DataTypeReference from, DataTy { throw new NotSupportedQueryFragmentException("Invalid attributes specified for type " + toSqlType.SqlDataTypeOption, toSqlType); } + + if (targetCollation != null) + expr = Expr.Call(() => ConvertCollation(Expr.Arg(), Expr.Arg()), expr, Expression.Constant(targetCollation)); } // Apply changes to precision & scale @@ -646,6 +649,20 @@ public static Expression Convert(Expression expr, DataTypeReference from, DataTy return expr; } + /// + /// Converts a value from one collation to another + /// + /// The value to convert + /// The collation to convert the to + /// A new value with the requested collation + public static SqlString ConvertCollation(SqlString value, Collation collation) + { + if (value.IsNull) + return value; + + return collation.ToSqlString(value.Value); + } + private static SqlString Truncate(SqlString value, int maxLength, string valueOnTruncate, Exception exceptionOnTruncate) { if (value.IsNull) diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs index c89ff4e6..ea741c95 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs @@ -18,8 +18,6 @@ namespace MarkMpn.Sql4Cds.Engine { class ExecutionPlanBuilder { - private int _colNameCounter; - private IDictionary _parameterTypes; private ExpressionCompilationContext _staticContext; private NodeCompilationContext _nodeContext; @@ -62,19 +60,19 @@ public IRootExecutionPlanNode[] Build(string sql, IDictionary(StringComparer.OrdinalIgnoreCase); - _staticContext = new ExpressionCompilationContext(DataSources, Options, _parameterTypes, null, null); - _nodeContext = new NodeCompilationContext(DataSources, Options, _parameterTypes); + var parameterTypes = new Dictionary(StringComparer.OrdinalIgnoreCase); + _staticContext = new ExpressionCompilationContext(DataSources, Options, parameterTypes, null, null); + _nodeContext = new NodeCompilationContext(DataSources, Options, parameterTypes); if (parameters != null) { foreach (var param in parameters) - _parameterTypes[param.Key] = param.Value; + parameterTypes[param.Key] = param.Value; } // Add in standard global variables - _parameterTypes["@@IDENTITY"] = DataTypeHelpers.EntityReference; - _parameterTypes["@@ROWCOUNT"] = DataTypeHelpers.Int; + parameterTypes["@@IDENTITY"] = DataTypeHelpers.EntityReference; + parameterTypes["@@ROWCOUNT"] = DataTypeHelpers.Int; var queries = new List(); @@ -123,7 +121,8 @@ public IRootExecutionPlanNode[] Build(string sql, IDictionary(), null, null, _parameterTypes); + var selectQry = ConvertSelectQuerySpec(select, Array.Empty(), null, null, _nodeContext); predicateSource = selectQry.Source; sourceCol = selectQry.ColumnSet[0].SourceColumn; } @@ -516,7 +515,7 @@ private IRootExecutionPlanNodeInternal ConvertSetVariableStatement(SetVariableSt if (set.Parameters != null && set.Parameters.Count > 0) throw new NotSupportedQueryFragmentException("Parameters are not supported", set.Parameters[0]); - if (!_parameterTypes.TryGetValue(set.Variable.Name, out var paramType)) + if (!_nodeContext.ParameterTypes.TryGetValue(set.Variable.Name, out var paramType)) throw new NotSupportedQueryFragmentException("Must declare the scalar variable", set.Variable); // Create the SELECT statement that generates the required information @@ -590,7 +589,7 @@ private IRootExecutionPlanNodeInternal[] ConvertDeclareVariableStatement(Declare foreach (var declaration in declare.Declarations) { - if (_parameterTypes.ContainsKey(declaration.VariableName.Value)) + if (_nodeContext.ParameterTypes.ContainsKey(declaration.VariableName.Value)) throw new NotSupportedQueryFragmentException("The variable name has already been declared. Variable names must be unique within a query batch", declaration.VariableName); // Apply default maximum length for [n][var]char types @@ -603,12 +602,20 @@ private IRootExecutionPlanNodeInternal[] ConvertDeclareVariableStatement(Declare throw new NotSupportedQueryFragmentException("Table variables are not supported", dataType); if (dataType.SqlDataTypeOption == SqlDataTypeOption.Char || - dataType.SqlDataTypeOption == SqlDataTypeOption.NChar || - dataType.SqlDataTypeOption == SqlDataTypeOption.VarChar || - dataType.SqlDataTypeOption == SqlDataTypeOption.NVarChar) + dataType.SqlDataTypeOption == SqlDataTypeOption.NChar || + dataType.SqlDataTypeOption == SqlDataTypeOption.VarChar || + dataType.SqlDataTypeOption == SqlDataTypeOption.NVarChar) { if (dataType.Parameters.Count == 0) dataType.Parameters.Add(new IntegerLiteral { Value = "1" }); + + declaration.DataType = new SqlDataTypeReferenceWithCollation + { + SqlDataTypeOption = dataType.SqlDataTypeOption, + Parameters = { dataType.Parameters[0] }, + Collation = _nodeContext.PrimaryDataSource.DefaultCollation, + CollationLabel = CollationLabel.CoercibleDefault + }; } } @@ -616,7 +623,7 @@ private IRootExecutionPlanNodeInternal[] ConvertDeclareVariableStatement(Declare // Make the variables available in our local copy of parameters so later statements // in the same batch can use them - _parameterTypes[declaration.VariableName.Value] = declaration.DataType; + _nodeContext.ParameterTypes[declaration.VariableName.Value] = declaration.DataType; if (declaration.Value != null) { @@ -773,7 +780,7 @@ private InsertNode ConvertInsertStatement(InsertStatement insert) string[] columns; if (insert.InsertSpecification.InsertSource is ValuesInsertSource values) - source = ConvertInsertValuesSource(values, insert.OptimizerHints, null, null, _parameterTypes, out columns); + source = ConvertInsertValuesSource(values, insert.OptimizerHints, null, null, _nodeContext, out columns); else if (insert.InsertSpecification.InsertSource is SelectInsertSource select) source = ConvertInsertSelectSource(select, insert.OptimizerHints, out columns); else @@ -782,22 +789,22 @@ private InsertNode ConvertInsertStatement(InsertStatement insert) return ConvertInsertSpecification(target, insert.InsertSpecification.Columns, source, columns, insert.OptimizerHints); } - private IDataExecutionPlanNodeInternal ConvertInsertValuesSource(ValuesInsertSource values, IList hints, INodeSchema outerSchema, Dictionary outerReferences, IDictionary parameterTypes, out string[] columns) + private IDataExecutionPlanNodeInternal ConvertInsertValuesSource(ValuesInsertSource values, IList hints, INodeSchema outerSchema, Dictionary outerReferences, NodeCompilationContext context, out string[] columns) { // Convert the values to an InlineDerviedTable var table = new InlineDerivedTable { - Alias = new Identifier { Value = $"Expr{++_colNameCounter}" } + Alias = new Identifier { Value = context.GetExpressionName() } }; foreach (var col in values.RowValues[0].ColumnValues) - table.Columns.Add(new Identifier { Value = $"Expr{++_colNameCounter}" }); + table.Columns.Add(new Identifier { Value = context.GetExpressionName() }); foreach (var row in values.RowValues) table.RowValues.Add(row); columns = table.Columns.Select(col => col.Value).ToArray(); - return ConvertInlineDerivedTable(table, hints, outerSchema, outerReferences, parameterTypes); + return ConvertInlineDerivedTable(table, hints, outerSchema, outerReferences, context); } private IExecutionPlanNodeInternal ConvertInsertSelectSource(SelectInsertSource selectSource, IList hints, out string[] columns) @@ -986,7 +993,7 @@ attr is LookupAttributeMetadata lookupAttr && if (targetLookupAttribute.Targets.Length > 1 && !virtualTypeAttributes.Contains(targetAttrName + "type") && targetLookupAttribute.AttributeType != AttributeTypeCode.PartyList && - (schema == null || (node.ColumnMappings[targetAttrName].ToColumnReference().GetType(GetExpressionContext(schema), out var lookupType) != typeof(SqlEntityReference) && lookupType != DataTypeHelpers.ImplicitIntForNullLiteral))) + (schema == null || (node.ColumnMappings[targetAttrName].ToColumnReference().GetType(GetExpressionContext(schema, _nodeContext), out var lookupType) != typeof(SqlEntityReference) && lookupType != DataTypeHelpers.ImplicitIntForNullLiteral))) { // Special case: not required for listmember.entityid if (metadata.LogicalName == "listmember" && targetLookupAttribute.LogicalName == "entityid") @@ -1429,7 +1436,7 @@ private UpdateNode ConvertSetClause(IList setClauses, HashSet update.PrimaryIdSource = $"{targetAlias}.{targetMetadata.PrimaryIdAttribute}"; var schema = select.Source.GetSchema(_nodeContext); - var expressionContext = new ExpressionCompilationContext(DataSources, Options, _parameterTypes, schema, null); + var expressionContext = GetExpressionContext(schema, _nodeContext); foreach (var assignment in setClauses.Cast()) { @@ -1614,7 +1621,7 @@ private IRootExecutionPlanNodeInternal ConvertSelectStatement(SelectStatement se variableAssignments.Add(set.Variable.Name); - if (!_parameterTypes.TryGetValue(set.Variable.Name, out var paramType)) + if (!_nodeContext.ParameterTypes.TryGetValue(set.Variable.Name, out var paramType)) throw new NotSupportedQueryFragmentException("Must declare the scalar variable", set.Variable); // Create the SELECT statement that generates the required information @@ -1670,7 +1677,7 @@ private IRootExecutionPlanNodeInternal ConvertSelectStatement(SelectStatement se } } - var converted = ConvertSelectStatement(select.QueryExpression, select.OptimizerHints, null, null, _parameterTypes); + var converted = ConvertSelectStatement(select.QueryExpression, select.OptimizerHints, null, null, _nodeContext); if (variableAssignments.Count > 0) { @@ -1688,26 +1695,26 @@ private IRootExecutionPlanNodeInternal ConvertSelectStatement(SelectStatement se return converted; } - private SelectNode ConvertSelectStatement(QueryExpression query, IList hints, INodeSchema outerSchema, Dictionary outerReferences, IDictionary parameterTypes) + private SelectNode ConvertSelectStatement(QueryExpression query, IList hints, INodeSchema outerSchema, Dictionary outerReferences, NodeCompilationContext context) { if (query is QuerySpecification querySpec) - return ConvertSelectQuerySpec(querySpec, hints, outerSchema, outerReferences, parameterTypes); + return ConvertSelectQuerySpec(querySpec, hints, outerSchema, outerReferences, context); if (query is BinaryQueryExpression binary) - return ConvertBinaryQuery(binary, hints, outerSchema, outerReferences, parameterTypes); + return ConvertBinaryQuery(binary, hints, outerSchema, outerReferences, context); if (query is QueryParenthesisExpression paren) { paren.QueryExpression.ForClause = paren.ForClause; paren.QueryExpression.OffsetClause = paren.OffsetClause; paren.QueryExpression.OrderByClause = paren.OrderByClause; - return ConvertSelectStatement(paren.QueryExpression, hints, outerSchema, outerReferences, parameterTypes); + return ConvertSelectStatement(paren.QueryExpression, hints, outerSchema, outerReferences, context); } throw new NotSupportedQueryFragmentException("Unhandled SELECT query expression", query); } - private SelectNode ConvertBinaryQuery(BinaryQueryExpression binary, IList hints, INodeSchema outerSchema, Dictionary outerReferences, IDictionary parameterTypes) + private SelectNode ConvertBinaryQuery(BinaryQueryExpression binary, IList hints, INodeSchema outerSchema, Dictionary outerReferences, NodeCompilationContext context) { if (binary.BinaryQueryExpressionType != BinaryQueryExpressionType.Union) throw new NotSupportedQueryFragmentException($"Unhandled {binary.BinaryQueryExpressionType} query type", binary); @@ -1715,8 +1722,8 @@ private SelectNode ConvertBinaryQuery(BinaryQueryExpression binary, IList new ColumnReferenceExpression { MultiPartIdentifier = new MultiPartIdentifier { Identifiers = { new Identifier { Value = col.OutputColumn } } } }).ToArray(), binary, parameterTypes, outerSchema, outerReferences, null); - node = ConvertOffsetClause(node, binary.OffsetClause, parameterTypes); + node = ConvertOrderByClause(node, hints, binary.OrderByClause, concat.ColumnSet.Select(col => new ColumnReferenceExpression { MultiPartIdentifier = new MultiPartIdentifier { Identifiers = { new Identifier { Value = col.OutputColumn } } } }).ToArray(), binary, context, outerSchema, outerReferences, null); + node = ConvertOffsetClause(node, binary.OffsetClause, context); var select = new SelectNode { Source = node }; select.ColumnSet.AddRange(concat.ColumnSet.Select((col, i) => new SelectColumn { SourceColumn = col.OutputColumn, SourceExpression = col.SourceExpressions[0], OutputColumn = left.ColumnSet[i].OutputColumn })); @@ -1766,7 +1773,7 @@ private SelectNode ConvertBinaryQuery(BinaryQueryExpression binary, IList hints, INodeSchema outerSchema, Dictionary outerReferences, IDictionary parameterTypes) + private SelectNode ConvertSelectQuerySpec(QuerySpecification querySpec, IList hints, INodeSchema outerSchema, Dictionary outerReferences, NodeCompilationContext context) { // Check for any aggregates in the FROM or WHERE clauses var aggregateCollector = new AggregateCollectingVisitor(); @@ -1786,33 +1793,33 @@ private SelectNode ConvertSelectQuerySpec(QuerySpecification querySpec, IList() } } : ConvertFromClause(querySpec.FromClause.TableReferences, hints, querySpec, outerSchema, outerReferences, parameterTypes); + var node = querySpec.FromClause == null ? new ConstantScanNode { Values = { new Dictionary() } } : ConvertFromClause(querySpec.FromClause.TableReferences, hints, querySpec, outerSchema, outerReferences, context); - node = ConvertInSubqueries(node, hints, querySpec, parameterTypes, outerSchema, outerReferences); - node = ConvertExistsSubqueries(node, hints, querySpec, parameterTypes, outerSchema, outerReferences); + node = ConvertInSubqueries(node, hints, querySpec, context, outerSchema, outerReferences); + node = ConvertExistsSubqueries(node, hints, querySpec, context, outerSchema, outerReferences); // Add filters from WHERE - node = ConvertWhereClause(node, hints, querySpec.WhereClause, outerSchema, outerReferences, parameterTypes, querySpec); + node = ConvertWhereClause(node, hints, querySpec.WhereClause, outerSchema, outerReferences, context, querySpec); // Add aggregates from GROUP BY/SELECT/HAVING/ORDER BY var preGroupByNode = node; - node = ConvertGroupByAggregates(node, querySpec, parameterTypes, outerSchema, outerReferences); - var nonAggregateSchema = preGroupByNode == node ? null : preGroupByNode.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); + node = ConvertGroupByAggregates(node, querySpec, context, outerSchema, outerReferences); + var nonAggregateSchema = preGroupByNode == node ? null : preGroupByNode.GetSchema(context); // Add filters from HAVING - node = ConvertHavingClause(node, hints, querySpec.HavingClause, parameterTypes, outerSchema, outerReferences, querySpec, nonAggregateSchema); + node = ConvertHavingClause(node, hints, querySpec.HavingClause, context, outerSchema, outerReferences, querySpec, nonAggregateSchema); // Add DISTINCT var distinct = querySpec.UniqueRowFilter == UniqueRowFilter.Distinct ? new DistinctNode { Source = node } : null; node = distinct ?? node; // Add SELECT - var selectNode = ConvertSelectClause(querySpec.SelectElements, hints, node, distinct, querySpec, parameterTypes, outerSchema, outerReferences, nonAggregateSchema); + var selectNode = ConvertSelectClause(querySpec.SelectElements, hints, node, distinct, querySpec, context, outerSchema, outerReferences, nonAggregateSchema); node = selectNode.Source; // Add sorts from ORDER BY var selectFields = new List(); - var preOrderSchema = node.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); + var preOrderSchema = node.GetSchema(context); foreach (var el in querySpec.SelectElements) { if (el is SelectScalarExpression expr) @@ -1835,21 +1842,21 @@ private SelectNode ConvertSelectQuerySpec(QuerySpecification querySpec, IList hints, TSqlFragment query, IDictionary parameterTypes, INodeSchema outerSchema, IDictionary outerReferences) + private IDataExecutionPlanNodeInternal ConvertInSubqueries(IDataExecutionPlanNodeInternal source, IList hints, TSqlFragment query, NodeCompilationContext context, INodeSchema outerSchema, IDictionary outerReferences) { var visitor = new InSubqueryVisitor(); query.Accept(visitor); @@ -1859,12 +1866,12 @@ private IDataExecutionPlanNodeInternal ConvertInSubqueries(IDataExecutionPlanNod var computeScalar = source as ComputeScalarNode; var rewrites = new Dictionary(); - var schema = source.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); + var schema = source.GetSchema(context); foreach (var inSubquery in visitor.InSubqueries) { // Validate the LHS expression - inSubquery.Expression.GetType(GetExpressionContext(schema, parameterTypes), out _); + inSubquery.Expression.GetType(GetExpressionContext(schema, context), out _); // Each query of the format "col1 IN (SELECT col2 FROM source)" becomes a left outer join: // LEFT JOIN source ON col1 = col2 @@ -1879,7 +1886,7 @@ private IDataExecutionPlanNodeInternal ConvertInSubqueries(IDataExecutionPlanNod source = computeScalar; } - var alias = $"Expr{++_colNameCounter}"; + var alias = context.GetExpressionName(); computeScalar.Columns[alias] = inSubquery.Expression.Clone(); lhsCol = alias.ToColumnReference(); } @@ -1890,9 +1897,10 @@ private IDataExecutionPlanNodeInternal ConvertInSubqueries(IDataExecutionPlanNod lhsCol = lhsColNormalized.ToColumnReference(); } - var parameters = parameterTypes == null ? new Dictionary(StringComparer.OrdinalIgnoreCase) : new Dictionary(parameterTypes, StringComparer.OrdinalIgnoreCase); + var parameters = context.ParameterTypes == null ? new Dictionary(StringComparer.OrdinalIgnoreCase) : new Dictionary(context.ParameterTypes, StringComparer.OrdinalIgnoreCase); + var innerContext = new NodeCompilationContext(context, parameters); var references = new Dictionary(); - var innerQuery = ConvertSelectStatement(inSubquery.Subquery.QueryExpression, hints, schema, references, parameters); + var innerQuery = ConvertSelectStatement(inSubquery.Subquery.QueryExpression, hints, schema, references, innerContext); // Scalar subquery must return exactly one column and one row if (innerQuery.ColumnSet.Count != 1) @@ -1904,7 +1912,7 @@ private IDataExecutionPlanNodeInternal ConvertInSubqueries(IDataExecutionPlanNod if (references.Count == 0) { - if (UseMergeJoin(source, innerQuery.Source, references, testColumn, lhsCol.GetColumnName(), true, out var outputCol, out var merge)) + if (UseMergeJoin(source, innerQuery.Source, context, references, testColumn, lhsCol.GetColumnName(), true, out var outputCol, out var merge)) { testColumn = outputCol; join = merge; @@ -1924,7 +1932,7 @@ private IDataExecutionPlanNodeInternal ConvertInSubqueries(IDataExecutionPlanNod // This isn't a correlated subquery, so we can use a foldable join type. Alias the results so there's no conflict with the // same table being used inside the IN subquery and elsewhere - var alias = new AliasNode(innerQuery, new Identifier { Value = $"Expr{++_colNameCounter}" }); + var alias = new AliasNode(innerQuery, new Identifier { Value = context.GetExpressionName() }); testColumn = $"{alias.Alias}.{alias.ColumnSet[0].OutputColumn}"; join = new HashJoinNode @@ -1939,7 +1947,7 @@ private IDataExecutionPlanNodeInternal ConvertInSubqueries(IDataExecutionPlanNod if (!join.SemiJoin) { // Convert the join to a semi join to ensure requests for wildcard columns aren't folded to the IN subquery - var definedValue = $"Expr{++_colNameCounter}"; + var definedValue = context.GetExpressionName(); join.SemiJoin = true; join.DefinedValues[definedValue] = testColumn; testColumn = definedValue; @@ -1953,9 +1961,9 @@ private IDataExecutionPlanNodeInternal ConvertInSubqueries(IDataExecutionPlanNod // get all the related records and spool that in memory to get the relevant results in the nested loop. Need to understand how // many rows are likely from the outer query to work out if this is going to be more efficient or not. if (innerQuery.Source is ISingleSourceExecutionPlanNode loopRightSourceSimple) - InsertCorrelatedSubquerySpool(loopRightSourceSimple, source, hints, parameterTypes, references.Values.ToArray()); + InsertCorrelatedSubquerySpool(loopRightSourceSimple, source, hints, context, references.Values.ToArray()); - var definedValue = $"Expr{++_colNameCounter}"; + var definedValue = context.GetExpressionName(); join = new NestedLoopNode { @@ -1991,7 +1999,7 @@ private IDataExecutionPlanNodeInternal ConvertInSubqueries(IDataExecutionPlanNod return source; } - private IDataExecutionPlanNodeInternal ConvertExistsSubqueries(IDataExecutionPlanNodeInternal source, IList hints, TSqlFragment query, IDictionary parameterTypes, INodeSchema outerSchema, IDictionary outerReferences) + private IDataExecutionPlanNodeInternal ConvertExistsSubqueries(IDataExecutionPlanNodeInternal source, IList hints, TSqlFragment query, NodeCompilationContext context, INodeSchema outerSchema, IDictionary outerReferences) { var visitor = new ExistsSubqueryVisitor(); query.Accept(visitor); @@ -2000,14 +2008,15 @@ private IDataExecutionPlanNodeInternal ConvertExistsSubqueries(IDataExecutionPla return source; var rewrites = new Dictionary(); - var schema = source.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); + var schema = source.GetSchema(context); foreach (var existsSubquery in visitor.ExistsSubqueries) { // Each query of the format "EXISTS (SELECT * FROM source)" becomes a outer semi join - var parameters = parameterTypes == null ? new Dictionary(StringComparer.OrdinalIgnoreCase) : new Dictionary(parameterTypes, StringComparer.OrdinalIgnoreCase); + var parameters = context.ParameterTypes == null ? new Dictionary(StringComparer.OrdinalIgnoreCase) : new Dictionary(context.ParameterTypes, StringComparer.OrdinalIgnoreCase); + var innerContext = new NodeCompilationContext(context, parameters); var references = new Dictionary(); - var innerQuery = ConvertSelectStatement(existsSubquery.Subquery.QueryExpression, hints, schema, references, parameters); + var innerQuery = ConvertSelectStatement(existsSubquery.Subquery.QueryExpression, hints, schema, references, innerContext); var innerSchema = innerQuery.Source.GetSchema(new NodeCompilationContext(DataSources, Options, parameters)); var innerSchemaPrimaryKey = innerSchema.PrimaryKey; @@ -2029,7 +2038,7 @@ private IDataExecutionPlanNodeInternal ConvertExistsSubqueries(IDataExecutionPla // We need a non-null value to use if (innerSchemaPrimaryKey == null) { - innerSchemaPrimaryKey = $"Expr{++_colNameCounter}"; + innerSchemaPrimaryKey = context.GetExpressionName(); if (!(innerQuery.Source is ComputeScalarNode computeScalar)) { @@ -2047,7 +2056,7 @@ private IDataExecutionPlanNodeInternal ConvertExistsSubqueries(IDataExecutionPla SpoolType = SpoolType.Lazy }; - testColumn = $"Expr{++_colNameCounter}"; + testColumn = context.GetExpressionName(); join = new NestedLoopNode { @@ -2062,7 +2071,7 @@ private IDataExecutionPlanNodeInternal ConvertExistsSubqueries(IDataExecutionPla } }; } - else if (UseMergeJoin(source, innerQuery.Source, references, null, null, true, out testColumn, out var merge)) + else if (UseMergeJoin(source, innerQuery.Source, context, references, null, null, true, out testColumn, out var merge)) { join = merge; } @@ -2074,7 +2083,7 @@ private IDataExecutionPlanNodeInternal ConvertExistsSubqueries(IDataExecutionPla // get all the related records and spool that in memory to get the relevant results in the nested loop. Need to understand how // many rows are likely from the outer query to work out if this is going to be more efficient or not. if (innerQuery.Source is ISingleSourceExecutionPlanNode loopRightSourceSimple) - InsertCorrelatedSubquerySpool(loopRightSourceSimple, source, hints, parameterTypes, references.Values.ToArray()); + InsertCorrelatedSubquerySpool(loopRightSourceSimple, source, hints, context, references.Values.ToArray()); // We only need one record to check for EXISTS if (!(innerQuery.Source is TopNode) && !(innerQuery.Source is OffsetFetchNode)) @@ -2089,7 +2098,7 @@ private IDataExecutionPlanNodeInternal ConvertExistsSubqueries(IDataExecutionPla // We need a non-null value to use if (innerSchemaPrimaryKey == null) { - innerSchemaPrimaryKey = $"Expr{++_colNameCounter}"; + innerSchemaPrimaryKey = context.GetExpressionName(); if (!(innerQuery.Source is ComputeScalarNode computeScalar)) { @@ -2100,7 +2109,7 @@ private IDataExecutionPlanNodeInternal ConvertExistsSubqueries(IDataExecutionPla computeScalar.Columns[innerSchemaPrimaryKey] = new IntegerLiteral { Value = "1" }; } - var definedValue = $"Expr{++_colNameCounter}"; + var definedValue = context.GetExpressionName(); join = new NestedLoopNode { @@ -2130,18 +2139,18 @@ private IDataExecutionPlanNodeInternal ConvertExistsSubqueries(IDataExecutionPla return source; } - private IDataExecutionPlanNodeInternal ConvertHavingClause(IDataExecutionPlanNodeInternal source, IList hints, HavingClause havingClause, IDictionary parameterTypes, INodeSchema outerSchema, IDictionary outerReferences, TSqlFragment query, INodeSchema nonAggregateSchema) + private IDataExecutionPlanNodeInternal ConvertHavingClause(IDataExecutionPlanNodeInternal source, IList hints, HavingClause havingClause, NodeCompilationContext context, INodeSchema outerSchema, IDictionary outerReferences, TSqlFragment query, INodeSchema nonAggregateSchema) { if (havingClause == null) return source; - CaptureOuterReferences(outerSchema, source, havingClause, parameterTypes, outerReferences); + CaptureOuterReferences(outerSchema, source, havingClause, context, outerReferences); var computeScalar = new ComputeScalarNode { Source = source }; - ConvertScalarSubqueries(havingClause.SearchCondition, hints, ref source, computeScalar, parameterTypes, query); + ConvertScalarSubqueries(havingClause.SearchCondition, hints, ref source, computeScalar, context, query); // Validate the final expression - havingClause.SearchCondition.GetType(GetExpressionContext(source.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)), parameterTypes, nonAggregateSchema), out _); + havingClause.SearchCondition.GetType(GetExpressionContext(source.GetSchema(context), context, nonAggregateSchema), out _); return new FilterNode { @@ -2150,7 +2159,7 @@ private IDataExecutionPlanNodeInternal ConvertHavingClause(IDataExecutionPlanNod }; } - private IDataExecutionPlanNodeInternal ConvertGroupByAggregates(IDataExecutionPlanNodeInternal source, QuerySpecification querySpec, IDictionary parameterTypes, INodeSchema outerSchema, IDictionary outerReferences) + private IDataExecutionPlanNodeInternal ConvertGroupByAggregates(IDataExecutionPlanNodeInternal source, QuerySpecification querySpec, NodeCompilationContext context, INodeSchema outerSchema, IDictionary outerReferences) { // Check if there is a GROUP BY clause or aggregate functions to convert if (querySpec.GroupByClause == null) @@ -2169,7 +2178,7 @@ private IDataExecutionPlanNodeInternal ConvertGroupByAggregates(IDataExecutionPl throw new NotSupportedQueryFragmentException("Unhandled GROUP BY option", querySpec.GroupByClause); } - var schema = source.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); + var schema = source.GetSchema(context); // Create the grouping expressions. Grouping is done on single columns only - if a grouping is a more complex expression, // create a new calculated column using a Compute Scalar node first. @@ -2177,7 +2186,7 @@ private IDataExecutionPlanNodeInternal ConvertGroupByAggregates(IDataExecutionPl if (querySpec.GroupByClause != null) { - CaptureOuterReferences(outerSchema, source, querySpec.GroupByClause, parameterTypes, outerReferences); + CaptureOuterReferences(outerSchema, source, querySpec.GroupByClause, context, outerReferences); foreach (var grouping in querySpec.GroupByClause.GroupingSpecifications) { @@ -2185,7 +2194,7 @@ private IDataExecutionPlanNodeInternal ConvertGroupByAggregates(IDataExecutionPl throw new NotSupportedQueryFragmentException("Unhandled GROUP BY expression", grouping); // Validate the GROUP BY expression - exprGroup.Expression.GetType(GetExpressionContext(schema, parameterTypes), out _); + exprGroup.Expression.GetType(GetExpressionContext(schema, context), out _); if (exprGroup.Expression is ColumnReferenceExpression col) { @@ -2247,7 +2256,7 @@ private IDataExecutionPlanNodeInternal ConvertGroupByAggregates(IDataExecutionPl } if (name == null) - name = $"Expr{++_colNameCounter}"; + name = context.GetExpressionName(); col = new ColumnReferenceExpression { @@ -2296,7 +2305,7 @@ private IDataExecutionPlanNodeInternal ConvertGroupByAggregates(IDataExecutionPl foreach (var aggregate in aggregateCollector.Aggregates.Select(a => new { Expression = a, Alias = (string)null }).Concat(aggregateCollector.SelectAggregates.Select(s => new { Expression = (FunctionCall)s.Expression, Alias = s.ColumnName?.Identifier?.Value }))) { - CaptureOuterReferences(outerSchema, source, aggregate.Expression, parameterTypes, outerReferences); + CaptureOuterReferences(outerSchema, source, aggregate.Expression, context, outerReferences); var converted = new Aggregate { @@ -2341,7 +2350,7 @@ private IDataExecutionPlanNodeInternal ConvertGroupByAggregates(IDataExecutionPl if (converted.AggregateType == AggregateType.CountStar) converted.SqlExpression = null; else - converted.SqlExpression.GetType(GetExpressionContext(schema, parameterTypes), out _); + converted.SqlExpression.GetType(GetExpressionContext(schema, context), out _); // Create a name for the column that holds the aggregate value in the result set. string aggregateName; @@ -2362,7 +2371,7 @@ private IDataExecutionPlanNodeInternal ConvertGroupByAggregates(IDataExecutionPl } else { - aggregateName = $"Expr{++_colNameCounter}"; + aggregateName = context.GetExpressionName(); } hashMatch.Aggregates[aggregateName] = converted; @@ -2415,7 +2424,7 @@ func.Parameters[0] is ColumnReferenceExpression datepart && return false; } - private IDataExecutionPlanNodeInternal ConvertOffsetClause(IDataExecutionPlanNodeInternal source, OffsetClause offsetClause, IDictionary parameterTypes) + private IDataExecutionPlanNodeInternal ConvertOffsetClause(IDataExecutionPlanNodeInternal source, OffsetClause offsetClause, NodeCompilationContext context) { if (offsetClause == null) return source; @@ -2438,7 +2447,7 @@ private IDataExecutionPlanNodeInternal ConvertOffsetClause(IDataExecutionPlanNod }; } - private IDataExecutionPlanNodeInternal ConvertTopClause(IDataExecutionPlanNodeInternal source, TopRowFilter topRowFilter, OrderByClause orderByClause, IDictionary parameterTypes) + private IDataExecutionPlanNodeInternal ConvertTopClause(IDataExecutionPlanNodeInternal source, TopRowFilter topRowFilter, OrderByClause orderByClause, NodeCompilationContext context) { if (topRowFilter == null) return source; @@ -2456,7 +2465,7 @@ private IDataExecutionPlanNodeInternal ConvertTopClause(IDataExecutionPlanNodeIn if (orderByClause == null) throw new NotSupportedQueryFragmentException("The TOP N WITH TIES clause is not allowed without a corresponding ORDER BY clause", topRowFilter); - var schema = source.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); + var schema = source.GetSchema(context); foreach (var sort in orderByClause.OrderByElements) { @@ -2496,17 +2505,17 @@ private IDataExecutionPlanNodeInternal ConvertTopClause(IDataExecutionPlanNodeIn }; } - private IDataExecutionPlanNodeInternal ConvertOrderByClause(IDataExecutionPlanNodeInternal source, IList hints, OrderByClause orderByClause, ScalarExpression[] selectList, TSqlFragment query, IDictionary parameterTypes, INodeSchema outerSchema, Dictionary outerReferences, INodeSchema nonAggregateSchema) + private IDataExecutionPlanNodeInternal ConvertOrderByClause(IDataExecutionPlanNodeInternal source, IList hints, OrderByClause orderByClause, ScalarExpression[] selectList, TSqlFragment query, NodeCompilationContext context, INodeSchema outerSchema, Dictionary outerReferences, INodeSchema nonAggregateSchema) { if (orderByClause == null) return source; - CaptureOuterReferences(outerSchema, source, orderByClause, parameterTypes, outerReferences); + CaptureOuterReferences(outerSchema, source, orderByClause, context, outerReferences); var computeScalar = new ComputeScalarNode { Source = source }; - ConvertScalarSubqueries(orderByClause, hints, ref source, computeScalar, parameterTypes, query); + ConvertScalarSubqueries(orderByClause, hints, ref source, computeScalar, context, query); - var schema = source.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); + var schema = source.GetSchema(context); var sort = new SortNode { Source = source }; // Sorts can use aliases from the SELECT clause @@ -2553,15 +2562,15 @@ private IDataExecutionPlanNodeInternal ConvertOrderByClause(IDataExecutionPlanNo !(orderBy.Expression is VariableReference) && !(orderBy.Expression is Literal)) { - var calculated = ComputeScalarExpression(orderBy.Expression, hints, query, computeScalar, nonAggregateSchema, parameterTypes, ref source); + var calculated = ComputeScalarExpression(orderBy.Expression, hints, query, computeScalar, nonAggregateSchema, context, ref source); sort.Source = source; - schema = source.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); + schema = source.GetSchema(context); calculationRewrites[orderBy.Expression] = calculated.ToColumnReference(); } // Validate the expression - orderBy.Expression.GetType(GetExpressionContext(schema, parameterTypes, nonAggregateSchema), out _); + orderBy.Expression.GetType(GetExpressionContext(schema, context, nonAggregateSchema), out _); sort.Sorts.Add(orderBy.Clone()); } @@ -2576,7 +2585,7 @@ private IDataExecutionPlanNodeInternal ConvertOrderByClause(IDataExecutionPlanNo return sort; } - private IDataExecutionPlanNodeInternal ConvertWhereClause(IDataExecutionPlanNodeInternal source, IList hints, WhereClause whereClause, INodeSchema outerSchema, Dictionary outerReferences, IDictionary parameterTypes, TSqlFragment query) + private IDataExecutionPlanNodeInternal ConvertWhereClause(IDataExecutionPlanNodeInternal source, IList hints, WhereClause whereClause, INodeSchema outerSchema, Dictionary outerReferences, NodeCompilationContext context, TSqlFragment query) { if (whereClause == null) return source; @@ -2584,13 +2593,13 @@ private IDataExecutionPlanNodeInternal ConvertWhereClause(IDataExecutionPlanNode if (whereClause.Cursor != null) throw new NotSupportedQueryFragmentException("Unsupported cursor", whereClause.Cursor); - CaptureOuterReferences(outerSchema, source, whereClause.SearchCondition, parameterTypes, outerReferences); + CaptureOuterReferences(outerSchema, source, whereClause.SearchCondition, context, outerReferences); var computeScalar = new ComputeScalarNode { Source = source }; - ConvertScalarSubqueries(whereClause.SearchCondition, hints, ref source, computeScalar, parameterTypes, query); + ConvertScalarSubqueries(whereClause.SearchCondition, hints, ref source, computeScalar, context, query); // Validate the final expression - whereClause.SearchCondition.GetType(GetExpressionContext(source.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)), parameterTypes), out _); + whereClause.SearchCondition.GetType(GetExpressionContext(source.GetSchema(context), context), out _); return new FilterNode { @@ -2599,7 +2608,7 @@ private IDataExecutionPlanNodeInternal ConvertWhereClause(IDataExecutionPlanNode }; } - private TSqlFragment CaptureOuterReferences(INodeSchema outerSchema, IDataExecutionPlanNodeInternal source, TSqlFragment query, IDictionary parameterTypes, IDictionary outerReferences) + private TSqlFragment CaptureOuterReferences(INodeSchema outerSchema, IDataExecutionPlanNodeInternal source, TSqlFragment query, NodeCompilationContext context, IDictionary outerReferences) { if (outerSchema == null) return query; @@ -2607,7 +2616,7 @@ private TSqlFragment CaptureOuterReferences(INodeSchema outerSchema, IDataExecut // We're in a subquery. Check if any columns in the WHERE clause are from the outer query // so we know which columns to pass through and rewrite the filter to use parameters var rewrites = new Dictionary(); - var innerSchema = source?.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); + var innerSchema = source?.GetSchema(context); var columns = query.GetColumns(); foreach (var column in columns) @@ -2624,9 +2633,9 @@ private TSqlFragment CaptureOuterReferences(INodeSchema outerSchema, IDataExecut if (fromOuter) { - var paramName = $"@Expr{++_colNameCounter}"; + var paramName = "@" + context.GetExpressionName(); outerReferences.Add(outerColumn, paramName); - parameterTypes[paramName] = outerSchema.Schema[outerColumn]; + context.ParameterTypes[paramName] = outerSchema.Schema[outerColumn]; rewrites.Add( new ColumnReferenceExpression { MultiPartIdentifier = new MultiPartIdentifier { Identifiers = { new Identifier { Value = column } } } }, @@ -2643,9 +2652,9 @@ private TSqlFragment CaptureOuterReferences(INodeSchema outerSchema, IDataExecut return query; } - private SelectNode ConvertSelectClause(IList selectElements, IList hints, IDataExecutionPlanNodeInternal node, DistinctNode distinct, TSqlFragment query, IDictionary parameterTypes, INodeSchema outerSchema, IDictionary outerReferences, INodeSchema nonAggregateSchema) + private SelectNode ConvertSelectClause(IList selectElements, IList hints, IDataExecutionPlanNodeInternal node, DistinctNode distinct, TSqlFragment query, NodeCompilationContext context, INodeSchema outerSchema, IDictionary outerReferences, INodeSchema nonAggregateSchema) { - var schema = node.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); + var schema = node.GetSchema(context); var select = new SelectNode { @@ -2659,19 +2668,20 @@ private SelectNode ConvertSelectClause(IList selectElements, ILis foreach (var element in selectElements) { - CaptureOuterReferences(outerSchema, computeScalar, element, parameterTypes, outerReferences); + CaptureOuterReferences(outerSchema, computeScalar, element, context, outerReferences); if (element is SelectScalarExpression scalar) { if (scalar.Expression is ColumnReferenceExpression col) { + // Check the expression is valid. This will throw an exception in case of missing columns etc. + col.GetType(GetExpressionContext(schema, context, nonAggregateSchema), out var colType); + if (colType is SqlDataTypeReferenceWithCollation colTypeColl && colTypeColl.CollationLabel == CollationLabel.NoCollation) + throw new NotSupportedQueryFragmentException("Cannot resolve collation conflict", element); + var colName = col.GetColumnName(); - if (!schema.ContainsColumn(colName, out colName)) - { - // Column name isn't valid. Use the expression extensions to throw a consistent error message - col.GetType(GetExpressionContext(schema, parameterTypes, nonAggregateSchema), out _); - } + schema.ContainsColumn(colName, out colName); var alias = scalar.ColumnName?.Value ?? col.MultiPartIdentifier.Identifiers.Last().Value; @@ -2685,7 +2695,12 @@ private SelectNode ConvertSelectClause(IList selectElements, ILis else { var scalarSource = distinct?.Source ?? node; - var alias = ComputeScalarExpression(scalar.Expression, hints, query, computeScalar, nonAggregateSchema, parameterTypes, ref scalarSource); + var alias = ComputeScalarExpression(scalar.Expression, hints, query, computeScalar, nonAggregateSchema, context, ref scalarSource); + + var scalarSchema = computeScalar.GetSchema(context); + var colType = scalarSchema.Schema[alias]; + if (colType is SqlDataTypeReferenceWithCollation colTypeColl && colTypeColl.CollationLabel == CollationLabel.NoCollation) + throw new NotSupportedQueryFragmentException("Cannot resolve collation conflict", element); if (distinct != null) distinct.Source = scalarSource; @@ -2704,9 +2719,21 @@ private SelectNode ConvertSelectClause(IList selectElements, ILis { var colName = star.Qualifier == null ? null : String.Join(".", star.Qualifier.Identifiers.Select(id => id.Value)); - if (colName != null && !schema.Schema.Keys.Any(col => col.StartsWith(colName + ".", StringComparison.OrdinalIgnoreCase))) + var cols = schema.Schema.Keys + .Where(col => colName == null || col.StartsWith(colName + ".", StringComparison.OrdinalIgnoreCase)) + .ToList(); + + if (colName != null && cols.Count == 0) throw new NotSupportedQueryFragmentException("The column prefix does not match with a table name or alias name used in the query", star); + // Can't select no-collation columns + foreach (var col in cols) + { + var colType = schema.Schema[col]; + if (colType is SqlDataTypeReferenceWithCollation colTypeColl && colTypeColl.CollationLabel == CollationLabel.NoCollation) + throw new NotSupportedQueryFragmentException("Cannot resolve collation conflict", element); + } + select.ColumnSet.Add(new SelectColumn { SourceColumn = colName, @@ -2737,7 +2764,7 @@ private SelectNode ConvertSelectClause(IList selectElements, ILis { if (col.AllColumns) { - var distinctSchema = distinct.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); + var distinctSchema = distinct.GetSchema(context); distinct.Columns.AddRange(distinctSchema.Schema.Keys.Where(k => col.SourceColumn == null || (k.Split('.')[0] + ".*") == col.SourceColumn)); } else @@ -2750,23 +2777,23 @@ private SelectNode ConvertSelectClause(IList selectElements, ILis return select; } - private string ComputeScalarExpression(ScalarExpression expression, IList hints, TSqlFragment query, ComputeScalarNode computeScalar, INodeSchema nonAggregateSchema, IDictionary parameterTypes, ref IDataExecutionPlanNodeInternal node) + private string ComputeScalarExpression(ScalarExpression expression, IList hints, TSqlFragment query, ComputeScalarNode computeScalar, INodeSchema nonAggregateSchema, NodeCompilationContext context, ref IDataExecutionPlanNodeInternal node) { - var computedColumn = ConvertScalarSubqueries(expression, hints, ref node, computeScalar, parameterTypes, query); + var computedColumn = ConvertScalarSubqueries(expression, hints, ref node, computeScalar, context, query); if (computedColumn != null) expression = computedColumn; // Check the type of this expression now so any errors can be reported - var computeScalarSchema = computeScalar.Source.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); - expression.GetType(GetExpressionContext(computeScalarSchema, parameterTypes, nonAggregateSchema), out _); + var computeScalarSchema = computeScalar.Source.GetSchema(context); + expression.GetType(GetExpressionContext(computeScalarSchema, context, nonAggregateSchema), out _); - var alias = $"Expr{++_colNameCounter}"; + var alias = context.GetExpressionName(); computeScalar.Columns[alias] = expression.Clone(); return alias; } - private ColumnReferenceExpression ConvertScalarSubqueries(TSqlFragment expression, IList hints, ref IDataExecutionPlanNodeInternal node, ComputeScalarNode computeScalar, IDictionary parameterTypes, TSqlFragment query) + private ColumnReferenceExpression ConvertScalarSubqueries(TSqlFragment expression, IList hints, ref IDataExecutionPlanNodeInternal node, ComputeScalarNode computeScalar, NodeCompilationContext context, TSqlFragment query) { /* * Possible subquery execution plans: @@ -2783,10 +2810,11 @@ private ColumnReferenceExpression ConvertScalarSubqueries(TSqlFragment expressio foreach (var subquery in subqueryVisitor.Subqueries) { - var outerSchema = node.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); + var outerSchema = node.GetSchema(context); var outerReferences = new Dictionary(); - var innerParameterTypes = parameterTypes == null ? new Dictionary(StringComparer.OrdinalIgnoreCase) : new Dictionary(parameterTypes, StringComparer.OrdinalIgnoreCase); - var subqueryPlan = ConvertSelectStatement(subquery.QueryExpression, hints, outerSchema, outerReferences, innerParameterTypes); + var innerParameterTypes = context.ParameterTypes == null ? new Dictionary(StringComparer.OrdinalIgnoreCase) : new Dictionary(context.ParameterTypes, StringComparer.OrdinalIgnoreCase); + var innerContext = new NodeCompilationContext(context, innerParameterTypes); + var subqueryPlan = ConvertSelectStatement(subquery.QueryExpression, hints, outerSchema, outerReferences, innerContext); // Scalar subquery must return exactly one column and one row if (subqueryPlan.ColumnSet.Count != 1) @@ -2795,22 +2823,22 @@ private ColumnReferenceExpression ConvertScalarSubqueries(TSqlFragment expressio string outputcol; var subqueryCol = subqueryPlan.ColumnSet[0].SourceColumn; BaseJoinNode join = null; - if (UseMergeJoin(node, subqueryPlan.Source, outerReferences, subqueryCol, null, false, out outputcol, out var merge)) + if (UseMergeJoin(node, subqueryPlan.Source, context, outerReferences, subqueryCol, null, false, out outputcol, out var merge)) { join = merge; } else { - outputcol = $"Expr{++_colNameCounter}"; + outputcol = context.GetExpressionName(); var loopRightSource = subqueryPlan.Source; // Unless the subquery has got an explicit TOP 1 clause, insert an aggregate and assertion nodes // to check for one row - if (!(subqueryPlan.Source.EstimateRowsOut(new NodeCompilationContext(DataSources, Options, parameterTypes)) is RowCountEstimateDefiniteRange range) || range.Maximum > 1) + if (!(subqueryPlan.Source.EstimateRowsOut(context) is RowCountEstimateDefiniteRange range) || range.Maximum > 1) { - subqueryCol = $"Expr{++_colNameCounter}"; - var rowCountCol = $"Expr{++_colNameCounter}"; + subqueryCol = context.GetExpressionName(); + var rowCountCol = context.GetExpressionName(); var aggregate = new HashMatchAggregateNode { Source = loopRightSource, @@ -2846,7 +2874,7 @@ private ColumnReferenceExpression ConvertScalarSubqueries(TSqlFragment expressio // If it is correlated, add a spool where possible closer to the data source if (outerReferences.Count == 0) { - if (EstimateRowsOut(node, parameterTypes) > 1) + if (EstimateRowsOut(node, context) > 1) { var spool = new TableSpoolNode { Source = loopRightSource, SpoolType = SpoolType.Lazy }; loopRightSource = spool; @@ -2854,7 +2882,7 @@ private ColumnReferenceExpression ConvertScalarSubqueries(TSqlFragment expressio } else if (loopRightSource is ISingleSourceExecutionPlanNode loopRightSourceSimple) { - InsertCorrelatedSubquerySpool(loopRightSourceSimple, node, hints, parameterTypes, outerReferences.Values.ToArray()); + InsertCorrelatedSubquerySpool(loopRightSourceSimple, node, hints, context, outerReferences.Values.ToArray()); } // Add a nested loop to call the subquery @@ -2887,7 +2915,7 @@ private ColumnReferenceExpression ConvertScalarSubqueries(TSqlFragment expressio return null; } - private bool UseMergeJoin(IDataExecutionPlanNodeInternal node, IDataExecutionPlanNode subqueryPlan, Dictionary outerReferences, string subqueryCol, string inPredicateCol, bool semiJoin, out string outputCol, out MergeJoinNode merge) + private bool UseMergeJoin(IDataExecutionPlanNodeInternal node, IDataExecutionPlanNode subqueryPlan, NodeCompilationContext context, Dictionary outerReferences, string subqueryCol, string inPredicateCol, bool semiJoin, out string outputCol, out MergeJoinNode merge) { outputCol = null; merge = null; @@ -2969,7 +2997,7 @@ private bool UseMergeJoin(IDataExecutionPlanNodeInternal node, IDataExecutionPla if (alias != null) fetch.Alias = alias.Alias; else - fetch.Alias = $"Expr{++_colNameCounter}"; + fetch.Alias = context.GetExpressionName(); var rightAttribute = innerKey.ToColumnReference(); if (rightAttribute.MultiPartIdentifier.Identifiers.Count == 2) @@ -3007,7 +3035,7 @@ private bool UseMergeJoin(IDataExecutionPlanNodeInternal node, IDataExecutionPla } merge.SemiJoin = true; - var definedValue = $"Expr{++_colNameCounter}"; + var definedValue = context.GetExpressionName(); merge.DefinedValues[definedValue] = outputCol ?? rightAttribute.GetColumnName(); outputCol = definedValue; } @@ -3015,7 +3043,7 @@ private bool UseMergeJoin(IDataExecutionPlanNodeInternal node, IDataExecutionPla return true; } - private void InsertCorrelatedSubquerySpool(ISingleSourceExecutionPlanNode node, IDataExecutionPlanNode outerSource, IList hints, IDictionary parameterTypes, string[] outerReferences) + private void InsertCorrelatedSubquerySpool(ISingleSourceExecutionPlanNode node, IDataExecutionPlanNode outerSource, IList hints, NodeCompilationContext context, string[] outerReferences) { if (hints.Any(hint => hint.HintKind == OptimizerHintKind.NoPerformanceSpool)) return; @@ -3079,8 +3107,8 @@ private void InsertCorrelatedSubquerySpool(ISingleSourceExecutionPlanNode node, // Check the estimated counts for the outer loop and the source at the point we'd insert the spool // If the outer loop is non-trivial (>= 100 rows) or the inner loop is small (<= 5000 records) then we want // to use the spool. - var outerCount = EstimateRowsOut((IDataExecutionPlanNodeInternal) outerSource, parameterTypes); - var innerCount = outerCount >= 100 ? -1 : EstimateRowsOut(lastCorrelatedStep.Source, parameterTypes); + var outerCount = EstimateRowsOut((IDataExecutionPlanNodeInternal) outerSource, context); + var innerCount = outerCount >= 100 ? -1 : EstimateRowsOut(lastCorrelatedStep.Source, context); if (outerCount >= 100 || innerCount <= 5000) { @@ -3094,18 +3122,18 @@ private void InsertCorrelatedSubquerySpool(ISingleSourceExecutionPlanNode node, } } - private int EstimateRowsOut(IExecutionPlanNode source, IDictionary parameterTypes) + private int EstimateRowsOut(IExecutionPlanNode source, NodeCompilationContext context) { if (source is IDataExecutionPlanNodeInternal dataNode) { - dataNode.EstimateRowsOut(new NodeCompilationContext(DataSources, Options, parameterTypes)); + dataNode.EstimateRowsOut(context); return dataNode.EstimatedRowsOut; } else { foreach (var child in source.GetSources()) { - EstimateRowsOut(child, parameterTypes); + EstimateRowsOut(child, context); } } @@ -3166,13 +3194,13 @@ private bool SplitCorrelatedCriteria(BooleanExpression filter, out BooleanExpres return false; } - private IDataExecutionPlanNodeInternal ConvertFromClause(IList tables, IList hints, TSqlFragment query, INodeSchema outerSchema, Dictionary outerReferences, IDictionary parameterTypes) + private IDataExecutionPlanNodeInternal ConvertFromClause(IList tables, IList hints, TSqlFragment query, INodeSchema outerSchema, Dictionary outerReferences, NodeCompilationContext context) { - var node = ConvertTableReference(tables[0], hints, query, outerSchema, outerReferences, parameterTypes); + var node = ConvertTableReference(tables[0], hints, query, outerSchema, outerReferences, context); for (var i = 1; i < tables.Count; i++) { - var nextTable = ConvertTableReference(tables[i], hints, query, outerSchema, outerReferences, parameterTypes); + var nextTable = ConvertTableReference(tables[i], hints, query, outerSchema, outerReferences, context); // Join predicates will be lifted from the WHERE clause during folding later. For now, just add a table spool // to cache the results of the second table and use a nested loop to join them. @@ -3184,7 +3212,7 @@ private IDataExecutionPlanNodeInternal ConvertFromClause(IList t return node; } - private IDataExecutionPlanNodeInternal ConvertTableReference(TableReference reference, IList hints, TSqlFragment query, INodeSchema outerSchema, Dictionary outerReferences, IDictionary parameterTypes) + private IDataExecutionPlanNodeInternal ConvertTableReference(TableReference reference, IList hints, TSqlFragment query, INodeSchema outerSchema, Dictionary outerReferences, NodeCompilationContext context) { if (reference is NamedTableReference table) { @@ -3305,10 +3333,10 @@ private IDataExecutionPlanNodeInternal ConvertTableReference(TableReference refe { // If the join involves the primary key of one table we can safely use a merge join. // Otherwise use a nested loop join - var lhs = ConvertTableReference(join.FirstTableReference, hints, query, outerSchema, outerReferences, parameterTypes); - var rhs = ConvertTableReference(join.SecondTableReference, hints, query, outerSchema, outerReferences, parameterTypes); - var lhsSchema = lhs.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); - var rhsSchema = rhs.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); + var lhs = ConvertTableReference(join.FirstTableReference, hints, query, outerSchema, outerReferences, context); + var rhs = ConvertTableReference(join.SecondTableReference, hints, query, outerSchema, outerReferences, context); + var lhsSchema = lhs.GetSchema(context); + var rhsSchema = rhs.GetSchema(context); var fixedValueColumns = GetFixedValueColumnsFromWhereClause(query, lhsSchema, rhsSchema); var joinConditionVisitor = new JoinConditionVisitor(lhsSchema, rhsSchema, fixedValueColumns); @@ -3335,7 +3363,7 @@ private IDataExecutionPlanNodeInternal ConvertTableReference(TableReference refe lhs = lhsComputeScalar; } - var lhsColumn = ComputeScalarExpression(joinConditionVisitor.LhsExpression, hints, query, lhsComputeScalar, null, parameterTypes, ref lhs); + var lhsColumn = ComputeScalarExpression(joinConditionVisitor.LhsExpression, hints, query, lhsComputeScalar, null, context, ref lhs); joinConditionVisitor.LhsKey = new ColumnReferenceExpression { MultiPartIdentifier = new MultiPartIdentifier { Identifiers = { new Identifier { Value = lhsColumn } } } }; } @@ -3347,7 +3375,7 @@ private IDataExecutionPlanNodeInternal ConvertTableReference(TableReference refe rhs = rhsComputeScalar; } - var rhsColumn = ComputeScalarExpression(joinConditionVisitor.RhsExpression, hints, query, rhsComputeScalar, null, parameterTypes, ref lhs); + var rhsColumn = ComputeScalarExpression(joinConditionVisitor.RhsExpression, hints, query, rhsComputeScalar, null, context, ref lhs); joinConditionVisitor.RhsKey = new ColumnReferenceExpression { MultiPartIdentifier = new MultiPartIdentifier { Identifiers = { new Identifier { Value = rhsColumn } } } }; } } @@ -3424,8 +3452,8 @@ private IDataExecutionPlanNodeInternal ConvertTableReference(TableReference refe } // Validate the join condition - var joinSchema = joinNode.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); - join.SearchCondition.GetType(GetExpressionContext(joinSchema, parameterTypes), out _); + var joinSchema = joinNode.GetSchema(context); + join.SearchCondition.GetType(GetExpressionContext(joinSchema, context), out _); return joinNode; } @@ -3435,33 +3463,34 @@ private IDataExecutionPlanNodeInternal ConvertTableReference(TableReference refe if (queryDerivedTable.Columns.Count > 0) throw new NotSupportedQueryFragmentException("Unhandled query derived table column list", queryDerivedTable); - var select = ConvertSelectStatement(queryDerivedTable.QueryExpression, hints, outerSchema, outerReferences, parameterTypes); + var select = ConvertSelectStatement(queryDerivedTable.QueryExpression, hints, outerSchema, outerReferences, context); var alias = new AliasNode(select, queryDerivedTable.Alias); return alias; } if (reference is InlineDerivedTable inlineDerivedTable) - return ConvertInlineDerivedTable(inlineDerivedTable, hints, outerSchema, outerReferences, parameterTypes); + return ConvertInlineDerivedTable(inlineDerivedTable, hints, outerSchema, outerReferences, context); if (reference is UnqualifiedJoin unqualifiedJoin) { - var lhs = ConvertTableReference(unqualifiedJoin.FirstTableReference, hints, query, outerSchema, outerReferences, parameterTypes); + var lhs = ConvertTableReference(unqualifiedJoin.FirstTableReference, hints, query, outerSchema, outerReferences, context); IDataExecutionPlanNodeInternal rhs; Dictionary lhsReferences; if (unqualifiedJoin.UnqualifiedJoinType == UnqualifiedJoinType.CrossJoin) { - rhs = ConvertTableReference(unqualifiedJoin.SecondTableReference, hints, query, outerSchema, outerReferences, parameterTypes); + rhs = ConvertTableReference(unqualifiedJoin.SecondTableReference, hints, query, outerSchema, outerReferences, context); lhsReferences = null; } else { // CROSS APPLY / OUTER APPLY - treat the second table as a correlated subquery - var lhsSchema = lhs.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); + var lhsSchema = lhs.GetSchema(context); lhsReferences = new Dictionary(); - var innerParameterTypes = parameterTypes == null ? new Dictionary(StringComparer.OrdinalIgnoreCase) : new Dictionary(parameterTypes, StringComparer.OrdinalIgnoreCase); - var subqueryPlan = ConvertTableReference(unqualifiedJoin.SecondTableReference, hints, query, lhsSchema, lhsReferences, innerParameterTypes); + var innerParameterTypes = context.ParameterTypes == null ? new Dictionary(StringComparer.OrdinalIgnoreCase) : new Dictionary(context.ParameterTypes, StringComparer.OrdinalIgnoreCase); + var innerContext = new NodeCompilationContext(context, innerParameterTypes); + var subqueryPlan = ConvertTableReference(unqualifiedJoin.SecondTableReference, hints, query, lhsSchema, lhsReferences, innerContext); rhs = subqueryPlan; // If the subquery is uncorrelated, add a table spool to cache the results @@ -3471,7 +3500,7 @@ private IDataExecutionPlanNodeInternal ConvertTableReference(TableReference refe var spool = new TableSpoolNode { Source = rhs, SpoolType = SpoolType.Lazy }; rhs = spool; } - else if (UseMergeJoin(lhs, subqueryPlan, lhsReferences, null, null, false, out _, out var merge)) + else if (UseMergeJoin(lhs, subqueryPlan, context, lhsReferences, null, null, false, out _, out var merge)) { if (unqualifiedJoin.UnqualifiedJoinType == UnqualifiedJoinType.CrossApply) merge.JoinType = QualifiedJoinType.Inner; @@ -3480,7 +3509,7 @@ private IDataExecutionPlanNodeInternal ConvertTableReference(TableReference refe } else if (rhs is ISingleSourceExecutionPlanNode loopRightSourceSimple) { - InsertCorrelatedSubquerySpool(loopRightSourceSimple, lhs, hints, parameterTypes, lhsReferences.Values.ToArray()); + InsertCorrelatedSubquerySpool(loopRightSourceSimple, lhs, hints, context, lhsReferences.Values.ToArray()); } } @@ -3500,7 +3529,7 @@ private IDataExecutionPlanNodeInternal ConvertTableReference(TableReference refe if (reference is SchemaObjectFunctionTableReference tvf) { // Capture any references to data from an outer query - CaptureOuterReferences(outerSchema, null, tvf, parameterTypes, outerReferences); + CaptureOuterReferences(outerSchema, null, tvf, context, outerReferences); // Convert any scalar subqueries in the parameters to its own execution plan, and capture the references from those plans // as parameters to be passed to the function @@ -3508,19 +3537,19 @@ private IDataExecutionPlanNodeInternal ConvertTableReference(TableReference refe var computeScalar = new ComputeScalarNode { Source = source }; foreach (var param in tvf.Parameters.ToList()) - ConvertScalarSubqueries(param, hints, ref source, computeScalar, parameterTypes, tvf); + ConvertScalarSubqueries(param, hints, ref source, computeScalar, context, tvf); if (source is ConstantScanNode) source = null; else if (computeScalar.Columns.Count > 0) source = computeScalar; - var scalarSubquerySchema = source?.GetSchema(new NodeCompilationContext(DataSources, Options, parameterTypes)); + var scalarSubquerySchema = source?.GetSchema(context); var scalarSubqueryReferences = new Dictionary(); - CaptureOuterReferences(scalarSubquerySchema, null, tvf, parameterTypes, scalarSubqueryReferences); + CaptureOuterReferences(scalarSubquerySchema, null, tvf, context, scalarSubqueryReferences); var dataSource = SelectDataSource(tvf.SchemaObject); - var execute = ExecuteMessageNode.FromMessage(tvf, dataSource, GetExpressionContext(null, parameterTypes)); + var execute = ExecuteMessageNode.FromMessage(tvf, dataSource, GetExpressionContext(null, context)); if (source == null) return execute; @@ -3586,7 +3615,7 @@ private void GetFixedValueColumnsFromWhereClause(HashSet columns, Boolea } } - private IDataExecutionPlanNodeInternal ConvertInlineDerivedTable(InlineDerivedTable inlineDerivedTable, IList hints, INodeSchema outerSchema, Dictionary outerReferences, IDictionary parameterTypes) + private IDataExecutionPlanNodeInternal ConvertInlineDerivedTable(InlineDerivedTable inlineDerivedTable, IList hints, INodeSchema outerSchema, Dictionary outerReferences, NodeCompilationContext context) { // Check all the rows have the expected number of values and column names are unique var columnNames = inlineDerivedTable.Columns.Select(col => col.Value).ToList(); @@ -3614,7 +3643,7 @@ private IDataExecutionPlanNodeInternal ConvertInlineDerivedTable(InlineDerivedTa }; } - var converted = ConvertSelectStatement(select, hints, outerSchema, outerReferences, parameterTypes); + var converted = ConvertSelectStatement(select, hints, outerSchema, outerReferences, context); var source = converted.Source; // Make sure expected column names are used @@ -3662,9 +3691,9 @@ private QuerySpecification CreateSelectRow(RowValue row, IList colum return querySpec; } - private ExpressionCompilationContext GetExpressionContext(INodeSchema schema, IDictionary parameterTypes = null, INodeSchema nonAggregateSchema = null) + private ExpressionCompilationContext GetExpressionContext(INodeSchema schema, NodeCompilationContext context, INodeSchema nonAggregateSchema = null) { - return new ExpressionCompilationContext(DataSources, Options, parameterTypes ?? _parameterTypes, schema, nonAggregateSchema); + return new ExpressionCompilationContext(context, schema, nonAggregateSchema); } } } diff --git a/MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs b/MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs index be33eeb3..85c3a31a 100644 --- a/MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs +++ b/MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs @@ -333,6 +333,7 @@ public static SqlInt32 Year(SqlDateTime date) /// The string to get the prefix of /// The number of characters to return /// The first characters of the string + [CollationSensitive] public static SqlString Left(SqlString s, [MaxLength] SqlInt32 length) { if (s.IsNull || length.IsNull) @@ -350,6 +351,7 @@ public static SqlString Left(SqlString s, [MaxLength] SqlInt32 length) /// The string to get the suffix of /// The number of characters to return /// The last characters of the string + [CollationSensitive] public static SqlString Right(SqlString s, [MaxLength] SqlInt32 length) { if (s.IsNull || length.IsNull) @@ -368,6 +370,7 @@ public static SqlString Right(SqlString s, [MaxLength] SqlInt32 length) /// The substring to be found /// The replacement string /// Replaces any instances of with in the + [CollationSensitive] public static SqlString Replace(SqlString input, SqlString find, SqlString replace) { if (input.IsNull || find.IsNull || replace.IsNull) @@ -381,6 +384,7 @@ public static SqlString Replace(SqlString input, SqlString find, SqlString repla /// /// The string expression to be evaluated /// + [CollationSensitive] public static SqlInt32 Len(SqlString s) { if (s.IsNull) @@ -428,6 +432,7 @@ public static SqlInt32 DataLength(T value, [SourceType] DataTypeReference typ /// An integer that specifies where the returned characters start (the numbering is 1 based, meaning that the first character in the expression is 1) /// A positive integer that specifies how many characters of the expression will be returned /// + [CollationSensitive] public static SqlString Substring(SqlString expression, SqlInt32 start, [MaxLength] SqlInt32 length) { if (expression.IsNull || start.IsNull || length.IsNull) @@ -495,6 +500,7 @@ public static SqlString RTrim([MaxLength] SqlString expression) /// A character expression containing the sequence to find /// A character expression to search. /// + [CollationSensitive] public static SqlInt32 CharIndex(SqlString find, SqlString search) { return CharIndex(find, search, 0); @@ -507,6 +513,7 @@ public static SqlInt32 CharIndex(SqlString find, SqlString search) /// A character expression to search. /// An integer or bigint expression at which the search starts. If start_location is not specified, has a negative value, or has a zero (0) value, the search starts at the beginning of expressionToSearch. /// + [CollationSensitive] public static SqlInt32 CharIndex(SqlString find, SqlString search, SqlInt32 startLocation) { if (find.IsNull || search.IsNull || startLocation.IsNull) @@ -521,6 +528,18 @@ public static SqlInt32 CharIndex(SqlString find, SqlString search, SqlInt32 star return search.Value.IndexOf(find.Value, startLocation.Value - 1, StringComparison.OrdinalIgnoreCase) + 1; } + /// + /// Returns the starting position of the first occurrence of a pattern in a specified expression, or zero if the pattern is not found, on all valid text and character data types. + /// + /// A character expression that contains the sequence to be found. Wildcard characters can be used; however, the % character must come before and follow + /// An expression that is searched for the specified + /// + [CollationSensitive] + public static SqlInt32 PatIndex(SqlString pattern, SqlString expression) + { + throw new NotImplementedException(); + } + /// /// Returns the single-byte character with the specified integer code /// @@ -924,4 +943,12 @@ class SourceTypeAttribute : Attribute class OptionalAttribute : Attribute { } + + /// + /// Indicates that a function is collation sensitive + /// + [AttributeUsage(AttributeTargets.Method)] + class CollationSensitiveAttribute : Attribute + { + } } diff --git a/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.projitems b/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.projitems index 17601662..a3d67510 100644 --- a/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.projitems +++ b/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.projitems @@ -112,6 +112,7 @@ + diff --git a/MarkMpn.Sql4Cds.Engine/NodeContext.cs b/MarkMpn.Sql4Cds.Engine/NodeContext.cs index 24d1da40..4ded09a4 100644 --- a/MarkMpn.Sql4Cds.Engine/NodeContext.cs +++ b/MarkMpn.Sql4Cds.Engine/NodeContext.cs @@ -12,6 +12,9 @@ namespace MarkMpn.Sql4Cds.Engine /// class NodeCompilationContext { + private readonly NodeCompilationContext _parentContext; + private int _expressionCounter; + /// /// Creates a new /// @@ -28,6 +31,21 @@ public NodeCompilationContext( ParameterTypes = parameterTypes; } + /// + /// Creates a new as a child of another context + /// + /// The parent context that this context is being created from + /// The names and types of the parameters that are available to this section of the query + public NodeCompilationContext( + NodeCompilationContext parentContext, + IDictionary parameterTypes) + { + DataSources = parentContext.DataSources; + Options = parentContext.Options; + ParameterTypes = parameterTypes; + _parentContext = parentContext; + } + /// /// Returns the data sources that are available to the query /// @@ -47,6 +65,18 @@ public NodeCompilationContext( /// Returns the details of the primary data source /// public DataSource PrimaryDataSource => DataSources[Options.PrimaryDataSource]; + + /// + /// Generates a unique name for an expression + /// + /// The name to use for the expression + public string GetExpressionName() + { + if (_parentContext != null) + return _parentContext.GetExpressionName(); + + return $"Expr{++_expressionCounter}"; + } } /// diff --git a/MarkMpn.Sql4Cds.Engine/Visitors/ExplicitCollationVisitor.cs b/MarkMpn.Sql4Cds.Engine/Visitors/ExplicitCollationVisitor.cs new file mode 100644 index 00000000..0f368dd0 --- /dev/null +++ b/MarkMpn.Sql4Cds.Engine/Visitors/ExplicitCollationVisitor.cs @@ -0,0 +1,38 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.SqlServer.TransactSql.ScriptDom; + +namespace MarkMpn.Sql4Cds.Engine.Visitors +{ + /// + /// Replaces explicit collations with a function call to make further processing simpler + /// + class ExplicitCollationVisitor : RewriteVisitorBase + { + protected override ScalarExpression ReplaceExpression(ScalarExpression expression, out string name) + { + name = null; + + if (expression is PrimaryExpression primary && + primary.Collation != null) + { + return new FunctionCall + { + FunctionName = new Identifier { Value = "ExplicitCollation" }, + Parameters = + { + expression + } + }; + } + + return expression; + } + + protected override BooleanExpression ReplaceExpression(BooleanExpression expression) + { + return expression; + } + } +} From 320b45800600156cbd81696b1532d162c4b28239 Mon Sep 17 00:00:00 2001 From: Mark Carrington Date: Fri, 24 Mar 2023 20:17:42 +0000 Subject: [PATCH 07/34] Fixed collation for cast/convert --- .../ExecutionPlan/ExpressionExtensions.cs | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs index 9dcf6427..e0694d1a 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs @@ -1408,19 +1408,25 @@ private static Expression ToExpression(this ConvertCall convert, ExpressionCompi sqlType = convert.DataType; - // Set default length to 30 if (sqlType is SqlDataTypeReference sqlTargetType && - sqlTargetType.SqlDataTypeOption.IsStringType() && - sqlTargetType.Parameters.Count == 0) + sqlTargetType.SqlDataTypeOption.IsStringType()) { - var coll = sqlType as SqlDataTypeReferenceWithCollation; + // Set default length to 30 + if (sqlTargetType.Parameters.Count == 0) + sqlTargetType.Parameters.Add(new IntegerLiteral { Value = "30" }); + + // If the input is a character string, the output string has the collation label of the input string + // If the input is not a character string, the output string is coercible-default and assigned the collation of the current database for the connection + var valueTypeColl = valueType as SqlDataTypeReferenceWithCollation; + var collation = valueTypeColl?.Collation ?? context.PrimaryDataSource.DefaultCollation; + var collationLabel = valueTypeColl?.CollationLabel ?? CollationLabel.CoercibleDefault; sqlType = new SqlDataTypeReferenceWithCollation { SqlDataTypeOption = sqlTargetType.SqlDataTypeOption, - Parameters = { new IntegerLiteral { Value = "30" } }, - Collation = coll?.Collation ?? context.PrimaryDataSource.DefaultCollation, - CollationLabel = coll?.CollationLabel ?? CollationLabel.CoercibleDefault + Parameters = { sqlTargetType.Parameters[0] }, + Collation = collation, + CollationLabel = collationLabel }; } From 8dbb9f9b7082fba9c5a531bdd62ad07a204597d4 Mon Sep 17 00:00:00 2001 From: Mark Carrington Date: Fri, 24 Mar 2023 20:34:10 +0000 Subject: [PATCH 08/34] Made LIKE collation aware --- .../ExecutionPlan/ExpressionExtensions.cs | 49 +++++++++++++++---- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs index e0694d1a..6ce8cfb0 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs @@ -973,7 +973,6 @@ private static Expression ToExpression(this LikePredicate like, ExpressionCompil var pattern = like.SecondExpression.ToExpression(context, contextParam, out var patternType); var escape = like.EscapeExpression?.ToExpression(context, contextParam, out escapeType); - // TODO: Use the collations of the value/pattern and ensure they are consistent sqlType = DataTypeHelpers.Bit; var stringType = DataTypeHelpers.NVarChar(Int32.MaxValue, context.PrimaryDataSource.DefaultCollation, CollationLabel.CoercibleDefault); @@ -1032,7 +1031,12 @@ private static Regex LikeToRegex(SqlString pattern, SqlString escape) var inRange = false; var escapeChar = escape.IsNull ? '\0' : escape.Value[0]; - foreach (var ch in pattern.Value) + var pat = pattern.Value; + + if (pattern.SqlCompareOptions.HasFlag(SqlCompareOptions.IgnoreNonSpace)) + pat = RemoveDiacritics(pat); + + foreach (var ch in pat) { if (escapeChar != '\0' && ch == escapeChar) { @@ -1093,7 +1097,7 @@ private static Regex LikeToRegex(SqlString pattern, SqlString escape) regexBuilder.Append("$"); - return new Regex(regexBuilder.ToString(), RegexOptions.IgnoreCase); + return new Regex(regexBuilder.ToString(), pattern.SqlCompareOptions.HasFlag(SqlCompareOptions.IgnoreCase) ? RegexOptions.IgnoreCase : RegexOptions.None); } private static SqlBoolean Like(SqlString value, SqlString pattern, SqlString escape, bool not) @@ -1103,12 +1107,8 @@ private static SqlBoolean Like(SqlString value, SqlString pattern, SqlString esc // Convert the LIKE pattern to a regex var regex = LikeToRegex(pattern, escape); - var result = regex.IsMatch(value.Value); - - if (not) - result = !result; - return result; + return Like(value, regex, not); } private static SqlBoolean Like(SqlString value, Regex pattern, bool not) @@ -1116,7 +1116,12 @@ private static SqlBoolean Like(SqlString value, Regex pattern, bool not) if (value.IsNull) return false; - var result = pattern.IsMatch(value.Value); + var text = value.Value; + + if (value.SqlCompareOptions.HasFlag(SqlCompareOptions.IgnoreNonSpace)) + text = RemoveDiacritics(text); + + var result = pattern.IsMatch(text); if (not) result = !result; @@ -1124,6 +1129,32 @@ private static SqlBoolean Like(SqlString value, Regex pattern, bool not) return result; } + /// + /// Removes accents from a string, used for accent-insensitive collations + /// + /// https://stackoverflow.com/a/249126/269629 + /// The text to remove the accents from + /// A version of the with accents removed + static string RemoveDiacritics(string text) + { + var normalizedString = text.Normalize(NormalizationForm.FormD); + var stringBuilder = new StringBuilder(capacity: normalizedString.Length); + + for (int i = 0; i < normalizedString.Length; i++) + { + char c = normalizedString[i]; + var unicodeCategory = CharUnicodeInfo.GetUnicodeCategory(c); + if (unicodeCategory != UnicodeCategory.NonSpacingMark) + { + stringBuilder.Append(c); + } + } + + return stringBuilder + .ToString() + .Normalize(NormalizationForm.FormC); + } + private static Expression ToExpression(this SimpleCaseExpression simpleCase, ExpressionCompilationContext context, ParameterExpression contextParam, out DataTypeReference sqlType) { // Convert all the different elements to expressions From f75be041616fcaab2e0e7b9c58d8faf8340c5de5 Mon Sep 17 00:00:00 2001 From: Mark Carrington Date: Fri, 24 Mar 2023 20:34:17 +0000 Subject: [PATCH 09/34] Load collations from XTB --- MarkMpn.Sql4Cds.Engine/DataSource.cs | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/MarkMpn.Sql4Cds.Engine/DataSource.cs b/MarkMpn.Sql4Cds.Engine/DataSource.cs index c6eee687..334b04d4 100644 --- a/MarkMpn.Sql4Cds.Engine/DataSource.cs +++ b/MarkMpn.Sql4Cds.Engine/DataSource.cs @@ -14,6 +14,8 @@ namespace MarkMpn.Sql4Cds.Engine /// public class DataSource { + private Collation _defaultCollation; + /// /// Creates a new using default values based on an existing connection. /// @@ -45,8 +47,6 @@ public DataSource(IOrganizationService org) Name = name; TableSizeCache = new TableSizeCache(org, Metadata); MessageCache = new MessageCache(org, Metadata); - - DefaultCollation = LoadDefaultCollation(); } /// @@ -84,16 +84,29 @@ public DataSource() /// /// Returns the default collation used by this instance /// - internal Collation DefaultCollation { get; set; } + internal Collation DefaultCollation + { + get + { + if (_defaultCollation == null) + _defaultCollation = LoadDefaultCollation(); + + return _defaultCollation; + } + set + { + _defaultCollation = value; + } + } private Collation LoadDefaultCollation() { var qry = new QueryExpression("organization") { - ColumnSet = new ColumnSet("lcid") + ColumnSet = new ColumnSet("localeid") }; var org = Connection.RetrieveMultiple(qry).Entities[0]; - var lcid = org.GetAttributeValue("lcid"); + var lcid = org.GetAttributeValue("localeid"); // Collation options are set based on the default language. Most are CI/AI but a few are not // https://learn.microsoft.com/en-us/power-platform/admin/language-collations#language-and-associated-collation-used-with-dataverse From 4bcbf6d53804284f7bc004e059f00c362a9eaa30 Mon Sep 17 00:00:00 2001 From: Mark Carrington Date: Fri, 24 Mar 2023 20:47:11 +0000 Subject: [PATCH 10/34] Implemented PATINDEX --- .../ExecutionPlan/ExpressionExtensions.cs | 40 ++++++++++++++----- MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs | 16 +++++++- MarkMpn.Sql4Cds/FunctionMetadata.cs | 3 ++ 3 files changed, 47 insertions(+), 12 deletions(-) diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs index 6ce8cfb0..ef5d65b3 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs @@ -1010,7 +1010,7 @@ private static Expression ToExpression(this LikePredicate like, ExpressionCompil // Do a one-off conversion to regex try { - var regex = LikeToRegex((SqlString)((ConstantExpression)pattern).Value, (SqlString)(((ConstantExpression)escape)?.Value ?? SqlString.Null)); + var regex = LikeToRegex((SqlString)((ConstantExpression)pattern).Value, (SqlString)(((ConstantExpression)escape)?.Value ?? SqlString.Null), false); return Expr.Call(() => Like(Expr.Arg(), Expr.Arg(), Expr.Arg()), value, Expression.Constant(regex), Expression.Constant(like.NotDefined)); } catch (ArgumentException ex) @@ -1022,20 +1022,37 @@ private static Expression ToExpression(this LikePredicate like, ExpressionCompil return Expr.Call(() => Like(Expr.Arg(), Expr.Arg(), Expr.Arg(), Expr.Arg()), value, pattern, escape, Expression.Constant(like.NotDefined)); } - private static Regex LikeToRegex(SqlString pattern, SqlString escape) + internal static Regex LikeToRegex(SqlString pattern, SqlString escape, bool patIndex) { var regexBuilder = new StringBuilder(); - regexBuilder.Append("^"); - - var escaped = false; - var inRange = false; - var escapeChar = escape.IsNull ? '\0' : escape.Value[0]; - var pat = pattern.Value; if (pattern.SqlCompareOptions.HasFlag(SqlCompareOptions.IgnoreNonSpace)) pat = RemoveDiacritics(pat); + var endWildcard = false; + + if (!patIndex) + { + regexBuilder.Append("^"); + } + else + { + if (!pattern.Value.StartsWith("%")) + regexBuilder.Append("^"); + else + pat = pat.TrimStart('%'); + + endWildcard = pat.EndsWith("%"); + + if (endWildcard) + pat = pat.TrimEnd('%'); + } + + var escaped = false; + var inRange = false; + var escapeChar = escape.IsNull ? '\0' : escape.Value[0]; + foreach (var ch in pat) { if (escapeChar != '\0' && ch == escapeChar) @@ -1095,7 +1112,8 @@ private static Regex LikeToRegex(SqlString pattern, SqlString escape) if (escaped || inRange) throw new ArgumentException("Invalid LIKE pattern"); - regexBuilder.Append("$"); + if (!patIndex || !endWildcard) + regexBuilder.Append("$"); return new Regex(regexBuilder.ToString(), pattern.SqlCompareOptions.HasFlag(SqlCompareOptions.IgnoreCase) ? RegexOptions.IgnoreCase : RegexOptions.None); } @@ -1106,7 +1124,7 @@ private static SqlBoolean Like(SqlString value, SqlString pattern, SqlString esc return false; // Convert the LIKE pattern to a regex - var regex = LikeToRegex(pattern, escape); + var regex = LikeToRegex(pattern, escape, false); return Like(value, regex, not); } @@ -1135,7 +1153,7 @@ private static SqlBoolean Like(SqlString value, Regex pattern, bool not) /// https://stackoverflow.com/a/249126/269629 /// The text to remove the accents from /// A version of the with accents removed - static string RemoveDiacritics(string text) + internal static string RemoveDiacritics(string text) { var normalizedString = text.Normalize(NormalizationForm.FormD); var stringBuilder = new StringBuilder(capacity: normalizedString.Length); diff --git a/MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs b/MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs index 85c3a31a..ae9dd051 100644 --- a/MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs +++ b/MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs @@ -537,7 +537,21 @@ public static SqlInt32 CharIndex(SqlString find, SqlString search, SqlInt32 star [CollationSensitive] public static SqlInt32 PatIndex(SqlString pattern, SqlString expression) { - throw new NotImplementedException(); + if (pattern.IsNull || expression.IsNull) + return 0; + + var regex = ExpressionExtensions.LikeToRegex(pattern, SqlString.Null, true); + var value = expression.Value; + + if (expression.SqlCompareOptions.HasFlag(SqlCompareOptions.IgnoreNonSpace)) + value = ExpressionExtensions.RemoveDiacritics(value); + + var match = regex.Match(value); + + if (!match.Success) + return 0; + + return match.Index + 1; } /// diff --git a/MarkMpn.Sql4Cds/FunctionMetadata.cs b/MarkMpn.Sql4Cds/FunctionMetadata.cs index 182d938e..3efb526f 100644 --- a/MarkMpn.Sql4Cds/FunctionMetadata.cs +++ b/MarkMpn.Sql4Cds/FunctionMetadata.cs @@ -147,6 +147,9 @@ public abstract class SqlFunctions [Description("Returns a value formatted with the specified format and optional culture")] public abstract object format(object value, string format, string culture); + + [Description("Returns the starting position of the first occurrence of a pattern in a specified expression, or zero if the pattern is not found")] + public abstract string patindex(string pattern, string expression); } } } From 863a201c019ae7ccb290f8167b76177146dd73a1 Mon Sep 17 00:00:00 2001 From: Mark Carrington Date: Fri, 24 Mar 2023 21:15:37 +0000 Subject: [PATCH 11/34] Added UPPER and LOWER functions --- MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs | 28 +++++++++++++++++++ MarkMpn.Sql4Cds/FunctionMetadata.cs | 6 ++++ 2 files changed, 34 insertions(+) diff --git a/MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs b/MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs index ae9dd051..b4e3f905 100644 --- a/MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs +++ b/MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs @@ -679,6 +679,34 @@ private static SqlString Format(IFormattable value, SqlString format, SqlString return SqlString.Null; } } + + /// + /// Returns a character expression with lowercase character data converted to uppercase + /// + /// An expression of character data + /// + [CollationSensitive] + public static SqlString Upper([MaxLength] SqlString value) + { + if (value.IsNull) + return value; + + return value.Value.ToUpper(value.CultureInfo); + } + + /// + /// Returns a character expression with uppercase character data converted to lowercase + /// + /// An expression of character data + /// + [CollationSensitive] + public static SqlString Lower([MaxLength] SqlString value) + { + if (value.IsNull) + return value; + + return value.Value.ToLower(value.CultureInfo); + } } /// diff --git a/MarkMpn.Sql4Cds/FunctionMetadata.cs b/MarkMpn.Sql4Cds/FunctionMetadata.cs index 3efb526f..1a05b078 100644 --- a/MarkMpn.Sql4Cds/FunctionMetadata.cs +++ b/MarkMpn.Sql4Cds/FunctionMetadata.cs @@ -150,6 +150,12 @@ public abstract class SqlFunctions [Description("Returns the starting position of the first occurrence of a pattern in a specified expression, or zero if the pattern is not found")] public abstract string patindex(string pattern, string expression); + + [Description("Converts a string to uppercase")] + public abstract string upper(string value); + + [Description("Converts a string to lowercase")] + public abstract string lower(string value); } } } From 60eac3e86309a0e9c27a3e1e7f69169178b25928 Mon Sep 17 00:00:00 2001 From: Mark Carrington Date: Fri, 24 Mar 2023 21:15:48 +0000 Subject: [PATCH 12/34] Fixed bugs with collation sensitive functions --- .../AdoProviderTests.cs | 23 ++++++++++++++ .../ExecutionPlanTests.cs | 4 +-- .../ExecutionPlan/ExpressionExtensions.cs | 30 ++++++++++++++++--- 3 files changed, 51 insertions(+), 6 deletions(-) diff --git a/MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs index f29011d4..31dab17f 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs @@ -1050,5 +1050,28 @@ public void SortByCollation() } } } + + [TestMethod] + public void CollationSensitiveFunctions() + { + using (var con = new Sql4CdsConnection(_localDataSource)) + using (var cmd = con.CreateCommand()) + { + cmd.CommandText = "select case when 'test' like 't%' then 1 else 0 end"; + var actual = cmd.ExecuteScalar(); + + Assert.AreEqual(1, actual); + + cmd.CommandText = "select case when 'TEST' collate latin1_general_cs_ai like 't%' then 1 else 0 end"; + actual = cmd.ExecuteScalar(); + + Assert.AreEqual(0, actual); + + cmd.CommandText = "select case when upper('test' collate latin1_general_cs_ai) like 't%' then 1 else 0 end"; + actual = cmd.ExecuteScalar(); + + Assert.AreEqual(0, actual); + } + } } } diff --git a/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs index bfd70c78..5c57c58b 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs @@ -3625,7 +3625,7 @@ DECLARE @test varchar(3) { if (plan is IDataReaderExecutionPlanNode selectQuery) { - var results = selectQuery.Execute(new NodeExecutionContext(_dataSources, this, parameterTypes, parameterValues), CommandBehavior.Default); + var results = selectQuery.Execute(new NodeExecutionContext(_localDataSource, this, parameterTypes, parameterValues), CommandBehavior.Default); var dataTable = new DataTable(); dataTable.Load(results); @@ -3635,7 +3635,7 @@ DECLARE @test varchar(3) } else if (plan is IDmlQueryExecutionPlanNode dmlQuery) { - dmlQuery.Execute(new NodeExecutionContext(_dataSources, this, parameterTypes, parameterValues), out _); + dmlQuery.Execute(new NodeExecutionContext(_localDataSource, this, parameterTypes, parameterValues), out _); } } } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs index ef5d65b3..fca17766 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExpressionExtensions.cs @@ -764,6 +764,9 @@ private static MethodInfo GetMethod(Type targetType, DataSource primaryDataSourc throw new NotSupportedQueryFragmentException($"Cannot convert {paramTypes[i].ToSql()} to {paramType.ToSqlType(primaryDataSource).ToSql()}", i < paramOffset ? func : func.Parameters[i - paramOffset]); } + if (sqlType == null) + sqlType = method.ReturnType.ToSqlType(primaryDataSource); + if (method.GetCustomAttribute(typeof(CollationSensitiveAttribute)) != null) { // If method is collation sensitive: @@ -822,9 +825,6 @@ private static MethodInfo GetMethod(Type targetType, DataSource primaryDataSourc } } - if (sqlType == null) - sqlType = method.ReturnType.ToSqlType(primaryDataSource); - return method; } @@ -982,6 +982,7 @@ private static Expression ToExpression(this LikePredicate like, ExpressionCompil throw new NotSupportedQueryFragmentException($"No implicit conversion exists for types {valueType.ToSql()} and {stringType.ToSql()}", like.FirstExpression); value = SqlTypeConverter.Convert(value, valueType, stringType); + valueType = stringType; } if (pattern.Type != typeof(SqlString)) @@ -990,6 +991,7 @@ private static Expression ToExpression(this LikePredicate like, ExpressionCompil throw new NotSupportedQueryFragmentException($"No implicit conversion exists for types {patternType.ToSql()} and {stringType.ToSql()}", like.SecondExpression); pattern = SqlTypeConverter.Convert(pattern, patternType, stringType); + patternType = stringType; } if (escape != null && escape.Type != typeof(SqlString)) @@ -998,6 +1000,23 @@ private static Expression ToExpression(this LikePredicate like, ExpressionCompil throw new NotSupportedQueryFragmentException($"No implicit conversion exists for types {escapeType.ToSql()} and {stringType.ToSql()}", like.EscapeExpression); escape = SqlTypeConverter.Convert(escape, escapeType, stringType); + escapeType = stringType; + } + + if (!SqlDataTypeReferenceWithCollation.TryConvertCollation((SqlDataTypeReference)valueType, (SqlDataTypeReference)patternType, out var collation, out var collationLabel)) + throw new NotSupportedQueryFragmentException($"Cannot resolve collation conflict between '{((SqlDataTypeReferenceWithCollation)valueType).Collation.Name}' and {((SqlDataTypeReferenceWithCollation)patternType).Collation.Name}' in like operation", like); + + ((SqlDataTypeReferenceWithCollation)stringType).Collation = collation; + ((SqlDataTypeReferenceWithCollation)stringType).CollationLabel = collationLabel; + + if (escapeType != null && !SqlDataTypeReferenceWithCollation.TryConvertCollation(stringType, (SqlDataTypeReference)escapeType, out collation, out collationLabel)) + { + throw new NotSupportedQueryFragmentException($"Cannot resolve collation conflict between '{((SqlDataTypeReferenceWithCollation)stringType).Collation.Name}' and {((SqlDataTypeReferenceWithCollation)escapeType).Collation.Name}' in like operation", like); + } + else + { + ((SqlDataTypeReferenceWithCollation)stringType).Collation = collation; + ((SqlDataTypeReferenceWithCollation)stringType).CollationLabel = collationLabel; } AssertCollationSensitive(stringType, "like operation", like); @@ -1010,7 +1029,7 @@ private static Expression ToExpression(this LikePredicate like, ExpressionCompil // Do a one-off conversion to regex try { - var regex = LikeToRegex((SqlString)((ConstantExpression)pattern).Value, (SqlString)(((ConstantExpression)escape)?.Value ?? SqlString.Null), false); + var regex = LikeToRegex(SqlTypeConverter.ConvertCollation((SqlString)((ConstantExpression)pattern).Value, collation), (SqlString)(((ConstantExpression)escape)?.Value ?? SqlString.Null), false); return Expr.Call(() => Like(Expr.Arg(), Expr.Arg(), Expr.Arg()), value, Expression.Constant(regex), Expression.Constant(like.NotDefined)); } catch (ArgumentException ex) @@ -1019,6 +1038,9 @@ private static Expression ToExpression(this LikePredicate like, ExpressionCompil } } + value = Expr.Call(() => SqlTypeConverter.ConvertCollation(Expr.Arg(), Expr.Arg()), value, Expression.Constant(collation)); + pattern = Expr.Call(() => SqlTypeConverter.ConvertCollation(Expr.Arg(), Expr.Arg()), pattern, Expression.Constant(collation)); + return Expr.Call(() => Like(Expr.Arg(), Expr.Arg(), Expr.Arg(), Expr.Arg()), value, pattern, escape, Expression.Constant(like.NotDefined)); } From 9077195019f9654e772799cf4aa3162b3fc88ce4 Mon Sep 17 00:00:00 2001 From: Mark Carrington Date: Wed, 29 Mar 2023 08:51:52 +0100 Subject: [PATCH 13/34] System function progress --- .../ExecutionPlanTests.cs | 8 ++ MarkMpn.Sql4Cds.Engine/Collation.cs | 9 ++ .../ExecutionPlan/SystemFunctionNode.cs | 94 +++++++++++++++++++ .../ExecutionPlanBuilder.cs | 23 ++++- .../MarkMpn.Sql4Cds.Engine.projitems | 1 + 5 files changed, 134 insertions(+), 1 deletion(-) create mode 100644 MarkMpn.Sql4Cds.Engine/ExecutionPlan/SystemFunctionNode.cs diff --git a/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs index 5c57c58b..114ee289 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs @@ -5143,5 +5143,13 @@ public void NoCollationExprWithExplicitCollationCollationSensitiveFunctionError( var query = "SELECT PATINDEX((CASE WHEN p.employees > f.employees THEN p.name ELSE f.name END) COLLATE Latin1_General_CI_AS, 'a') FROM prod.dbo.account p, french.dbo.account f"; planBuilder.Build(query, null, out _); } + + [TestMethod] + public void CollationFunctions() + { + var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); + var query = "SELECT name, COLLATIONPROPERTY(name, 'lcid') FROM sys.fn_helpcollations()"; + planBuilder.Build(query, null, out _); + } } } diff --git a/MarkMpn.Sql4Cds.Engine/Collation.cs b/MarkMpn.Sql4Cds.Engine/Collation.cs index acb40582..6ff92988 100644 --- a/MarkMpn.Sql4Cds.Engine/Collation.cs +++ b/MarkMpn.Sql4Cds.Engine/Collation.cs @@ -180,6 +180,15 @@ public static bool TryParse(string name, out Collation coll) return false; } + public static IEnumerable GetAllCollations() + { + foreach (var kvp in _collationNameToLcid) + { + yield return new Collation(kvp.Key + "_BIN", kvp.Value, SqlCompareOptions.BinarySort); + yield return new Collation(kvp.Key + "_BIN2", kvp.Value, SqlCompareOptions.BinarySort2); + } + } + /// /// Applies the current collation to a string value /// diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SystemFunctionNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SystemFunctionNode.cs new file mode 100644 index 00000000..af93cb66 --- /dev/null +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SystemFunctionNode.cs @@ -0,0 +1,94 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Text; +using Microsoft.SqlServer.TransactSql.ScriptDom; +using Microsoft.Xrm.Sdk; + +namespace MarkMpn.Sql4Cds.Engine.ExecutionPlan +{ + class SystemFunctionNode : BaseDataNode + { + /// + /// The instance that this node will be executed against + /// + [Category("Data Source")] + [Description("The data source this query is executed against")] + public string DataSource { get; set; } + + /// + /// The name of the function to execute + /// + [Category("System Function")] + [Description("The name of the function to execute")] + public SystemFunction SystemFunction { get; set; } + + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) + { + } + + public override object Clone() + { + return new SystemFunctionNode + { + DataSource = DataSource, + SystemFunction = SystemFunction + }; + } + + public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) + { + return this; + } + + public override INodeSchema GetSchema(NodeCompilationContext context) + { + var dataSource = context.DataSources[DataSource]; + + switch (SystemFunction) + { + case SystemFunction.fn_helpcollations: + return new NodeSchema( + schema: new Dictionary(StringComparer.OrdinalIgnoreCase) + { + ["name"] = DataTypeHelpers.NVarChar(128, dataSource.DefaultCollation, CollationLabel.CoercibleDefault), + ["description"] = DataTypeHelpers.NVarChar(1000, dataSource.DefaultCollation, CollationLabel.CoercibleDefault), + }, + aliases: null, + primaryKey: null, + sortOrder: null, + notNullColumns: new[] { "name", "description" }); + } + + throw new NotSupportedException("Unsupported function " + SystemFunction); + } + + public override IEnumerable GetSources() + { + return Array.Empty(); + } + + protected override RowCountEstimate EstimateRowsOutInternal(NodeCompilationContext context) + { + return new RowCountEstimate(100); + } + + protected override IEnumerable ExecuteInternal(NodeExecutionContext context) + { + var dataSource = context.DataSources[DataSource]; + + switch (SystemFunction) + { + case SystemFunction.fn_helpcollations: + break; + } + + throw new NotSupportedException("Unsupported function " + SystemFunction); + } + } + + enum SystemFunction + { + fn_helpcollations + } +} diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs index ea741c95..1552c799 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs @@ -3549,7 +3549,28 @@ private IDataExecutionPlanNodeInternal ConvertTableReference(TableReference refe CaptureOuterReferences(scalarSubquerySchema, null, tvf, context, scalarSubqueryReferences); var dataSource = SelectDataSource(tvf.SchemaObject); - var execute = ExecuteMessageNode.FromMessage(tvf, dataSource, GetExpressionContext(null, context)); + IDataExecutionPlanNodeInternal execute; + + if (String.IsNullOrEmpty(tvf.SchemaObject.SchemaIdentifier?.Value) || + tvf.SchemaObject.SchemaIdentifier.Value.Equals("dbo", StringComparison.OrdinalIgnoreCase)) + { + execute = ExecuteMessageNode.FromMessage(tvf, dataSource, GetExpressionContext(null, context)); + } + else if (tvf.SchemaObject.SchemaIdentifier.Value.Equals("sys", StringComparison.OrdinalIgnoreCase)) + { + if (!Enum.TryParse(tvf.SchemaObject.BaseIdentifier.Value, true, out var systemFunction)) + throw new NotSupportedQueryFragmentException("Invalid function name", tvf); + + execute = new SystemFunctionNode + { + DataSource = dataSource.Name, + SystemFunction = systemFunction + }; + } + else + { + throw new NotSupportedQueryFragmentException("Invalid function name", tvf); + } if (source == null) return execute; diff --git a/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.projitems b/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.projitems index a3d67510..55ff95ca 100644 --- a/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.projitems +++ b/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.projitems @@ -26,6 +26,7 @@ + From be9c856bb3903071fe80434403c5ce0c7d6fcf94 Mon Sep 17 00:00:00 2001 From: Mark Carrington Date: Wed, 29 Mar 2023 08:56:03 +0100 Subject: [PATCH 14/34] Added failing test for IN subquery --- .../AdoProviderTests.cs | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs index 31dab17f..3c3840f0 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs @@ -1073,5 +1073,28 @@ public void CollationSensitiveFunctions() Assert.AreEqual(0, actual); } } + + [TestMethod] + public void MergeSemiJoin() + { + using (var con = new Sql4CdsConnection(_dataSources)) + using (var cmd = con.CreateCommand()) + { + cmd.CommandText = "insert into account (name) values ('data8')"; + cmd.ExecuteNonQuery(); + cmd.ExecuteNonQuery(); + + con.ChangeDatabase("prod"); + cmd.ExecuteNonQuery(); + + cmd.CommandText = "SELECT name FROM account WHERE name IN (SELECT name FROM uat..account)"; + + using (var reader = cmd.ExecuteReader()) + { + Assert.IsTrue(reader.Read()); + Assert.IsFalse(reader.Read()); + } + } + } } } From 62cd0688bd57b8f49236d5f1771430f61c7c7b9b Mon Sep 17 00:00:00 2001 From: Mark Carrington Date: Fri, 31 Mar 2023 21:00:58 +0100 Subject: [PATCH 15/34] Added collation functions --- .../ExecutionPlanTests.cs | 11 ++- .../Ado/Sql4CdsParameter.cs | 12 +-- MarkMpn.Sql4Cds.Engine/Collation.cs | 78 ++++++++++++++++++- .../ExecutionPlan/SystemFunctionNode.cs | 14 +++- MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs | 55 +++++++++++++ MarkMpn.Sql4Cds/FunctionMetadata.cs | 3 + 6 files changed, 159 insertions(+), 14 deletions(-) diff --git a/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs index 114ee289..e438066c 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs @@ -5148,8 +5148,15 @@ public void NoCollationExprWithExplicitCollationCollationSensitiveFunctionError( public void CollationFunctions() { var planBuilder = new ExecutionPlanBuilder(_dataSources.Values, new OptionsWrapper(this) { PrimaryDataSource = "prod" }); - var query = "SELECT name, COLLATIONPROPERTY(name, 'lcid') FROM sys.fn_helpcollations()"; - planBuilder.Build(query, null, out _); + var query = "SELECT *, COLLATIONPROPERTY(name, 'lcid') FROM sys.fn_helpcollations()"; + var plans = planBuilder.Build(query, null, out _); + + Assert.AreEqual(1, plans.Length); + var select = AssertNode(plans[0]); + CollectionAssert.AreEqual(new[] { "name", "description", "Expr1" }, select.ColumnSet.Select(col => col.OutputColumn).ToArray()); + var computeScalar = AssertNode(select.Source); + var sysFunc = AssertNode(computeScalar.Source); + Assert.AreEqual(SystemFunction.fn_helpcollations, sysFunc.SystemFunction); } } } diff --git a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsParameter.cs b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsParameter.cs index c8572558..cd111f55 100644 --- a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsParameter.cs +++ b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsParameter.cs @@ -149,11 +149,11 @@ internal DataTypeReference GetDataType() switch (DbType) { case DbType.AnsiString: - _dataType = DataTypeHelpers.VarChar(Size, new Collation(null, LocaleId, CompareInfo), CollationLabel.CoercibleDefault); + _dataType = DataTypeHelpers.VarChar(Size, new Collation(null, LocaleId, CompareInfo, null), CollationLabel.CoercibleDefault); break; case DbType.AnsiStringFixedLength: - _dataType = DataTypeHelpers.Char(Size, new Collation(null, LocaleId, CompareInfo), CollationLabel.CoercibleDefault); + _dataType = DataTypeHelpers.Char(Size, new Collation(null, LocaleId, CompareInfo, null), CollationLabel.CoercibleDefault); break; case DbType.Binary: @@ -226,11 +226,11 @@ internal DataTypeReference GetDataType() break; case DbType.String: - _dataType = DataTypeHelpers.NVarChar(Size, new Collation(null, LocaleId, CompareInfo), CollationLabel.CoercibleDefault); + _dataType = DataTypeHelpers.NVarChar(Size, new Collation(null, LocaleId, CompareInfo, null), CollationLabel.CoercibleDefault); break; case DbType.StringFixedLength: - _dataType = DataTypeHelpers.NChar(Size, new Collation(null, LocaleId, CompareInfo), CollationLabel.CoercibleDefault); + _dataType = DataTypeHelpers.NChar(Size, new Collation(null, LocaleId, CompareInfo, null), CollationLabel.CoercibleDefault); break; case DbType.Time: @@ -250,11 +250,11 @@ internal DataTypeReference GetDataType() break; case DbType.VarNumeric: - _dataType = DataTypeHelpers.NVarChar(Int32.MaxValue, new Collation(null, LocaleId, CompareInfo), CollationLabel.CoercibleDefault); + _dataType = DataTypeHelpers.NVarChar(Int32.MaxValue, new Collation(null, LocaleId, CompareInfo, null), CollationLabel.CoercibleDefault); break; case DbType.Xml: - _dataType = DataTypeHelpers.NVarChar(Int32.MaxValue, new Collation(null, LocaleId, CompareInfo), CollationLabel.CoercibleDefault); + _dataType = DataTypeHelpers.NVarChar(Int32.MaxValue, new Collation(null, LocaleId, CompareInfo, null), CollationLabel.CoercibleDefault); break; } } diff --git a/MarkMpn.Sql4Cds.Engine/Collation.cs b/MarkMpn.Sql4Cds.Engine/Collation.cs index 6ff92988..6f1d53bd 100644 --- a/MarkMpn.Sql4Cds.Engine/Collation.cs +++ b/MarkMpn.Sql4Cds.Engine/Collation.cs @@ -37,11 +37,13 @@ static Collation() /// The name the collation was parsed from /// The locale ID to use /// Additional comparison options - public Collation(string name, int lcid, SqlCompareOptions compareOptions) + /// A description of the collation for display purposes + public Collation(string name, int lcid, SqlCompareOptions compareOptions, string description) { Name = name; LCID = lcid; CompareOptions = compareOptions; + Description = description; } /// @@ -82,6 +84,11 @@ public Collation(int lcid, bool caseSensitive, bool accentSensitive) /// public string Name { get; } + /// + /// Returns the description of the collation + /// + public string Description { get; } + /// /// Returns the default collation to be used for system data /// @@ -171,7 +178,7 @@ public static bool TryParse(string name, out Collation coll) if (!_collationNameToLcid.TryGetValue(collationName, out var lcid)) break; - coll = new Collation(name, lcid, compareOptions); + coll = new Collation(name, lcid, compareOptions, null); return true; } } @@ -184,8 +191,71 @@ public static IEnumerable GetAllCollations() { foreach (var kvp in _collationNameToLcid) { - yield return new Collation(kvp.Key + "_BIN", kvp.Value, SqlCompareOptions.BinarySort); - yield return new Collation(kvp.Key + "_BIN2", kvp.Value, SqlCompareOptions.BinarySort2); + yield return new Collation(kvp.Key + "_BIN", kvp.Value, SqlCompareOptions.BinarySort, kvp.Key + ", binary sort"); + yield return new Collation(kvp.Key + "_BIN2", kvp.Value, SqlCompareOptions.BinarySort2, kvp.Key + ", binary code point comparison sort"); + + var options = SqlCompareOptions.None; + var description = new string[5]; + description[0] = kvp.Key; + + foreach (var c in new[] { "_CI", "_CS" }) + { + if (c == "_CI") + { + options |= SqlCompareOptions.IgnoreCase; + description[1] = "case-insensitive"; + } + else + { + options &= ~SqlCompareOptions.IgnoreCase; + description[1] = "case-sensitive"; + } + + foreach (var a in new[] { "_AI", "_AS" }) + { + if (a == "_AI") + { + options |= SqlCompareOptions.IgnoreNonSpace; + description[2] = "accent-insensitive"; + } + else + { + options &= ~SqlCompareOptions.IgnoreNonSpace; + description[2] = "accent-sensitive"; + } + + foreach (var k in new[] { "", "_KS" }) + { + if (k == "") + { + options |= SqlCompareOptions.IgnoreKanaType; + description[3] = "kanatype-insensitive"; + } + else + { + options &= ~SqlCompareOptions.IgnoreKanaType; + description[3] = "kanatype-sensitive"; + } + + foreach (var w in new[] { "", "_WS" }) + { + if (w == "") + { + options |= SqlCompareOptions.IgnoreWidth; + description[4] = "width-insensitive"; + } + else + { + options &= ~SqlCompareOptions.IgnoreWidth; + description[4] = "width-sensitive"; + } + + // Albanian-100, case-sensitive, accent-insensitive, kanatype-sensitive, width-insensitive + yield return new Collation(kvp.Key + c + a + k + w, kvp.Value, options, String.Join(", ", description)); + } + } + } + } } } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SystemFunctionNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SystemFunctionNode.cs index af93cb66..3c47caa7 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SystemFunctionNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SystemFunctionNode.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.ComponentModel; +using System.Data.SqlTypes; using System.Text; using Microsoft.SqlServer.TransactSql.ScriptDom; using Microsoft.Xrm.Sdk; @@ -80,10 +81,19 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont switch (SystemFunction) { case SystemFunction.fn_helpcollations: + foreach (var coll in Collation.GetAllCollations()) + { + yield return new Entity + { + ["name"] = dataSource.DefaultCollation.ToSqlString(coll.Name), + ["description"] = dataSource.DefaultCollation.ToSqlString(coll.Description) + }; + } break; - } - throw new NotSupportedException("Unsupported function " + SystemFunction); + default: + throw new NotSupportedException("Unsupported function " + SystemFunction); + } } } diff --git a/MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs b/MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs index b4e3f905..691a3e47 100644 --- a/MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs +++ b/MarkMpn.Sql4Cds.Engine/ExpressionFunctions.cs @@ -707,6 +707,61 @@ public static SqlString Lower([MaxLength] SqlString value) return value.Value.ToLower(value.CultureInfo); } + + /// + /// Returns the requested property of a specified collation + /// + /// The name of the collation + /// The collation property + /// + public static SqlInt32 CollationProperty(SqlString collation, SqlString property) + { + if (collation.IsNull || property.IsNull) + return SqlInt32.Null; + + if (!Collation.TryParse(collation.Value, out var coll)) + return SqlInt32.Null; + + switch (property.Value.ToLowerInvariant()) + { + case "codepage": + return 0; + + case "lcid": + return coll.LCID; + + case "comparisonstyle": + var compare = 0; + + if (coll.CompareOptions.HasFlag(CompareOptions.IgnoreCase)) + compare |= 1; + + if (coll.CompareOptions.HasFlag(CompareOptions.IgnoreNonSpace)) + compare |= 2; + + if (coll.CompareOptions.HasFlag(CompareOptions.IgnoreKanaType)) + compare |= 65536; + + if (coll.CompareOptions.HasFlag(CompareOptions.IgnoreWidth)) + compare |= 131072; + + return compare; + + case "version": + if (coll.Name.Contains("140")) + return 3; + + if (coll.Name.Contains("100")) + return 2; + + if (coll.Name.Contains("90")) + return 1; + + return 0; + } + + return SqlInt32.Null; + } } /// diff --git a/MarkMpn.Sql4Cds/FunctionMetadata.cs b/MarkMpn.Sql4Cds/FunctionMetadata.cs index 1a05b078..643b4a39 100644 --- a/MarkMpn.Sql4Cds/FunctionMetadata.cs +++ b/MarkMpn.Sql4Cds/FunctionMetadata.cs @@ -156,6 +156,9 @@ public abstract class SqlFunctions [Description("Converts a string to lowercase")] public abstract string lower(string value); + + [Description("Returns the requested property of a specified collation")] + public abstract int collationproperty(string collation_name, string property); } } } From 3b3d5dc779093be0deb13bd08d15610d3bdfa969 Mon Sep 17 00:00:00 2001 From: Mark Carrington Date: Sun, 2 Apr 2023 16:29:58 +0100 Subject: [PATCH 16/34] Prevent duplicating calculated columns when also using SELECT * --- .../ExecutionPlan/AliasNode.cs | 12 +++++++++-- .../ExecutionPlan/SelectNode.cs | 20 ++++++++++++++++--- .../ExecutionPlanBuilder.cs | 12 ++++++----- 3 files changed, 34 insertions(+), 10 deletions(-) diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AliasNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AliasNode.cs index 81fa4045..1c4540cf 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AliasNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AliasNode.cs @@ -24,6 +24,7 @@ public AliasNode(SelectNode select, Identifier identifier) ColumnSet.AddRange(select.ColumnSet); Source = select.Source; Alias = identifier.Value; + LogicalSourceSchema = select.LogicalSourceSchema; // Check for duplicate columns var duplicateColumn = select.ColumnSet @@ -60,6 +61,12 @@ private AliasNode() [Browsable(false)] public IDataExecutionPlanNodeInternal Source { get; set; } + /// + /// The schema that shold be used for expanding "*" columns + /// + [Browsable(false)] + public INodeSchema LogicalSourceSchema { get; set; } + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { var mappings = ColumnSet.Where(col => !col.AllColumns).ToDictionary(col => col.OutputColumn, col => col.SourceColumn); @@ -94,7 +101,7 @@ public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext Source.Parent = this; SelectNode.FoldFetchXmlColumns(Source, ColumnSet, context); - SelectNode.ExpandWildcardColumns(Source, ColumnSet, context); + SelectNode.ExpandWildcardColumns(Source, LogicalSourceSchema, ColumnSet, context); if (Source is FetchXmlScan fetchXml) { @@ -266,7 +273,8 @@ public override object Clone() var clone = new AliasNode { Alias = Alias, - Source = (IDataExecutionPlanNodeInternal)Source.Clone() + Source = (IDataExecutionPlanNodeInternal)Source.Clone(), + LogicalSourceSchema = LogicalSourceSchema, }; clone.Source.Parent = clone; diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SelectNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SelectNode.cs index 812202d5..c402b485 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SelectNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SelectNode.cs @@ -22,6 +22,8 @@ class SelectNode : BaseNode, ISingleSourceExecutionPlanNode, IDataReaderExecutio private int _executionCount; private readonly Timer _timer = new Timer(); + public SelectNode() { } + /// /// The columns that should be included in the query results /// @@ -36,6 +38,12 @@ class SelectNode : BaseNode, ISingleSourceExecutionPlanNode, IDataReaderExecutio [Browsable(false)] public IDataExecutionPlanNodeInternal Source { get; set; } + /// + /// The schema that shold be used for expanding "*" columns + /// + [Browsable(false)] + public INodeSchema LogicalSourceSchema { get; set; } + [Browsable(false)] public string Sql { get; set; } @@ -257,10 +265,10 @@ private void FoldMetadataColumns(IDataExecutionPlanNode source, List columnSet, NodeCompilationContext context) + internal static void ExpandWildcardColumns(IDataExecutionPlanNodeInternal source, INodeSchema sourceSchema, List columnSet, NodeCompilationContext context) { // Expand any AllColumns if (columnSet.Any(col => col.AllColumns)) @@ -276,8 +284,13 @@ internal static void ExpandWildcardColumns(IDataExecutionPlanNodeInternal source continue; } - foreach (var src in schema.Schema.Keys.Where(k => col.SourceColumn == null || k.StartsWith(col.SourceColumn + ".", StringComparison.OrdinalIgnoreCase)).OrderBy(k => k, StringComparer.OrdinalIgnoreCase)) + foreach (var src in sourceSchema.Schema.Keys.Where(k => col.SourceColumn == null || k.StartsWith(col.SourceColumn + ".", StringComparison.OrdinalIgnoreCase)).OrderBy(k => k, StringComparer.OrdinalIgnoreCase)) { + // Columns might be available in the logical source schema but not in + // the real one, e.g. due to aggregation + if (!schema.ContainsColumn(src, out _)) + src.ToColumnReference().GetType(new ExpressionCompilationContext(context, schema, sourceSchema), out _); + expanded.Add(new SelectColumn { SourceColumn = src, @@ -312,6 +325,7 @@ public object Clone() var clone = new SelectNode { Source = (IDataExecutionPlanNodeInternal)Source.Clone(), + LogicalSourceSchema = LogicalSourceSchema, Sql = Sql, Index = Index, Length = Length diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs index 1552c799..9bb78cae 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs @@ -261,7 +261,7 @@ private IRootExecutionPlanNodeInternal[] ConvertExecuteStatement(ExecuteStatemen if (message.OutputParameters.Any(p => !p.IsScalarType())) { // Expose the produced data set - var select = new SelectNode { Source = node }; + var select = new SelectNode { Source = node, LogicalSourceSchema = schema }; foreach (var col in schema.Schema.Keys.OrderBy(col => col)) select.ColumnSet.Add(new SelectColumn { SourceColumn = col, OutputColumn = col }); @@ -1767,7 +1767,7 @@ private SelectNode ConvertBinaryQuery(BinaryQueryExpression binary, IList new ColumnReferenceExpression { MultiPartIdentifier = new MultiPartIdentifier { Identifiers = { new Identifier { Value = col.OutputColumn } } } }).ToArray(), binary, context, outerSchema, outerReferences, null); node = ConvertOffsetClause(node, binary.OffsetClause, context); - var select = new SelectNode { Source = node }; + var select = new SelectNode { Source = node, LogicalSourceSchema = concat.GetSchema(context) }; select.ColumnSet.AddRange(concat.ColumnSet.Select((col, i) => new SelectColumn { SourceColumn = col.OutputColumn, SourceExpression = col.SourceExpressions[0], OutputColumn = left.ColumnSet[i].OutputColumn })); return select; @@ -1794,6 +1794,7 @@ private SelectNode ConvertSelectQuerySpec(QuerySpecification querySpec, IList() } } : ConvertFromClause(querySpec.FromClause.TableReferences, hints, querySpec, outerSchema, outerReferences, context); + var logicalSchema = node.GetSchema(context); node = ConvertInSubqueries(node, hints, querySpec, context, outerSchema, outerReferences); node = ConvertExistsSubqueries(node, hints, querySpec, context, outerSchema, outerReferences); @@ -1814,7 +1815,7 @@ private SelectNode ConvertSelectQuerySpec(QuerySpecification querySpec, IList selectElements, IList hints, IDataExecutionPlanNodeInternal node, DistinctNode distinct, TSqlFragment query, NodeCompilationContext context, INodeSchema outerSchema, IDictionary outerReferences, INodeSchema nonAggregateSchema) + private SelectNode ConvertSelectClause(IList selectElements, IList hints, IDataExecutionPlanNodeInternal node, DistinctNode distinct, TSqlFragment query, NodeCompilationContext context, INodeSchema outerSchema, IDictionary outerReferences, INodeSchema nonAggregateSchema, INodeSchema logicalSourceSchema) { var schema = node.GetSchema(context); var select = new SelectNode { - Source = node + Source = node, + LogicalSourceSchema = logicalSourceSchema }; var computeScalar = new ComputeScalarNode From 5e8ecd54af639169d008b7ebc5f814786f310fb4 Mon Sep 17 00:00:00 2001 From: Mark Carrington Date: Mon, 3 Apr 2023 18:43:42 +0100 Subject: [PATCH 17/34] Handle executing the same command after changing database --- MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsCommand.cs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsCommand.cs b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsCommand.cs index d55ca340..38a04df4 100644 --- a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsCommand.cs +++ b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsCommand.cs @@ -27,6 +27,7 @@ public class Sql4CdsCommand : DbCommand private CommandType _commandType; private CancellationTokenSource _cts; private bool _cancelledManually; + private string _lastDatabase; public Sql4CdsCommand(Sql4CdsConnection connection) : this(connection, string.Empty) { @@ -161,7 +162,7 @@ public override object ExecuteScalar() public override void Prepare() { - if (UseTDSEndpointDirectly || Plan != null) + if (_lastDatabase == _connection.Database && (UseTDSEndpointDirectly || Plan != null)) return; GeneratePlan(true); @@ -199,6 +200,7 @@ public IRootExecutionPlanNode[] GeneratePlan(bool compileForExecution) var plan = _planBuilder.Build(commandText, ((Sql4CdsParameterCollection)Parameters).GetParameterTypes(), out var useTDSEndpointDirectly); UseTDSEndpointDirectly = useTDSEndpointDirectly; + _lastDatabase = _connection.Database; if (compileForExecution) { From 4b103456409965680c414223873d4c75dcc65017 Mon Sep 17 00:00:00 2001 From: Mark Carrington Date: Mon, 3 Apr 2023 18:44:48 +0100 Subject: [PATCH 18/34] Keep column ordering consistent --- .../AdoProviderTests.cs | 2 + .../ExecutionPlan/AliasNode.cs | 4 +- .../ExecutionPlan/BaseAggregateNode.cs | 2 +- .../ExecutionPlan/BaseDmlNode.cs | 2 +- .../ExecutionPlan/BaseJoinNode.cs | 6 +- .../ExecutionPlan/ComputeScalarNode.cs | 2 +- .../ExecutionPlan/ConcatenateNode.cs | 2 +- .../ExecutionPlan/ConstantScanNode.cs | 9 +- .../ExecutionPlan/ExecuteMessageNode.cs | 9 +- .../ExecutionPlan/FetchXmlScan.cs | 12 +- .../ExecutionPlan/GlobalOptionSetQueryNode.cs | 2 +- .../ExecutionPlan/MetadataQueryNode.cs | 12 +- .../ExecutionPlan/NodeSchema.cs | 122 ++++++++++++++++-- .../RetrieveTotalRecordCountNode.cs | 2 +- .../ExecutionPlan/SelectNode.cs | 2 +- .../ExecutionPlan/SystemFunctionNode.cs | 2 +- 16 files changed, 151 insertions(+), 41 deletions(-) diff --git a/MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs index 3c3840f0..be276058 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/AdoProviderTests.cs @@ -1080,6 +1080,8 @@ public void MergeSemiJoin() using (var con = new Sql4CdsConnection(_dataSources)) using (var cmd = con.CreateCommand()) { + cmd.CommandTimeout = 0; + cmd.CommandText = "insert into account (name) values ('data8')"; cmd.ExecuteNonQuery(); cmd.ExecuteNonQuery(); diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AliasNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AliasNode.cs index 1c4540cf..7359904b 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AliasNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AliasNode.cs @@ -155,7 +155,7 @@ public override INodeSchema GetSchema(NodeCompilationContext context) { // Map the base names to the alias names var sourceSchema = Source.GetSchema(context); - var schema = new Dictionary(StringComparer.OrdinalIgnoreCase); + var schema = new ColumnList(); var aliases = new Dictionary>(); var primaryKey = (string)null; var mappings = new Dictionary(StringComparer.OrdinalIgnoreCase); @@ -218,7 +218,7 @@ public override INodeSchema GetSchema(NodeCompilationContext context) sortOrder: sortOrder); } - private void AddSchemaColumn(string outputColumn, string sourceColumn, Dictionary schema, Dictionary> aliases, ref string primaryKey, Dictionary mappings, INodeSchema sourceSchema) + private void AddSchemaColumn(string outputColumn, string sourceColumn, ColumnList schema, Dictionary> aliases, ref string primaryKey, Dictionary mappings, INodeSchema sourceSchema) { if (!sourceSchema.ContainsColumn(sourceColumn, out var normalized)) return; diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseAggregateNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseAggregateNode.cs index ad73a0fa..e1073edf 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseAggregateNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseAggregateNode.cs @@ -178,7 +178,7 @@ public override INodeSchema GetSchema(NodeCompilationContext context) { var sourceSchema = Source.GetSchema(context); var expressionContext = new ExpressionCompilationContext(context, sourceSchema, null); - var schema = new Dictionary(StringComparer.OrdinalIgnoreCase); + var schema = new ColumnList(); var aliases = new Dictionary>(StringComparer.OrdinalIgnoreCase); var primaryKey = (string)null; var notNullColumns = new List(); diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs index 416f6780..291bddc3 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs @@ -211,7 +211,7 @@ protected List GetDmlSourceEntities(NodeExecutionContext context, out IN // Store the values under the column index as well as name for compatibility with INSERT ... SELECT ... var dataTable = new DataTable(); var schemaTable = dataReader.GetSchemaTable(); - var columnTypes = new Dictionary(StringComparer.OrdinalIgnoreCase); + var columnTypes = new ColumnList(); var targetDataSource = context.DataSources[DataSource]; for (var i = 0; i < schemaTable.Rows.Count; i++) diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseJoinNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseJoinNode.cs index 1b46b3fa..d408cc98 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseJoinNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseJoinNode.cs @@ -17,6 +17,7 @@ abstract class BaseJoinNode : BaseDataNode private INodeSchema _lastLeftSchema; private INodeSchema _lastRightSchema; private INodeSchema _lastSchema; + private bool _lastSchemaIncludedSemiJoin; /// /// The first data source to merge @@ -115,10 +116,10 @@ protected virtual INodeSchema GetSchema(NodeCompilationContext context, bool inc var outerSchema = LeftSource.GetSchema(context); var innerSchema = GetRightSchema(context); - if (outerSchema == _lastLeftSchema && innerSchema == _lastRightSchema) + if (outerSchema == _lastLeftSchema && innerSchema == _lastRightSchema && includeSemiJoin == _lastSchemaIncludedSemiJoin) return _lastSchema; - var schema = new Dictionary(StringComparer.OrdinalIgnoreCase); + var schema = new ColumnList(); var aliases = new Dictionary>(StringComparer.OrdinalIgnoreCase); var primaryKey = GetPrimaryKey(outerSchema, innerSchema); var notNullColumns = new List(); @@ -162,6 +163,7 @@ protected virtual INodeSchema GetSchema(NodeCompilationContext context, bool inc aliases: aliases, notNullColumns: notNullColumns, sortOrder: GetSortOrder(outerSchema, innerSchema)); + _lastSchemaIncludedSemiJoin = includeSemiJoin; return _lastSchema; } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ComputeScalarNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ComputeScalarNode.cs index f6bd3489..7d971116 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ComputeScalarNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ComputeScalarNode.cs @@ -54,7 +54,7 @@ public override INodeSchema GetSchema(NodeCompilationContext context) // Copy the source schema and add in the additional computed columns var sourceSchema = Source.GetSchema(context); var expressionCompilationContext = new ExpressionCompilationContext(context, sourceSchema, null); - var schema = new Dictionary(sourceSchema.Schema.Count, StringComparer.OrdinalIgnoreCase); + var schema = new ColumnList(); foreach (var col in sourceSchema.Schema) schema[col.Key] = col.Value; diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ConcatenateNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ConcatenateNode.cs index 478d93e7..f7bf40bb 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ConcatenateNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ConcatenateNode.cs @@ -48,7 +48,7 @@ protected override IEnumerable ExecuteInternal(NodeExecutionContext cont public override INodeSchema GetSchema(NodeCompilationContext context) { - var schema = new Dictionary(StringComparer.OrdinalIgnoreCase); + var schema = new ColumnList(); var sourceSchema = Sources[0].GetSchema(context); diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ConstantScanNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ConstantScanNode.cs index 36b8e685..1fb6502a 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ConstantScanNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ConstantScanNode.cs @@ -31,7 +31,7 @@ class ConstantScanNode : BaseDataNode /// The types of values to be returned /// [Browsable(false)] - public Dictionary Schema { get; private set; } = new Dictionary(); + public IDictionary Schema { get; private set; } = new ColumnList(); protected override IEnumerable ExecuteInternal(NodeExecutionContext context) { @@ -56,9 +56,14 @@ public override IEnumerable GetSources() public override INodeSchema GetSchema(NodeCompilationContext context) { + var schema = new ColumnList(); + + foreach (var col in Schema) + schema[PrefixWithAlias(col.Key)] = col.Value; + return new NodeSchema( primaryKey: null, - schema: Schema.ToDictionary(kvp => PrefixWithAlias(kvp.Key), kvp => kvp.Value, StringComparer.OrdinalIgnoreCase), + schema: schema, aliases: Schema.ToDictionary(kvp => kvp.Key, kvp => (IReadOnlyList) new List { PrefixWithAlias(kvp.Key) }, StringComparer.OrdinalIgnoreCase), notNullColumns: null, sortOrder: null); diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteMessageNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteMessageNode.cs index 33a2469d..79d3f684 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteMessageNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteMessageNode.cs @@ -92,7 +92,7 @@ class ExecuteMessageNode : BaseDataNode, IDmlQueryExecutionPlanNode /// The types of values to be returned /// [Browsable(false)] - public Dictionary Schema { get; private set; } = new Dictionary(); + public IDictionary Schema { get; private set; } = new ColumnList(); /// /// Indicates if custom plugins should be skipped @@ -148,9 +148,14 @@ private bool GetBypassPluginExecution(IList queryHints, IQueryExe public override INodeSchema GetSchema(NodeCompilationContext context) { + var schema = new ColumnList(); + + foreach (var col in Schema) + schema[PrefixWithAlias(col.Key)] = col.Value; + return new NodeSchema( primaryKey: null, - schema: Schema.ToDictionary(kvp => PrefixWithAlias(kvp.Key), kvp => kvp.Value, StringComparer.OrdinalIgnoreCase), + schema: schema, aliases: Schema.ToDictionary(kvp => kvp.Key, kvp => (IReadOnlyList)new List { PrefixWithAlias(kvp.Key) }, StringComparer.OrdinalIgnoreCase), notNullColumns: null, sortOrder: null); diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FetchXmlScan.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FetchXmlScan.cs index 2f86555e..6c85af17 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FetchXmlScan.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FetchXmlScan.cs @@ -655,7 +655,7 @@ public override INodeSchema GetSchema(NodeCompilationContext context) var entity = FetchXml.Items.OfType().Single(); var meta = dataSource.Metadata[entity.name]; - var schema = new Dictionary(StringComparer.OrdinalIgnoreCase); + var schema = new ColumnList(); var aliases = new Dictionary>(StringComparer.OrdinalIgnoreCase); var primaryKey = FetchXml.aggregate ? null : $"{Alias}.{meta.PrimaryIdAttribute}"; var notNullColumns = new HashSet(); @@ -759,7 +759,7 @@ internal static bool IsValidAlias(string alias) return Regex.IsMatch(alias, "^[A-Za-z_][A-Za-z0-9_]*$"); } - private void AddSchemaAttributes(DataSource dataSource, Dictionary schema, Dictionary> aliases, ref string primaryKey, HashSet notNullColumns, List sortOrder, string entityName, string alias, object[] items, bool innerJoin, bool requireTablePrefix) + private void AddSchemaAttributes(DataSource dataSource, ColumnList schema, Dictionary> aliases, ref string primaryKey, HashSet notNullColumns, List sortOrder, string entityName, string alias, object[] items, bool innerJoin, bool requireTablePrefix) { if (items == null && !ReturnFullSchema) return; @@ -768,7 +768,7 @@ private void AddSchemaAttributes(DataSource dataSource, Dictionary a.LogicalName)) { if (attrMetadata.IsValidForRead == false) continue; @@ -918,7 +918,7 @@ private void AddSchemaAttributes(DataSource dataSource, Dictionary schema, Dictionary> aliases, HashSet notNullColumns, string alias, filter filter) + private void AddNotNullFilters(ColumnList schema, Dictionary> aliases, HashSet notNullColumns, string alias, filter filter) { if (filter.Items == null) return; @@ -941,7 +941,7 @@ private void AddNotNullFilters(Dictionary schema, Dic AddNotNullFilters(schema, aliases, notNullColumns, alias, subFilter); } - private void AddSchemaAttribute(DataSource dataSource, Dictionary schema, Dictionary> aliases, HashSet notNullColumns, string fullName, string simpleName, DataTypeReference type, AttributeMetadata attrMetadata, bool innerJoin) + private void AddSchemaAttribute(DataSource dataSource, ColumnList schema, Dictionary> aliases, HashSet notNullColumns, string fullName, string simpleName, DataTypeReference type, AttributeMetadata attrMetadata, bool innerJoin) { var notNull = innerJoin && (attrMetadata.RequiredLevel?.Value == AttributeRequiredLevel.SystemRequired || attrMetadata.LogicalName == "createdon" || attrMetadata.LogicalName == "createdby" || attrMetadata.AttributeOf == "createdby"); @@ -969,7 +969,7 @@ private void AddSchemaAttribute(DataSource dataSource, Dictionary schema, Dictionary> aliases, HashSet notNullColumns, string fullName, string simpleName, DataTypeReference type, bool notNull) + private void AddSchemaAttribute(ColumnList schema, Dictionary> aliases, HashSet notNullColumns, string fullName, string simpleName, DataTypeReference type, bool notNull) { schema[fullName] = type; diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/GlobalOptionSetQueryNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/GlobalOptionSetQueryNode.cs index c31f8032..fbb79d97 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/GlobalOptionSetQueryNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/GlobalOptionSetQueryNode.cs @@ -117,7 +117,7 @@ public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext public override INodeSchema GetSchema(NodeCompilationContext context) { - var schema = new Dictionary(StringComparer.OrdinalIgnoreCase); + var schema = new ColumnList(); var aliases = new Dictionary>(StringComparer.OrdinalIgnoreCase); foreach (var prop in _optionsetProps.Values) diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/MetadataQueryNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/MetadataQueryNode.cs index 2e45f681..581287c2 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/MetadataQueryNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/MetadataQueryNode.cs @@ -441,7 +441,7 @@ public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext public override INodeSchema GetSchema(NodeCompilationContext context) { - var schema = new Dictionary(StringComparer.OrdinalIgnoreCase); + var schema = new ColumnList(); var aliases = new Dictionary>(); var primaryKey = (string)null; var notNullColumns = new HashSet(StringComparer.OrdinalIgnoreCase); @@ -456,7 +456,7 @@ public override INodeSchema GetSchema(NodeCompilationContext context) if (Query.Properties != null) entityProps = entityProps.Where(p => Query.Properties.AllProperties || Query.Properties.PropertyNames.Contains(p.PropertyName, StringComparer.OrdinalIgnoreCase)); - foreach (var prop in entityProps) + foreach (var prop in entityProps.OrderBy(p => p.SqlName)) { var fullName = $"{EntityAlias}.{prop.SqlName}"; schema[fullName] = prop.SqlType; @@ -483,7 +483,7 @@ public override INodeSchema GetSchema(NodeCompilationContext context) if (Query.AttributeQuery?.Properties != null) attributeProps = attributeProps.Where(p => Query.AttributeQuery.Properties.AllProperties || Query.AttributeQuery.Properties.PropertyNames.Contains(p.PropertyName, StringComparer.OrdinalIgnoreCase)); - foreach (var prop in attributeProps) + foreach (var prop in attributeProps.OrderBy(p => p.SqlName)) { var fullName = $"{AttributeAlias}.{prop.SqlName}"; schema[fullName] = prop.SqlType; @@ -511,7 +511,7 @@ public override INodeSchema GetSchema(NodeCompilationContext context) if (Query.RelationshipQuery?.Properties != null) relationshipProps = relationshipProps.Where(p => Query.RelationshipQuery.Properties.AllProperties || Query.RelationshipQuery.Properties.PropertyNames.Contains(p.PropertyName, StringComparer.OrdinalIgnoreCase)); - foreach (var prop in relationshipProps) + foreach (var prop in relationshipProps.OrderBy(p => p.SqlName)) { var fullName = $"{OneToManyRelationshipAlias}.{prop.SqlName}"; schema[fullName] = prop.SqlType; @@ -539,7 +539,7 @@ public override INodeSchema GetSchema(NodeCompilationContext context) if (Query.RelationshipQuery?.Properties != null) relationshipProps = relationshipProps.Where(p => Query.RelationshipQuery.Properties.AllProperties || Query.RelationshipQuery.Properties.PropertyNames.Contains(p.PropertyName, StringComparer.OrdinalIgnoreCase)); - foreach (var prop in relationshipProps) + foreach (var prop in relationshipProps.OrderBy(p => p.SqlName)) { var fullName = $"{ManyToOneRelationshipAlias}.{prop.SqlName}"; schema[fullName] = prop.SqlType; @@ -567,7 +567,7 @@ public override INodeSchema GetSchema(NodeCompilationContext context) if (Query.RelationshipQuery?.Properties != null) relationshipProps = relationshipProps.Where(p => Query.RelationshipQuery.Properties.AllProperties || Query.RelationshipQuery.Properties.PropertyNames.Contains(p.PropertyName, StringComparer.OrdinalIgnoreCase)); - foreach (var prop in relationshipProps) + foreach (var prop in relationshipProps.OrderBy(p => p.SqlName)) { var fullName = $"{ManyToManyRelationshipAlias}.{prop.SqlName}"; schema[fullName] = prop.SqlType; diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/NodeSchema.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/NodeSchema.cs index 1b196d43..a2cca010 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/NodeSchema.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/NodeSchema.cs @@ -1,5 +1,7 @@ using System; +using System.Collections; using System.Collections.Generic; +using System.Collections.Specialized; using System.Linq; using System.Text; using System.Threading.Tasks; @@ -10,7 +12,7 @@ namespace MarkMpn.Sql4Cds.Engine.ExecutionPlan /// /// Describes the schema of data produced by a node in an execution plan /// - public class NodeSchema : INodeSchema + class NodeSchema : INodeSchema { /// /// Creates a new @@ -18,7 +20,7 @@ public class NodeSchema : INodeSchema public NodeSchema(IReadOnlyDictionary schema, IReadOnlyDictionary> aliases, string primaryKey, IReadOnlyList notNullColumns, IReadOnlyList sortOrder) { PrimaryKey = primaryKey; - Schema = schema ?? new Dictionary(); + Schema = schema ?? new ColumnList(); Aliases = aliases ?? new Dictionary>(); SortOrder = sortOrder ?? Array.Empty(); NotNullColumns = notNullColumns ?? Array.Empty(); @@ -32,19 +34,12 @@ public NodeSchema(INodeSchema copy) { PrimaryKey = copy.PrimaryKey; - if (copy.Schema is Dictionary schema) - { - Schema = new Dictionary(schema, StringComparer.OrdinalIgnoreCase); - } - else - { - schema = new Dictionary(copy.Schema.Count, StringComparer.OrdinalIgnoreCase); + var schema = new ColumnList(); - foreach (var kvp in copy.Schema) - schema[kvp.Key] = kvp.Value; + foreach (var kvp in copy.Schema) + schema[kvp.Key] = kvp.Value; - Schema = schema; - } + Schema = schema; if (copy.Aliases is Dictionary> aliases) { @@ -168,4 +163,105 @@ public interface INodeSchema /// true if the data is sorted by the required columns, irrespective of the column ordering, or false otherwise bool IsSortedBy(ISet requiredSorts); } + + class ColumnList : IDictionary, IReadOnlyDictionary + { + private readonly OrderedDictionary _inner; + + public ColumnList() + { + _inner = new OrderedDictionary(StringComparer.OrdinalIgnoreCase); + } + + public DataTypeReference this[string key] + { + get => (DataTypeReference)_inner[key]; + set => _inner[key] = value; + } + + public ICollection Keys => _inner.Keys.Cast().ToList(); + + public ICollection Values => _inner.Values.Cast().ToList(); + + public int Count => _inner.Count; + + public bool IsReadOnly => false; + + IEnumerable IReadOnlyDictionary.Keys => _inner.Keys.Cast(); + + IEnumerable IReadOnlyDictionary.Values => _inner.Values.Cast(); + + public void Add(string key, DataTypeReference value) + { + _inner.Add(key, value); + } + + public void Add(KeyValuePair item) + { + _inner.Add(item.Key, item.Value); + } + + public void Clear() + { + _inner.Clear(); + } + + public bool Contains(KeyValuePair item) + { + return TryGetValue(item.Key, out var value) && value == item.Value; + } + + public bool ContainsKey(string key) + { + return _inner.Contains(key); + } + + public void CopyTo(KeyValuePair[] array, int arrayIndex) + { + _inner.CopyTo(array, arrayIndex); + } + + public IEnumerator> GetEnumerator() + { + var enumerator = _inner.GetEnumerator(); + + while (enumerator.MoveNext()) + yield return new KeyValuePair((string)enumerator.Key, (DataTypeReference)enumerator.Value); + } + + public bool Remove(string key) + { + if (!_inner.Contains(key)) + return false; + + _inner.Remove(key); + return true; + } + + public bool Remove(KeyValuePair item) + { + if (!Contains(item)) + return false; + + _inner.Remove(item.Key); + return true; + } + + public bool TryGetValue(string key, out DataTypeReference value) + { + if (!_inner.Contains(key)) + { + value = null; + return false; + } + + value = (DataTypeReference)_inner[key]; + return true; + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + } } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/RetrieveTotalRecordCountNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/RetrieveTotalRecordCountNode.cs index acc1b475..05f1c7f2 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/RetrieveTotalRecordCountNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/RetrieveTotalRecordCountNode.cs @@ -51,7 +51,7 @@ public override INodeSchema GetSchema(NodeCompilationContext context) { return new NodeSchema( primaryKey: null, - schema: new Dictionary(StringComparer.OrdinalIgnoreCase) + schema: new ColumnList { [$"{EntityName}_count"] = DataTypeHelpers.BigInt }, diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SelectNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SelectNode.cs index c402b485..3e778f48 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SelectNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SelectNode.cs @@ -284,7 +284,7 @@ internal static void ExpandWildcardColumns(IDataExecutionPlanNodeInternal source continue; } - foreach (var src in sourceSchema.Schema.Keys.Where(k => col.SourceColumn == null || k.StartsWith(col.SourceColumn + ".", StringComparison.OrdinalIgnoreCase)).OrderBy(k => k, StringComparer.OrdinalIgnoreCase)) + foreach (var src in sourceSchema.Schema.Keys.Where(k => col.SourceColumn == null || k.StartsWith(col.SourceColumn + ".", StringComparison.OrdinalIgnoreCase))) { // Columns might be available in the logical source schema but not in // the real one, e.g. due to aggregation diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SystemFunctionNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SystemFunctionNode.cs index 3c47caa7..25b3bac9 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SystemFunctionNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/SystemFunctionNode.cs @@ -50,7 +50,7 @@ public override INodeSchema GetSchema(NodeCompilationContext context) { case SystemFunction.fn_helpcollations: return new NodeSchema( - schema: new Dictionary(StringComparer.OrdinalIgnoreCase) + schema: new ColumnList { ["name"] = DataTypeHelpers.NVarChar(128, dataSource.DefaultCollation, CollationLabel.CoercibleDefault), ["description"] = DataTypeHelpers.NVarChar(1000, dataSource.DefaultCollation, CollationLabel.CoercibleDefault), From 113966d631651771041890612a4efd85293cf417 Mon Sep 17 00:00:00 2001 From: Mark Carrington Date: Mon, 10 Apr 2023 17:05:12 +0100 Subject: [PATCH 19/34] Allow ordering columns in original schema order instead of alphabetically --- .../ExecutionPlanTests.cs | 2 + .../OptionsWrapper.cs | 3 + .../Sql2FetchXmlTests.cs | 2 + MarkMpn.Sql4Cds.Engine.Tests/StubOptions.cs | 2 + .../Ado/CancellationTokenOptionsWrapper.cs | 2 + .../Ado/ChangeDatabaseOptionsWrapper.cs | 3 + .../Ado/DefaultQueryExecutionOptions.cs | 2 + .../Ado/Sql4CdsConnection.cs | 9 +++ .../ExecutionPlan/FetchXmlScan.cs | 25 +++++-- .../ExecutionPlan/GlobalOptionSetQueryNode.cs | 13 +++- .../ExecutionPlan/MetadataQueryNode.cs | 68 ++++++++++++++++--- .../IQueryExecutionOptions.cs | 21 ++++++ MarkMpn.Sql4Cds/QueryExecutionOptions.cs | 1 + MarkMpn.Sql4Cds/Settings.cs | 2 + MarkMpn.Sql4Cds/SettingsForm.Designer.cs | 17 ++++- MarkMpn.Sql4Cds/SettingsForm.cs | 2 + MarkMpn.Sql4Cds/SettingsForm.resx | 2 +- 17 files changed, 157 insertions(+), 19 deletions(-) diff --git a/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs index e438066c..c3a25402 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs @@ -81,6 +81,8 @@ void IQueryExecutionOptions.Progress(double? progress, string message) bool IQueryExecutionOptions.QuotedIdentifiers => true; + public ColumnOrdering ColumnOrdering => ColumnOrdering.Alphabetical; + [TestMethod] public void SimpleSelect() { diff --git a/MarkMpn.Sql4Cds.Engine.Tests/OptionsWrapper.cs b/MarkMpn.Sql4Cds.Engine.Tests/OptionsWrapper.cs index b7bc8d4f..97482a8f 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/OptionsWrapper.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/OptionsWrapper.cs @@ -31,6 +31,7 @@ public OptionsWrapper(IQueryExecutionOptions options) PrimaryDataSource = options.PrimaryDataSource; UserId = options.UserId; QuotedIdentifiers = options.QuotedIdentifiers; + ColumnOrdering = options.ColumnOrdering; } public CancellationToken CancellationToken { get; set; } @@ -63,6 +64,8 @@ public OptionsWrapper(IQueryExecutionOptions options) public bool QuotedIdentifiers { get; set; } + public ColumnOrdering ColumnOrdering { get; set; } + public void ConfirmDelete(ConfirmDmlStatementEventArgs e) { } diff --git a/MarkMpn.Sql4Cds.Engine.Tests/Sql2FetchXmlTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/Sql2FetchXmlTests.cs index de6ee303..89dc7d64 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/Sql2FetchXmlTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/Sql2FetchXmlTests.cs @@ -56,6 +56,8 @@ public class Sql2FetchXmlTests : FakeXrmEasyTestsBase, IQueryExecutionOptions bool IQueryExecutionOptions.QuotedIdentifiers => false; + ColumnOrdering IQueryExecutionOptions.ColumnOrdering => ColumnOrdering.Alphabetical; + [TestMethod] public void SimpleSelect() { diff --git a/MarkMpn.Sql4Cds.Engine.Tests/StubOptions.cs b/MarkMpn.Sql4Cds.Engine.Tests/StubOptions.cs index 091cd630..a265a3b0 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/StubOptions.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/StubOptions.cs @@ -33,6 +33,8 @@ class StubOptions : IQueryExecutionOptions bool IQueryExecutionOptions.BypassCustomPlugins => false; + ColumnOrdering IQueryExecutionOptions.ColumnOrdering => ColumnOrdering.Alphabetical; + void IQueryExecutionOptions.ConfirmInsert(ConfirmDmlStatementEventArgs e) { } diff --git a/MarkMpn.Sql4Cds.Engine/Ado/CancellationTokenOptionsWrapper.cs b/MarkMpn.Sql4Cds.Engine/Ado/CancellationTokenOptionsWrapper.cs index 9be106f4..95c990db 100644 --- a/MarkMpn.Sql4Cds.Engine/Ado/CancellationTokenOptionsWrapper.cs +++ b/MarkMpn.Sql4Cds.Engine/Ado/CancellationTokenOptionsWrapper.cs @@ -47,6 +47,8 @@ public CancellationTokenOptionsWrapper(IQueryExecutionOptions options, Cancellat public bool QuotedIdentifiers => _options.QuotedIdentifiers; + public ColumnOrdering ColumnOrdering => _options.ColumnOrdering; + public void ConfirmDelete(ConfirmDmlStatementEventArgs e) { _options.ConfirmDelete(e); diff --git a/MarkMpn.Sql4Cds.Engine/Ado/ChangeDatabaseOptionsWrapper.cs b/MarkMpn.Sql4Cds.Engine/Ado/ChangeDatabaseOptionsWrapper.cs index 45f7b2ac..b764273a 100644 --- a/MarkMpn.Sql4Cds.Engine/Ado/ChangeDatabaseOptionsWrapper.cs +++ b/MarkMpn.Sql4Cds.Engine/Ado/ChangeDatabaseOptionsWrapper.cs @@ -28,6 +28,7 @@ public ChangeDatabaseOptionsWrapper(Sql4CdsConnection connection, IQueryExecutio UseLocalTimeZone = options.UseLocalTimeZone; BypassCustomPlugins = options.BypassCustomPlugins; QuotedIdentifiers = options.QuotedIdentifiers; + ColumnOrdering = options.ColumnOrdering; } public CancellationToken CancellationToken => _options.CancellationToken; @@ -58,6 +59,8 @@ public ChangeDatabaseOptionsWrapper(Sql4CdsConnection connection, IQueryExecutio public bool QuotedIdentifiers { get; set; } + public ColumnOrdering ColumnOrdering { get; set; } + public void ConfirmDelete(ConfirmDmlStatementEventArgs e) { if (!e.Cancel) diff --git a/MarkMpn.Sql4Cds.Engine/Ado/DefaultQueryExecutionOptions.cs b/MarkMpn.Sql4Cds.Engine/Ado/DefaultQueryExecutionOptions.cs index be742e13..75e22104 100644 --- a/MarkMpn.Sql4Cds.Engine/Ado/DefaultQueryExecutionOptions.cs +++ b/MarkMpn.Sql4Cds.Engine/Ado/DefaultQueryExecutionOptions.cs @@ -90,6 +90,8 @@ public DefaultQueryExecutionOptions(DataSource dataSource, CancellationToken can public bool QuotedIdentifiers => true; + public ColumnOrdering ColumnOrdering => ColumnOrdering.Strict; + public void ConfirmDelete(ConfirmDmlStatementEventArgs e) { } diff --git a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsConnection.cs b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsConnection.cs index e5ccaf7c..75edd1f5 100644 --- a/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsConnection.cs +++ b/MarkMpn.Sql4Cds.Engine/Ado/Sql4CdsConnection.cs @@ -221,6 +221,15 @@ public bool QuotedIdentifiers /// public string ApplicationName { get; set; } + /// + /// Returns or sets how columns should be sorted within tables + /// + public ColumnOrdering ColumnOrdering + { + get => _options.ColumnOrdering; + set => _options.ColumnOrdering = value; + } + internal Dictionary GlobalVariableTypes => _globalVariableTypes; internal Dictionary GlobalVariableValues => _globalVariableValues; diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FetchXmlScan.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FetchXmlScan.cs index 6c85af17..f7907002 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FetchXmlScan.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/FetchXmlScan.cs @@ -661,7 +661,7 @@ public override INodeSchema GetSchema(NodeCompilationContext context) var notNullColumns = new HashSet(); var sortOrder = new List(); - AddSchemaAttributes(dataSource, schema, aliases, ref primaryKey, notNullColumns, sortOrder, entity.name, Alias, entity.Items, true, false); + AddSchemaAttributes(context, dataSource, schema, aliases, ref primaryKey, notNullColumns, sortOrder, entity.name, Alias, entity.Items, true, false); _lastSchema = new NodeSchema( primaryKey: primaryKey, @@ -759,7 +759,7 @@ internal static bool IsValidAlias(string alias) return Regex.IsMatch(alias, "^[A-Za-z_][A-Za-z0-9_]*$"); } - private void AddSchemaAttributes(DataSource dataSource, ColumnList schema, Dictionary> aliases, ref string primaryKey, HashSet notNullColumns, List sortOrder, string entityName, string alias, object[] items, bool innerJoin, bool requireTablePrefix) + private void AddSchemaAttributes(NodeCompilationContext context, DataSource dataSource, ColumnList schema, Dictionary> aliases, ref string primaryKey, HashSet notNullColumns, List sortOrder, string entityName, string alias, object[] items, bool innerJoin, bool requireTablePrefix) { if (items == null && !ReturnFullSchema) return; @@ -768,7 +768,7 @@ private void AddSchemaAttributes(DataSource dataSource, ColumnList schema, Dicti if (ReturnFullSchema && !FetchXml.aggregate) { - foreach (var attrMetadata in meta.Attributes.OrderBy(a => a.LogicalName)) + foreach (var attrMetadata in SortAttributes(meta, context)) { if (attrMetadata.IsValidForRead == false) continue; @@ -842,7 +842,7 @@ private void AddSchemaAttributes(DataSource dataSource, ColumnList schema, Dicti if (items.OfType().Any()) { - foreach (var attrMetadata in meta.Attributes) + foreach (var attrMetadata in SortAttributes(meta, context)) { if (attrMetadata.IsValidForRead == false) continue; @@ -907,7 +907,7 @@ private void AddSchemaAttributes(DataSource dataSource, ColumnList schema, Dicti } } - AddSchemaAttributes(dataSource, schema, aliases, ref primaryKey, notNullColumns, sortOrder, linkEntity.name, linkEntity.alias, linkEntity.Items, innerJoin && linkEntity.linktype == "inner", requireTablePrefix || linkEntity.RequireTablePrefix); + AddSchemaAttributes(context, dataSource, schema, aliases, ref primaryKey, notNullColumns, sortOrder, linkEntity.name, linkEntity.alias, linkEntity.Items, innerJoin && linkEntity.linktype == "inner", requireTablePrefix || linkEntity.RequireTablePrefix); } if (innerJoin) @@ -918,6 +918,21 @@ private void AddSchemaAttributes(DataSource dataSource, ColumnList schema, Dicti } } + private IEnumerable SortAttributes(EntityMetadata metadata, NodeCompilationContext context) + { + switch (context.Options.ColumnOrdering) + { + case ColumnOrdering.Alphabetical: + return metadata.Attributes.OrderBy(a => a.LogicalName); + + case ColumnOrdering.Strict: + return metadata.Attributes.OrderBy(a => a.ColumnNumber.Value); + + default: + throw new ArgumentOutOfRangeException("Invalid column ordering " + context.Options.ColumnOrdering); + } + } + private void AddNotNullFilters(ColumnList schema, Dictionary> aliases, HashSet notNullColumns, string alias, filter filter) { if (filter.Items == null) diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/GlobalOptionSetQueryNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/GlobalOptionSetQueryNode.cs index fbb79d97..f7d45365 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/GlobalOptionSetQueryNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/GlobalOptionSetQueryNode.cs @@ -20,6 +20,7 @@ class OptionSetProperty public IDictionary> Accessors { get; set; } public DataTypeReference SqlType { get; set; } public Type NetType { get; set; } + public IComparable[] DataMemberOrder { get; set; } } private static readonly Type[] _optionsetTypes; @@ -65,7 +66,8 @@ static GlobalOptionSetQueryNode() Name = g.Key.ToLowerInvariant(), SqlType = type, NetType = netType, - Accessors = g.ToDictionary(p => p.Type, p => MetadataQueryNode.GetPropertyAccessor(p.Property, netType)) + Accessors = g.ToDictionary(p => p.Type, p => MetadataQueryNode.GetPropertyAccessor(p.Property, netType)), + DataMemberOrder = MetadataQueryNode.GetDataMemberOrder(g.First().Property) }; }) .Where(p => p != null) @@ -120,7 +122,14 @@ public override INodeSchema GetSchema(NodeCompilationContext context) var schema = new ColumnList(); var aliases = new Dictionary>(StringComparer.OrdinalIgnoreCase); - foreach (var prop in _optionsetProps.Values) + var props = (IEnumerable)_optionsetProps.Values; + + if (context.Options.ColumnOrdering == ColumnOrdering.Alphabetical) + props = props.OrderBy(p => p.Name); + else + props = props.OrderBy(p => p.DataMemberOrder[0]).ThenBy(p => p.DataMemberOrder[1]).ThenBy(p => p.DataMemberOrder[2]); + + foreach (var prop in props) { schema[$"{Alias}.{prop.Name}"] = prop.SqlType; diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/MetadataQueryNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/MetadataQueryNode.cs index 581287c2..a6671f4a 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/MetadataQueryNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/MetadataQueryNode.cs @@ -5,6 +5,7 @@ using System.Linq; using System.Linq.Expressions; using System.Reflection; +using System.Runtime.Serialization; using System.Text; using System.Threading.Tasks; using Microsoft.SqlServer.TransactSql.ScriptDom; @@ -28,6 +29,7 @@ class MetadataProperty public Func Accessor { get; set; } public DataTypeReference SqlType { get; set; } public Type Type { get; set; } + public IComparable[] DataMemberOrder { get; set; } } class AttributeProperty @@ -38,6 +40,7 @@ class AttributeProperty public DataTypeReference SqlType { get; set; } public Type Type { get; set; } public bool IsNullable { get; set; } + public IComparable[] DataMemberOrder { get; set; } } private IDictionary _entityCols; @@ -72,7 +75,7 @@ static MetadataQueryNode() _entityProps = typeof(EntityMetadata) .GetProperties(BindingFlags.Public | BindingFlags.Instance) .Where(p => !excludedEntityProps.Contains(p.Name)) - .ToDictionary(p => p.Name, p => new MetadataProperty { SqlName = p.Name.ToLowerInvariant(), PropertyName = p.Name, Type = p.PropertyType, SqlType = GetPropertyType(p.PropertyType), Accessor = GetPropertyAccessor(p, GetPropertyType(p.PropertyType).ToNetType(out _)) }, StringComparer.OrdinalIgnoreCase); + .ToDictionary(p => p.Name, p => new MetadataProperty { SqlName = p.Name.ToLowerInvariant(), PropertyName = p.Name, Type = p.PropertyType, SqlType = GetPropertyType(p.PropertyType), Accessor = GetPropertyAccessor(p, GetPropertyType(p.PropertyType).ToNetType(out _)), DataMemberOrder = GetDataMemberOrder(p) }, StringComparer.OrdinalIgnoreCase); var excludedOneToManyRelationshipProps = new[] { @@ -85,7 +88,7 @@ static MetadataQueryNode() _oneToManyRelationshipProps = typeof(OneToManyRelationshipMetadata) .GetProperties(BindingFlags.Public | BindingFlags.Instance) .Where(p => !excludedOneToManyRelationshipProps.Contains(p.Name)) - .ToDictionary(p => p.Name, p => new MetadataProperty { SqlName = p.Name.ToLowerInvariant(), PropertyName = p.Name, Type = p.PropertyType, SqlType = GetPropertyType(p.PropertyType), Accessor = GetPropertyAccessor(p, GetPropertyType(p.PropertyType).ToNetType(out _)) }, StringComparer.OrdinalIgnoreCase); + .ToDictionary(p => p.Name, p => new MetadataProperty { SqlName = p.Name.ToLowerInvariant(), PropertyName = p.Name, Type = p.PropertyType, SqlType = GetPropertyType(p.PropertyType), Accessor = GetPropertyAccessor(p, GetPropertyType(p.PropertyType).ToNetType(out _)), DataMemberOrder = GetDataMemberOrder(p) }, StringComparer.OrdinalIgnoreCase); var excludedManyToManyRelationshipProps = new[] { @@ -97,7 +100,7 @@ static MetadataQueryNode() _manyToManyRelationshipProps = typeof(ManyToManyRelationshipMetadata) .GetProperties(BindingFlags.Public | BindingFlags.Instance) .Where(p => !excludedManyToManyRelationshipProps.Contains(p.Name)) - .ToDictionary(p => p.Name, p => new MetadataProperty { SqlName = p.Name.ToLowerInvariant(), PropertyName = p.Name, Type = p.PropertyType, SqlType = GetPropertyType(p.PropertyType), Accessor = GetPropertyAccessor(p, GetPropertyType(p.PropertyType).ToNetType(out _)) }, StringComparer.OrdinalIgnoreCase); + .ToDictionary(p => p.Name, p => new MetadataProperty { SqlName = p.Name.ToLowerInvariant(), PropertyName = p.Name, Type = p.PropertyType, SqlType = GetPropertyType(p.PropertyType), Accessor = GetPropertyAccessor(p, GetPropertyType(p.PropertyType).ToNetType(out _)), DataMemberOrder = GetDataMemberOrder(p) }, StringComparer.OrdinalIgnoreCase); // Get a list of all attribute types _attributeTypes = typeof(AttributeMetadata).Assembly @@ -145,7 +148,8 @@ static MetadataQueryNode() SqlType = type, Type = netType, Accessors = g.ToDictionary(p => p.Type, p => GetPropertyAccessor(p.Property, netType)), - IsNullable = isNullable + IsNullable = isNullable, + DataMemberOrder = GetDataMemberOrder(g.First().Property) }; }) .Where(p => p != null) @@ -456,7 +460,12 @@ public override INodeSchema GetSchema(NodeCompilationContext context) if (Query.Properties != null) entityProps = entityProps.Where(p => Query.Properties.AllProperties || Query.Properties.PropertyNames.Contains(p.PropertyName, StringComparer.OrdinalIgnoreCase)); - foreach (var prop in entityProps.OrderBy(p => p.SqlName)) + if (context.Options.ColumnOrdering == ColumnOrdering.Alphabetical) + entityProps = entityProps.OrderBy(p => p.SqlName); + else + entityProps = entityProps.OrderBy(p => p.DataMemberOrder[0]).ThenBy(p => p.DataMemberOrder[1]).ThenBy(p => p.DataMemberOrder[2]); + + foreach (var prop in entityProps) { var fullName = $"{EntityAlias}.{prop.SqlName}"; schema[fullName] = prop.SqlType; @@ -483,7 +492,12 @@ public override INodeSchema GetSchema(NodeCompilationContext context) if (Query.AttributeQuery?.Properties != null) attributeProps = attributeProps.Where(p => Query.AttributeQuery.Properties.AllProperties || Query.AttributeQuery.Properties.PropertyNames.Contains(p.PropertyName, StringComparer.OrdinalIgnoreCase)); - foreach (var prop in attributeProps.OrderBy(p => p.SqlName)) + if (context.Options.ColumnOrdering == ColumnOrdering.Alphabetical) + attributeProps = attributeProps.OrderBy(p => p.SqlName); + else + attributeProps = attributeProps.OrderBy(p => p.DataMemberOrder[0]).ThenBy(p => p.DataMemberOrder[1]).ThenBy(p => p.DataMemberOrder[2]); + + foreach (var prop in attributeProps) { var fullName = $"{AttributeAlias}.{prop.SqlName}"; schema[fullName] = prop.SqlType; @@ -511,7 +525,12 @@ public override INodeSchema GetSchema(NodeCompilationContext context) if (Query.RelationshipQuery?.Properties != null) relationshipProps = relationshipProps.Where(p => Query.RelationshipQuery.Properties.AllProperties || Query.RelationshipQuery.Properties.PropertyNames.Contains(p.PropertyName, StringComparer.OrdinalIgnoreCase)); - foreach (var prop in relationshipProps.OrderBy(p => p.SqlName)) + if (context.Options.ColumnOrdering == ColumnOrdering.Alphabetical) + relationshipProps = relationshipProps.OrderBy(p => p.SqlName); + else + relationshipProps = relationshipProps.OrderBy(p => p.DataMemberOrder[0]).ThenBy(p => p.DataMemberOrder[1]).ThenBy(p => p.DataMemberOrder[2]); + + foreach (var prop in relationshipProps) { var fullName = $"{OneToManyRelationshipAlias}.{prop.SqlName}"; schema[fullName] = prop.SqlType; @@ -539,7 +558,12 @@ public override INodeSchema GetSchema(NodeCompilationContext context) if (Query.RelationshipQuery?.Properties != null) relationshipProps = relationshipProps.Where(p => Query.RelationshipQuery.Properties.AllProperties || Query.RelationshipQuery.Properties.PropertyNames.Contains(p.PropertyName, StringComparer.OrdinalIgnoreCase)); - foreach (var prop in relationshipProps.OrderBy(p => p.SqlName)) + if (context.Options.ColumnOrdering == ColumnOrdering.Alphabetical) + relationshipProps = relationshipProps.OrderBy(p => p.SqlName); + else + relationshipProps = relationshipProps.OrderBy(p => p.DataMemberOrder[0]).ThenBy(p => p.DataMemberOrder[1]).ThenBy(p => p.DataMemberOrder[2]); + + foreach (var prop in relationshipProps) { var fullName = $"{ManyToOneRelationshipAlias}.{prop.SqlName}"; schema[fullName] = prop.SqlType; @@ -567,7 +591,12 @@ public override INodeSchema GetSchema(NodeCompilationContext context) if (Query.RelationshipQuery?.Properties != null) relationshipProps = relationshipProps.Where(p => Query.RelationshipQuery.Properties.AllProperties || Query.RelationshipQuery.Properties.PropertyNames.Contains(p.PropertyName, StringComparer.OrdinalIgnoreCase)); - foreach (var prop in relationshipProps.OrderBy(p => p.SqlName)) + if (context.Options.ColumnOrdering == ColumnOrdering.Alphabetical) + relationshipProps = relationshipProps.OrderBy(p => p.SqlName); + else + relationshipProps = relationshipProps.OrderBy(p => p.DataMemberOrder[0]).ThenBy(p => p.DataMemberOrder[1]).ThenBy(p => p.DataMemberOrder[2]); + + foreach (var prop in relationshipProps) { var fullName = $"{ManyToManyRelationshipAlias}.{prop.SqlName}"; schema[fullName] = prop.SqlType; @@ -715,6 +744,27 @@ internal static Func GetPropertyAccessor(PropertyInfo prop, Type return func; } + internal static IComparable[] GetDataMemberOrder(PropertyInfo prop) + { + // https://learn.microsoft.com/en-us/dotnet/framework/wcf/feature-details/data-member-order + var inheritanceDepth = 0; + var type = prop.DeclaringType; + while (type.BaseType != null) + { + inheritanceDepth++; + type = type.BaseType; + } + + var attr = prop.GetCustomAttribute(); + + return new IComparable[] + { + inheritanceDepth, + attr?.Order ?? Int32.MinValue, + prop.Name + }; + } + private static SqlString ApplyCollation(string value) { if (value == null) diff --git a/MarkMpn.Sql4Cds.Engine/IQueryExecutionOptions.cs b/MarkMpn.Sql4Cds.Engine/IQueryExecutionOptions.cs index 52c83c22..d4cb9231 100644 --- a/MarkMpn.Sql4Cds.Engine/IQueryExecutionOptions.cs +++ b/MarkMpn.Sql4Cds.Engine/IQueryExecutionOptions.cs @@ -112,5 +112,26 @@ interface IQueryExecutionOptions /// Returns or sets a value indicating if SQL will be parsed using quoted identifiers /// bool QuotedIdentifiers { get; } + + /// + /// Indicates how columns should be assumed to be ordered within tables + /// + ColumnOrdering ColumnOrdering { get; } + } + + /// + /// Indicates how columns should be assumed to be ordered within tables + /// + public enum ColumnOrdering + { + /// + /// Columns are ordered according to the original metadata order + /// + Strict, + + /// + /// Columns are ordered alphabetically + /// + Alphabetical } } diff --git a/MarkMpn.Sql4Cds/QueryExecutionOptions.cs b/MarkMpn.Sql4Cds/QueryExecutionOptions.cs index 9ba36d76..19d18223 100644 --- a/MarkMpn.Sql4Cds/QueryExecutionOptions.cs +++ b/MarkMpn.Sql4Cds/QueryExecutionOptions.cs @@ -42,6 +42,7 @@ public void ApplySettings(bool execute) _con.UseLocalTimeZone = Settings.Instance.ShowLocalTimes; _con.BypassCustomPlugins = Settings.Instance.BypassCustomPlugins; _con.QuotedIdentifiers = Settings.Instance.QuotedIdentifiers; + _con.ColumnOrdering = Settings.Instance.ColumnOrdering; _con.PreInsert += ConfirmInsert; _con.PreUpdate += ConfirmUpdate; diff --git a/MarkMpn.Sql4Cds/Settings.cs b/MarkMpn.Sql4Cds/Settings.cs index df3df1ec..06cf053d 100644 --- a/MarkMpn.Sql4Cds/Settings.cs +++ b/MarkMpn.Sql4Cds/Settings.cs @@ -62,6 +62,8 @@ public class Settings public bool ShowFetchXMLInEstimatedExecutionPlans { get; set; } = true; public FetchXml2SqlOptions FetchXml2SqlOptions { get; set; } = new FetchXml2SqlOptions(); + + public ColumnOrdering ColumnOrdering { get; set; } = ColumnOrdering.Alphabetical; } public class TabContent diff --git a/MarkMpn.Sql4Cds/SettingsForm.Designer.cs b/MarkMpn.Sql4Cds/SettingsForm.Designer.cs index 57340bf6..7884b5fd 100644 --- a/MarkMpn.Sql4Cds/SettingsForm.Designer.cs +++ b/MarkMpn.Sql4Cds/SettingsForm.Designer.cs @@ -64,6 +64,7 @@ private void InitializeComponent() this.showTooltipsCheckbox = new System.Windows.Forms.CheckBox(); this.tabControl1 = new System.Windows.Forms.TabControl(); this.tabPage1 = new System.Windows.Forms.TabPage(); + this.schemaColumnOrderingCheckbox = new System.Windows.Forms.CheckBox(); this.showFetchXMLInEstimatedExecutionPlansCheckBox = new System.Windows.Forms.CheckBox(); this.tabPage2 = new System.Windows.Forms.TabPage(); this.label15 = new System.Windows.Forms.Label(); @@ -386,7 +387,7 @@ private void InitializeComponent() this.localTimesComboBox.Margin = new System.Windows.Forms.Padding(2); this.localTimesComboBox.Name = "localTimesComboBox"; this.localTimesComboBox.Size = new System.Drawing.Size(203, 21); - this.localTimesComboBox.TabIndex = 13; + this.localTimesComboBox.TabIndex = 12; // // label11 // @@ -395,7 +396,7 @@ private void InitializeComponent() this.label11.Margin = new System.Windows.Forms.Padding(2, 0, 2, 0); this.label11.Name = "label11"; this.label11.Size = new System.Drawing.Size(128, 13); - this.label11.TabIndex = 12; + this.label11.TabIndex = 11; this.label11.Text = "Treat date/time values as"; // // maxDopUpDown @@ -481,6 +482,7 @@ private void InitializeComponent() // // tabPage1 // + this.tabPage1.Controls.Add(this.schemaColumnOrderingCheckbox); this.tabPage1.Controls.Add(this.showFetchXMLInEstimatedExecutionPlansCheckBox); this.tabPage1.Controls.Add(this.pictureBox6); this.tabPage1.Controls.Add(this.pictureBox4); @@ -507,6 +509,16 @@ private void InitializeComponent() this.tabPage1.Text = "Query Execution"; this.tabPage1.UseVisualStyleBackColor = true; // + // schemaColumnOrderingCheckbox + // + this.schemaColumnOrderingCheckbox.AutoSize = true; + this.schemaColumnOrderingCheckbox.Location = new System.Drawing.Point(9, 193); + this.schemaColumnOrderingCheckbox.Name = "schemaColumnOrderingCheckbox"; + this.schemaColumnOrderingCheckbox.Size = new System.Drawing.Size(163, 17); + this.schemaColumnOrderingCheckbox.TabIndex = 13; + this.schemaColumnOrderingCheckbox.Text = "Use schema column ordering"; + this.schemaColumnOrderingCheckbox.UseVisualStyleBackColor = true; + // // showFetchXMLInEstimatedExecutionPlansCheckBox // this.showFetchXMLInEstimatedExecutionPlansCheckBox.AutoSize = true; @@ -887,5 +899,6 @@ private void InitializeComponent() private ScintillaNET.Scintilla nativeSqlScintilla; private System.Windows.Forms.CheckBox showFetchXMLInEstimatedExecutionPlansCheckBox; private System.Windows.Forms.LinkLabel fetchXml2SqlConversionAdvancedLinkLabel; + private System.Windows.Forms.CheckBox schemaColumnOrderingCheckbox; } } \ No newline at end of file diff --git a/MarkMpn.Sql4Cds/SettingsForm.cs b/MarkMpn.Sql4Cds/SettingsForm.cs index 6f11a3ed..b6f4cbba 100644 --- a/MarkMpn.Sql4Cds/SettingsForm.cs +++ b/MarkMpn.Sql4Cds/SettingsForm.cs @@ -42,6 +42,7 @@ public SettingsForm(Settings settings) localDateFormatCheckbox.Checked = settings.LocalFormatDates; simpleSqlRadioButton.Checked = !settings.UseNativeSqlConversion; nativeSqlRadioButton.Checked = settings.UseNativeSqlConversion; + schemaColumnOrderingCheckbox.Checked = settings.ColumnOrdering == ColumnOrdering.Strict; SetSqlStyle(simpleSqlScintilla); SetSqlStyle(nativeSqlScintilla); @@ -119,6 +120,7 @@ protected override void OnClosing(CancelEventArgs e) _settings.LocalFormatDates = localDateFormatCheckbox.Checked; _settings.UseNativeSqlConversion = nativeSqlRadioButton.Checked; _settings.FetchXml2SqlOptions = _fetchXml2SqlOptions; + _settings.ColumnOrdering = schemaColumnOrderingCheckbox.Checked ? ColumnOrdering.Strict : ColumnOrdering.Alphabetical; } } diff --git a/MarkMpn.Sql4Cds/SettingsForm.resx b/MarkMpn.Sql4Cds/SettingsForm.resx index 64900265..e109cc81 100644 --- a/MarkMpn.Sql4Cds/SettingsForm.resx +++ b/MarkMpn.Sql4Cds/SettingsForm.resx @@ -121,7 +121,7 @@ iVBORw0KGgoAAAANSUhEUgAAAFAAAABQCAIAAAABc2X6AAAABGdBTUEAALGPC/xhBQAAAAlwSFlzAAAO - wwAADsMBx2+oZAAAF2tJREFUeF7tm3l0leW1xvnj1laLDCKEUQwIARWrogwVrSBeiorKnHk6mUMCBAJh + wQAADsEBuJFr7QAAF2tJREFUeF7tm3l0leW1xvnj1laLDCKEUQwIARWrogwVrSBeiorKnHk6mUMCBAJh SIwiQ8KQhEwQAhhAkRlEnEASBUEqoG2Ve22VoXVoa+vqogp6r7X3+c7vZOflBOqw7l+3l/Ws3T08e797 f+/7DQdsix+26vAvhRY/ah1y2ZXtBSmXt+kolxScLq85jIMiKV0VBMuVjt+YAqsQAi5BEshjZQWyBHnk NxMPijlJvLyt10mj7CT5w1YhLX7Q8mpji+Tx2nS8om0nnM0BQYrxBX8tL+qV9s8sSFcIv0WDPJSiAkUo From 3cfdd0bbf68bbfc079f98fe7107cd024234d3af9 Mon Sep 17 00:00:00 2001 From: Mark Carrington Date: Mon, 10 Apr 2023 17:15:07 +0100 Subject: [PATCH 20/34] Fixed disconnect button Fixes #284 --- MarkMpn.Sql4Cds/ObjectExplorer.Designer.cs | 3 +++ MarkMpn.Sql4Cds/ObjectExplorer.cs | 7 +++++++ MarkMpn.Sql4Cds/ObjectExplorer.resx | 2 +- 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/MarkMpn.Sql4Cds/ObjectExplorer.Designer.cs b/MarkMpn.Sql4Cds/ObjectExplorer.Designer.cs index 1723095a..08bbecab 100644 --- a/MarkMpn.Sql4Cds/ObjectExplorer.Designer.cs +++ b/MarkMpn.Sql4Cds/ObjectExplorer.Designer.cs @@ -61,6 +61,7 @@ private void InitializeComponent() this.treeView.Size = new System.Drawing.Size(284, 236); this.treeView.TabIndex = 1; this.treeView.BeforeExpand += new System.Windows.Forms.TreeViewCancelEventHandler(this.treeView_BeforeExpand); + this.treeView.AfterSelect += new System.Windows.Forms.TreeViewEventHandler(this.treeView_AfterSelect); this.treeView.NodeMouseClick += new System.Windows.Forms.TreeNodeMouseClickEventHandler(this.treeView_NodeMouseClick); this.treeView.NodeMouseDoubleClick += new System.Windows.Forms.TreeNodeMouseClickEventHandler(this.treeView_NodeMouseDoubleClick); // @@ -189,11 +190,13 @@ private void InitializeComponent() // tsbDisconnect // this.tsbDisconnect.DisplayStyle = System.Windows.Forms.ToolStripItemDisplayStyle.Image; + this.tsbDisconnect.Enabled = false; this.tsbDisconnect.Image = ((System.Drawing.Image)(resources.GetObject("tsbDisconnect.Image"))); this.tsbDisconnect.ImageTransparentColor = System.Drawing.Color.Magenta; this.tsbDisconnect.Name = "tsbDisconnect"; this.tsbDisconnect.Size = new System.Drawing.Size(23, 22); this.tsbDisconnect.Text = "Disconnect"; + this.tsbDisconnect.Click += new System.EventHandler(this.disconnectToolStripMenuItem_Click); // // ObjectExplorer // diff --git a/MarkMpn.Sql4Cds/ObjectExplorer.cs b/MarkMpn.Sql4Cds/ObjectExplorer.cs index 49355e31..21033a65 100644 --- a/MarkMpn.Sql4Cds/ObjectExplorer.cs +++ b/MarkMpn.Sql4Cds/ObjectExplorer.cs @@ -563,6 +563,8 @@ private void disconnectToolStripMenuItem_Click(object sender, EventArgs e) node = node.Parent; node.Remove(); + + tsbDisconnect.Enabled = treeView.SelectedNode != null; } private void serverContextMenuStrip_Opening(object sender, System.ComponentModel.CancelEventArgs e) @@ -617,5 +619,10 @@ private void tsbConnect_Click(object sender, EventArgs e) { _connect(); } + + private void treeView_AfterSelect(object sender, TreeViewEventArgs e) + { + tsbDisconnect.Enabled = treeView.SelectedNode != null; + } } } diff --git a/MarkMpn.Sql4Cds/ObjectExplorer.resx b/MarkMpn.Sql4Cds/ObjectExplorer.resx index 3210be8c..c2d838d2 100644 --- a/MarkMpn.Sql4Cds/ObjectExplorer.resx +++ b/MarkMpn.Sql4Cds/ObjectExplorer.resx @@ -125,7 +125,7 @@ AAEAAAD/////AQAAAAAAAAAMAgAAAFdTeXN0ZW0uV2luZG93cy5Gb3JtcywgVmVyc2lvbj00LjAuMC4w LCBDdWx0dXJlPW5ldXRyYWwsIFB1YmxpY0tleVRva2VuPWI3N2E1YzU2MTkzNGUwODkFAQAAACZTeXN0 ZW0uV2luZG93cy5Gb3Jtcy5JbWFnZUxpc3RTdHJlYW1lcgEAAAAERGF0YQcCAgAAAAkDAAAADwMAAADy - GwAAAk1TRnQBSQFMAgEBGwEAAdgBAAHYAQABEAEAARABAAT/AQkBAAj/AUIBTQE2AQQGAAE2AQQCAAEo + GwAAAk1TRnQBSQFMAgEBGwEAAfABAAHwAQABEAEAARABAAT/AQkBAAj/AUIBTQE2AQQGAAE2AQQCAAEo AwABQAMAAXADAAEBAQABCAYAARwYAAGAAgABgAMAAoABAAGAAwABgAEAAYABAAKAAgADwAEAAcAB3AHA AQAB8AHKAaYBAAEzBQABMwEAATMBAAEzAQACMwIAAxYBAAMcAQADIgEAAykBAANVAQADTQEAA0IBAAM5 AQABgAF8Af8BAAJQAf8BAAGTAQAB1gEAAf8B7AHMAQABxgHWAe8BAAHWAucBAAGQAakBrQIAAf8BMwMA From 157ad5a50a5ff06714ac143911987d9efd3c0501 Mon Sep 17 00:00:00 2001 From: Mark Carrington Date: Mon, 10 Apr 2023 17:31:11 +0100 Subject: [PATCH 21/34] Allow setting batch size for DML queries via query hint Fixes #282 --- .../ExecutionPlan/AssignVariablesNode.cs | 3 + .../ExecutionPlan/BaseDmlNode.cs | 35 +++++++++- .../ExecutionPlan/DeleteNode.cs | 3 + .../ExecutionPlan/ExecuteAsNode.cs | 3 + .../ExecutionPlan/InsertNode.cs | 3 + .../ExecutionPlan/UpdateNode.cs | 3 + MarkMpn.Sql4Cds.Engine/MetaMetadataCache.cs | 64 ++++++++++++++++++- .../OptimizerHintValidatingVisitor.cs | 3 + 8 files changed, 112 insertions(+), 5 deletions(-) diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AssignVariablesNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AssignVariablesNode.cs index 3d6a6baf..1b1a9ffb 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AssignVariablesNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/AssignVariablesNode.cs @@ -31,6 +31,9 @@ class AssignVariablesNode : BaseDmlNode [Browsable(false)] public override int MaxDOP { get; set; } + [Browsable(false)] + public override int BatchSize { get; set; } + [Browsable(false)] public override bool BypassCustomPluginExecution { get; set; } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs index 291bddc3..086ce83d 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs @@ -97,6 +97,12 @@ public void Dispose() [Description("The maximum number of operations that will be performed in parallel")] public abstract int MaxDOP { get; set; } + /// + /// The number of requests that will be submitted in a single batch + /// + [Description("The number of requests that will be submitted in a single batch")] + public abstract int BatchSize { get; set; } + /// /// Indicates if custom plugins should be skipped /// @@ -139,6 +145,7 @@ public virtual IRootExecutionPlanNodeInternal[] FoldQuery(NodeCompilationContext } MaxDOP = GetMaxDOP(context, hints); + BatchSize = GetBatchSize(context, hints); BypassCustomPluginExecution = GetBypassPluginExecution(context, hints); return new[] { this }; @@ -157,7 +164,7 @@ private int GetMaxDOP(NodeCompilationContext context, IList query if (maxDopHint != null) { if (!(maxDopHint.Value is IntegerLiteral maxDop) || !Int32.TryParse(maxDop.Value, out var value) || value < 1) - throw new NotSupportedQueryFragmentException("MAXDOP requires a positive integer value"); + throw new NotSupportedQueryFragmentException("MAXDOP requires a positive integer value", maxDopHint); return value; } @@ -165,6 +172,28 @@ private int GetMaxDOP(NodeCompilationContext context, IList query return context.Options.MaxDegreeOfParallelism; } + private int GetBatchSize(NodeCompilationContext context, IList queryHints) + { + if (queryHints == null) + return context.Options.BatchSize; + + var batchSizeHint = queryHints + .OfType() + .SelectMany(hint => hint.Hints) + .Where(hint => hint.Value.StartsWith("BATCH_SIZE_", StringComparison.OrdinalIgnoreCase)) + .FirstOrDefault(); + + if (batchSizeHint != null) + { + if (!Int32.TryParse(batchSizeHint.Value.Substring(11), out var value) || value < 1) + throw new NotSupportedQueryFragmentException("BATCH_SIZE requires a positive integer value", batchSizeHint); + + return value; + } + + return context.Options.BatchSize; + } + private bool GetBypassPluginExecution(NodeCompilationContext context, IList queryHints) { if (queryHints == null) @@ -555,7 +584,7 @@ protected string ExecuteDmlOperation(IOrganizationService org, IQueryExecutionOp if (BypassCustomPluginExecution) request.Parameters["BypassCustomPluginExecution"] = true; - if (options.BatchSize == 1) + if (BatchSize == 1) { var newCount = Interlocked.Increment(ref inProgressCount); var progress = (double)newCount / entities.Count; @@ -586,7 +615,7 @@ protected string ExecuteDmlOperation(IOrganizationService org, IQueryExecutionOp threadLocalState.EMR.Requests.Add(request); - if (threadLocalState.EMR.Requests.Count == options.BatchSize) + if (threadLocalState.EMR.Requests.Count == BatchSize) { var newCount = Interlocked.Add(ref inProgressCount, threadLocalState.EMR.Requests.Count); var progress = (double)newCount / entities.Count; diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/DeleteNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/DeleteNode.cs index a7552813..24543278 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/DeleteNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/DeleteNode.cs @@ -50,6 +50,9 @@ class DeleteNode : BaseDmlNode [Category("Delete")] public override int MaxDOP { get; set; } + [Category("Delete")] + public override int BatchSize { get; set; } + [Category("Delete")] public override bool BypassCustomPluginExecution { get; set; } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteAsNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteAsNode.cs index 30a5a4ab..a5b77ac3 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteAsNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/ExecuteAsNode.cs @@ -34,6 +34,9 @@ class ExecuteAsNode : BaseDmlNode, IImpersonateRevertExecutionPlanNode [Browsable(false)] public override int MaxDOP { get; set; } + [Browsable(false)] + public override int BatchSize { get; set; } + [Browsable(false)] public override bool BypassCustomPluginExecution { get; set; } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/InsertNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/InsertNode.cs index 9252e6e4..ed3c2022 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/InsertNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/InsertNode.cs @@ -42,6 +42,9 @@ class InsertNode : BaseDmlNode [Category("Insert")] public override int MaxDOP { get; set; } + [Category("Insert")] + public override int BatchSize { get; set; } + [Category("Insert")] public override bool BypassCustomPluginExecution { get; set; } diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/UpdateNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/UpdateNode.cs index fb77ceca..aef31cb6 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/UpdateNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/UpdateNode.cs @@ -49,6 +49,9 @@ class UpdateNode : BaseDmlNode [Category("Update")] public override int MaxDOP { get; set; } + [Category("Update")] + public override int BatchSize { get; set; } + [Category("Update")] public override bool BypassCustomPluginExecution { get; set; } diff --git a/MarkMpn.Sql4Cds.Engine/MetaMetadataCache.cs b/MarkMpn.Sql4Cds.Engine/MetaMetadataCache.cs index 8ffa4367..ad8edc8f 100644 --- a/MarkMpn.Sql4Cds.Engine/MetaMetadataCache.cs +++ b/MarkMpn.Sql4Cds.Engine/MetaMetadataCache.cs @@ -4,10 +4,12 @@ using System.Linq; using System.Reflection; using System.Text; +using System.Threading; using System.Threading.Tasks; using MarkMpn.Sql4Cds.Engine.ExecutionPlan; using Microsoft.Xrm.Sdk; using Microsoft.Xrm.Sdk.Metadata; +using Microsoft.Xrm.Sdk.Query; namespace MarkMpn.Sql4Cds.Engine { @@ -16,6 +18,64 @@ namespace MarkMpn.Sql4Cds.Engine /// public class MetaMetadataCache : IAttributeMetadataCache { + class StubOptions : IQueryExecutionOptions + { + public CancellationToken CancellationToken => throw new NotImplementedException(); + + public bool BlockUpdateWithoutWhere => throw new NotImplementedException(); + + public bool BlockDeleteWithoutWhere => throw new NotImplementedException(); + + public bool UseBulkDelete => throw new NotImplementedException(); + + public int BatchSize => throw new NotImplementedException(); + + public bool UseTDSEndpoint => throw new NotImplementedException(); + + public int MaxDegreeOfParallelism => throw new NotImplementedException(); + + public bool ColumnComparisonAvailable => throw new NotImplementedException(); + + public bool UseLocalTimeZone => throw new NotImplementedException(); + + public List JoinOperatorsAvailable => throw new NotImplementedException(); + + public bool BypassCustomPlugins => throw new NotImplementedException(); + + public string PrimaryDataSource => throw new NotImplementedException(); + + public Guid UserId => throw new NotImplementedException(); + + public bool QuotedIdentifiers => throw new NotImplementedException(); + + public ColumnOrdering ColumnOrdering => ColumnOrdering.Alphabetical; + + public void ConfirmDelete(ConfirmDmlStatementEventArgs e) + { + throw new NotImplementedException(); + } + + public void ConfirmInsert(ConfirmDmlStatementEventArgs e) + { + throw new NotImplementedException(); + } + + public void ConfirmUpdate(ConfirmDmlStatementEventArgs e) + { + throw new NotImplementedException(); + } + + public bool ContinueRetrieve(int count) + { + throw new NotImplementedException(); + } + + public void Progress(double? progress, string message) + { + throw new NotImplementedException(); + } + } + private readonly IAttributeMetadataCache _inner; private static readonly IDictionary _customMetadata; @@ -32,7 +92,7 @@ static MetaMetadataCache() metadataNode.ManyToOneRelationshipAlias = "relationship_n_1"; metadataNode.ManyToManyRelationshipAlias = "relationship_n_n"; - var metadataSchema = metadataNode.GetSchema(new NodeCompilationContext(null, null, null)); + var metadataSchema = metadataNode.GetSchema(new NodeCompilationContext(null, new StubOptions(), null)); _customMetadata["metadata." + metadataNode.EntityAlias] = SchemaToMetadata(metadataSchema, metadataNode.EntityAlias); _customMetadata["metadata." + metadataNode.AttributeAlias] = SchemaToMetadata(metadataSchema, metadataNode.AttributeAlias); @@ -43,7 +103,7 @@ static MetaMetadataCache() var optionsetNode = new GlobalOptionSetQueryNode(); optionsetNode.Alias = "globaloptionset"; - var optionsetSchema = optionsetNode.GetSchema(new NodeCompilationContext(null, null, null)); + var optionsetSchema = optionsetNode.GetSchema(new NodeCompilationContext(null, new StubOptions(), null)); _customMetadata["metadata." + optionsetNode.Alias] = SchemaToMetadata(optionsetSchema, optionsetNode.Alias); } diff --git a/MarkMpn.Sql4Cds.Engine/Visitors/OptimizerHintValidatingVisitor.cs b/MarkMpn.Sql4Cds.Engine/Visitors/OptimizerHintValidatingVisitor.cs index 7857bd2e..5d96921c 100644 --- a/MarkMpn.Sql4Cds.Engine/Visitors/OptimizerHintValidatingVisitor.cs +++ b/MarkMpn.Sql4Cds.Engine/Visitors/OptimizerHintValidatingVisitor.cs @@ -68,6 +68,9 @@ class OptimizerHintValidatingVisitor : TSqlFragmentVisitor { // Custom hint to set the default page size of FetchXML queries "FETCHXML_PAGE_SIZE_", + + // Custom hint to set the batch size for DML queries + "BATCH_SIZE_", }; private readonly bool _removeSql4CdsHints; From 79f493c7720555cd6a8f4422007c3497dba85003 Mon Sep 17 00:00:00 2001 From: Mark Carrington Date: Tue, 11 Apr 2023 22:00:12 +0100 Subject: [PATCH 22/34] Added confirmation prompts to ADS extension --- AzureDataStudioExtension/CHANGELOG.md | 7 +- AzureDataStudioExtension/package.json | 25 +++++++ AzureDataStudioExtension/package.nls.json | 5 ++ AzureDataStudioExtension/src/main.ts | 18 +++-- .../Configuration/Sql4CdsSettings.cs | 10 +++ .../Contracts/ConfirmationRequest.cs | 17 +++++ .../Contracts/ConfirmationResponse.cs | 17 +++++ .../QueryExecution/QueryExecutionHandler.cs | 75 +++++++++++++++++++ 8 files changed, 167 insertions(+), 7 deletions(-) create mode 100644 MarkMpn.Sql4Cds.LanguageServer/QueryExecution/Contracts/ConfirmationRequest.cs create mode 100644 MarkMpn.Sql4Cds.LanguageServer/QueryExecution/Contracts/ConfirmationResponse.cs diff --git a/AzureDataStudioExtension/CHANGELOG.md b/AzureDataStudioExtension/CHANGELOG.md index f5c0253d..a589f8b9 100644 --- a/AzureDataStudioExtension/CHANGELOG.md +++ b/AzureDataStudioExtension/CHANGELOG.md @@ -2,4 +2,9 @@ ## [v7.1.0](https://github.com/MarkMpn/Sql4Cds/releases/tag/v7.1.0) - 2023-01-31 -This is the first release of SQL 4 CDS for Azure Data Studio. \ No newline at end of file +This is the first release of SQL 4 CDS for Azure Data Studio. + +## [v7.2.0](https://github.com/MarkMpn/Sql4Cds/releases/tag/v7.2.0) - 2023-04-30 + +Fixes starting the SQL 4 CDS language server on non-Windows platforms. +Adds confirmation prompts and safety limits mirroring the XrmToolBox tool. diff --git a/AzureDataStudioExtension/package.json b/AzureDataStudioExtension/package.json index d60b2a36..56f210db 100644 --- a/AzureDataStudioExtension/package.json +++ b/AzureDataStudioExtension/package.json @@ -117,6 +117,31 @@ "type": "boolean", "default": false, "description": "%sql4cds.quotedIdentifiers.description%" + }, + "SQL4CDS.insertWarnThreshold": { + "type": "integer", + "default": 1, + "description": "%sql4cds.insertWarnThreshold.description%" + }, + "SQL4CDS.updateWarnThreshold": { + "type": "integer", + "default": 0, + "description": "%sql4cds.updateWarnThreshold.description%" + }, + "SQL4CDS.deleteWarnThreshold": { + "type": "integer", + "default": 0, + "description": "%sql4cds.deleteWarnThreshold.description%" + }, + "SQL4CDS.selectLimit": { + "type": "integer", + "default": 0, + "description": "%sql4cds.selectLimit.description%" + }, + "SQL4CDS.maxRetrievesPerQuery": { + "type": "integer", + "default": 100, + "description": "%sql4cds.maxRetrievesPerQuery.description%" } } }, diff --git a/AzureDataStudioExtension/package.nls.json b/AzureDataStudioExtension/package.nls.json index 01aaedef..03a3114c 100644 --- a/AzureDataStudioExtension/package.nls.json +++ b/AzureDataStudioExtension/package.nls.json @@ -15,6 +15,11 @@ "sql4cds.useLocalTimeZone.description": "Use local time zone for date/time values", "sql4cds.bypassCustomPlugins.description": "Bypass custom plugins for INSERT/UPDATE/DELETE requests", "sql4cds.quotedIdentifiers.description": "Use quoted identifiers", + "sql4cds.insertWarnThreshold.description": "Warn when inserting more than _ records", + "sql4cds.updateWarnThreshold.description": "Warn when updating more than _ records", + "sql4cds.deleteWarnThreshold.description": "Warn when deleting more than _ records", + "sql4cds.selectLimit.description": "Limit results to _ records (0 for unlimited)", + "sql4cds.maxRetrievesPerQuery.description": "Maximum retrievals per query (0 for unlimited)", "sql4cds.provider.displayName": "SQL 4 CDS", "sql4cds.connectionOptions.groupName.source": "Source", "sql4cds.connectionOptions.groupName.security": "Security", diff --git a/AzureDataStudioExtension/src/main.ts b/AzureDataStudioExtension/src/main.ts index 1b4bcbe8..3ec73b7d 100644 --- a/AzureDataStudioExtension/src/main.ts +++ b/AzureDataStudioExtension/src/main.ts @@ -57,7 +57,7 @@ export async function activate(context: vscode.ExtensionContext) { let diagnosticCollection = vscode.languages.createDiagnosticCollection(Constants.providerId); context.subscriptions.push(diagnosticCollection); - let e = path.join(Utils.getResolvedServiceInstallationPath(), "MarkMpn.Sql4Cds.LanguageServer.exe"); + let e = path.join(Utils.getResolvedServiceInstallationPath(), "MarkMpn.Sql4Cds.LanguageServer.dll"); let serverOptions = generateServerOptions(e); languageClient = new SqlOpsDataClient(Constants.serviceName, serverOptions, clientOptions); languageClient.onReady().then(() => { @@ -69,6 +69,13 @@ export async function activate(context: vscode.ExtensionContext) { statusView.text = msg; statusView.show(); }); + languageClient.onNotification("sql4cds/confirmation", (message: {ownerUri: string, msg: string}) => { + vscode.window + .showInformationMessage(message.msg, "Yes", "No") + .then(answer => { + languageClient.sendNotification("sql4cds/confirm", { ownerUri: message.ownerUri, result: answer === "Yes" }) + }); + }); languageClient.onNotification("query/batchComplete", () => { statusView.hide(); }); @@ -114,8 +121,7 @@ export async function activate(context: vscode.ExtensionContext) { } function generateServerOptions(executablePath: string): ServerOptions { - let serverArgs = []; - let serverCommand: string = executablePath; + let serverArgs = [executablePath]; let config = vscode.workspace.getConfiguration(Constants.providerId); if (config) { @@ -124,8 +130,8 @@ function generateServerOptions(executablePath: string): ServerOptions { let useLocalSource = config["useDebugSource"]; if (useLocalSource) { let localSourcePath = config["debugSourcePath"]; - let filePath = path.join(localSourcePath, "MarkMpn.Sql4Cds.LanguageServer.exe"); - serverCommand = filePath; + let filePath = path.join(localSourcePath, "MarkMpn.Sql4Cds.LanguageServer.dll"); + serverArgs[0] = filePath; let enableStartupDebugging = config["enableStartupDebugging"]; if (enableStartupDebugging) @@ -143,7 +149,7 @@ function generateServerOptions(executablePath: string): ServerOptions { } // run the service host - return { command: serverCommand, args: serverArgs, transport: TransportKind.stdio }; + return { command: "dotnet", args: serverArgs, transport: TransportKind.stdio }; } function generateHandleServerProviderEvent() { diff --git a/MarkMpn.Sql4Cds.LanguageServer/Configuration/Sql4CdsSettings.cs b/MarkMpn.Sql4Cds.LanguageServer/Configuration/Sql4CdsSettings.cs index 06a5ea44..a7734043 100644 --- a/MarkMpn.Sql4Cds.LanguageServer/Configuration/Sql4CdsSettings.cs +++ b/MarkMpn.Sql4Cds.LanguageServer/Configuration/Sql4CdsSettings.cs @@ -20,6 +20,16 @@ public class Sql4CdsSettings public bool QuotedIdentifiers { get; set; } + public int DeleteWarnThreshold { get; set; } + + public int UpdateWarnThreshold { get; set; } + + public int InsertWarnThreshold { get; set; } + + public int SelectLimit { get; set; } + + public int MaxRetrievesPerQuery { get; set; } + public static Sql4CdsSettings Instance { get; set; } } } diff --git a/MarkMpn.Sql4Cds.LanguageServer/QueryExecution/Contracts/ConfirmationRequest.cs b/MarkMpn.Sql4Cds.LanguageServer/QueryExecution/Contracts/ConfirmationRequest.cs new file mode 100644 index 00000000..29a48193 --- /dev/null +++ b/MarkMpn.Sql4Cds.LanguageServer/QueryExecution/Contracts/ConfirmationRequest.cs @@ -0,0 +1,17 @@ +using Microsoft.VisualStudio.LanguageServer.Protocol; + +namespace MarkMpn.Sql4Cds.LanguageServer.QueryExecution.Contracts +{ + public class ConfirmationParams + { + public string OwnerUri { get; set; } + + public string Msg { get; set; } + } + + public class ConfirmationRequest + { + public const string MessageName = "sql4cds/confirmation"; + public static readonly LspNotification Type = new LspNotification(MessageName); + } +} diff --git a/MarkMpn.Sql4Cds.LanguageServer/QueryExecution/Contracts/ConfirmationResponse.cs b/MarkMpn.Sql4Cds.LanguageServer/QueryExecution/Contracts/ConfirmationResponse.cs new file mode 100644 index 00000000..bd2793da --- /dev/null +++ b/MarkMpn.Sql4Cds.LanguageServer/QueryExecution/Contracts/ConfirmationResponse.cs @@ -0,0 +1,17 @@ +using Microsoft.VisualStudio.LanguageServer.Protocol; + +namespace MarkMpn.Sql4Cds.LanguageServer.QueryExecution.Contracts +{ + public class ConfirmationResponseParams + { + public string OwnerUri { get; set; } + + public bool Result { get; set; } + } + + public class ConfirmationResponse + { + public const string MessageName = "sql4cds/confirm"; + public static readonly LspNotification Type = new LspNotification(MessageName); + } +} diff --git a/MarkMpn.Sql4Cds.LanguageServer/QueryExecution/QueryExecutionHandler.cs b/MarkMpn.Sql4Cds.LanguageServer/QueryExecution/QueryExecutionHandler.cs index 2997db64..2bde5617 100644 --- a/MarkMpn.Sql4Cds.LanguageServer/QueryExecution/QueryExecutionHandler.cs +++ b/MarkMpn.Sql4Cds.LanguageServer/QueryExecution/QueryExecutionHandler.cs @@ -18,6 +18,7 @@ using MarkMpn.Sql4Cds.LanguageServer.Workspace; using Microsoft.SqlTools.ServiceLayer.ExecutionPlan.Contracts; using Microsoft.VisualStudio.LanguageServer.Protocol; +using Microsoft.Xrm.Sdk.Metadata; using Newtonsoft.Json; using Newtonsoft.Json.Linq; using StreamJsonRpc; @@ -31,6 +32,8 @@ class QueryExecutionHandler : IJsonRpcMethodHandler private readonly TextDocumentManager _documentManager; private readonly ConcurrentDictionary> _resultSets; private readonly ConcurrentDictionary _commands; + private readonly ConcurrentDictionary _confirmationEvents; + private readonly ConcurrentDictionary _confirmationResults; public QueryExecutionHandler(JsonRpc lsp, ConnectionManager connectionManager, TextDocumentManager documentManager) { @@ -39,6 +42,8 @@ public QueryExecutionHandler(JsonRpc lsp, ConnectionManager connectionManager, T _documentManager = documentManager; _resultSets = new ConcurrentDictionary>(); _commands = new ConcurrentDictionary(); + _confirmationEvents = new ConcurrentDictionary(); + _confirmationResults = new ConcurrentDictionary(); } public void Initialize(JsonRpc lsp) @@ -48,6 +53,13 @@ public void Initialize(JsonRpc lsp) lsp.AddHandler(QueryCancelRequest.Type, HandleQueryCancel); lsp.AddHandler(QueryExecutionPlanRequest.Type, HandleQueryExecutionPlan); lsp.AddHandler(QueryDisposeRequest.Type, HandleQueryDispose); + lsp.AddHandler(ConfirmationResponse.Type, HandleConfirmation); + } + + private void HandleConfirmation(ConfirmationResponseParams arg) + { + _confirmationResults[arg.OwnerUri] = arg.Result; + _confirmationEvents[arg.OwnerUri].Set(); } public ExecuteRequestResult HandleExecuteDocumentSelection(ExecuteDocumentSelectionParams request) @@ -127,9 +139,58 @@ public ExecuteRequestResult HandleExecuteDocumentSelection(ExecuteDocumentSelect }); } }); + var confirm = (ConfirmDmlStatementEventArgs e, int threshold, string operation) => + { + e.Cancel = false; + + if (e.Count > threshold || e.BypassCustomPluginExecution) + { + var msg = $"{operation} will affect {e.Count:N0} {GetDisplayName(e.Count, e.Metadata)}."; + if (e.BypassCustomPluginExecution) + msg += " This operation will bypass any custom plugins."; + msg += " Do you want to proceed?"; + + var evt = new ManualResetEventSlim(); + _confirmationEvents[request.OwnerUri] = evt; + _ = _lsp.NotifyAsync(ConfirmationRequest.Type, new ConfirmationParams { OwnerUri = request.OwnerUri, Msg = msg }); + evt.Wait(); + e.Cancel = !_confirmationResults[request.OwnerUri]; + _confirmationEvents.Remove(request.OwnerUri, out _); + _confirmationResults.Remove(request.OwnerUri, out _); + }; + }; + var confirmInsert = (EventHandler)((object sender, ConfirmDmlStatementEventArgs e) => + { + confirm(e, Sql4CdsSettings.Instance.InsertWarnThreshold, "Insert"); + }); + var confirmUpdate = (EventHandler)((object sender, ConfirmDmlStatementEventArgs e) => + { + confirm(e, Sql4CdsSettings.Instance.UpdateWarnThreshold, "Update"); + }); + var confirmDelete = (EventHandler)((object sender, ConfirmDmlStatementEventArgs e) => + { + confirm(e, Sql4CdsSettings.Instance.DeleteWarnThreshold, "Delete"); + }); + var pages = 0; + var confirmRetrieve = (EventHandler)((object sender, ConfirmRetrieveEventArgs e) => + { + e.Cancel = e.Count >= Sql4CdsSettings.Instance.SelectLimit; + + if (!e.Cancel) + { + pages++; + + if (Sql4CdsSettings.Instance.MaxRetrievesPerQuery != 0 && pages > Sql4CdsSettings.Instance.MaxRetrievesPerQuery) + throw new QueryExecutionException($"Hit maximum retrieval limit. This limit is in place to protect against excessive API requests. Try restricting the data to retrieve with WHERE clauses or eliminating subqueries.\r\nYour limit of {Sql4CdsSettings.Instance.MaxRetrievesPerQuery:N0} retrievals per query can be modified in Settings."); + } + }); session.Connection.InfoMessage += infoMessageHandler; session.Connection.Progress += progressHandler; + session.Connection.PreInsert += confirmInsert; + session.Connection.PreUpdate += confirmUpdate; + session.Connection.PreDelete += confirmDelete; + session.Connection.PreRetrieve += confirmRetrieve; var resultSets = new List(); _resultSets[request.OwnerUri] = resultSets; @@ -376,6 +437,10 @@ public ExecuteRequestResult HandleExecuteDocumentSelection(ExecuteDocumentSelect session.Connection.InfoMessage -= infoMessageHandler; session.Connection.Progress -= progressHandler; + session.Connection.PreInsert -= confirmInsert; + session.Connection.PreUpdate -= confirmUpdate; + session.Connection.PreDelete -= confirmDelete; + session.Connection.PreRetrieve -= confirmRetrieve; } var endTime = DateTime.UtcNow; @@ -406,6 +471,16 @@ public ExecuteRequestResult HandleExecuteDocumentSelection(ExecuteDocumentSelect return new ExecuteRequestResult(); } + private string GetDisplayName(int count, EntityMetadata meta) + { + if (count == 1) + return meta.DisplayName.UserLocalizedLabel?.Label ?? meta.LogicalName; + + return meta.DisplayCollectionName.UserLocalizedLabel?.Label ?? + meta.LogicalCollectionName ?? + meta.LogicalName; + } + private ExecutionPlanGraph ConvertExecutionPlan(IRootExecutionPlanNode plan, bool executed) { var id = 1; From 95fbbe8a8fea754af3a46b3d24fec4e04be01580 Mon Sep 17 00:00:00 2001 From: Mark Carrington Date: Tue, 11 Apr 2023 22:32:18 +0100 Subject: [PATCH 23/34] Use recommended degrees of parallelism --- .../MarkMpn.Sql4Cds.Engine.NetFx.csproj | 2 +- .../ExecutionPlan/BaseDmlNode.cs | 35 ++- MarkMpn.Sql4Cds/Settings.cs | 2 +- MarkMpn.Sql4Cds/SettingsForm.Designer.cs | 247 +++++++++--------- MarkMpn.Sql4Cds/SettingsForm.resx | 2 +- 5 files changed, 152 insertions(+), 136 deletions(-) diff --git a/MarkMpn.Sql4Cds.Engine.NetFx/MarkMpn.Sql4Cds.Engine.NetFx.csproj b/MarkMpn.Sql4Cds.Engine.NetFx/MarkMpn.Sql4Cds.Engine.NetFx.csproj index 6be6b510..85c9df73 100644 --- a/MarkMpn.Sql4Cds.Engine.NetFx/MarkMpn.Sql4Cds.Engine.NetFx.csproj +++ b/MarkMpn.Sql4Cds.Engine.NetFx/MarkMpn.Sql4Cds.Engine.NetFx.csproj @@ -68,7 +68,7 @@ 9.0.2.33 - 9.1.0.79 + 9.1.1.1 1.0.0 diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs index 086ce83d..4ef8e1a8 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs @@ -153,23 +153,44 @@ public virtual IRootExecutionPlanNodeInternal[] FoldQuery(NodeCompilationContext private int GetMaxDOP(NodeCompilationContext context, IList queryHints) { - if (queryHints == null) - return context.Options.MaxDegreeOfParallelism; + if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) + throw new NotSupportedQueryFragmentException("Unknown datasource"); + + var org = dataSource.Connection; + var recommendedMaxDop = 1; + +#if NETCOREAPP + var svc = org as ServiceClient; - var maxDopHint = queryHints + if (svc != null) + recommendedMaxDop = svc.RecommendedDegreesOfParallelism; +#else + var svc = org as CrmServiceClient; + + if (svc != null) + recommendedMaxDop = svc.RecommendedDegreesOfParallelism; +#endif + + var maxDopHint = (queryHints ?? Array.Empty()) .OfType() .Where(hint => hint.HintKind == OptimizerHintKind.MaxDop) .FirstOrDefault(); if (maxDopHint != null) { - if (!(maxDopHint.Value is IntegerLiteral maxDop) || !Int32.TryParse(maxDop.Value, out var value) || value < 1) - throw new NotSupportedQueryFragmentException("MAXDOP requires a positive integer value", maxDopHint); + if (!(maxDopHint.Value is IntegerLiteral maxDop) || !Int32.TryParse(maxDop.Value, out var value) || value < 0) + throw new NotSupportedQueryFragmentException("MAXDOP requires a positive integer value, or 0 to use recommended value", maxDopHint); - return value; + if (value > 0) + return value; + + return recommendedMaxDop; } - return context.Options.MaxDegreeOfParallelism; + if (context.Options.MaxDegreeOfParallelism > 0) + return context.Options.MaxDegreeOfParallelism; + + return recommendedMaxDop; } private int GetBatchSize(NodeCompilationContext context, IList queryHints) diff --git a/MarkMpn.Sql4Cds/Settings.cs b/MarkMpn.Sql4Cds/Settings.cs index 06cf053d..8956c44f 100644 --- a/MarkMpn.Sql4Cds/Settings.cs +++ b/MarkMpn.Sql4Cds/Settings.cs @@ -41,7 +41,7 @@ public class Settings public bool ShowIntellisenseTooltips { get; set; } = true; - public int MaxDegreeOfPaallelism { get; set; } = 10; + public int MaxDegreeOfPaallelism { get; set; } public bool IncludeFetchXml { get; set; } diff --git a/MarkMpn.Sql4Cds/SettingsForm.Designer.cs b/MarkMpn.Sql4Cds/SettingsForm.Designer.cs index 7884b5fd..348a66e6 100644 --- a/MarkMpn.Sql4Cds/SettingsForm.Designer.cs +++ b/MarkMpn.Sql4Cds/SettingsForm.Designer.cs @@ -30,6 +30,7 @@ private void InitializeComponent() { System.ComponentModel.ComponentResourceManager resources = new System.ComponentModel.ComponentResourceManager(typeof(SettingsForm)); this.topPanel = new System.Windows.Forms.Panel(); + this.pictureBox = new System.Windows.Forms.PictureBox(); this.label1 = new System.Windows.Forms.Label(); this.panel2 = new System.Windows.Forms.Panel(); this.cancelButton = new System.Windows.Forms.Button(); @@ -66,28 +67,28 @@ private void InitializeComponent() this.tabPage1 = new System.Windows.Forms.TabPage(); this.schemaColumnOrderingCheckbox = new System.Windows.Forms.CheckBox(); this.showFetchXMLInEstimatedExecutionPlansCheckBox = new System.Windows.Forms.CheckBox(); + this.pictureBox6 = new System.Windows.Forms.PictureBox(); + this.pictureBox4 = new System.Windows.Forms.PictureBox(); + this.pictureBox2 = new System.Windows.Forms.PictureBox(); + this.pictureBox1 = new System.Windows.Forms.PictureBox(); + this.bulkDeleteHelp = new System.Windows.Forms.PictureBox(); this.tabPage2 = new System.Windows.Forms.TabPage(); this.label15 = new System.Windows.Forms.Label(); this.insertWarnThresholdUpDown = new System.Windows.Forms.NumericUpDown(); this.label16 = new System.Windows.Forms.Label(); + this.pictureBox3 = new System.Windows.Forms.PictureBox(); this.tabPage3 = new System.Windows.Forms.TabPage(); this.localDateFormatCheckbox = new System.Windows.Forms.CheckBox(); this.rememberSessionCheckbox = new System.Windows.Forms.CheckBox(); this.tabPage4 = new System.Windows.Forms.TabPage(); + this.fetchXml2SqlConversionAdvancedLinkLabel = new System.Windows.Forms.LinkLabel(); this.nativeSqlScintilla = new ScintillaNET.Scintilla(); this.simpleSqlScintilla = new ScintillaNET.Scintilla(); this.nativeSqlRadioButton = new System.Windows.Forms.RadioButton(); this.simpleSqlRadioButton = new System.Windows.Forms.RadioButton(); this.label17 = new System.Windows.Forms.Label(); - this.pictureBox6 = new System.Windows.Forms.PictureBox(); - this.pictureBox4 = new System.Windows.Forms.PictureBox(); - this.pictureBox2 = new System.Windows.Forms.PictureBox(); - this.pictureBox1 = new System.Windows.Forms.PictureBox(); - this.bulkDeleteHelp = new System.Windows.Forms.PictureBox(); - this.pictureBox3 = new System.Windows.Forms.PictureBox(); - this.pictureBox = new System.Windows.Forms.PictureBox(); - this.fetchXml2SqlConversionAdvancedLinkLabel = new System.Windows.Forms.LinkLabel(); this.topPanel.SuspendLayout(); + ((System.ComponentModel.ISupportInitialize)(this.pictureBox)).BeginInit(); this.panel2.SuspendLayout(); ((System.ComponentModel.ISupportInitialize)(this.selectLimitUpDown)).BeginInit(); ((System.ComponentModel.ISupportInitialize)(this.updateWarnThresholdUpDown)).BeginInit(); @@ -97,17 +98,16 @@ private void InitializeComponent() ((System.ComponentModel.ISupportInitialize)(this.maxDopUpDown)).BeginInit(); this.tabControl1.SuspendLayout(); this.tabPage1.SuspendLayout(); - this.tabPage2.SuspendLayout(); - ((System.ComponentModel.ISupportInitialize)(this.insertWarnThresholdUpDown)).BeginInit(); - this.tabPage3.SuspendLayout(); - this.tabPage4.SuspendLayout(); ((System.ComponentModel.ISupportInitialize)(this.pictureBox6)).BeginInit(); ((System.ComponentModel.ISupportInitialize)(this.pictureBox4)).BeginInit(); ((System.ComponentModel.ISupportInitialize)(this.pictureBox2)).BeginInit(); ((System.ComponentModel.ISupportInitialize)(this.pictureBox1)).BeginInit(); ((System.ComponentModel.ISupportInitialize)(this.bulkDeleteHelp)).BeginInit(); + this.tabPage2.SuspendLayout(); + ((System.ComponentModel.ISupportInitialize)(this.insertWarnThresholdUpDown)).BeginInit(); ((System.ComponentModel.ISupportInitialize)(this.pictureBox3)).BeginInit(); - ((System.ComponentModel.ISupportInitialize)(this.pictureBox)).BeginInit(); + this.tabPage3.SuspendLayout(); + this.tabPage4.SuspendLayout(); this.SuspendLayout(); // // topPanel @@ -122,6 +122,18 @@ private void InitializeComponent() this.topPanel.Size = new System.Drawing.Size(441, 52); this.topPanel.TabIndex = 0; // + // pictureBox + // + this.pictureBox.BackColor = System.Drawing.Color.FromArgb(((int)(((byte)(7)))), ((int)(((byte)(14)))), ((int)(((byte)(22))))); + this.pictureBox.Image = ((System.Drawing.Image)(resources.GetObject("pictureBox.Image"))); + this.pictureBox.Location = new System.Drawing.Point(6, 6); + this.pictureBox.Margin = new System.Windows.Forms.Padding(2); + this.pictureBox.Name = "pictureBox"; + this.pictureBox.Size = new System.Drawing.Size(40, 40); + this.pictureBox.SizeMode = System.Windows.Forms.PictureBoxSizeMode.StretchImage; + this.pictureBox.TabIndex = 0; + this.pictureBox.TabStop = false; + // // label1 // this.label1.AutoSize = true; @@ -372,9 +384,9 @@ private void InitializeComponent() this.label12.AutoSize = true; this.label12.Location = new System.Drawing.Point(173, 52); this.label12.Name = "label12"; - this.label12.Size = new System.Drawing.Size(77, 13); + this.label12.Size = new System.Drawing.Size(131, 13); this.label12.TabIndex = 6; - this.label12.Text = "worker threads"; + this.label12.Text = "worker threads (0 for auto)"; // // localTimesComboBox // @@ -402,11 +414,6 @@ private void InitializeComponent() // maxDopUpDown // this.maxDopUpDown.Location = new System.Drawing.Point(65, 50); - this.maxDopUpDown.Minimum = new decimal(new int[] { - 1, - 0, - 0, - 0}); this.maxDopUpDown.Name = "maxDopUpDown"; this.maxDopUpDown.Size = new System.Drawing.Size(102, 20); this.maxDopUpDown.TabIndex = 5; @@ -532,6 +539,69 @@ private void InitializeComponent() this.showFetchXMLInEstimatedExecutionPlansCheckBox.Text = "Show FetchXML in estimated execution plans"; this.showFetchXMLInEstimatedExecutionPlansCheckBox.UseVisualStyleBackColor = true; // + // pictureBox6 + // + this.pictureBox6.Cursor = System.Windows.Forms.Cursors.Hand; + this.pictureBox6.Image = global::MarkMpn.Sql4Cds.Properties.Resources.StatusHelp_16x; + this.pictureBox6.Location = new System.Drawing.Point(310, 52); + this.pictureBox6.Name = "pictureBox6"; + this.pictureBox6.Size = new System.Drawing.Size(16, 16); + this.pictureBox6.TabIndex = 31; + this.pictureBox6.TabStop = false; + this.pictureBox6.Tag = "https://docs.microsoft.com/powerapps/developer/data-platform/api-limits#how-to-ma" + + "ximize-throughput"; + this.pictureBox6.Click += new System.EventHandler(this.helpIcon_Click); + // + // pictureBox4 + // + this.pictureBox4.Cursor = System.Windows.Forms.Cursors.Hand; + this.pictureBox4.Image = global::MarkMpn.Sql4Cds.Properties.Resources.StatusHelp_16x; + this.pictureBox4.Location = new System.Drawing.Point(124, 6); + this.pictureBox4.Name = "pictureBox4"; + this.pictureBox4.Size = new System.Drawing.Size(16, 16); + this.pictureBox4.TabIndex = 29; + this.pictureBox4.TabStop = false; + this.pictureBox4.Tag = "https://docs.microsoft.com/sql/t-sql/statements/set-quoted-identifier-transact-sq" + + "l"; + this.pictureBox4.Click += new System.EventHandler(this.helpIcon_Click); + // + // pictureBox2 + // + this.pictureBox2.Cursor = System.Windows.Forms.Cursors.Hand; + this.pictureBox2.Image = global::MarkMpn.Sql4Cds.Properties.Resources.StatusHelp_16x; + this.pictureBox2.Location = new System.Drawing.Point(250, 122); + this.pictureBox2.Name = "pictureBox2"; + this.pictureBox2.Size = new System.Drawing.Size(16, 16); + this.pictureBox2.TabIndex = 28; + this.pictureBox2.TabStop = false; + this.pictureBox2.Tag = "https://docs.microsoft.com/powerapps/developer/data-platform/dataverse-sql-query"; + this.pictureBox2.Click += new System.EventHandler(this.helpIcon_Click); + // + // pictureBox1 + // + this.pictureBox1.Cursor = System.Windows.Forms.Cursors.Hand; + this.pictureBox1.Image = global::MarkMpn.Sql4Cds.Properties.Resources.StatusHelp_16x; + this.pictureBox1.Location = new System.Drawing.Point(148, 99); + this.pictureBox1.Name = "pictureBox1"; + this.pictureBox1.Size = new System.Drawing.Size(16, 16); + this.pictureBox1.TabIndex = 27; + this.pictureBox1.TabStop = false; + this.pictureBox1.Tag = "https://docs.microsoft.com/powerapps/developer/data-platform/bypass-custom-busine" + + "ss-logic"; + this.pictureBox1.Click += new System.EventHandler(this.helpIcon_Click); + // + // bulkDeleteHelp + // + this.bulkDeleteHelp.Cursor = System.Windows.Forms.Cursors.Hand; + this.bulkDeleteHelp.Image = global::MarkMpn.Sql4Cds.Properties.Resources.StatusHelp_16x; + this.bulkDeleteHelp.Location = new System.Drawing.Point(167, 77); + this.bulkDeleteHelp.Name = "bulkDeleteHelp"; + this.bulkDeleteHelp.Size = new System.Drawing.Size(16, 16); + this.bulkDeleteHelp.TabIndex = 26; + this.bulkDeleteHelp.TabStop = false; + this.bulkDeleteHelp.Tag = "https://docs.microsoft.com/powerapps/developer/data-platform/delete-data-bulk"; + this.bulkDeleteHelp.Click += new System.EventHandler(this.helpIcon_Click); + // // tabPage2 // this.tabPage2.Controls.Add(this.label15); @@ -590,6 +660,18 @@ private void InitializeComponent() this.label16.TabIndex = 6; this.label16.Text = "Warn when inserting more than"; // + // pictureBox3 + // + this.pictureBox3.Cursor = System.Windows.Forms.Cursors.Hand; + this.pictureBox3.Image = global::MarkMpn.Sql4Cds.Properties.Resources.StatusHelp_16x; + this.pictureBox3.Location = new System.Drawing.Point(378, 34); + this.pictureBox3.Name = "pictureBox3"; + this.pictureBox3.Size = new System.Drawing.Size(16, 16); + this.pictureBox3.TabIndex = 27; + this.pictureBox3.TabStop = false; + this.pictureBox3.Tag = "https://docs.microsoft.com/powerapps/developer/data-platform/api-limits"; + this.pictureBox3.Click += new System.EventHandler(this.helpIcon_Click); + // // tabPage3 // this.tabPage3.Controls.Add(this.localDateFormatCheckbox); @@ -640,6 +722,18 @@ private void InitializeComponent() this.tabPage4.Text = "Conversion"; this.tabPage4.UseVisualStyleBackColor = true; // + // fetchXml2SqlConversionAdvancedLinkLabel + // + this.fetchXml2SqlConversionAdvancedLinkLabel.Anchor = ((System.Windows.Forms.AnchorStyles)((System.Windows.Forms.AnchorStyles.Top | System.Windows.Forms.AnchorStyles.Right))); + this.fetchXml2SqlConversionAdvancedLinkLabel.AutoSize = true; + this.fetchXml2SqlConversionAdvancedLinkLabel.Location = new System.Drawing.Point(348, 21); + this.fetchXml2SqlConversionAdvancedLinkLabel.Name = "fetchXml2SqlConversionAdvancedLinkLabel"; + this.fetchXml2SqlConversionAdvancedLinkLabel.Size = new System.Drawing.Size(56, 13); + this.fetchXml2SqlConversionAdvancedLinkLabel.TabIndex = 5; + this.fetchXml2SqlConversionAdvancedLinkLabel.TabStop = true; + this.fetchXml2SqlConversionAdvancedLinkLabel.Text = "Advanced"; + this.fetchXml2SqlConversionAdvancedLinkLabel.LinkClicked += new System.Windows.Forms.LinkLabelLinkClickedEventHandler(this.fetchXml2SqlConversionAdvancedLinkLabel_LinkClicked); + // // nativeSqlScintilla // this.nativeSqlScintilla.Anchor = ((System.Windows.Forms.AnchorStyles)(((System.Windows.Forms.AnchorStyles.Top | System.Windows.Forms.AnchorStyles.Left) @@ -692,105 +786,6 @@ private void InitializeComponent() this.label17.TabIndex = 0; this.label17.Text = "When converting FetchXML to SQL (e.g. from FetchXML Builder):"; // - // pictureBox6 - // - this.pictureBox6.Cursor = System.Windows.Forms.Cursors.Hand; - this.pictureBox6.Image = global::MarkMpn.Sql4Cds.Properties.Resources.StatusHelp_16x; - this.pictureBox6.Location = new System.Drawing.Point(256, 52); - this.pictureBox6.Name = "pictureBox6"; - this.pictureBox6.Size = new System.Drawing.Size(16, 16); - this.pictureBox6.TabIndex = 31; - this.pictureBox6.TabStop = false; - this.pictureBox6.Tag = "https://docs.microsoft.com/powerapps/developer/data-platform/api-limits#how-to-ma" + - "ximize-throughput"; - this.pictureBox6.Click += new System.EventHandler(this.helpIcon_Click); - // - // pictureBox4 - // - this.pictureBox4.Cursor = System.Windows.Forms.Cursors.Hand; - this.pictureBox4.Image = global::MarkMpn.Sql4Cds.Properties.Resources.StatusHelp_16x; - this.pictureBox4.Location = new System.Drawing.Point(124, 6); - this.pictureBox4.Name = "pictureBox4"; - this.pictureBox4.Size = new System.Drawing.Size(16, 16); - this.pictureBox4.TabIndex = 29; - this.pictureBox4.TabStop = false; - this.pictureBox4.Tag = "https://docs.microsoft.com/sql/t-sql/statements/set-quoted-identifier-transact-sq" + - "l"; - this.pictureBox4.Click += new System.EventHandler(this.helpIcon_Click); - // - // pictureBox2 - // - this.pictureBox2.Cursor = System.Windows.Forms.Cursors.Hand; - this.pictureBox2.Image = global::MarkMpn.Sql4Cds.Properties.Resources.StatusHelp_16x; - this.pictureBox2.Location = new System.Drawing.Point(250, 122); - this.pictureBox2.Name = "pictureBox2"; - this.pictureBox2.Size = new System.Drawing.Size(16, 16); - this.pictureBox2.TabIndex = 28; - this.pictureBox2.TabStop = false; - this.pictureBox2.Tag = "https://docs.microsoft.com/powerapps/developer/data-platform/dataverse-sql-query"; - this.pictureBox2.Click += new System.EventHandler(this.helpIcon_Click); - // - // pictureBox1 - // - this.pictureBox1.Cursor = System.Windows.Forms.Cursors.Hand; - this.pictureBox1.Image = global::MarkMpn.Sql4Cds.Properties.Resources.StatusHelp_16x; - this.pictureBox1.Location = new System.Drawing.Point(148, 99); - this.pictureBox1.Name = "pictureBox1"; - this.pictureBox1.Size = new System.Drawing.Size(16, 16); - this.pictureBox1.TabIndex = 27; - this.pictureBox1.TabStop = false; - this.pictureBox1.Tag = "https://docs.microsoft.com/powerapps/developer/data-platform/bypass-custom-busine" + - "ss-logic"; - this.pictureBox1.Click += new System.EventHandler(this.helpIcon_Click); - // - // bulkDeleteHelp - // - this.bulkDeleteHelp.Cursor = System.Windows.Forms.Cursors.Hand; - this.bulkDeleteHelp.Image = global::MarkMpn.Sql4Cds.Properties.Resources.StatusHelp_16x; - this.bulkDeleteHelp.Location = new System.Drawing.Point(167, 77); - this.bulkDeleteHelp.Name = "bulkDeleteHelp"; - this.bulkDeleteHelp.Size = new System.Drawing.Size(16, 16); - this.bulkDeleteHelp.TabIndex = 26; - this.bulkDeleteHelp.TabStop = false; - this.bulkDeleteHelp.Tag = "https://docs.microsoft.com/powerapps/developer/data-platform/delete-data-bulk"; - this.bulkDeleteHelp.Click += new System.EventHandler(this.helpIcon_Click); - // - // pictureBox3 - // - this.pictureBox3.Cursor = System.Windows.Forms.Cursors.Hand; - this.pictureBox3.Image = global::MarkMpn.Sql4Cds.Properties.Resources.StatusHelp_16x; - this.pictureBox3.Location = new System.Drawing.Point(378, 34); - this.pictureBox3.Name = "pictureBox3"; - this.pictureBox3.Size = new System.Drawing.Size(16, 16); - this.pictureBox3.TabIndex = 27; - this.pictureBox3.TabStop = false; - this.pictureBox3.Tag = "https://docs.microsoft.com/powerapps/developer/data-platform/api-limits"; - this.pictureBox3.Click += new System.EventHandler(this.helpIcon_Click); - // - // pictureBox - // - this.pictureBox.BackColor = System.Drawing.Color.FromArgb(((int)(((byte)(7)))), ((int)(((byte)(14)))), ((int)(((byte)(22))))); - this.pictureBox.Image = ((System.Drawing.Image)(resources.GetObject("pictureBox.Image"))); - this.pictureBox.Location = new System.Drawing.Point(6, 6); - this.pictureBox.Margin = new System.Windows.Forms.Padding(2); - this.pictureBox.Name = "pictureBox"; - this.pictureBox.Size = new System.Drawing.Size(40, 40); - this.pictureBox.SizeMode = System.Windows.Forms.PictureBoxSizeMode.StretchImage; - this.pictureBox.TabIndex = 0; - this.pictureBox.TabStop = false; - // - // fetchXml2SqlConversionAdvancedLinkLabel - // - this.fetchXml2SqlConversionAdvancedLinkLabel.Anchor = ((System.Windows.Forms.AnchorStyles)((System.Windows.Forms.AnchorStyles.Top | System.Windows.Forms.AnchorStyles.Right))); - this.fetchXml2SqlConversionAdvancedLinkLabel.AutoSize = true; - this.fetchXml2SqlConversionAdvancedLinkLabel.Location = new System.Drawing.Point(348, 21); - this.fetchXml2SqlConversionAdvancedLinkLabel.Name = "fetchXml2SqlConversionAdvancedLinkLabel"; - this.fetchXml2SqlConversionAdvancedLinkLabel.Size = new System.Drawing.Size(56, 13); - this.fetchXml2SqlConversionAdvancedLinkLabel.TabIndex = 5; - this.fetchXml2SqlConversionAdvancedLinkLabel.TabStop = true; - this.fetchXml2SqlConversionAdvancedLinkLabel.Text = "Advanced"; - this.fetchXml2SqlConversionAdvancedLinkLabel.LinkClicked += new System.Windows.Forms.LinkLabelLinkClickedEventHandler(this.fetchXml2SqlConversionAdvancedLinkLabel_LinkClicked); - // // SettingsForm // this.AcceptButton = this.okButton; @@ -812,6 +807,7 @@ private void InitializeComponent() this.Text = "SQL 4 CDS Settings"; this.topPanel.ResumeLayout(false); this.topPanel.PerformLayout(); + ((System.ComponentModel.ISupportInitialize)(this.pictureBox)).EndInit(); this.panel2.ResumeLayout(false); ((System.ComponentModel.ISupportInitialize)(this.selectLimitUpDown)).EndInit(); ((System.ComponentModel.ISupportInitialize)(this.updateWarnThresholdUpDown)).EndInit(); @@ -822,20 +818,19 @@ private void InitializeComponent() this.tabControl1.ResumeLayout(false); this.tabPage1.ResumeLayout(false); this.tabPage1.PerformLayout(); + ((System.ComponentModel.ISupportInitialize)(this.pictureBox6)).EndInit(); + ((System.ComponentModel.ISupportInitialize)(this.pictureBox4)).EndInit(); + ((System.ComponentModel.ISupportInitialize)(this.pictureBox2)).EndInit(); + ((System.ComponentModel.ISupportInitialize)(this.pictureBox1)).EndInit(); + ((System.ComponentModel.ISupportInitialize)(this.bulkDeleteHelp)).EndInit(); this.tabPage2.ResumeLayout(false); this.tabPage2.PerformLayout(); ((System.ComponentModel.ISupportInitialize)(this.insertWarnThresholdUpDown)).EndInit(); + ((System.ComponentModel.ISupportInitialize)(this.pictureBox3)).EndInit(); this.tabPage3.ResumeLayout(false); this.tabPage3.PerformLayout(); this.tabPage4.ResumeLayout(false); this.tabPage4.PerformLayout(); - ((System.ComponentModel.ISupportInitialize)(this.pictureBox6)).EndInit(); - ((System.ComponentModel.ISupportInitialize)(this.pictureBox4)).EndInit(); - ((System.ComponentModel.ISupportInitialize)(this.pictureBox2)).EndInit(); - ((System.ComponentModel.ISupportInitialize)(this.pictureBox1)).EndInit(); - ((System.ComponentModel.ISupportInitialize)(this.bulkDeleteHelp)).EndInit(); - ((System.ComponentModel.ISupportInitialize)(this.pictureBox3)).EndInit(); - ((System.ComponentModel.ISupportInitialize)(this.pictureBox)).EndInit(); this.ResumeLayout(false); } diff --git a/MarkMpn.Sql4Cds/SettingsForm.resx b/MarkMpn.Sql4Cds/SettingsForm.resx index e109cc81..006f7a17 100644 --- a/MarkMpn.Sql4Cds/SettingsForm.resx +++ b/MarkMpn.Sql4Cds/SettingsForm.resx @@ -121,7 +121,7 @@ iVBORw0KGgoAAAANSUhEUgAAAFAAAABQCAIAAAABc2X6AAAABGdBTUEAALGPC/xhBQAAAAlwSFlzAAAO - wQAADsEBuJFr7QAAF2tJREFUeF7tm3l0leW1xvnj1laLDCKEUQwIARWrogwVrSBeiorKnHk6mUMCBAJh + vwAADr8BOAVTJAAAF2tJREFUeF7tm3l0leW1xvnj1laLDCKEUQwIARWrogwVrSBeiorKnHk6mUMCBAJh SIwiQ8KQhEwQAhhAkRlEnEASBUEqoG2Ve22VoXVoa+vqogp6r7X3+c7vZOflBOqw7l+3l/Ws3T08e797 f+/7DQdsix+26vAvhRY/ah1y2ZXtBSmXt+kolxScLq85jIMiKV0VBMuVjt+YAqsQAi5BEshjZQWyBHnk NxMPijlJvLyt10mj7CT5w1YhLX7Q8mpji+Tx2nS8om0nnM0BQYrxBX8tL+qV9s8sSFcIv0WDPJSiAkUo From c9fed26ee94436278a3ea645afca70fcd5b5a973 Mon Sep 17 00:00:00 2001 From: Mark Carrington Date: Sat, 15 Apr 2023 16:48:51 +0100 Subject: [PATCH 24/34] Fixed failing tests --- .../FakeXrmEasyTestsBase.cs | 16 ++++++++ .../MarkMpn.Sql4Cds.Engine.Tests.csproj | 39 ++++++++++++++++--- MarkMpn.Sql4Cds.Engine.Tests/packages.config | 15 +++++-- .../ExecutionPlan/BaseDmlNode.cs | 3 ++ 4 files changed, 65 insertions(+), 8 deletions(-) diff --git a/MarkMpn.Sql4Cds.Engine.Tests/FakeXrmEasyTestsBase.cs b/MarkMpn.Sql4Cds.Engine.Tests/FakeXrmEasyTestsBase.cs index cc9b16cf..30c68288 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/FakeXrmEasyTestsBase.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/FakeXrmEasyTestsBase.cs @@ -78,6 +78,9 @@ public FakeXrmEasyTestsBase() SetMaxLength(_context); SetMaxLength(_context2); + + SetColumnNumber(_context); + SetColumnNumber(_context2); } private void SetPrimaryNameAttributes(XrmFakedContext context) @@ -191,5 +194,18 @@ private void SetMaxLength(XrmFakedContext context) context.SetEntityMetadata(entity); } } + + private void SetColumnNumber(XrmFakedContext context) + { + foreach (var entity in context.CreateMetadataQuery()) + { + var index = 0; + + foreach (var attr in entity.Attributes) + typeof(AttributeMetadata).GetProperty(nameof(AttributeMetadata.ColumnNumber)).SetValue(attr, index++); + + context.SetEntityMetadata(entity); + } + } } } diff --git a/MarkMpn.Sql4Cds.Engine.Tests/MarkMpn.Sql4Cds.Engine.Tests.csproj b/MarkMpn.Sql4Cds.Engine.Tests/MarkMpn.Sql4Cds.Engine.Tests.csproj index 3d7a48d4..4e7492d9 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/MarkMpn.Sql4Cds.Engine.Tests.csproj +++ b/MarkMpn.Sql4Cds.Engine.Tests/MarkMpn.Sql4Cds.Engine.Tests.csproj @@ -52,8 +52,11 @@ ..\packages\FakeXrmEasy.9.1.57.0\lib\net452\FakeXrmEasy.dll + + ..\packages\Microsoft.Bcl.AsyncInterfaces.6.0.0\lib\net461\Microsoft.Bcl.AsyncInterfaces.dll + - ..\packages\Microsoft.CrmSdk.CoreAssemblies.9.0.2.33\lib\net462\Microsoft.Crm.Sdk.Proxy.dll + ..\packages\Microsoft.CrmSdk.CoreAssemblies.9.0.2.45\lib\net462\Microsoft.Crm.Sdk.Proxy.dll @@ -63,7 +66,7 @@ ..\packages\Microsoft.IdentityModel.Clients.ActiveDirectory.5.2.8\lib\net45\Microsoft.IdentityModel.Clients.ActiveDirectory.dll - ..\packages\Microsoft.CrmSdk.XrmTooling.CoreAssembly.9.1.0.79\lib\net462\Microsoft.Rest.ClientRuntime.dll + ..\packages\Microsoft.CrmSdk.XrmTooling.CoreAssembly.9.1.1.1\lib\net462\Microsoft.Rest.ClientRuntime.dll @@ -73,16 +76,16 @@ ..\packages\MSTest.TestFramework.2.1.1\lib\net45\Microsoft.VisualStudio.TestPlatform.TestFramework.Extensions.dll - ..\packages\Microsoft.CrmSdk.CoreAssemblies.9.0.2.33\lib\net462\Microsoft.Xrm.Sdk.dll + ..\packages\Microsoft.CrmSdk.CoreAssemblies.9.0.2.45\lib\net462\Microsoft.Xrm.Sdk.dll ..\packages\Microsoft.CrmSdk.Deployment.9.0.2.28\lib\net462\Microsoft.Xrm.Sdk.Deployment.dll - ..\packages\Microsoft.CrmSdk.Workflow.9.0.2.28\lib\net462\Microsoft.Xrm.Sdk.Workflow.dll + ..\packages\Microsoft.CrmSdk.Workflow.9.0.2.42\lib\net462\Microsoft.Xrm.Sdk.Workflow.dll - ..\packages\Microsoft.CrmSdk.XrmTooling.CoreAssembly.9.1.0.79\lib\net462\Microsoft.Xrm.Tooling.Connector.dll + ..\packages\Microsoft.CrmSdk.XrmTooling.CoreAssembly.9.1.1.1\lib\net462\Microsoft.Xrm.Tooling.Connector.dll ..\packages\Newtonsoft.Json.13.0.1\lib\net45\Newtonsoft.Json.dll @@ -91,6 +94,9 @@ + + ..\packages\System.Buffers.4.5.1\lib\net461\System.Buffers.dll + @@ -104,6 +110,9 @@ True True + + ..\packages\System.Memory.4.5.4\lib\net461\System.Memory.dll + ..\packages\System.Net.Http.4.3.4\lib\net46\System.Net.Http.dll @@ -112,11 +121,17 @@ + + ..\packages\System.Numerics.Vectors.4.5.0\lib\net46\System.Numerics.Vectors.dll + ..\packages\System.Runtime.4.3.1\lib\net462\System.Runtime.dll True True + + ..\packages\System.Runtime.CompilerServices.Unsafe.6.0.0\lib\net461\System.Runtime.CompilerServices.Unsafe.dll + @@ -141,6 +156,18 @@ + + ..\packages\System.Text.Encodings.Web.6.0.0\lib\net461\System.Text.Encodings.Web.dll + + + ..\packages\System.Text.Json.6.0.2\lib\net461\System.Text.Json.dll + + + ..\packages\System.Threading.Tasks.Extensions.4.5.4\lib\net461\System.Threading.Tasks.Extensions.dll + + + ..\packages\System.ValueTuple.4.5.0\lib\net47\System.ValueTuple.dll + @@ -198,6 +225,8 @@ + + \ No newline at end of file diff --git a/MarkMpn.Sql4Cds.Engine.Tests/packages.config b/MarkMpn.Sql4Cds.Engine.Tests/packages.config index 2f2a77bd..cef51bb1 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/packages.config +++ b/MarkMpn.Sql4Cds.Engine.Tests/packages.config @@ -2,21 +2,30 @@ - + + - - + + + + + + + + + + \ No newline at end of file diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs index 4ef8e1a8..b62fac28 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseDmlNode.cs @@ -153,6 +153,9 @@ public virtual IRootExecutionPlanNodeInternal[] FoldQuery(NodeCompilationContext private int GetMaxDOP(NodeCompilationContext context, IList queryHints) { + if (DataSource == null) + return 1; + if (!context.DataSources.TryGetValue(DataSource, out var dataSource)) throw new NotSupportedQueryFragmentException("Unknown datasource"); From 1ea2dc0d27f6a5a28f42b1cb4f7467b8f5ea4805 Mon Sep 17 00:00:00 2001 From: Mark Carrington Date: Sat, 15 Apr 2023 16:54:59 +0100 Subject: [PATCH 25/34] Removed downgraded package --- .../MarkMpn.Sql4Cds.Engine.NetFx.csproj | 3 --- 1 file changed, 3 deletions(-) diff --git a/MarkMpn.Sql4Cds.Engine.NetFx/MarkMpn.Sql4Cds.Engine.NetFx.csproj b/MarkMpn.Sql4Cds.Engine.NetFx/MarkMpn.Sql4Cds.Engine.NetFx.csproj index 85c9df73..80f25502 100644 --- a/MarkMpn.Sql4Cds.Engine.NetFx/MarkMpn.Sql4Cds.Engine.NetFx.csproj +++ b/MarkMpn.Sql4Cds.Engine.NetFx/MarkMpn.Sql4Cds.Engine.NetFx.csproj @@ -64,9 +64,6 @@ 2.20.0 - - 9.0.2.33 - 9.1.1.1 From 0f5834e0e38c738c94b5591dcb77ce4097ec97eb Mon Sep 17 00:00:00 2001 From: Mark Carrington Date: Sun, 16 Apr 2023 15:16:45 +0100 Subject: [PATCH 26/34] Converted packages.config to PackageReference --- .../AutocompleteMenu-ScintillaNET.csproj | 9 +- AutocompleteMenu/packages.config | 4 - ...rkMpn.Sql4Cds.Engine.FetchXml.Tests.csproj | 75 +----- .../packages.config | 19 -- .../MarkMpn.Sql4Cds.Engine.NetFx.csproj | 4 +- .../MarkMpn.Sql4Cds.Engine.Tests.csproj | 127 +-------- MarkMpn.Sql4Cds.Engine.Tests/app.config | 28 +- MarkMpn.Sql4Cds.Engine.Tests/packages.config | 31 --- .../Connection/PersistentMetadataCache.cs | 2 +- .../MarkMpn.Sql4Cds.LanguageServer.csproj | 4 +- .../MarkMpn.Sql4Cds.SSMS.18.csproj | 2 +- .../MarkMpn.Sql4Cds.Tests.csproj | 98 ++----- MarkMpn.Sql4Cds.Tests/packages.config | 21 -- MarkMpn.Sql4Cds/ILMergeConfig.json | 10 + MarkMpn.Sql4Cds/MarkMpn.Sql4Cds.csproj | 240 ++---------------- MarkMpn.Sql4Cds/packages.config | 51 ---- build.yml | 6 - pr-build.yml | 6 - 18 files changed, 101 insertions(+), 636 deletions(-) delete mode 100644 AutocompleteMenu/packages.config delete mode 100644 MarkMpn.Sql4Cds.Engine.FetchXml.Tests/packages.config delete mode 100644 MarkMpn.Sql4Cds.Engine.Tests/packages.config delete mode 100644 MarkMpn.Sql4Cds.Tests/packages.config delete mode 100644 MarkMpn.Sql4Cds/packages.config diff --git a/AutocompleteMenu/AutocompleteMenu-ScintillaNET.csproj b/AutocompleteMenu/AutocompleteMenu-ScintillaNET.csproj index 300c900c..94f510a7 100644 --- a/AutocompleteMenu/AutocompleteMenu-ScintillaNET.csproj +++ b/AutocompleteMenu/AutocompleteMenu-ScintillaNET.csproj @@ -34,9 +34,6 @@ false - - ..\packages\jacobslusser.ScintillaNET.3.6.3\lib\net40\ScintillaNET.dll - @@ -70,7 +67,11 @@ - + + + + 3.6.3 +