Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use deferred construction if there are required members (fix #71) #73

Merged
merged 1 commit into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -690,13 +690,14 @@ private static void WriteRowFactory(in GenerateState context, CodeWriter sb, ITy
}

var hasInitOnlyMembers = members.Any(member => member.IsInitOnly);
var hasRequiredMembers = members.Any(member => member.IsRequired);
var hasGetOnlyMembers = members.Any(member => member is { IsGettable: true, IsSettable: false, IsInitOnly: false });
var useConstructorDeferred = map.Constructor is not null;
var useFactoryMethodDeferred = map.FactoryMethod is not null;

// Implementation detail:
// constructor takes advantage over factory method.
var useDeferredConstruction = useConstructorDeferred || useFactoryMethodDeferred || hasInitOnlyMembers || hasGetOnlyMembers;
var useDeferredConstruction = useConstructorDeferred || useFactoryMethodDeferred || hasInitOnlyMembers || hasGetOnlyMembers || hasRequiredMembers;

WriteRowFactoryHeader();

Expand Down
8 changes: 8 additions & 0 deletions src/Dapper.AOT.Analyzers/Internal/CodeWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,14 @@ public static bool IsSettableInstanceMember(ISymbol symbol, out ITypeSymbol type
return false;
}

public static bool IsRequired(ISymbol symbol)
=> symbol switch
{
IPropertySymbol prop => prop.IsRequired,
IFieldSymbol field => field.IsRequired,
_ => false,
};

public static bool IsInitOnlyInstanceMember(ISymbol symbol, out ITypeSymbol type)
{
if (symbol.DeclaredAccessibility == Accessibility.Public && !symbol.IsStatic)
Expand Down
3 changes: 3 additions & 0 deletions src/Dapper.AOT.Analyzers/Internal/Inspection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,7 @@ public readonly struct ElementMember
public bool IsGettable => (_flags & ElementMemberFlags.IsGettable) != 0;
public bool IsSettable => (_flags & ElementMemberFlags.IsSettable) != 0;
public bool IsInitOnly => (_flags & ElementMemberFlags.IsInitOnly) != 0;
public bool IsRequired => (_flags & ElementMemberFlags.IsRequired) != 0;
public bool IsExpandable => (_flags & ElementMemberFlags.IsExpandable) != 0;

/// <summary>
Expand All @@ -438,6 +439,7 @@ public enum ElementMemberFlags
IsSettable = 1 << 1,
IsInitOnly = 1 << 2,
IsExpandable = 1 << 3,
IsRequired = 1 << 4,
}

public ElementMember(
Expand Down Expand Up @@ -711,6 +713,7 @@ internal static ImmutableArray<ElementMember> GetMembers(bool forParameters, ITy
if (CodeWriter.IsGettableInstanceMember(member, out _)) flags |= ElementMember.ElementMemberFlags.IsGettable;
if (CodeWriter.IsSettableInstanceMember(member, out _)) flags |= ElementMember.ElementMemberFlags.IsSettable;
if (CodeWriter.IsInitOnlyInstanceMember(member, out _)) flags |= ElementMember.ElementMemberFlags.IsInitOnly;
if (CodeWriter.IsRequired(member)) flags |= ElementMember.ElementMemberFlags.IsRequired;

if (forParameters)
{
Expand Down
37 changes: 37 additions & 0 deletions test/Dapper.AOT.Test/Interceptors/RequiredProperties.input.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
using Dapper;
using Microsoft.Data.SqlClient;
using System.Threading.Tasks;

[DapperAot]

public class SomeType
{
public int Id { get; set; }
public required string Title { get; set; }
public bool IsComplete { get; set; }

public Task<SomeType> Insert(SqlConnection connection)
=> connection.QuerySingleAsync<SomeType>("""
INSERT INTO MyTable(Title, IsComplete)
Values(@Title, @IsComplete)
RETURNING *
""", this);
}


namespace System.Runtime.CompilerServices
{
[System.AttributeUsage(System.AttributeTargets.All, AllowMultiple = true, Inherited = false)]
sealed file class CompilerFeatureRequiredAttribute : Attribute
{
public CompilerFeatureRequiredAttribute(string _) { }
}

[System.AttributeUsage(System.AttributeTargets.Class | System.AttributeTargets.Field | System.AttributeTargets.Property | System.AttributeTargets.Struct, AllowMultiple = false, Inherited = false)]
sealed file class RequiredMemberAttribute : Attribute {}
}
namespace System.Diagnostics.CodeAnalysis
{
[System.AttributeUsage(System.AttributeTargets.Constructor, AllowMultiple = false, Inherited = false)]
sealed file class SetsRequiredMembersAttribute : Attribute { }
}
164 changes: 164 additions & 0 deletions test/Dapper.AOT.Test/Interceptors/RequiredProperties.output.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
#nullable enable
namespace Dapper.AOT // interceptors must be in a known namespace
{
file static class DapperGeneratedInterceptors
{
[global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\RequiredProperties.input.cs", 14, 23)]
internal static global::System.Threading.Tasks.Task<global::SomeType> QuerySingleAsync0(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, int? commandTimeout, global::System.Data.CommandType? commandType)
{
// Query, Async, TypedResult, HasParameters, SingleRow, Text, AtLeastOne, AtMostOne, BindResultsByName, KnownParameters
// takes parameter: global::SomeType
// parameter map: IsComplete Title
// returns data: global::SomeType
global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql));
global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text);
global::System.Diagnostics.Debug.Assert(param is not null);

return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), CommandFactory0.Instance).QuerySingleAsync((global::SomeType)param!, RowFactory0.Instance);

}

private class CommonCommandFactory<T> : global::Dapper.CommandFactory<T>
{
public override global::System.Data.Common.DbCommand GetCommand(global::System.Data.Common.DbConnection connection, string sql, global::System.Data.CommandType commandType, T args)
{
var cmd = base.GetCommand(connection, sql, commandType, args);
// apply special per-provider command initialization logic for OracleCommand
if (cmd is global::Oracle.ManagedDataAccess.Client.OracleCommand cmd0)
{
cmd0.BindByName = true;
cmd0.InitialLONGFetchSize = -1;

}
return cmd;
}

}

private static readonly CommonCommandFactory<object?> DefaultCommandFactory = new();

private sealed class RowFactory0 : global::Dapper.RowFactory<global::SomeType>
{
internal static readonly RowFactory0 Instance = new();
private RowFactory0() {}
public override object? Tokenize(global::System.Data.Common.DbDataReader reader, global::System.Span<int> tokens, int columnOffset)
{
for (int i = 0; i < tokens.Length; i++)
{
int token = -1;
var name = reader.GetName(columnOffset);
var type = reader.GetFieldType(columnOffset);
switch (NormalizedHash(name))
{
case 926444256U when NormalizedEquals(name, "id"):
token = type == typeof(int) ? 0 : 3; // two tokens for right-typed and type-flexible
break;
case 2556802313U when NormalizedEquals(name, "title"):
token = type == typeof(string) ? 1 : 4;
break;
case 2563538258U when NormalizedEquals(name, "iscomplete"):
token = type == typeof(bool) ? 2 : 5;
break;

}
tokens[i] = token;
columnOffset++;

}
return null;
}
public override global::SomeType Read(global::System.Data.Common.DbDataReader reader, global::System.ReadOnlySpan<int> tokens, int columnOffset, object? state)
{
int value0 = default;
string? value1 = default;
bool value2 = default;
foreach (var token in tokens)
{
switch (token)
{
case 0:
value0 = reader.GetInt32(columnOffset);
break;
case 3:
value0 = GetValue<int>(reader, columnOffset);
break;
case 1:
value1 = reader.IsDBNull(columnOffset) ? (string?)null : reader.GetString(columnOffset);
break;
case 4:
value1 = reader.IsDBNull(columnOffset) ? (string?)null : GetValue<string>(reader, columnOffset);
break;
case 2:
value2 = reader.GetBoolean(columnOffset);
break;
case 5:
value2 = GetValue<bool>(reader, columnOffset);
break;

}
columnOffset++;

}
return new global::SomeType
{
Id = value0,
Title = value1,
IsComplete = value2,
};
}
}

private sealed class CommandFactory0 : CommonCommandFactory<global::SomeType>
{
internal static readonly CommandFactory0 Instance = new();
public override void AddParameters(in global::Dapper.UnifiedCommand cmd, global::SomeType args)
{
var ps = cmd.Parameters;
global::System.Data.Common.DbParameter p;
p = cmd.CreateParameter();
p.ParameterName = "Title";
p.DbType = global::System.Data.DbType.String;
p.Size = -1;
p.Direction = global::System.Data.ParameterDirection.Input;
p.Value = AsValue(args.Title);
ps.Add(p);

p = cmd.CreateParameter();
p.ParameterName = "IsComplete";
p.DbType = global::System.Data.DbType.Boolean;
p.Direction = global::System.Data.ParameterDirection.Input;
p.Value = AsValue(args.IsComplete);
ps.Add(p);

}
public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, global::SomeType args)
{
var ps = cmd.Parameters;
ps[0].Value = AsValue(args.Title);
ps[1].Value = AsValue(args.IsComplete);

}
public override bool CanPrepare => true;

}


}
}
namespace System.Runtime.CompilerServices
{
// this type is needed by the compiler to implement interceptors - it doesn't need to
// come from the runtime itself, though

[global::System.Diagnostics.Conditional("DEBUG")] // not needed post-build, so: evaporate
[global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = true)]
sealed file class InterceptsLocationAttribute : global::System.Attribute
{
public InterceptsLocationAttribute(string path, int lineNumber, int columnNumber)
{
_ = path;
_ = lineNumber;
_ = columnNumber;
}
}
}
Loading
Loading