Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.
/ corefx Public archive

Implement Socket.Send/ReceiveAsync cancellation support #34212

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/System.Net.Sockets/src/System.Net.Sockets.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@
<Compile Include="$(CommonPath)\Interop\Windows\Winsock\WSABuffer.cs">
<Link>Interop\Windows\Winsock\WSABuffer.cs</Link>
</Compile>
<Compile Include="$(CommonPath)\CoreLib\Interop\Windows\Kernel32\Interop.CancelIoEx.cs">
<Link>Common\Interop\Windows\Interop.CancelIoEx.cs</Link>
</Compile>
<Compile Include="$(CommonPath)\Interop\Windows\kernel32\Interop.SetFileCompletionNotificationModes.cs">
<Link>Interop\Windows\kernel32\Interop.SetFileCompletionNotificationModes.cs</Link>
</Compile>
Expand Down
183 changes: 147 additions & 36 deletions src/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,6 @@ internal Task<int> ReceiveAsync(ArraySegment<byte> buffer, SocketFlags socketFla
return ReceiveAsync((Memory<byte>)buffer, socketFlags, fromNetworkStream, default).AsTask();
}

// TODO https://github.com/dotnet/corefx/issues/24430:
// Fully plumb cancellation down into socket operations.

internal ValueTask<int> ReceiveAsync(Memory<byte> buffer, SocketFlags socketFlags, bool fromNetworkStream, CancellationToken cancellationToken)
{
if (cancellationToken.IsCancellationRequested)
Expand All @@ -201,14 +198,16 @@ internal ValueTask<int> ReceiveAsync(Memory<byte> buffer, SocketFlags socketFlag
saea.SetBuffer(buffer);
saea.SocketFlags = socketFlags;
saea.WrapExceptionsInIOExceptions = fromNetworkStream;
return saea.ReceiveAsync(this);
}
else
{
// We couldn't get a cached instance, due to a concurrent receive operation on the socket.
// Fall back to wrapping APM.
return new ValueTask<int>(ReceiveAsyncApm(buffer, socketFlags));
return saea.ReceiveAsync(this, cancellationToken);
}

// We couldn't get a cached instance, due to a concurrent receive operation on the socket.
// Fall back to wrapping APM.
Task<int> apmTask = ReceiveAsyncApm(buffer, socketFlags);
return new ValueTask<int>(
cancellationToken.CanBeCanceled && !apmTask.IsCompleted ?
WaitWithCancellationAsync(apmTask, cancellationToken) :
apmTask);
}

/// <summary>Implements Task-returning ReceiveAsync on top of Begin/EndReceive.</summary>
Expand Down Expand Up @@ -350,14 +349,16 @@ internal ValueTask<int> SendAsync(ReadOnlyMemory<byte> buffer, SocketFlags socke
saea.SetBuffer(MemoryMarshal.AsMemory(buffer));
saea.SocketFlags = socketFlags;
saea.WrapExceptionsInIOExceptions = false;
return saea.SendAsync(this);
}
else
{
// We couldn't get a cached instance, due to a concurrent send operation on the socket.
// Fall back to wrapping APM.
return new ValueTask<int>(SendAsyncApm(buffer, socketFlags));
return saea.SendAsync(this, cancellationToken);
}

// We couldn't get a cached instance, due to a concurrent send operation on the socket.
// Fall back to wrapping APM.
Task<int> apmTask = SendAsyncApm(buffer, socketFlags);
return new ValueTask<int>(
cancellationToken.CanBeCanceled && !apmTask.IsCompleted ?
WaitWithCancellationAsync(apmTask, cancellationToken) :
apmTask);
}

internal ValueTask SendAsyncForNetworkStream(ReadOnlyMemory<byte> buffer, SocketFlags socketFlags, CancellationToken cancellationToken)
Expand All @@ -374,14 +375,16 @@ internal ValueTask SendAsyncForNetworkStream(ReadOnlyMemory<byte> buffer, Socket
saea.SetBuffer(MemoryMarshal.AsMemory(buffer));
saea.SocketFlags = socketFlags;
saea.WrapExceptionsInIOExceptions = true;
return saea.SendAsyncForNetworkStream(this);
}
else
{
// We couldn't get a cached instance, due to a concurrent send operation on the socket.
// Fall back to wrapping APM.
return new ValueTask(SendAsyncApm(buffer, socketFlags));
return saea.SendAsyncForNetworkStream(this, cancellationToken);
}

// We couldn't get a cached instance, due to a concurrent send operation on the socket.
// Fall back to wrapping APM.
Task<int> apmTask = SendAsyncApm(buffer, socketFlags);
return new ValueTask(
cancellationToken.CanBeCanceled && !apmTask.IsCompleted ?
WaitWithCancellationAsync(apmTask, cancellationToken) :
apmTask);
}

/// <summary>Implements Task-returning SendAsync on top of Begin/EndSend.</summary>
Expand Down Expand Up @@ -470,6 +473,27 @@ internal Task<int> SendToAsync(ArraySegment<byte> buffer, SocketFlags socketFlag
return tcs.Task;
}

