diff --git a/src/Moq/ExpressionExtensions.cs b/src/Moq/ExpressionExtensions.cs index 256f74cd4..d17eb5169 100644 --- a/src/Moq/ExpressionExtensions.cs +++ b/src/Moq/ExpressionExtensions.cs @@ -10,6 +10,7 @@ using System.Reflection; using System.Text; +using Moq.Async; using Moq.Properties; using Moq.Protected; @@ -294,6 +295,16 @@ void Split(Expression e, out Expression r /* remainder */, out InvocationShape p { var memberAccessExpression = (MemberExpression)e; Debug.Assert(memberAccessExpression.Member is PropertyInfo); + + if (IsResult(memberAccessExpression.Member, out var awaitableFactory)) + { + Split(memberAccessExpression.Expression, out r, out p); + p.AddResultExpression( + awaitable => Expression.MakeMemberAccess(awaitable, memberAccessExpression.Member), + awaitableFactory); + return; + } + r = memberAccessExpression.Expression; var parameter = Expression.Parameter(r.Type, r is ParameterExpression ope ? ope.Name : ParameterName); var property = memberAccessExpression.GetReboundProperty(); @@ -327,6 +338,15 @@ void Split(Expression e, out Expression r /* remainder */, out InvocationShape p throw new InvalidOperationException(); // this should be unreachable } } + + bool IsResult(MemberInfo member, out IAwaitableFactory awaitableFactory) + { + var instanceType = member.DeclaringType; + awaitableFactory = AwaitableFactory.TryGet(instanceType); + var returnType = member switch { PropertyInfo p => p.PropertyType, + _ => null }; + return awaitableFactory != null && object.Equals(returnType, awaitableFactory.ResultType); + } } internal static PropertyInfo GetReboundProperty(this MemberExpression expression) diff --git a/src/Moq/Invocation.cs b/src/Moq/Invocation.cs index 64b7cbc73..b859d11f0 100644 --- a/src/Moq/Invocation.cs +++ b/src/Moq/Invocation.cs @@ -7,6 +7,8 @@ using System.Reflection; using System.Text; +using Moq.Async; + namespace Moq { internal abstract class Invocation : IInvocation @@ -89,6 +91,18 @@ public Exception Exception } } + public void ConvertResultToAwaitable(IAwaitableFactory awaitableFactory) + { + if (this.result is ExceptionResult r) + { + this.result = awaitableFactory.CreateFaulted(r.Exception); + } + else if (this.result != null && !this.method.ReturnType.IsAssignableFrom(this.result.GetType())) + { + this.result = awaitableFactory.CreateCompleted(this.result); + } + } + public bool IsVerified => this.verified; /// <summary> diff --git a/src/Moq/InvocationShape.cs b/src/Moq/InvocationShape.cs index 11fdec7ee..2c9a78768 100644 --- a/src/Moq/InvocationShape.cs +++ b/src/Moq/InvocationShape.cs @@ -8,6 +8,7 @@ using System.Linq.Expressions; using System.Reflection; +using Moq.Async; using Moq.Expressions.Visitors; using E = System.Linq.Expressions.Expression; @@ -60,11 +61,12 @@ public static InvocationShape CreateFrom(Invocation invocation) private static readonly Expression[] noArguments = new Expression[0]; private static readonly IMatcher[] noArgumentMatchers = new IMatcher[0]; - public readonly LambdaExpression Expression; + public LambdaExpression Expression; public readonly MethodInfo Method; public readonly IReadOnlyList<Expression> Arguments; private readonly IMatcher[] argumentMatchers; + private IAwaitableFactory awaitableFactory; private MethodInfo methodImplementation; private Expression[] partiallyEvaluatedArguments; #if DEBUG @@ -98,6 +100,17 @@ public InvocationShape(LambdaExpression expression, MethodInfo method, IReadOnly this.exactGenericTypeArguments = exactGenericTypeArguments; } + public void AddResultExpression(Func<E, E> add, IAwaitableFactory awaitableFactory) + { + this.Expression = E.Lambda(add(this.Expression.Body), this.Expression.Parameters); + this.awaitableFactory = awaitableFactory; + } + + public bool HasResultExpression(out IAwaitableFactory awaitableFactory) + { + return (awaitableFactory = this.awaitableFactory) != null; + } + public void Deconstruct(out LambdaExpression expression, out MethodInfo method, out IReadOnlyList<Expression> arguments) { expression = this.Expression; diff --git a/src/Moq/Setup.cs b/src/Moq/Setup.cs index 66c20f595..6a72d7949 100644 --- a/src/Moq/Setup.cs +++ b/src/Moq/Setup.cs @@ -65,7 +65,25 @@ public void Execute(Invocation invocation) this.Condition?.SetupEvaluatedSuccessfully(); this.expectation.SetupEvaluatedSuccessfully(invocation); - this.ExecuteCore(invocation); + if (this.expectation.HasResultExpression(out var awaitableFactory)) + { + try + { + this.ExecuteCore(invocation); + } + catch (Exception exception) + { + invocation.Exception = exception; + } + finally + { + invocation.ConvertResultToAwaitable(awaitableFactory); + } + } + else + { + this.ExecuteCore(invocation); + } } protected abstract void ExecuteCore(Invocation invocation);