From 38cfa09b5784c6d75284d9dc1b1c20fb2849659b Mon Sep 17 00:00:00 2001 From: Marc Gravell Date: Wed, 15 Nov 2023 13:23:02 +0000 Subject: [PATCH] Use deferred construction if there are required members (fix #71) --- .../DapperInterceptorGenerator.cs | 3 +- .../Internal/CodeWriter.cs | 8 + .../Internal/Inspection.cs | 3 + .../Interceptors/RequiredProperties.input.cs | 37 ++++ .../Interceptors/RequiredProperties.output.cs | 164 ++++++++++++++++++ .../RequiredProperties.output.netfx.cs | 164 ++++++++++++++++++ .../RequiredProperties.output.netfx.txt | 4 + .../RequiredProperties.output.txt | 4 + 8 files changed, 386 insertions(+), 1 deletion(-) create mode 100644 test/Dapper.AOT.Test/Interceptors/RequiredProperties.input.cs create mode 100644 test/Dapper.AOT.Test/Interceptors/RequiredProperties.output.cs create mode 100644 test/Dapper.AOT.Test/Interceptors/RequiredProperties.output.netfx.cs create mode 100644 test/Dapper.AOT.Test/Interceptors/RequiredProperties.output.netfx.txt create mode 100644 test/Dapper.AOT.Test/Interceptors/RequiredProperties.output.txt diff --git a/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.cs b/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.cs index b0ca9fbc..fd611ffe 100644 --- a/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.cs +++ b/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.cs @@ -690,13 +690,14 @@ private static void WriteRowFactory(in GenerateState context, CodeWriter sb, ITy } var hasInitOnlyMembers = members.Any(member => member.IsInitOnly); + var hasRequiredMembers = members.Any(member => member.IsRequired); var hasGetOnlyMembers = members.Any(member => member is { IsGettable: true, IsSettable: false, IsInitOnly: false }); var useConstructorDeferred = map.Constructor is not null; var useFactoryMethodDeferred = map.FactoryMethod is not null; // Implementation detail: // constructor takes advantage over factory method. - var useDeferredConstruction = useConstructorDeferred || useFactoryMethodDeferred || hasInitOnlyMembers || hasGetOnlyMembers; + var useDeferredConstruction = useConstructorDeferred || useFactoryMethodDeferred || hasInitOnlyMembers || hasGetOnlyMembers || hasRequiredMembers; WriteRowFactoryHeader(); diff --git a/src/Dapper.AOT.Analyzers/Internal/CodeWriter.cs b/src/Dapper.AOT.Analyzers/Internal/CodeWriter.cs index 5e4b6feb..edb74d7e 100644 --- a/src/Dapper.AOT.Analyzers/Internal/CodeWriter.cs +++ b/src/Dapper.AOT.Analyzers/Internal/CodeWriter.cs @@ -201,6 +201,14 @@ public static bool IsSettableInstanceMember(ISymbol symbol, out ITypeSymbol type return false; } + public static bool IsRequired(ISymbol symbol) + => symbol switch + { + IPropertySymbol prop => prop.IsRequired, + IFieldSymbol field => field.IsRequired, + _ => false, + }; + public static bool IsInitOnlyInstanceMember(ISymbol symbol, out ITypeSymbol type) { if (symbol.DeclaredAccessibility == Accessibility.Public && !symbol.IsStatic) diff --git a/src/Dapper.AOT.Analyzers/Internal/Inspection.cs b/src/Dapper.AOT.Analyzers/Internal/Inspection.cs index d0f9a690..5578bde7 100644 --- a/src/Dapper.AOT.Analyzers/Internal/Inspection.cs +++ b/src/Dapper.AOT.Analyzers/Internal/Inspection.cs @@ -412,6 +412,7 @@ public readonly struct ElementMember public bool IsGettable => (_flags & ElementMemberFlags.IsGettable) != 0; public bool IsSettable => (_flags & ElementMemberFlags.IsSettable) != 0; public bool IsInitOnly => (_flags & ElementMemberFlags.IsInitOnly) != 0; + public bool IsRequired => (_flags & ElementMemberFlags.IsRequired) != 0; public bool IsExpandable => (_flags & ElementMemberFlags.IsExpandable) != 0; /// @@ -438,6 +439,7 @@ public enum ElementMemberFlags IsSettable = 1 << 1, IsInitOnly = 1 << 2, IsExpandable = 1 << 3, + IsRequired = 1 << 4, } public ElementMember( @@ -711,6 +713,7 @@ internal static ImmutableArray GetMembers(bool forParameters, ITy if (CodeWriter.IsGettableInstanceMember(member, out _)) flags |= ElementMember.ElementMemberFlags.IsGettable; if (CodeWriter.IsSettableInstanceMember(member, out _)) flags |= ElementMember.ElementMemberFlags.IsSettable; if (CodeWriter.IsInitOnlyInstanceMember(member, out _)) flags |= ElementMember.ElementMemberFlags.IsInitOnly; + if (CodeWriter.IsRequired(member)) flags |= ElementMember.ElementMemberFlags.IsRequired; if (forParameters) { diff --git a/test/Dapper.AOT.Test/Interceptors/RequiredProperties.input.cs b/test/Dapper.AOT.Test/Interceptors/RequiredProperties.input.cs new file mode 100644 index 00000000..41ddf020 --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/RequiredProperties.input.cs @@ -0,0 +1,37 @@ +using Dapper; +using Microsoft.Data.SqlClient; +using System.Threading.Tasks; + +[DapperAot] + +public class SomeType +{ + public int Id { get; set; } + public required string Title { get; set; } + public bool IsComplete { get; set; } + + public Task Insert(SqlConnection connection) + => connection.QuerySingleAsync(""" + INSERT INTO MyTable(Title, IsComplete) + Values(@Title, @IsComplete) + RETURNING * + """, this); +} + + +namespace System.Runtime.CompilerServices +{ + [System.AttributeUsage(System.AttributeTargets.All, AllowMultiple = true, Inherited = false)] + sealed file class CompilerFeatureRequiredAttribute : Attribute + { + public CompilerFeatureRequiredAttribute(string _) { } + } + + [System.AttributeUsage(System.AttributeTargets.Class | System.AttributeTargets.Field | System.AttributeTargets.Property | System.AttributeTargets.Struct, AllowMultiple = false, Inherited = false)] + sealed file class RequiredMemberAttribute : Attribute {} +} +namespace System.Diagnostics.CodeAnalysis +{ + [System.AttributeUsage(System.AttributeTargets.Constructor, AllowMultiple = false, Inherited = false)] + sealed file class SetsRequiredMembersAttribute : Attribute { } +} \ No newline at end of file diff --git a/test/Dapper.AOT.Test/Interceptors/RequiredProperties.output.cs b/test/Dapper.AOT.Test/Interceptors/RequiredProperties.output.cs new file mode 100644 index 00000000..37e3ec31 --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/RequiredProperties.output.cs @@ -0,0 +1,164 @@ +#nullable enable +namespace Dapper.AOT // interceptors must be in a known namespace +{ + file static class DapperGeneratedInterceptors + { + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\RequiredProperties.input.cs", 14, 23)] + internal static global::System.Threading.Tasks.Task QuerySingleAsync0(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, int? commandTimeout, global::System.Data.CommandType? commandType) + { + // Query, Async, TypedResult, HasParameters, SingleRow, Text, AtLeastOne, AtMostOne, BindResultsByName, KnownParameters + // takes parameter: global::SomeType + // parameter map: IsComplete Title + // returns data: global::SomeType + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql)); + global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text); + global::System.Diagnostics.Debug.Assert(param is not null); + + return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), CommandFactory0.Instance).QuerySingleAsync((global::SomeType)param!, RowFactory0.Instance); + + } + + private class CommonCommandFactory : global::Dapper.CommandFactory + { + public override global::System.Data.Common.DbCommand GetCommand(global::System.Data.Common.DbConnection connection, string sql, global::System.Data.CommandType commandType, T args) + { + var cmd = base.GetCommand(connection, sql, commandType, args); + // apply special per-provider command initialization logic for OracleCommand + if (cmd is global::Oracle.ManagedDataAccess.Client.OracleCommand cmd0) + { + cmd0.BindByName = true; + cmd0.InitialLONGFetchSize = -1; + + } + return cmd; + } + + } + + private static readonly CommonCommandFactory DefaultCommandFactory = new(); + + private sealed class RowFactory0 : global::Dapper.RowFactory + { + internal static readonly RowFactory0 Instance = new(); + private RowFactory0() {} + public override object? Tokenize(global::System.Data.Common.DbDataReader reader, global::System.Span tokens, int columnOffset) + { + for (int i = 0; i < tokens.Length; i++) + { + int token = -1; + var name = reader.GetName(columnOffset); + var type = reader.GetFieldType(columnOffset); + switch (NormalizedHash(name)) + { + case 926444256U when NormalizedEquals(name, "id"): + token = type == typeof(int) ? 0 : 3; // two tokens for right-typed and type-flexible + break; + case 2556802313U when NormalizedEquals(name, "title"): + token = type == typeof(string) ? 1 : 4; + break; + case 2563538258U when NormalizedEquals(name, "iscomplete"): + token = type == typeof(bool) ? 2 : 5; + break; + + } + tokens[i] = token; + columnOffset++; + + } + return null; + } + public override global::SomeType Read(global::System.Data.Common.DbDataReader reader, global::System.ReadOnlySpan tokens, int columnOffset, object? state) + { + int value0 = default; + string? value1 = default; + bool value2 = default; + foreach (var token in tokens) + { + switch (token) + { + case 0: + value0 = reader.GetInt32(columnOffset); + break; + case 3: + value0 = GetValue(reader, columnOffset); + break; + case 1: + value1 = reader.IsDBNull(columnOffset) ? (string?)null : reader.GetString(columnOffset); + break; + case 4: + value1 = reader.IsDBNull(columnOffset) ? (string?)null : GetValue(reader, columnOffset); + break; + case 2: + value2 = reader.GetBoolean(columnOffset); + break; + case 5: + value2 = GetValue(reader, columnOffset); + break; + + } + columnOffset++; + + } + return new global::SomeType + { + Id = value0, + Title = value1, + IsComplete = value2, + }; + } + } + + private sealed class CommandFactory0 : CommonCommandFactory + { + internal static readonly CommandFactory0 Instance = new(); + public override void AddParameters(in global::Dapper.UnifiedCommand cmd, global::SomeType args) + { + var ps = cmd.Parameters; + global::System.Data.Common.DbParameter p; + p = cmd.CreateParameter(); + p.ParameterName = "Title"; + p.DbType = global::System.Data.DbType.String; + p.Size = -1; + p.Direction = global::System.Data.ParameterDirection.Input; + p.Value = AsValue(args.Title); + ps.Add(p); + + p = cmd.CreateParameter(); + p.ParameterName = "IsComplete"; + p.DbType = global::System.Data.DbType.Boolean; + p.Direction = global::System.Data.ParameterDirection.Input; + p.Value = AsValue(args.IsComplete); + ps.Add(p); + + } + public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, global::SomeType args) + { + var ps = cmd.Parameters; + ps[0].Value = AsValue(args.Title); + ps[1].Value = AsValue(args.IsComplete); + + } + public override bool CanPrepare => true; + + } + + + } +} +namespace System.Runtime.CompilerServices +{ + // this type is needed by the compiler to implement interceptors - it doesn't need to + // come from the runtime itself, though + + [global::System.Diagnostics.Conditional("DEBUG")] // not needed post-build, so: evaporate + [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = true)] + sealed file class InterceptsLocationAttribute : global::System.Attribute + { + public InterceptsLocationAttribute(string path, int lineNumber, int columnNumber) + { + _ = path; + _ = lineNumber; + _ = columnNumber; + } + } +} \ No newline at end of file diff --git a/test/Dapper.AOT.Test/Interceptors/RequiredProperties.output.netfx.cs b/test/Dapper.AOT.Test/Interceptors/RequiredProperties.output.netfx.cs new file mode 100644 index 00000000..37e3ec31 --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/RequiredProperties.output.netfx.cs @@ -0,0 +1,164 @@ +#nullable enable +namespace Dapper.AOT // interceptors must be in a known namespace +{ + file static class DapperGeneratedInterceptors + { + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\RequiredProperties.input.cs", 14, 23)] + internal static global::System.Threading.Tasks.Task QuerySingleAsync0(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, int? commandTimeout, global::System.Data.CommandType? commandType) + { + // Query, Async, TypedResult, HasParameters, SingleRow, Text, AtLeastOne, AtMostOne, BindResultsByName, KnownParameters + // takes parameter: global::SomeType + // parameter map: IsComplete Title + // returns data: global::SomeType + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql)); + global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text); + global::System.Diagnostics.Debug.Assert(param is not null); + + return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), CommandFactory0.Instance).QuerySingleAsync((global::SomeType)param!, RowFactory0.Instance); + + } + + private class CommonCommandFactory : global::Dapper.CommandFactory + { + public override global::System.Data.Common.DbCommand GetCommand(global::System.Data.Common.DbConnection connection, string sql, global::System.Data.CommandType commandType, T args) + { + var cmd = base.GetCommand(connection, sql, commandType, args); + // apply special per-provider command initialization logic for OracleCommand + if (cmd is global::Oracle.ManagedDataAccess.Client.OracleCommand cmd0) + { + cmd0.BindByName = true; + cmd0.InitialLONGFetchSize = -1; + + } + return cmd; + } + + } + + private static readonly CommonCommandFactory DefaultCommandFactory = new(); + + private sealed class RowFactory0 : global::Dapper.RowFactory + { + internal static readonly RowFactory0 Instance = new(); + private RowFactory0() {} + public override object? Tokenize(global::System.Data.Common.DbDataReader reader, global::System.Span tokens, int columnOffset) + { + for (int i = 0; i < tokens.Length; i++) + { + int token = -1; + var name = reader.GetName(columnOffset); + var type = reader.GetFieldType(columnOffset); + switch (NormalizedHash(name)) + { + case 926444256U when NormalizedEquals(name, "id"): + token = type == typeof(int) ? 0 : 3; // two tokens for right-typed and type-flexible + break; + case 2556802313U when NormalizedEquals(name, "title"): + token = type == typeof(string) ? 1 : 4; + break; + case 2563538258U when NormalizedEquals(name, "iscomplete"): + token = type == typeof(bool) ? 2 : 5; + break; + + } + tokens[i] = token; + columnOffset++; + + } + return null; + } + public override global::SomeType Read(global::System.Data.Common.DbDataReader reader, global::System.ReadOnlySpan tokens, int columnOffset, object? state) + { + int value0 = default; + string? value1 = default; + bool value2 = default; + foreach (var token in tokens) + { + switch (token) + { + case 0: + value0 = reader.GetInt32(columnOffset); + break; + case 3: + value0 = GetValue(reader, columnOffset); + break; + case 1: + value1 = reader.IsDBNull(columnOffset) ? (string?)null : reader.GetString(columnOffset); + break; + case 4: + value1 = reader.IsDBNull(columnOffset) ? (string?)null : GetValue(reader, columnOffset); + break; + case 2: + value2 = reader.GetBoolean(columnOffset); + break; + case 5: + value2 = GetValue(reader, columnOffset); + break; + + } + columnOffset++; + + } + return new global::SomeType + { + Id = value0, + Title = value1, + IsComplete = value2, + }; + } + } + + private sealed class CommandFactory0 : CommonCommandFactory + { + internal static readonly CommandFactory0 Instance = new(); + public override void AddParameters(in global::Dapper.UnifiedCommand cmd, global::SomeType args) + { + var ps = cmd.Parameters; + global::System.Data.Common.DbParameter p; + p = cmd.CreateParameter(); + p.ParameterName = "Title"; + p.DbType = global::System.Data.DbType.String; + p.Size = -1; + p.Direction = global::System.Data.ParameterDirection.Input; + p.Value = AsValue(args.Title); + ps.Add(p); + + p = cmd.CreateParameter(); + p.ParameterName = "IsComplete"; + p.DbType = global::System.Data.DbType.Boolean; + p.Direction = global::System.Data.ParameterDirection.Input; + p.Value = AsValue(args.IsComplete); + ps.Add(p); + + } + public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, global::SomeType args) + { + var ps = cmd.Parameters; + ps[0].Value = AsValue(args.Title); + ps[1].Value = AsValue(args.IsComplete); + + } + public override bool CanPrepare => true; + + } + + + } +} +namespace System.Runtime.CompilerServices +{ + // this type is needed by the compiler to implement interceptors - it doesn't need to + // come from the runtime itself, though + + [global::System.Diagnostics.Conditional("DEBUG")] // not needed post-build, so: evaporate + [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = true)] + sealed file class InterceptsLocationAttribute : global::System.Attribute + { + public InterceptsLocationAttribute(string path, int lineNumber, int columnNumber) + { + _ = path; + _ = lineNumber; + _ = columnNumber; + } + } +} \ No newline at end of file diff --git a/test/Dapper.AOT.Test/Interceptors/RequiredProperties.output.netfx.txt b/test/Dapper.AOT.Test/Interceptors/RequiredProperties.output.netfx.txt new file mode 100644 index 00000000..721f00dd --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/RequiredProperties.output.netfx.txt @@ -0,0 +1,4 @@ +Generator produced 1 diagnostics: + +Hidden DAP000 L1 C1 +Dapper.AOT handled 1 of 1 possible call-sites using 1 interceptors, 1 commands and 1 readers diff --git a/test/Dapper.AOT.Test/Interceptors/RequiredProperties.output.txt b/test/Dapper.AOT.Test/Interceptors/RequiredProperties.output.txt new file mode 100644 index 00000000..721f00dd --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/RequiredProperties.output.txt @@ -0,0 +1,4 @@ +Generator produced 1 diagnostics: + +Hidden DAP000 L1 C1 +Dapper.AOT handled 1 of 1 possible call-sites using 1 interceptors, 1 commands and 1 readers