/// <summary>Waits for the provided task to complete, and in the interim cancels pending operations on the socket if cancellation is requested.</summary>
private async Task<int> WaitWithCancellationAsync(Task<int> task, CancellationToken cancellationToken)
{
Debug.Assert(cancellationToken.CanBeCanceled);
using (cancellationToken.UnsafeRegister(s => SocketPal.CancelPendingOperations((Socket)s), this))
{
try
{
return await task.ConfigureAwait(false);
}
catch (SocketException se)
{
if (se.SocketErrorCode == SocketError.OperationAborted)
{
cancellationToken.ThrowIfCancellationRequested();
}
throw;
}
}
}

/// <summary>Validates the supplied array segment, throwing if its array or indices are null or out-of-bounds, respectively.</summary>
private static void ValidateBuffer(ArraySegment<byte> buffer)
{
Expand Down Expand Up @@ -819,7 +843,9 @@ internal sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, IVal
/// when the operation does complete.
/// </summary>
private Action<object> _continuation = s_availableSentinel;
/// <summary>Captured ExecutionContext to use when invoking the continuation.</summary>
private ExecutionContext _executionContext;
/// <summary>Context or scheduler to which continuations should be queued.</summary>
private object _scheduler;
/// <summary>Current token value given to a ValueTask and then verified against the value it passes back to us.</summary>
/// <remarks>
Expand All @@ -828,6 +854,19 @@ internal sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, IVal
/// it's already being reused by someone else.
/// </remarks>
private short _token;
/// <summary>CancellationToken associated with the current operation.</summary>
private CancellationToken _cancellationToken;
/// <summary>Cancellation registration for the current operation.</summary>
private CancellationTokenRegistration _cancellationRegistration;
/// <summary>Flag used to communicate between the threads initiating and completing operations.</summary>
private byte _cancellationState;

/// <summary>Indicates that cancelation isn't applicable for the operation.</summary>
private const byte CancellationState_None = 0;
/// <summary>Indicates that a cancellation callback is in the process of being registered.</summary>
private const byte CancellationState_Registering = 1;
/// <summary>Indicates that a cancellation callback has been registered.</summary>
private const byte CancellationState_Registered = 2;

/// <summary>Initializes the event args.</summary>
public AwaitableSocketAsyncEventArgs() :
Expand All @@ -842,12 +881,26 @@ public bool Reserve() =>

private void Release()
{
_cancellationToken = default;
_token++;
Volatile.Write(ref _continuation, s_availableSentinel);
}

protected override void OnCompleted(SocketAsyncEventArgs _)
{
// _cancellationState will be None if cancellation can't happen,
// Registering if it's in the process of being registered, and
// Registered if it's been registered. If cancellation may happen,
// we spin until it's been registered, so that we can safely unregister.
// That way, there's no concern about leaking a registration.
if (_cancellationState != CancellationState_None)
{
var sw = new SpinWait();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it make sense to call this only when needed? (unless it is really cheap). I would expect that cases when we have contention on state would be rare.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's cheap. SpinWait is a struct, so

var sw = new SpinWait();

is identical to:

SpinWait sw = default;

while (Volatile.Read(ref _cancellationState) == CancellationState_Registering) sw.SpinOnce();
_cancellationRegistration.Dispose();
_cancellationState = CancellationState_None;
}

// When the operation completes, see if OnCompleted was already called to hook up a continuation.
// If it was, invoke the continuation.
Action<object> c = _continuation;
Expand Down Expand Up @@ -882,13 +935,29 @@ protected override void OnCompleted(SocketAsyncEventArgs _)

/// <summary>Initiates a receive operation on the associated socket.</summary>
/// <returns>This instance.</returns>
public ValueTask<int> ReceiveAsync(Socket socket)
public ValueTask<int> ReceiveAsync(Socket socket, CancellationToken cancellationToken)
{
Debug.Assert(Volatile.Read(ref _continuation) == null, $"Expected null continuation to indicate reserved for use");

if (socket.ReceiveAsync(this))
if (cancellationToken.CanBeCanceled)
{
return new ValueTask<int>(this, _token);
_cancellationState = CancellationState_Registering;
if (socket.ReceiveAsync(this))
{
// Store the token and the registration, then let the callback know that it can safely dispose of the registration.
_cancellationToken = cancellationToken;
_cancellationRegistration = cancellationToken.UnsafeRegister(s => SocketPal.CancelPendingOperations((Socket)s), socket);
Volatile.Write(ref _cancellationState, CancellationState_Registered);
return new ValueTask<int>(this, _token);
}
}
else
{
_cancellationState = CancellationState_None;
if (socket.ReceiveAsync(this))
{
return new ValueTask<int>(this, _token);
}
}

int bytesTransferred = BytesTransferred;
Expand All @@ -903,13 +972,29 @@ public ValueTask<int> ReceiveAsync(Socket socket)

/// <summary>Initiates a send operation on the associated socket.</summary>
/// <returns>This instance.</returns>
public ValueTask<int> SendAsync(Socket socket)
public ValueTask<int> SendAsync(Socket socket, CancellationToken cancellationToken)
{
Debug.Assert(Volatile.Read(ref _continuation) == null, $"Expected null continuation to indicate reserved for use");

if (socket.SendAsync(this))
if (cancellationToken.CanBeCanceled)
{
_cancellationState = CancellationState_Registering;
if (socket.SendAsync(this))
{
// Store the token and the registration, then let the callback know that it can safely dispose of the registration.
_cancellationToken = cancellationToken;
_cancellationRegistration = cancellationToken.UnsafeRegister(s => SocketPal.CancelPendingOperations((Socket)s), socket);
Volatile.Write(ref _cancellationState, CancellationState_Registered);
return new ValueTask<int>(this, _token);
}
}
else
{
return new ValueTask<int>(this, _token);
_cancellationState = CancellationState_None;
if (socket.SendAsync(this))
{
return new ValueTask<int>(this, _token);
}
}

int bytesTransferred = BytesTransferred;
Expand All @@ -922,13 +1007,29 @@ public ValueTask<int> SendAsync(Socket socket)
new ValueTask<int>(Task.FromException<int>(CreateException(error)));
}

public ValueTask SendAsyncForNetworkStream(Socket socket)
public ValueTask SendAsyncForNetworkStream(Socket socket, CancellationToken cancellationToken)
{
Debug.Assert(Volatile.Read(ref _continuation) == null, $"Expected null continuation to indicate reserved for use");

if (socket.SendAsync(this))
if (cancellationToken.CanBeCanceled)
{
return new ValueTask(this, _token);
_cancellationState = CancellationState_Registering;
if (socket.SendAsync(this))
{
// Store the token and the registration, then let the callback know that it can safely dispose of the registration.
_cancellationToken = cancellationToken;
_cancellationRegistration = cancellationToken.UnsafeRegister(s => SocketPal.CancelPendingOperations((Socket)s), socket);
Volatile.Write(ref _cancellationState, CancellationState_Registered);
return new ValueTask(this, _token);
}
}
else
{
_cancellationState = CancellationState_None;
if (socket.SendAsync(this))
{
return new ValueTask(this, _token);
}
}

SocketError error = SocketError;
Expand Down Expand Up @@ -1051,12 +1152,13 @@ public int GetResult(short token)

SocketError error = SocketError;
int bytes = BytesTransferred;
CancellationToken cancellationToken = _cancellationToken;

Release();

if (error != SocketError.Success)
{
ThrowException(error);
ThrowException(error, cancellationToken);
}
return bytes;
}
Expand All @@ -1069,20 +1171,29 @@ void IValueTaskSource.GetResult(short token)
}

SocketError error = SocketError;
CancellationToken cancellationToken = _cancellationToken;

Release();

if (error != SocketError.Success)
{
ThrowException(error);
ThrowException(error, cancellationToken);
}
}

private void ThrowIncorrectTokenException() => throw new InvalidOperationException(SR.InvalidOperation_IncorrectToken);

private void ThrowMultipleContinuationsException() => throw new InvalidOperationException(SR.InvalidOperation_MultipleContinuations);

private void ThrowException(SocketError error) => throw CreateException(error);
private void ThrowException(SocketError error, CancellationToken cancellationToken)
{
if (error == SocketError.OperationAborted)
{
cancellationToken.ThrowIfCancellationRequested();
}

throw CreateException(error);
}

private Exception CreateException(SocketError error)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1067,19 +1067,21 @@ public void CancelAndContinueProcessing(TOperation op)
nextOp?.Dispatch();
}

