Skip to content

Commit

Permalink
Send stream completion when client errors stream (#34147)
Browse files Browse the repository at this point in the history
  • Loading branch information
BrennanConroy authored Jul 9, 2021
1 parent 15949fa commit 2387c21
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,9 @@ private static class Log
private static readonly Action<ILogger, Exception> _errorHandshakeCanceled =
LoggerMessage.Define(LogLevel.Error, new EventId(83, "ErrorHandshakeCanceled"), "The handshake was canceled by the client.");

private static readonly Action<ILogger, string, Exception?> _erroredStream =
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(84, "ErroredStream"), "Client threw an error for stream '{StreamId}'.");

public static void PreparingNonBlockingInvocation(ILogger logger, string target, int count)
{
_preparingNonBlockingInvocation(logger, target, count, null);
Expand Down Expand Up @@ -664,6 +667,11 @@ public static void ErrorHandshakeCanceled(ILogger logger, Exception exception)
{
_errorHandshakeCanceled(logger, exception);
}

public static void ErroredStream(ILogger logger, string streamId, Exception exception)
{
_erroredStream(logger, streamId, exception);
}
}
}
}
Expand Down
35 changes: 20 additions & 15 deletions src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,10 @@ private void LaunchStreams(ConnectionState connectionState, Dictionary<string, o
return;
}

_state.AssertInConnectionLock();
// It's safe to access connectionState.UploadStreamToken as we still have the connection lock
var cts = CancellationTokenSource.CreateLinkedTokenSource(connectionState.UploadStreamToken, cancellationToken);

foreach (var kvp in readers)
{
var reader = kvp.Value;
Expand All @@ -708,19 +712,19 @@ private void LaunchStreams(ConnectionState connectionState, Dictionary<string, o
{
_ = _sendIAsyncStreamItemsMethod
.MakeGenericMethod(reader.GetType().GetInterface("IAsyncEnumerable`1")!.GetGenericArguments())
.Invoke(this, new object[] { connectionState, kvp.Key.ToString(), reader, cancellationToken });
.Invoke(this, new object[] { connectionState, kvp.Key.ToString(), reader, cts });
continue;
}
_ = _sendStreamItemsMethod
.MakeGenericMethod(reader.GetType().GetGenericArguments())
.Invoke(this, new object[] { connectionState, kvp.Key.ToString(), reader, cancellationToken });
.Invoke(this, new object[] { connectionState, kvp.Key.ToString(), reader, cts });
}
}

// this is called via reflection using the `_sendStreamItems` field
private Task SendStreamItems<T>(ConnectionState connectionState, string streamId, ChannelReader<T> reader, CancellationToken token)
private Task SendStreamItems<T>(ConnectionState connectionState, string streamId, ChannelReader<T> reader, CancellationTokenSource tokenSource)
{
async Task ReadChannelStream(CancellationTokenSource tokenSource)
async Task ReadChannelStream()
{
while (await reader.WaitToReadAsync(tokenSource.Token))
{
Expand All @@ -732,13 +736,13 @@ async Task ReadChannelStream(CancellationTokenSource tokenSource)
}
}

return CommonStreaming(connectionState, streamId, token, ReadChannelStream);
return CommonStreaming(connectionState, streamId, ReadChannelStream);
}

// this is called via reflection using the `_sendIAsyncStreamItemsMethod` field
private Task SendIAsyncEnumerableStreamItems<T>(ConnectionState connectionState, string streamId, IAsyncEnumerable<T> stream, CancellationToken token)
private Task SendIAsyncEnumerableStreamItems<T>(ConnectionState connectionState, string streamId, IAsyncEnumerable<T> stream, CancellationTokenSource tokenSource)
{
async Task ReadAsyncEnumerableStream(CancellationTokenSource tokenSource)
async Task ReadAsyncEnumerableStream()
{
var streamValues = AsyncEnumerableAdapters.MakeCancelableTypedAsyncEnumerable(stream, tokenSource);

Expand All @@ -749,25 +753,26 @@ async Task ReadAsyncEnumerableStream(CancellationTokenSource tokenSource)
}
}

return CommonStreaming(connectionState, streamId, token, ReadAsyncEnumerableStream);
return CommonStreaming(connectionState, streamId, ReadAsyncEnumerableStream);
}

private async Task CommonStreaming(ConnectionState connectionState, string streamId, CancellationToken token, Func<CancellationTokenSource, Task> createAndConsumeStream)
private async Task CommonStreaming(ConnectionState connectionState, string streamId, Func<Task> createAndConsumeStream)
{
// It's safe to access connectionState.UploadStreamToken as we still have the connection lock
_state.AssertInConnectionLock();
var cts = CancellationTokenSource.CreateLinkedTokenSource(connectionState.UploadStreamToken, token);

Log.StartingStream(_logger, streamId);
string? responseError = null;
try
{
await createAndConsumeStream(cts);
await createAndConsumeStream();
}
catch (OperationCanceledException)
{
Log.CancelingStream(_logger, streamId);
responseError = $"Stream canceled by client.";
responseError = "Stream canceled by client.";
}
catch (Exception ex)
{
Log.ErroredStream(_logger, streamId, ex);
responseError = $"Stream errored by client: '{ex}'";
}

Log.CompletingStream(_logger, streamId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,32 @@ public async Task UploadStreamCancellationSendsStreamComplete()
}
}

[Fact]
[LogLevel(LogLevel.Trace)]
public async Task UploadStreamErrorSendsStreamComplete()
{
using (StartVerifiableLog())
{
var connection = new TestConnection();
var hubConnection = CreateHubConnection(connection, loggerFactory: LoggerFactory);
await hubConnection.StartAsync().DefaultTimeout();

var cts = new CancellationTokenSource();
var channel = Channel.CreateUnbounded<int>();
var invokeTask = hubConnection.InvokeAsync<object>("UploadMethod", channel.Reader, cts.Token);

var invokeMessage = await connection.ReadSentJsonAsync().DefaultTimeout();
Assert.Equal(HubProtocolConstants.InvocationMessageType, invokeMessage["type"]);

channel.Writer.Complete(new Exception("error from client"));

// the next sent message should be a completion message
var complete = await connection.ReadSentJsonAsync().DefaultTimeout();
Assert.Equal(HubProtocolConstants.CompletionMessageType, complete["type"]);
Assert.StartsWith("Stream errored by client: 'System.Exception: error from client", ((string)complete["error"]));
}
}

[Fact]
[LogLevel(LogLevel.Trace)]
public async Task InvocationCanCompleteBeforeStreamCompletes()
Expand Down

0 comments on commit 2387c21

Please sign in to comment.