Skip to content

Commit

Permalink
Use deferred construction if there are required members (fix #71)
Browse files Browse the repository at this point in the history
  • Loading branch information
mgravell committed Nov 15, 2023
1 parent 82fce34 commit 38cfa09
Show file tree
Hide file tree
Showing 8 changed files with 386 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -690,13 +690,14 @@ private static void WriteRowFactory(in GenerateState context, CodeWriter sb, ITy
}

var hasInitOnlyMembers = members.Any(member => member.IsInitOnly);
var hasRequiredMembers = members.Any(member => member.IsRequired);
var hasGetOnlyMembers = members.Any(member => member is { IsGettable: true, IsSettable: false, IsInitOnly: false });
var useConstructorDeferred = map.Constructor is not null;
var useFactoryMethodDeferred = map.FactoryMethod is not null;

// Implementation detail:
// constructor takes advantage over factory method.
var useDeferredConstruction = useConstructorDeferred || useFactoryMethodDeferred || hasInitOnlyMembers || hasGetOnlyMembers;
var useDeferredConstruction = useConstructorDeferred || useFactoryMethodDeferred || hasInitOnlyMembers || hasGetOnlyMembers || hasRequiredMembers;

WriteRowFactoryHeader();

Expand Down
8 changes: 8 additions & 0 deletions src/Dapper.AOT.Analyzers/Internal/CodeWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,14 @@ public static bool IsSettableInstanceMember(ISymbol symbol, out ITypeSymbol type
return false;
}

public static bool IsRequired(ISymbol symbol)
=> symbol switch
{
IPropertySymbol prop => prop.IsRequired,
IFieldSymbol field => field.IsRequired,
_ => false,
};

public static bool IsInitOnlyInstanceMember(ISymbol symbol, out ITypeSymbol type)
{
if (symbol.DeclaredAccessibility == Accessibility.Public && !symbol.IsStatic)
Expand Down
3 changes: 3 additions & 0 deletions src/Dapper.AOT.Analyzers/Internal/Inspection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,7 @@ public readonly struct ElementMember
public bool IsGettable => (_flags & ElementMemberFlags.IsGettable) != 0;
public bool IsSettable => (_flags & ElementMemberFlags.IsSettable) != 0;
public bool IsInitOnly => (_flags & ElementMemberFlags.IsInitOnly) != 0;
public bool IsRequired => (_flags & ElementMemberFlags.IsRequired) != 0;
public bool IsExpandable => (_flags & ElementMemberFlags.IsExpandable) != 0;

/// <summary>
Expand All @@ -438,6 +439,7 @@ public enum ElementMemberFlags
IsSettable = 1 << 1,
IsInitOnly = 1 << 2,
IsExpandable = 1 << 3,
IsRequired = 1 << 4,
}

public ElementMember(
Expand Down Expand Up @@ -711,6 +713,7 @@ internal static ImmutableArray<ElementMember> GetMembers(bool forParameters, ITy
if (CodeWriter.IsGettableInstanceMember(member, out _)) flags |= ElementMember.ElementMemberFlags.IsGettable;
if (CodeWriter.IsSettableInstanceMember(member, out _)) flags |= ElementMember.ElementMemberFlags.IsSettable;
if (CodeWriter.IsInitOnlyInstanceMember(member, out _)) flags |= ElementMember.ElementMemberFlags.IsInitOnly;
if (CodeWriter.IsRequired(member)) flags |= ElementMember.ElementMemberFlags.IsRequired;

if (forParameters)
{
Expand Down
37 changes: 37 additions & 0 deletions test/Dapper.AOT.Test/Interceptors/RequiredProperties.input.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
using Dapper;
using Microsoft.Data.SqlClient;
using System.Threading.Tasks;

[DapperAot]

public class SomeType
{
public int Id { get; set; }
public required string Title { get; set; }
public bool IsComplete { get; set; }

public Task<SomeType> Insert(SqlConnection connection)
=> connection.QuerySingleAsync<SomeType>("""
INSERT INTO MyTable(Title, IsComplete)
Values(@Title, @IsComplete)
RETURNING *
""", this);
}


namespace System.Runtime.CompilerServices
{
[System.AttributeUsage(System.AttributeTargets.All, AllowMultiple = true, Inherited = false)]
sealed file class CompilerFeatureRequiredAttribute : Attribute
{
public CompilerFeatureRequiredAttribute(string _) { }
}

[System.AttributeUsage(System.AttributeTargets.Class | System.AttributeTargets.Field | System.AttributeTargets.Property | System.AttributeTargets.Struct, AllowMultiple = false, Inherited = false)]
sealed file class RequiredMemberAttribute : Attribute {}
}
namespace System.Diagnostics.CodeAnalysis
{
[System.AttributeUsage(System.AttributeTargets.Constructor, AllowMultiple = false, Inherited = false)]
sealed file class SetsRequiredMembersAttribute : Attribute { }
}
164 changes: 164 additions & 0 deletions test/Dapper.AOT.Test/Interceptors/RequiredProperties.output.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
#nullable enable
namespace Dapper.AOT // interceptors must be in a known namespace
{
file static class DapperGeneratedInterceptors
{
[global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\RequiredProperties.input.cs", 14, 23)]
internal static global::System.Threading.Tasks.Task<global::SomeType> QuerySingleAsync0(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, int? commandTimeout, global::System.Data.CommandType? commandType)
{
// Query, Async, TypedResult, HasParameters, SingleRow, Text, AtLeastOne, AtMostOne, BindResultsByName, KnownParameters
// takes parameter: global::SomeType
// parameter map: IsComplete Title
// returns data: global::SomeType
global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql));
global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text);
global::System.Diagnostics.Debug.Assert(param is not null);

return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), CommandFactory0.Instance).QuerySingleAsync((global::SomeType)param!, RowFactory0.Instance);

}

private class CommonCommandFactory<T> : global::Dapper.CommandFactory<T>
{
public override global::System.Data.Common.DbCommand GetCommand(global::System.Data.Common.DbConnection connection, string sql, global::System.Data.CommandType commandType, T args)
{
var cmd = base.GetCommand(connection, sql, commandType, args);
// apply special per-provider command initialization logic for OracleCommand
if (cmd is global::Oracle.ManagedDataAccess.Client.OracleCommand cmd0)
{
cmd0.BindByName = true;
cmd0.InitialLONGFetchSize = -1;

}
return cmd;
}

}

private static readonly CommonCommandFactory<object?> DefaultCommandFactory = new();

private sealed class RowFactory0 : global::Dapper.RowFactory<global::SomeType>
{
internal static readonly RowFactory0 Instance = new();
private RowFactory0() {}
public override object? Tokenize(global::System.Data.Common.DbDataReader reader, global::System.Span<int> tokens, int columnOffset)
{
for (int i = 0; i < tokens.Length; i++)
{
int token = -1;
var name = reader.GetName(columnOffset);
var type = reader.GetFieldType(columnOffset);
switch (NormalizedHash(name))
{
case 926444256U when NormalizedEquals(name, "id"):
token = type == typeof(int) ? 0 : 3; // two tokens for right-typed and type-flexible
break;
case 2556802313U when NormalizedEquals(name, "title"):
token = type == typeof(string) ? 1 : 4;
break;
case 2563538258U when NormalizedEquals(name, "iscomplete"):
token = type == typeof(bool) ? 2 : 5;
break;

}
tokens[i] = token;
columnOffset++;

}
return null;
}
public override global::SomeType Read(global::System.Data.Common.DbDataReader reader, global::System.ReadOnlySpan<int> tokens, int columnOffset, object? state)
{
int value0 = default;
string? value1 = default;
bool value2 = default;
foreach (var token in tokens)
{
switch (token)
{
case 0:
value0 = reader.GetInt32(columnOffset);
break;
case 3:
value0 = GetValue<int>(reader, columnOffset);
break;
case 1:
value1 = reader.IsDBNull(columnOffset) ? (string?)null : reader.GetString(columnOffset);
break;
case 4:
value1 = reader.IsDBNull(columnOffset) ? (string?)null : GetValue<string>(reader, columnOffset);
break;
case 2:
value2 = reader.GetBoolean(columnOffset);
break;
case 5:
value2 = GetValue<bool>(reader, columnOffset);
break;

}
columnOffset++;

}
return new global::SomeType
{
Id = value0,
Title = value1,
IsComplete = value2,
};
}
}

private sealed class CommandFactory0 : CommonCommandFactory<global::SomeType>
{
internal static readonly CommandFactory0 Instance = new();
public override void AddParameters(in global::Dapper.UnifiedCommand cmd, global::SomeType args)
{
var ps = cmd.Parameters;
global::System.Data.Common.DbParameter p;
p = cmd.CreateParameter();
p.ParameterName = "Title";
p.DbType = global::System.Data.DbType.String;
p.Size = -1;
p.Direction = global::System.Data.ParameterDirection.Input;
p.Value = AsValue(args.Title);
ps.Add(p);

p = cmd.CreateParameter();
p.ParameterName = "IsComplete";
p.DbType = global::System.Data.DbType.Boolean;
p.Direction = global::System.Data.ParameterDirection.Input;
p.Value = AsValue(args.IsComplete);
ps.Add(p);

}
public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, global::SomeType args)
{
var ps = cmd.Parameters;
ps[0].Value = AsValue(args.Title);
ps[1].Value = AsValue(args.IsComplete);

}
public override bool CanPrepare => true;

}


}
}
namespace System.Runtime.CompilerServices
{
// this type is needed by the compiler to implement interceptors - it doesn't need to
// come from the runtime itself, though

[global::System.Diagnostics.Conditional("DEBUG")] // not needed post-build, so: evaporate
[global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = true)]
sealed file class InterceptsLocationAttribute : global::System.Attribute
{
public InterceptsLocationAttribute(string path, int lineNumber, int columnNumber)
{
_ = path;
_ = lineNumber;
_ = columnNumber;
}
}
}
Loading

0 comments on commit 38cfa09

Please sign in to comment.