// Called when the socket is closed.
public void StopAndAbort(SocketAsyncContext context)
// Called to cancel all pending items, either in response to a cancellation
// request or in response to closing the socket.
public void Cancel(SocketAsyncContext context, bool close)
{
// We should be called exactly once, by SafeSocketHandle.
Debug.Assert(_state != QueueState.Stopped);

using (Lock())
{
Trace(context, $"Enter");

Debug.Assert(_state != QueueState.Stopped);

_state = QueueState.Stopped;
if (close)
{
Debug.Assert(_state != QueueState.Stopped);
_state = QueueState.Stopped;
}

if (_tail != null)
{
Expand Down Expand Up @@ -1160,9 +1162,7 @@ private void Register()

public void Close()
{
// Drain queues
_sendQueue.StopAndAbort(this);
_receiveQueue.StopAndAbort(this);
Cancel(close: true);

lock (_registerLock)
{
Expand All @@ -1172,6 +1172,13 @@ public void Close()
}
}

public void Cancel(bool close)
{
// Drain queues
_sendQueue.Cancel(this, close);
_receiveQueue.Cancel(this, close);
}

public void SetNonBlocking()
{
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,11 @@ public static unsafe SocketError Bind(SafeSocketHandle handle, ProtocolType sock
return err == Interop.Error.SUCCESS ? SocketError.Success : GetSocketErrorForErrorCode(err);
}

public static void CancelPendingOperations(Socket socket)
{
socket.SafeHandle.AsyncContext.Cancel(close: false);
}

public static SocketError Listen(SafeSocketHandle handle, int backlog)
{
Interop.Error err = Interop.Sys.Listen(handle, backlog);
Expand Down
Loading