diff --git a/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs b/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs index 762d0df38119f0..790ad3feb2939d 100644 --- a/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs +++ b/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs @@ -89,6 +89,7 @@ internal QuicStream() { } public override bool CanTimeout { get { throw null; } } public override long Length { get { throw null; } } public override long Position { get { throw null; } set { } } + public bool ReadsCompleted { get { throw null; } } public long StreamId { get { throw null; } } public void AbortRead(long errorCode) { } public void AbortWrite(long errorCode) { } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs index 1b58009a2fd355..2c3d50a58e80e6 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs @@ -58,6 +58,8 @@ internal override int WriteTimeout internal override bool CanRead => !_disposed && ReadStreamBuffer is not null; + internal override bool ReadsCompleted => ReadStreamBuffer?.IsComplete ?? false; + internal override int Read(Span buffer) { CheckDisposed(); diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs index 1aa31bd9873ca9..bbaf9c4ed41bb7 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Buffers; -using System.Collections.Generic; using System.Diagnostics; using System.IO; using System.Net.Quic.Implementations.MsQuic.Internal; @@ -50,6 +49,7 @@ private sealed class State public QuicBuffer[] ReceiveQuicBuffers = Array.Empty(); public int ReceiveQuicBuffersCount; public int ReceiveQuicBuffersTotalBytes; + public bool ReceiveIsFinal; // set when ReadState.PendingRead: public Memory ReceiveUserBuffer; @@ -193,6 +193,8 @@ internal MsQuicStream(MsQuicConnection.State connectionState, QUIC_STREAM_OPEN_F internal override bool CanWrite => _disposed == 0 && _canWrite; + internal override bool ReadsCompleted => _state.ReadState == ReadState.ReadsCompleted; + internal override bool CanTimeout => true; private int _readTimeout = Timeout.Infinite; @@ -415,13 +417,13 @@ internal override ValueTask ReadAsync(Memory destination, Cancellatio initialReadState = _state.ReadState; abortError = _state.ReadErrorCode; - // Failure scenario: pre-canceled token. Transition: any -> Aborted + // Failure scenario: pre-canceled token. Transition: Any non-final -> Aborted // PendingRead state indicates there is another concurrent read operation in flight // which is forbidden, so it is handled separately if (initialReadState != ReadState.PendingRead && cancellationToken.IsCancellationRequested) { initialReadState = ReadState.Aborted; - _state.ReadState = ReadState.Aborted; + CleanupReadStateAndCheckPending(_state, ReadState.Aborted); preCanceled = true; } @@ -442,16 +444,14 @@ internal override ValueTask ReadAsync(Memory destination, Cancellatio if (cancellationToken.CanBeCanceled) { + // Failure scenario: cancellation. Transition: Any non-final -> Aborted _state.ReceiveCancellationRegistration = cancellationToken.UnsafeRegister(static (obj, token) => { var state = (State)obj!; bool completePendingRead; lock (state) { - completePendingRead = state.ReadState == ReadState.PendingRead; - state.Stream = null; - state.ReceiveUserBuffer = null; - state.ReadState = ReadState.Aborted; + completePendingRead = CleanupReadStateAndCheckPending(state, ReadState.Aborted); } if (completePendingRead) @@ -468,7 +468,8 @@ internal override ValueTask ReadAsync(Memory destination, Cancellatio return _state.ReceiveResettableCompletionSource.GetValueTask(); } - // Success scenario: data already available, completing synchronously. Transition IndividualReadComplete->None + // Success scenario: data already available, completing synchronously. + // Transition IndividualReadComplete->None, or IndividualReadComplete->ReadsCompleted, if it was the last message and we fully consumed it if (initialReadState == ReadState.IndividualReadComplete) { _state.ReadState = ReadState.None; @@ -481,6 +482,11 @@ internal override ValueTask ReadAsync(Memory destination, Cancellatio // Need to re-enable receives because MsQuic will pause them when we don't consume the entire buffer. EnableReceive(); } + else if (_state.ReceiveIsFinal) + { + // This was a final message and we've consumed everything. We can complete the state without waiting for PEER_SEND_SHUTDOWN + _state.ReadState = ReadState.ReadsCompleted; + } return new ValueTask(taken); } @@ -512,7 +518,10 @@ internal override ValueTask ReadAsync(Memory destination, Cancellatio /// The number of bytes copied. private static unsafe int CopyMsQuicBuffersToUserBuffer(ReadOnlySpan sourceBuffers, Span destinationBuffer) { - Debug.Assert(sourceBuffers.Length != 0); + if (sourceBuffers.Length == 0) + { + return 0; + } int originalDestinationLength = destinationBuffer.Length; QuicBuffer nativeBuffer; @@ -543,16 +552,7 @@ internal override void AbortRead(long errorCode) bool shouldComplete = false; lock (_state) { - if (_state.ReadState == ReadState.PendingRead) - { - shouldComplete = true; - _state.Stream = null; - _state.ReceiveUserBuffer = null; - } - if (_state.ReadState < ReadState.ReadsCompleted) - { - _state.ReadState = ReadState.Aborted; - } + shouldComplete = CleanupReadStateAndCheckPending(_state, ReadState.Aborted); } if (shouldComplete) @@ -754,9 +754,7 @@ private void Dispose(bool disposing) if (_state.ReadState < ReadState.ReadsCompleted || _state.ReadState == ReadState.Aborted) { abortRead = true; - completeRead = _state.ReadState == ReadState.PendingRead; - _state.Stream = null; - _state.ReadState = ReadState.Aborted; + completeRead = CleanupReadStateAndCheckPending(_state, ReadState.Aborted); } if (_state.ShutdownState == ShutdownState.None) @@ -881,11 +879,9 @@ private static unsafe uint HandleEventRecv(State state, ref StreamEvent evt) { ref StreamEventDataReceive receiveEvent = ref evt.Data.Receive; - if (receiveEvent.BufferCount == 0) + if (NetEventSource.Log.IsEnabled()) { - // This is a 0-length receive that happens once reads are finished (via abort or otherwise). - // State changes for this are handled in PEER_SEND_SHUTDOWN / PEER_SEND_ABORT / SHUTDOWN_COMPLETE event handlers. - return MsQuicStatusCodes.Success; + NetEventSource.Info(state, $"{state.TraceId} Stream received {receiveEvent.TotalBufferLength} bytes{(receiveEvent.Flags.HasFlag(QUIC_RECEIVE_FLAGS.FIN) ? " with FIN flag" : "")}"); } int readLength; @@ -922,8 +918,27 @@ private static unsafe uint HandleEventRecv(State state, ref StreamEvent evt) state.ReceiveQuicBuffersCount = (int)receiveEvent.BufferCount; state.ReceiveQuicBuffersTotalBytes = checked((int)receiveEvent.TotalBufferLength); - state.ReadState = ReadState.IndividualReadComplete; - return MsQuicStatusCodes.Pending; + state.ReceiveIsFinal = receiveEvent.Flags.HasFlag(QUIC_RECEIVE_FLAGS.FIN); + + // 0-length receive can happens once reads are finished (gracefully or otherwise). + if (state.ReceiveQuicBuffersTotalBytes == 0) + { + if (state.ReceiveIsFinal) + { + // We can complete the state without waiting for PEER_SEND_SHUTDOWN + state.ReadState = ReadState.ReadsCompleted; + } + + // if it was not a graceful shutdown, we defer aborting to PEER_SEND_ABORT event handler + return MsQuicStatusCodes.Success; + } + else + { + // Normal RECEIVE - data will be buffered until user calls ReadAsync() and no new event will be issued until EnableReceive() + state.ReadState = ReadState.IndividualReadComplete; + return MsQuicStatusCodes.Pending; + } + case ReadState.PendingRead: // There is a pending ReadAsync(). @@ -933,8 +948,17 @@ private static unsafe uint HandleEventRecv(State state, ref StreamEvent evt) state.ReadState = ReadState.None; readLength = CopyMsQuicBuffersToUserBuffer(new ReadOnlySpan(receiveEvent.Buffers, (int)receiveEvent.BufferCount), state.ReceiveUserBuffer.Span); + + // This was a final message and we've consumed everything. We can complete the state without waiting for PEER_SEND_SHUTDOWN + if (receiveEvent.Flags.HasFlag(QUIC_RECEIVE_FLAGS.FIN) && (uint)readLength == receiveEvent.TotalBufferLength) + { + state.ReadState = ReadState.ReadsCompleted; + } + // Else, if this was a final message, but we haven't consumed it fully, FIN flag will arrive again in the next RECEIVE event + state.ReceiveUserBuffer = null; break; + default: Debug.Assert(state.ReadState is ReadState.Aborted or ReadState.ConnectionClosed, $"Unexpected {nameof(ReadState)} '{state.ReadState}' in {nameof(HandleEventRecv)}."); @@ -1008,16 +1032,7 @@ private static uint HandleEventShutdownComplete(State state, ref StreamEvent evt // This event won't occur within the middle of a receive. if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(state, $"{state.TraceId} Stream completing resettable event source."); - if (state.ReadState == ReadState.PendingRead) - { - shouldReadComplete = true; - state.Stream = null; - state.ReceiveUserBuffer = null; - } - if (state.ReadState < ReadState.ReadsCompleted) - { - state.ReadState = ReadState.ReadsCompleted; - } + shouldReadComplete = CleanupReadStateAndCheckPending(state, ReadState.ReadsCompleted); if (state.ShutdownState == ShutdownState.None) { @@ -1051,13 +1066,7 @@ private static uint HandleEventPeerSendAborted(State state, ref StreamEvent evt) bool shouldComplete = false; lock (state) { - if (state.ReadState == ReadState.PendingRead) - { - shouldComplete = true; - state.Stream = null; - state.ReceiveUserBuffer = null; - } - state.ReadState = ReadState.Aborted; + shouldComplete = CleanupReadStateAndCheckPending(state, ReadState.Aborted); state.ReadErrorCode = (long)evt.Data.PeerSendAborted.ErrorCode; } @@ -1079,16 +1088,7 @@ private static uint HandleEventPeerSendShutdown(State state) // This event won't occur within the middle of a receive. if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(state, $"{state.TraceId} Stream completing resettable event source."); - if (state.ReadState == ReadState.PendingRead) - { - shouldComplete = true; - state.Stream = null; - state.ReceiveUserBuffer = null; - } - if (state.ReadState < ReadState.ReadsCompleted) - { - state.ReadState = ReadState.ReadsCompleted; - } + shouldComplete = CleanupReadStateAndCheckPending(state, ReadState.ReadsCompleted); } if (shouldComplete) @@ -1378,11 +1378,7 @@ private static uint HandleEventConnectionClose(State state) lock (state) { - shouldCompleteRead = state.ReadState == ReadState.PendingRead; - if (state.ReadState < ReadState.ReadsCompleted) - { - state.ReadState = ReadState.ConnectionClosed; - } + shouldCompleteRead = CleanupReadStateAndCheckPending(state, ReadState.ConnectionClosed); if (state.SendState == SendState.None || state.SendState == SendState.Pending) { @@ -1428,15 +1424,47 @@ private static uint HandleEventConnectionClose(State state) private static Exception GetConnectionAbortedException(State state) => ThrowHelper.GetConnectionAbortedException(state.ConnectionState.AbortErrorCode); + private static bool CleanupReadStateAndCheckPending(State state, ReadState finalState) + { + Debug.Assert(finalState >= ReadState.ReadsCompleted, $"Expected final read state, got {finalState}"); + Debug.Assert(Monitor.IsEntered(state)); + + bool shouldComplete = false; + if (state.ReadState == ReadState.PendingRead) + { + shouldComplete = true; + state.Stream = null; + state.ReceiveUserBuffer = null; + state.ReceiveCancellationRegistration.Unregister(); + } + if (state.ReadState < ReadState.ReadsCompleted) + { + state.ReadState = finalState; + } + return shouldComplete; + } + // Read state transitions: // - // None --(data arrives in event RECV)-> IndividualReadComplete --(user calls ReadAsync() & completes syncronously)-> None - // None --(user calls ReadAsync() & waits)-> PendingRead --(data arrives in event RECV & completes user's ReadAsync())-> None + // None --(data arrives in event RECV)-> IndividualReadComplete + // None --(data arrives in event RECV with FIN flag)-> IndividualReadComplete(+FIN) + // None --(0-byte data arrives in event RECV with FIN flag)-> ReadsCompleted + // None --(user calls ReadAsync() & waits)-> PendingRead + // + // IndividualReadComplete --(user calls ReadAsync())-> None + // IndividualReadComplete(+FIN) --(user calls ReadAsync() & consumes only partial data)-> None + // IndividualReadComplete(+FIN) --(user calls ReadAsync() & consumes full data)-> ReadsCompleted + // + // PendingRead --(data arrives in event RECV & completes user's ReadAsync())-> None + // PendingRead --(data arrives in event RECV with FIN flag & completes user's ReadAsync() with only partial data)-> None + // PendingRead --(data arrives in event RECV with FIN flag & completes user's ReadAsync() with full data)-> ReadsCompleted + // // Any non-final state --(event PEER_SEND_SHUTDOWN or SHUTDOWN_COMPLETED with ConnectionClosed=false)-> ReadsCompleted // Any non-final state --(event PEER_SEND_ABORT)-> Aborted // Any non-final state --(user calls AbortRead())-> Aborted - // Any state --(CancellationToken's cancellation for ReadAsync())-> Aborted (TODO: should it be only for non-final as others?) + // Any non-final state --(CancellationToken's cancellation for ReadAsync())-> Aborted // Any non-final state --(event SHUTDOWN_COMPLETED with ConnectionClosed=true)-> ConnectionClosed + // // Closed - no transitions, set for Unidirectional write-only streams private enum ReadState { diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/QuicStreamProvider.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/QuicStreamProvider.cs index f011e561855ee3..66c9a8b6e51c2d 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/QuicStreamProvider.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/QuicStreamProvider.cs @@ -15,6 +15,8 @@ internal abstract class QuicStreamProvider : IDisposable, IAsyncDisposable internal abstract bool CanRead { get; } + internal abstract bool ReadsCompleted { get; } + internal abstract int ReadTimeout { get; set; } internal abstract int Read(Span buffer); diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs index 55ba9953260f03..8a6dbe496ed4a4 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs @@ -71,6 +71,8 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati public override bool CanRead => _provider.CanRead; + public bool ReadsCompleted => _provider.ReadsCompleted; + public override int Read(Span buffer) => _provider.Read(buffer); public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) => _provider.ReadAsync(buffer, cancellationToken); diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs index 253707dea2688a..74d2b978a725b6 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs @@ -4,6 +4,7 @@ using System.Buffers; using System.Collections.Generic; using System.Diagnostics; +using System.Diagnostics.Tracing; using System.Linq; using System.Net.Security; using System.Net.Sockets; @@ -22,7 +23,7 @@ namespace System.Net.Quic.Tests [Collection("NoParallelTests")] public class MsQuicTests : QuicTestBase { - private static ReadOnlyMemory s_data = Encoding.UTF8.GetBytes("Hello world!"); + private static byte[] s_data = Encoding.UTF8.GetBytes("Hello world!"); public MsQuicTests(ITestOutputHelper output) : base(output) { } @@ -788,5 +789,69 @@ public async Task BigWrite_SmallRead_Success(bool closeWithData) } } } + + [Fact] + public async Task BasicTest_WithReadsCompletedCheck() + { + await RunClientServer( + iterations: 100, + serverFunction: async connection => + { + using QuicStream stream = await connection.AcceptStreamAsync(); + Assert.False(stream.ReadsCompleted); + + byte[] buffer = new byte[s_data.Length]; + int bytesRead = await ReadAll(stream, buffer); + + Assert.True(stream.ReadsCompleted); + Assert.Equal(s_data.Length, bytesRead); + Assert.Equal(s_data, buffer); + + await stream.WriteAsync(s_data, endStream: true); + await stream.ShutdownCompleted(); + }, + clientFunction: async connection => + { + using QuicStream stream = connection.OpenBidirectionalStream(); + Assert.False(stream.ReadsCompleted); + + await stream.WriteAsync(s_data, endStream: true); + + byte[] buffer = new byte[s_data.Length]; + int bytesRead = await ReadAll(stream, buffer); + + Assert.True(stream.ReadsCompleted); + Assert.Equal(s_data.Length, bytesRead); + Assert.Equal(s_data, buffer); + + await stream.ShutdownCompleted(); + } + ); + } + + [Fact] + public async Task Read_ReadsCompleted_ReportedBeforeReturning0() + { + await RunBidirectionalClientServer( + async clientStream => + { + await clientStream.WriteAsync(new byte[1], endStream: true); + }, + async serverStream => + { + Assert.False(serverStream.ReadsCompleted); + + var received = await serverStream.ReadAsync(new byte[1]); + Assert.Equal(1, received); + Assert.True(serverStream.ReadsCompleted); + + var task = serverStream.ReadAsync(new byte[1]); + Assert.True(task.IsCompleted); + + received = await task; + Assert.Equal(0, received); + Assert.True(serverStream.ReadsCompleted); + }); + } } }