From 51494867569e9e564bdab25ae1e063fea01a9350 Mon Sep 17 00:00:00 2001 From: GrahamTheCoder Date: Sun, 12 Dec 2021 21:56:47 +0000 Subject: [PATCH] Convert ref extension method to static invocation - fixes #785 --- CHANGELOG.md | 1 + CodeConverter/CSharp/ExpressionNodeVisitor.cs | 30 +++++++++---- CodeConverter/Util/ISymbolExtensions.cs | 6 +++ Tests/CSharp/MemberTests/MemberTests.cs | 43 +++++++++++++++++++ 4 files changed, 71 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8ddcda730..dea5bba8d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ### VB -> C# +* Convert extension methods on ByRef reference types to static invocations [#785](https://github.com/icsharpcode/CodeConverter/issues/785) ### C# -> VB diff --git a/CodeConverter/CSharp/ExpressionNodeVisitor.cs b/CodeConverter/CSharp/ExpressionNodeVisitor.cs index 471e6e18f..3b23fd701 100644 --- a/CodeConverter/CSharp/ExpressionNodeVisitor.cs +++ b/CodeConverter/CSharp/ExpressionNodeVisitor.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Data; using System.Linq; +using System.Security.Cryptography; using System.Threading.Tasks; using ICSharpCode.CodeConverter.Shared; using ICSharpCode.CodeConverter.Util; @@ -13,6 +14,7 @@ using Microsoft.CodeAnalysis.Simplification; using Microsoft.VisualBasic.CompilerServices; using IOperation = Microsoft.CodeAnalysis.IOperation; +using ISymbolExtensions = ICSharpCode.CodeConverter.Util.ISymbolExtensions; using SyntaxFactory = Microsoft.CodeAnalysis.CSharp.SyntaxFactory; using SyntaxKind = Microsoft.CodeAnalysis.CSharp.SyntaxKind; using VBasic = Microsoft.CodeAnalysis.VisualBasic; @@ -897,16 +899,24 @@ private async Task ConvertInvocationAsync(VBSyntax.InvocationE return convertedExpression; //Parameterless property access } - if (!IsElementAtOrDefaultInvocation(invocationSymbol, expressionSymbol)) { - return SyntaxFactory.InvocationExpression(convertedExpression, - await ConvertArgumentListOrEmptyAsync(node, node.ArgumentList)); - } + var convertedArgumentList= await ConvertArgumentListOrEmptyAsync(node, node.ArgumentList); - var newExpression = GetElementAtOrDefaultExpression(expressionType, convertedExpression); + if (IsElementAtOrDefaultInvocation(invocationSymbol, expressionSymbol)) { + convertedExpression = GetElementAtOrDefaultExpression(expressionType, convertedExpression); + } - return SyntaxFactory.InvocationExpression(newExpression, - await ConvertArgumentListOrEmptyAsync(node, node.ArgumentList)); + if (invocationSymbol.IsReducedExtension() && invocationSymbol is IMethodSymbol {ReducedFrom: {Parameters: var parameters}} && + !parameters.FirstOrDefault().ValidCSharpExtensionMethodParameter() && + node.Expression is VBSyntax.MemberAccessExpressionSyntax maes) { + var thisArgExpression = await maes.Expression.AcceptAsync(TriviaConvertingExpressionVisitor); + var thisArg = Microsoft.CodeAnalysis.CSharp.SyntaxFactory.Argument(thisArgExpression).WithRefKindKeyword(GetRefToken(RefKind.Ref)); + convertedArgumentList = SyntaxFactory.ArgumentList(SyntaxFactory.SeparatedList(convertedArgumentList.Arguments.Prepend(thisArg))); + var containingType = (ExpressionSyntax) CommonConversions.CsSyntaxGenerator.TypeExpression(invocationSymbol.ContainingType); + convertedExpression = SyntaxFactory.MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, containingType, + SyntaxFactory.IdentifierName(CommonConversions.CsEscapedIdentifier(invocationSymbol.Name))); + } + return SyntaxFactory.InvocationExpression(convertedExpression, convertedArgumentList); } private async Task<(ExpressionSyntax, bool isElementAccess)> ConvertInvocationSubExpressionAsync(VBSyntax.InvocationExpressionSyntax node, @@ -1148,7 +1158,8 @@ public override async Task VisitParameter(VBSyntax.ParameterSy var attributes = (await node.AttributeLists.SelectManyAsync(CommonConversions.ConvertAttributeAsync)).ToList(); var modifiers = CommonConversions.ConvertModifiers(node, node.Modifiers, TokenContext.Local); - var csParamSymbol = CommonConversions.GetDeclaredCsOriginalSymbolOrNull(node) as IParameterSymbol; + var vbSymbol = _semanticModel.GetDeclaredSymbol(node) as IParameterSymbol; + var csParamSymbol = CommonConversions.GetCsOriginalSymbolOrNull(vbSymbol) as IParameterSymbol; if (csParamSymbol?.RefKind == RefKind.Out || node.AttributeLists.Any(CommonConversions.HasOutAttribute)) { modifiers = SyntaxFactory.TokenList(modifiers .Where(m => !m.IsKind(SyntaxKind.RefKeyword)) @@ -1193,7 +1204,8 @@ public override async Task VisitParameter(VBSyntax.ParameterSy } if (node.Parent.Parent is VBSyntax.MethodStatementSyntax mss - && mss.AttributeLists.Any(CommonConversions.HasExtensionAttribute) && node.Parent.ChildNodes().First() == node) { + && mss.AttributeLists.Any(CommonConversions.HasExtensionAttribute) && node.Parent.ChildNodes().First() == node && + vbSymbol.ValidCSharpExtensionMethodParameter()) { modifiers = modifiers.Insert(0, SyntaxFactory.Token(SyntaxKind.ThisKeyword)); } return SyntaxFactory.Parameter( diff --git a/CodeConverter/Util/ISymbolExtensions.cs b/CodeConverter/Util/ISymbolExtensions.cs index 3997f4954..1d18a187e 100644 --- a/CodeConverter/Util/ISymbolExtensions.cs +++ b/CodeConverter/Util/ISymbolExtensions.cs @@ -89,6 +89,12 @@ public static bool IsReducedTypeParameterMethod(this ISymbol symbol) { return symbol is IMethodSymbol ms && ms.ReducedFrom?.TypeParameters.Count() > ms.TypeParameters.Count(); } + + /// + /// Since non value types can't be ref types for extension methods in C#, convert to a static invocation + /// https://github.com/icsharpcode/CodeConverter/issues/785 + /// + public static bool ValidCSharpExtensionMethodParameter(this IParameterSymbol vbSymbol) => vbSymbol != null && (vbSymbol.RefKind != RefKind.Ref || vbSymbol.Type.IsValueType); } } diff --git a/Tests/CSharp/MemberTests/MemberTests.cs b/Tests/CSharp/MemberTests/MemberTests.cs index bc5ec77a8..a8389b9f1 100644 --- a/Tests/CSharp/MemberTests/MemberTests.cs +++ b/Tests/CSharp/MemberTests/MemberTests.cs @@ -574,6 +574,49 @@ public static void TestMethod(this string str) }"); } + [Fact] + public async Task TestRefExtensionMethodAsync() + { + await TestConversionVisualBasicToCSharpAsync( +@"Imports System +Imports System.Runtime.CompilerServices ' Removed since the extension attribute is removed + +Public Module MyExtensions + + Public Sub Add(Of T)(ByRef arr As T(), item As T) + Array.Resize(arr, arr.Length + 1) + arr(arr.Length - 1) = item + End Sub +End Module + +Public Module UsagePoint + Public Sub Main() + Dim arr = New Integer() {1, 2, 3} + arr.Add(4) + System.Console.WriteLine(arr(3)) + End Sub +End Module", @"using System; + +public static partial class MyExtensions +{ + public static void Add(ref T[] arr, T item) + { + Array.Resize(ref arr, arr.Length + 1); + arr[arr.Length - 1] = item; + } +} + +public static partial class UsagePoint +{ + public static void Main() + { + var arr = new int[] { 1, 2, 3 }; + MyExtensions.Add(ref arr, 4); + Console.WriteLine(arr[3]); + } +}"); + } + [Fact] public async Task TestExtensionWithinExtendedTypeAsync() {