diff --git a/DllImportGenerator/DllImportGenerator.IntegrationTests/ArrayTests.cs b/DllImportGenerator/DllImportGenerator.IntegrationTests/ArrayTests.cs index d535c13d7691..0a5b67328152 100644 --- a/DllImportGenerator/DllImportGenerator.IntegrationTests/ArrayTests.cs +++ b/DllImportGenerator/DllImportGenerator.IntegrationTests/ArrayTests.cs @@ -67,10 +67,16 @@ public class ArrayTests [Fact] public void IntArrayMarshalledToNativeAsExpected() { - var array = new [] { 1, 5, 79, 165, 32, 3 }; + var array = new[] { 1, 5, 79, 165, 32, 3 }; Assert.Equal(array.Sum(), NativeExportsNE.Arrays.Sum(array, array.Length)); } + [Fact] + public void NullIntArrayMarshalledToNativeAsExpected() + { + Assert.Equal(-1, NativeExportsNE.Arrays.Sum(null, 0)); + } + [Fact] public void ZeroLengthArrayMarshalledAsNonNull() { diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/BlittableArrayMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/BlittableArrayMarshaller.cs index 8e7291f9158b..cd42a96d8324 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/BlittableArrayMarshaller.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/BlittableArrayMarshaller.cs @@ -56,21 +56,58 @@ public override IEnumerable Generate(TypePositionInfo info, Stu var (managedIdentifer, nativeIdentifier) = context.GetIdentifiers(info); if (!info.IsByRef && !info.IsManagedReturnPosition && context.PinningSupported) { + string byRefIdentifier = $"__byref_{managedIdentifer}"; + if (context.CurrentStage == StubCodeContext.Stage.Marshal) + { + // [COMPAT] We use explicit byref calculations here instead of just using a fixed statement + // since a fixed statement converts a zero-length array to a null pointer. + // Many native APIs, such as GDI+, ICU, etc. validate that an array parameter is non-null + // even when the passed in array length is zero. To avoid breaking customers that want to move + // to source-generated interop in subtle ways, we explicitly pass a reference to the 0-th element + // of an array as long as it is non-null, matching the behavior of the built-in interop system + // for single-dimensional zero-based arrays. + + // ref = == null ? ref *(); + var nullRef = + PrefixUnaryExpression(SyntaxKind.PointerIndirectionExpression, + CastExpression( + PointerType(GetElementTypeSyntax(info)), + LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0)))); + + var getArrayDataReference = + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + ParseTypeName(TypeNames.System_Runtime_InteropServices_MemoryMarshal), + IdentifierName("GetArrayDataReference")), + ArgumentList(SingletonSeparatedList( + Argument(IdentifierName(managedIdentifer))))); + + yield return LocalDeclarationStatement( + VariableDeclaration( + RefType(GetElementTypeSyntax(info))) + .WithVariables(SingletonSeparatedList( + VariableDeclarator(Identifier(byRefIdentifier)) + .WithInitializer(EqualsValueClause( + RefExpression(ParenthesizedExpression( + ConditionalExpression( + BinaryExpression( + SyntaxKind.EqualsExpression, + IdentifierName(managedIdentifer), + LiteralExpression( + SyntaxKind.NullLiteralExpression)), + RefExpression(nullRef), + RefExpression(getArrayDataReference))))))))); + } if (context.CurrentStage == StubCodeContext.Stage.Pin) { - // fixed ( = &MemoryMarshal.GetArrayDataReference()) + // fixed ( = &) yield return FixedStatement( VariableDeclaration(AsNativeType(info), SingletonSeparatedList( VariableDeclarator(nativeIdentifier) .WithInitializer(EqualsValueClause( PrefixUnaryExpression(SyntaxKind.AddressOfExpression, - InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - ParseTypeName(TypeNames.System_Runtime_InteropServices_MemoryMarshal), - IdentifierName("GetArrayDataReference")), - ArgumentList( - SingletonSeparatedList(Argument(IdentifierName(managedIdentifer))) - ))))))), + IdentifierName(byRefIdentifier)))))), EmptyStatement()); } yield break; diff --git a/DllImportGenerator/DllImportGenerator/TypeNames.cs b/DllImportGenerator/DllImportGenerator/TypeNames.cs index 457a4db0386a..509969a05230 100644 --- a/DllImportGenerator/DllImportGenerator/TypeNames.cs +++ b/DllImportGenerator/DllImportGenerator/TypeNames.cs @@ -44,7 +44,7 @@ public static string MarshalEx(AnalyzerConfigOptions options) public const string System_Runtime_InteropServices_OutAttribute = "System.Runtime.InteropServices.OutAttribute"; public const string System_Runtime_InteropServices_InAttribute = "System.Runtime.InteropServices.InAttribute"; - + public const string System_Runtime_CompilerServices_SkipLocalsInitAttribute = "System.Runtime.CompilerServices.SkipLocalsInitAttribute"; } }