Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to set up the .Result of (value) tasks #1126

Merged
merged 6 commits into from
Jan 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,38 @@ The format is loosely based on [Keep a Changelog](http://keepachangelog.com/en/1

## Unreleased

#### Added

* Ability to directly set up the `.Result` of tasks and value tasks, which makes setup expressions more uniform by rendering dedicated async verbs like `.ReturnsAsync`, `.ThrowsAsync`, etc. unnecessary:

```diff
-mock.Setup(x => x.GetFooAsync()).ReturnsAsync(foo)
+mock.Setup(x => x.GetFooAsync().Result).Returns(foo)
```

This is useful in places where there currently aren't any such async verbs at all:

```diff
-Mock.Of<X>(x => x.GetFooAsync() == Task.FromResult(foo))
+Mock.Of<X>(x => x.GetFooAsync().Result == foo)
```

This also allows recursive setups / method chaining across async calls inside a single setup expression:

```diff
-mock.Setup(x => x.GetFooAsync()).ReturnsAsync(Mock.Of<IFoo>(f => f.Bar == bar))
+mock.Setup(x => x.GetFooAsync().Result.Bar).Returns(bar)
```

or, with only `Mock.Of`:

```diff
-Mock.Of<X>(x => x.GetFooAsync() == Task.FromResult(Mock.Of<IFoo>(f => f.Bar == bar)))
+Mock.Of<X>(x => x.GetFooAsync().Result.Bar == bar)
```

This should work in all principal setup methods (`Mock.Of`, `mock.Setup…`, `mock.Verify…`). Support in `mock.Protected()` and for custom awaitable types may be added in the future. (@stakx, #1125)

#### Changed

* Attempts to mark conditionals setup as verifiable are once again allowed; it turns out that forbidding it (as was done in #997 for version 4.14.0) is in fact a regression. (@stakx, #1121)
Expand Down
23 changes: 21 additions & 2 deletions src/Moq/ActionObserver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Linq.Expressions;
using System.Reflection;

using Moq.Async;
using Moq.Expressions.Visitors;
using Moq.Internals;
using Moq.Properties;
Expand Down Expand Up @@ -60,6 +61,19 @@ public override Expression<Action<T>> ReconstructExpression<T>(Action<T> action,
var invocation = recorder.Invocation;
if (invocation != null)
{
var resultType = invocation.Method.DeclaringType;
if (resultType.IsAssignableFrom(body.Type) == false)
{
if (AwaitableFactory.TryGet(body.Type) is { } awaitableHandler
&& awaitableHandler.ResultType.IsAssignableFrom(resultType))
{
// We are here because the current invocation cannot be chained onto the previous one,
// however it *can* be chained if we assume that there was a `.Result` query on the
// former invocation that we don't see because non-virtual members aren't recorded.
// In this case, we make things work by adding back the missing `.Result`:
body = awaitableHandler.CreateResultExpression(body);
}
}
body = Expression.Call(body, invocation.Method, GetArgumentExpressions(invocation, recorder.Matches.ToArray()));
}
else
Expand Down Expand Up @@ -227,7 +241,7 @@ private sealed class Recorder : IInterceptor
private int creationTimestamp;
private Invocation invocation;
private int invocationTimestamp;
private IProxy returnValue;
private object returnValue;

public Recorder(MatcherObserver matcherObserver)
{
Expand All @@ -248,7 +262,7 @@ public IEnumerable<Match> Matches
}
}

public Recorder Next => this.returnValue?.Interceptor as Recorder;
public Recorder Next => (Awaitable.TryGetResultRecursive(this.returnValue) as IProxy)?.Interceptor as Recorder;

public void Intercept(Invocation invocation)
{
Expand Down Expand Up @@ -277,6 +291,11 @@ public void Intercept(Invocation invocation)
{
this.returnValue = null;
}
else if (AwaitableFactory.TryGet(returnType) is { } awaitableFactory)
{
var result = CreateProxy(awaitableFactory.ResultType, null, this.matcherObserver, out _);
this.returnValue = awaitableFactory.CreateCompleted(result);
}
else if (returnType.IsMockable())
{
this.returnValue = CreateProxy(returnType, null, this.matcherObserver, out _);
Expand Down
40 changes: 40 additions & 0 deletions src/Moq/Async/AwaitExpression.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (c) 2007, Clarius Consulting, Manas Technology Solutions, InSTEDD, and Contributors.
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System;
using System.Diagnostics;
using System.Linq.Expressions;

namespace Moq.Async
{
internal sealed class AwaitExpression : Expression
{
private readonly IAwaitableFactory awaitableFactory;
private readonly Expression operand;

public AwaitExpression(Expression operand, IAwaitableFactory awaitableFactory)
{
Debug.Assert(awaitableFactory != null);
Debug.Assert(operand != null);

this.awaitableFactory = awaitableFactory;
this.operand = operand;
}

public override bool CanReduce => false;

public override ExpressionType NodeType => ExpressionType.Extension;

public Expression Operand => this.operand;

public override Type Type => this.awaitableFactory.ResultType;

public override string ToString()
{
return this.awaitableFactory.ResultType == typeof(void) ? $"await {this.operand}"
: $"(await {this.operand})";
}

protected override Expression VisitChildren(ExpressionVisitor visitor) => this;
}
}
27 changes: 27 additions & 0 deletions src/Moq/Async/AwaitableFactory`1.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;

namespace Moq.Async
{
Expand All @@ -23,6 +26,30 @@ object IAwaitableFactory.CreateCompleted(object result)
return this.CreateCompleted();
}

public abstract TAwaitable CreateFaulted(Exception exception);

object IAwaitableFactory.CreateFaulted(Exception exception)
{
Debug.Assert(exception != null);

return this.CreateFaulted(exception);
}

public abstract TAwaitable CreateFaulted(IEnumerable<Exception> exceptions);

object IAwaitableFactory.CreateFaulted(IEnumerable<Exception> exceptions)
{
Debug.Assert(exceptions != null);
Debug.Assert(exceptions.Any());

return this.CreateFaulted(exceptions);
}

Expression IAwaitableFactory.CreateResultExpression(Expression awaitableExpression)
{
return new AwaitExpression(awaitableExpression, this);
}

bool IAwaitableFactory.TryGetResult(object awaitable, out object result)
{
Debug.Assert(awaitable is TAwaitable);
Expand Down
24 changes: 24 additions & 0 deletions src/Moq/Async/AwaitableFactory`2.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;

namespace Moq.Async
{
Expand All @@ -23,8 +26,29 @@ object IAwaitableFactory.CreateCompleted(object result)
return this.CreateCompleted((TResult)result);
}

public abstract TAwaitable CreateFaulted(Exception exception);

object IAwaitableFactory.CreateFaulted(Exception exception)
{
Debug.Assert(exception != null);

return this.CreateFaulted(exception);
}

public abstract TAwaitable CreateFaulted(IEnumerable<Exception> exceptions);

object IAwaitableFactory.CreateFaulted(IEnumerable<Exception> exceptions)
{
Debug.Assert(exceptions != null);
Debug.Assert(exceptions.Any());

return this.CreateFaulted(exceptions);
}

public abstract bool TryGetResult(TAwaitable awaitable, out TResult result);

public abstract Expression CreateResultExpression(Expression awaitableExpression);

bool IAwaitableFactory.TryGetResult(object awaitable, out object result)
{
Debug.Assert(awaitable is TAwaitable);
Expand Down
8 changes: 8 additions & 0 deletions src/Moq/Async/IAwaitableFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System;
using System.Collections.Generic;
using System.Linq.Expressions;

namespace Moq.Async
{
Expand All @@ -11,6 +13,12 @@ internal interface IAwaitableFactory

object CreateCompleted(object result = null);

object CreateFaulted(Exception exception);

object CreateFaulted(IEnumerable<Exception> exceptions);

Expression CreateResultExpression(Expression awaitableExpression);

bool TryGetResult(object awaitable, out object result);
}
}
18 changes: 18 additions & 0 deletions src/Moq/Async/TaskFactory.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
// Copyright (c) 2007, Clarius Consulting, Manas Technology Solutions, InSTEDD, and Contributors.
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading.Tasks;

namespace Moq.Async
Expand All @@ -17,5 +21,19 @@ public override Task CreateCompleted()
{
return Task.FromResult<object>(default);
}

public override Task CreateFaulted(Exception exception)
{
var tcs = new TaskCompletionSource<object>();
tcs.SetException(exception);
return tcs.Task;
}

public override Task CreateFaulted(IEnumerable<Exception> exceptions)
{
var tcs = new TaskCompletionSource<object>();
tcs.SetException(exceptions);
return tcs.Task;
}
}
}
24 changes: 24 additions & 0 deletions src/Moq/Async/TaskFactory`1.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// Copyright (c) 2007, Clarius Consulting, Manas Technology Solutions, InSTEDD, and Contributors.
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Threading.Tasks;

namespace Moq.Async
Expand All @@ -12,6 +15,27 @@ public override Task<TResult> CreateCompleted(TResult result)
return Task.FromResult(result);
}

public override Task<TResult> CreateFaulted(Exception exception)
{
var tcs = new TaskCompletionSource<TResult>();
tcs.SetException(exception);
return tcs.Task;
}

public override Task<TResult> CreateFaulted(IEnumerable<Exception> exceptions)
{
var tcs = new TaskCompletionSource<TResult>();
tcs.SetException(exceptions);
return tcs.Task;
}

public override Expression CreateResultExpression(Expression awaitableExpression)
{
return Expression.MakeMemberAccess(
awaitableExpression,
typeof(Task<TResult>).GetProperty(nameof(Task<TResult>.Result)));
}

public override bool TryGetResult(Task<TResult> task, out TResult result)
{
if (task.Status == TaskStatus.RanToCompletion)
Expand Down
16 changes: 16 additions & 0 deletions src/Moq/Async/ValueTaskFactory.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) 2007, Clarius Consulting, Manas Technology Solutions, InSTEDD, and Contributors.
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System;
using System.Collections.Generic;
using System.Threading.Tasks;

namespace Moq.Async
Expand All @@ -17,5 +19,19 @@ public override ValueTask CreateCompleted()
{
return default;
}

public override ValueTask CreateFaulted(Exception exception)
{
var tcs = new TaskCompletionSource<object>();
tcs.SetException(exception);
return new ValueTask(tcs.Task);
}

public override ValueTask CreateFaulted(IEnumerable<Exception> exceptions)
{
var tcs = new TaskCompletionSource<object>();
tcs.SetException(exceptions);
return new ValueTask(tcs.Task);
}
}
}
24 changes: 24 additions & 0 deletions src/Moq/Async/ValueTaskFactory`1.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// Copyright (c) 2007, Clarius Consulting, Manas Technology Solutions, InSTEDD, and Contributors.
// All rights reserved. Licensed under the BSD 3-Clause License; see License.txt.

using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Threading.Tasks;

namespace Moq.Async
Expand All @@ -12,6 +15,27 @@ public override ValueTask<TResult> CreateCompleted(TResult result)
return new ValueTask<TResult>(result);
}

public override ValueTask<TResult> CreateFaulted(Exception exception)
{
var tcs = new TaskCompletionSource<TResult>();
tcs.SetException(exception);
return new ValueTask<TResult>(tcs.Task);
}

public override ValueTask<TResult> CreateFaulted(IEnumerable<Exception> exceptions)
{
var tcs = new TaskCompletionSource<TResult>();
tcs.SetException(exceptions);
return new ValueTask<TResult>(tcs.Task);
}

public override Expression CreateResultExpression(Expression awaitableExpression)
{
return Expression.MakeMemberAccess(
awaitableExpression,
typeof(ValueTask<TResult>).GetProperty(nameof(ValueTask<TResult>.Result)));
}

public override bool TryGetResult(ValueTask<TResult> valueTask, out TResult result)
{
if (valueTask.IsCompletedSuccessfully)
Expand Down
Loading