diff --git a/src/Bedrock.Framework/Protocols/WebSockets/WebSocketFrameReader.cs b/src/Bedrock.Framework/Protocols/WebSockets/WebSocketFrameReader.cs index 148037ef..2b7dc175 100644 --- a/src/Bedrock.Framework/Protocols/WebSockets/WebSocketFrameReader.cs +++ b/src/Bedrock.Framework/Protocols/WebSockets/WebSocketFrameReader.cs @@ -13,6 +13,11 @@ namespace Bedrock.Framework.Protocols.WebSockets /// public class WebSocketFrameReader : IMessageReader { + /// + /// An instance of the WebSocketFrameReader. + /// + private WebSocketPayloadReader _payloadReader; + /// /// Attempts to parse a message from a sequence. /// @@ -30,7 +35,7 @@ public bool TryParseMessage(in ReadOnlySequence input, ref SequencePositio return false; } - if (input.IsSingleSegment) + if (input.IsSingleSegment || input.FirstSpan.Length >= 14) { if (TryParseSpan(input.FirstSpan, input.Length, out var bytesRead, out message)) { @@ -123,8 +128,17 @@ private bool TryParseSpan(in ReadOnlySpan span, long inputLength, out int } var header = new WebSocketHeader(fin, opcode, masked, payloadLength, maskingKey); - message = new WebSocketReadFrame(header, new WebSocketPayloadReader(header)); + if(_payloadReader == null) + { + _payloadReader = new WebSocketPayloadReader(header); + } + else + { + _payloadReader.Reset(header); + } + + message = new WebSocketReadFrame(header, _payloadReader); bytesRead = 2 + extendedPayloadLengthSize + maskSize; return true; } diff --git a/src/Bedrock.Framework/Protocols/WebSockets/WebSocketMessageReader.cs b/src/Bedrock.Framework/Protocols/WebSockets/WebSocketMessageReader.cs index e1139a51..27543f4d 100644 --- a/src/Bedrock.Framework/Protocols/WebSockets/WebSocketMessageReader.cs +++ b/src/Bedrock.Framework/Protocols/WebSockets/WebSocketMessageReader.cs @@ -101,62 +101,91 @@ public WebSocketMessageReader(PipeReader transport, IControlFrameHandler control /// /// A cancellation token, if any. /// A message read result. - public async ValueTask ReadAsync(CancellationToken cancellationToken = default) + public ValueTask ReadAsync(CancellationToken cancellationToken = default) { if (_awaitingHeader) { - var frame = await GetNextMessageFrameAsync(cancellationToken).ConfigureAwait(false); - ValidateHeader(frame.Header); + var readTask = GetNextMessageFrameAsync(cancellationToken); + if(readTask.IsCompletedSuccessfully) + { + var frame = readTask.Result; + ValidateHeader(frame.Header); - _header = frame.Header; - _payloadReader = frame.Payload; + _header = frame.Header; + _payloadReader = frame.Payload; + } + else + { + return DoReadHeaderRequiredAsync(readTask, cancellationToken); + } } + return ReadPayloadAsync(cancellationToken); + } + + /// + /// Completes an async read when reading a header is required. + /// + /// The active async read task from the ProtocolReader. + /// A cancellation token. + /// A MessageReadResult. + private async ValueTask DoReadHeaderRequiredAsync(ValueTask readTask, CancellationToken cancellationToken) + { + var frame = await readTask.ConfigureAwait(false); + + ValidateHeader(frame.Header); + + _header = frame.Header; + _payloadReader = frame.Payload; + + return await ReadPayloadAsync(cancellationToken); + } + + /// + /// Reads a portion of a message payload. + /// + /// A cancellation token. + /// A MessageReadResult. + private ValueTask ReadPayloadAsync(CancellationToken cancellationToken) + { //Don't keep reading data into the buffer if we've hit a threshold //TODO: Is this even the right value to use in this context? if (_buffer.UnconsumedWrittenCount < _options.PauseWriterThreshold) { var readTask = _protocolReader.ReadAsync(_payloadReader, cancellationToken); - ProtocolReadResult> payloadSequence; - if (readTask.IsCompletedSuccessfully) { - payloadSequence = readTask.Result; + PopulateFromRead(readTask.Result); } else { - payloadSequence = await readTask; - } - - if (payloadSequence.IsCanceled) - { - throw new OperationCanceledException("Read canceled while attempting to read WebSocket payload."); - } - - var sequence = payloadSequence.Message; - - //If there is already data in the buffer, we'll need to add to it - if (_buffer.UnconsumedWrittenCount > 0) - { - if (sequence.IsSingleSegment) - { - _buffer.Write(sequence.FirstSpan); - } - else - { - foreach (var segment in sequence) - { - _buffer.Write(segment.Span); - } - } - } + return CreateMessageReadResultAsync(readTask, cancellationToken); + } + } - _currentSequence = payloadSequence.Message; - _isCompleted = payloadSequence.IsCompleted; - _isCanceled = payloadSequence.IsCanceled; + var endOfMessage = _header.Fin && _payloadReader.BytesRemaining == 0; - _awaitingHeader = _payloadReader.BytesRemaining == 0; + //Serve back buffered data, if it exists, else give the direct sequence without buffering + if (_buffer.UnconsumedWrittenCount > 0) + { + return new ValueTask( + new MessageReadResult(new ReadOnlySequence(_buffer.WrittenMemory), endOfMessage, _isCanceled, _isCompleted)); } + else + { + return new ValueTask(new MessageReadResult(_currentSequence, endOfMessage, _isCanceled, _isCompleted)); + } + } + + /// + /// Creates a new MessageReadResult asynchronously. + /// + /// The active read task from the ProtocolReader. + /// A cancellation token. + /// A new MessageReadResult. + private async ValueTask CreateMessageReadResultAsync(ValueTask>> readTask, CancellationToken cancellationToken) + { + PopulateFromRead(await readTask); var endOfMessage = _header.Fin && _payloadReader.BytesRemaining == 0; @@ -171,6 +200,42 @@ public async ValueTask ReadAsync(CancellationToken cancellati } } + /// + /// Populates the message reader from a payload read result. + /// + /// The read result to populate the message reader from. + private void PopulateFromRead(ProtocolReadResult> readResult) + { + if (readResult.IsCanceled) + { + throw new OperationCanceledException("Read canceled while attempting to read WebSocket payload."); + } + + var sequence = readResult.Message; + + //If there is already data in the buffer, we'll need to add to it + if (_buffer.UnconsumedWrittenCount > 0) + { + if (sequence.IsSingleSegment) + { + _buffer.Write(sequence.FirstSpan); + } + else + { + foreach (var segment in sequence) + { + _buffer.Write(segment.Span); + } + } + } + + _currentSequence = readResult.Message; + _isCompleted = readResult.IsCompleted; + _isCanceled = readResult.IsCanceled; + + _awaitingHeader = _payloadReader.BytesRemaining == 0; + } + /// /// Advances the reader to the provided position. /// @@ -223,15 +288,40 @@ public void AdvanceTo(SequencePosition consumed, SequencePosition examined) /// /// A cancellation token, if any. /// True if the message is text, false otherwise. - public async ValueTask MoveNextMessageAsync(CancellationToken cancellationToken = default) + public ValueTask MoveNextMessageAsync(CancellationToken cancellationToken = default) { if (_payloadReader is object && _payloadReader.BytesRemaining != 0) { throw new InvalidOperationException("MoveNextMessageAsync cannot be called while a message is still being read."); } - var frame = await GetNextMessageFrameAsync(cancellationToken); + var readTask = GetNextMessageFrameAsync(cancellationToken); + if (readTask.IsCompletedSuccessfully) + { + return new ValueTask(SetNextMessageAndGetIsText(readTask.Result)); + } + return DoSetNextMessageAsync(readTask); + } + + /// + /// Sets the next message frame asynchronously. + /// + /// The active ProtocolReader read task. + /// True if the next message is a text message, false otherwise. + private async ValueTask DoSetNextMessageAsync(ValueTask readTask) + { + return SetNextMessageAndGetIsText(await readTask); + } + + /// + /// Sets the message reader up with the next message frame data and determines if the message + /// is a text or binary message. + /// + /// The read frame to set the message reader with. + /// True if the next message is text, false otherwise. + private bool SetNextMessageAndGetIsText(WebSocketReadFrame frame) + { if (frame.Header.Opcode != WebSocketOpcode.Binary && frame.Header.Opcode != WebSocketOpcode.Text) { ThrowBadProtocol($"Expected a start of message frame of Binary or Text but received {frame.Header.Opcode} instead."); diff --git a/src/Bedrock.Framework/Protocols/WebSockets/WebSocketPayloadEncoder.cs b/src/Bedrock.Framework/Protocols/WebSockets/WebSocketPayloadEncoder.cs index f8edebd4..653add2f 100644 --- a/src/Bedrock.Framework/Protocols/WebSockets/WebSocketPayloadEncoder.cs +++ b/src/Bedrock.Framework/Protocols/WebSockets/WebSocketPayloadEncoder.cs @@ -13,7 +13,7 @@ namespace Bedrock.Framework.Protocols.WebSockets /// Masks or unmasks a WebSocket payload according to the provided masking key, tracking the /// masking key index accross mask or unmasking requests. /// - internal struct WebSocketPayloadEncoder + internal class WebSocketPayloadEncoder { /// /// The masking key to use to mask or unmask the payload. @@ -30,6 +30,15 @@ internal struct WebSocketPayloadEncoder /// /// The masking key to use to mask or unmask payloads. public WebSocketPayloadEncoder(int maskingKey) + { + Reset(maskingKey); + } + + /// + /// Resets the payload encoder. + /// + /// The masking key to use to mask or unmask payloads. + public void Reset(int maskingKey) { _maskingKey = maskingKey; _currentMaskIndex = 0; diff --git a/src/Bedrock.Framework/Protocols/WebSockets/WebSocketPayloadReader.cs b/src/Bedrock.Framework/Protocols/WebSockets/WebSocketPayloadReader.cs index 0f52a81f..a35e7bd3 100644 --- a/src/Bedrock.Framework/Protocols/WebSockets/WebSocketPayloadReader.cs +++ b/src/Bedrock.Framework/Protocols/WebSockets/WebSocketPayloadReader.cs @@ -43,6 +43,18 @@ public WebSocketPayloadReader(WebSocketHeader header) _masked = header.Masked; } + /// + /// Resets the payload reader. + /// + /// The WebSocketHeader associated with this payload. + public void Reset(WebSocketHeader header) + { + BytesRemaining = header.PayloadLength; + _masked = header.Masked; + + _payloadEncoder.Reset(header.MaskingKey); + } + /// /// Attempts to read the WebSocket payload from a sequence. /// diff --git a/tests/Bedrock.Framework.Benchmarks/Program.cs b/tests/Bedrock.Framework.Benchmarks/Program.cs index f997ccb0..e8e9daa7 100644 --- a/tests/Bedrock.Framework.Benchmarks/Program.cs +++ b/tests/Bedrock.Framework.Benchmarks/Program.cs @@ -1,5 +1,6 @@ using BenchmarkDotNet.Running; using System; +using System.Threading.Tasks; namespace Bedrock.Framework.Benchmarks { diff --git a/tests/Bedrock.Framework.Benchmarks/WebSocketProtocolBenchmarks.cs b/tests/Bedrock.Framework.Benchmarks/WebSocketProtocolBenchmarks.cs index b4398150..b9c799da 100644 --- a/tests/Bedrock.Framework.Benchmarks/WebSocketProtocolBenchmarks.cs +++ b/tests/Bedrock.Framework.Benchmarks/WebSocketProtocolBenchmarks.cs @@ -1,5 +1,6 @@ using Bedrock.Framework.Protocols.WebSockets; using BenchmarkDotNet.Attributes; +using BenchmarkDotNet.Configs; using Microsoft.AspNetCore.Connections; using System; using System.Buffers; @@ -13,21 +14,27 @@ namespace Bedrock.Framework.Benchmarks { + [GroupBenchmarksBy(BenchmarkLogicalGroupRule.ByCategory)] + [CategoriesColumn] public class WebSocketProtocolBenchmarks { - private WebSocket _webSocket; + private WebSocket _webSocketServer; - private WebSocketProtocol _webSocketProtocol; + private WebSocket _webSocketClient; - private DefaultConnectionContext _connectionContext; + private WebSocketProtocol _webSocketProtocolServer; - private MemoryStream _stream; + private WebSocketProtocol _webSocketProtocolClient; - private byte[] _message; + private DefaultConnectionContext _serverConnectionContext; - private ArraySegment _arrayBuffer; + private DefaultConnectionContext _clientConnectionContext; + + private MemoryStream _serverStream; - private ReadOnlyMemory _romBuffer; + private MemoryStream _clientStream; + + private ArraySegment _arrayBuffer; private class DummyPipeReader : PipeReader { @@ -64,50 +71,85 @@ private class DummyDuplexPipe : IDuplexPipe [GlobalSetup] public async ValueTask Setup() + { + var serverMessage = await GetMessageBytes(true, 4000); + var clientMessage = await GetMessageBytes(false, 4000); + + (_serverConnectionContext, _serverStream) = CreateContextAndStream(serverMessage); + (_clientConnectionContext, _clientStream) = CreateContextAndStream(clientMessage); + + _webSocketServer = WebSocket.CreateFromStream(_serverStream, true, null, TimeSpan.FromSeconds(30)); + _webSocketProtocolServer = new WebSocketProtocol(_serverConnectionContext, WebSocketProtocolType.Server); + + _webSocketClient = WebSocket.CreateFromStream(_clientStream, false, null, TimeSpan.FromSeconds(30)); + _webSocketProtocolClient = new WebSocketProtocol(_clientConnectionContext, WebSocketProtocolType.Server); + + _arrayBuffer = new ArraySegment(new byte[10000]); + } + + private async ValueTask GetMessageBytes(bool isMasked, long size) { var writer = new WebSocketFrameWriter(); var pipe = new Pipe(); - _message = new byte[4000]; - - var header = WebSocketHeader.CreateMasked(true, WebSocketOpcode.Binary, 4000); - writer.WriteMessage(new WebSocketWriteFrame(header, new ReadOnlySequence(_message)), pipe.Writer); + var header = new WebSocketHeader(true, WebSocketOpcode.Binary, isMasked, (ulong)size, isMasked ? WebSocketHeader.GenerateMaskingKey() : default); + writer.WriteMessage(new WebSocketWriteFrame(header, new ReadOnlySequence(new byte[4000])), pipe.Writer); await pipe.Writer.FlushAsync(); var result = await pipe.Reader.ReadAsync(); - _message = result.Buffer.ToArray(); + return result.Buffer.ToArray(); + } - var dummyReader = new DummyPipeReader { Result = new ReadResult(new ReadOnlySequence(_message), false, false) }; - var dummyDuplexPipe = new DummyDuplexPipe { DummyReader = dummyReader }; + private (DefaultConnectionContext context, MemoryStream stream) CreateContextAndStream(byte[] message) + { + var reader = new DummyPipeReader { Result = new ReadResult(new ReadOnlySequence(message), false, false) }; + var duplexPipe = new DummyDuplexPipe { DummyReader = reader }; - _connectionContext = new DefaultConnectionContext { Transport = dummyDuplexPipe }; - _stream = new MemoryStream(_message); + var stream = new MemoryStream(message); + var context = new DefaultConnectionContext { Transport = duplexPipe }; - _webSocket = WebSocket.CreateFromStream(_stream, true, null, TimeSpan.FromSeconds(30)); - _webSocketProtocol = new WebSocketProtocol(_connectionContext, WebSocketProtocolType.Server); + return (context, stream); + } - _arrayBuffer = new ArraySegment(new byte[10000]); - _romBuffer = new ReadOnlyMemory(_message); + [BenchmarkCategory("Masked"), Benchmark(Baseline = true)] + public async ValueTask WebSocketRead() + { + _clientStream.Seek(0, SeekOrigin.Begin); + + var endOfMessage = false; + while (!endOfMessage) + { + var result = await _webSocketClient.ReceiveAsync(_arrayBuffer, CancellationToken.None); + endOfMessage = result.EndOfMessage; + } } - [Benchmark(Baseline = true)] + [BenchmarkCategory("Unmasked"), Benchmark(Baseline = true)] public async ValueTask WebSocketReadMasked() { - _stream.Seek(0, SeekOrigin.Begin); + _serverStream.Seek(0, SeekOrigin.Begin); var endOfMessage = false; while (!endOfMessage) { - var result = await _webSocket.ReceiveAsync(_arrayBuffer, CancellationToken.None); + var result = await _webSocketServer.ReceiveAsync(_arrayBuffer, CancellationToken.None); endOfMessage = result.EndOfMessage; } } - [Benchmark] + [BenchmarkCategory("Masked"), Benchmark] + public async ValueTask WebSocketProtocolRead() + { + var message = await _webSocketProtocolClient.ReadAsync(); + var data = await message.Reader.ReadAsync(); + message.Reader.AdvanceTo(data.Data.End); + } + + [BenchmarkCategory("Unmasked"), Benchmark] public async ValueTask WebSocketProtocolReadMasked() { - var message = await _webSocketProtocol.ReadAsync(); + var message = await _webSocketProtocolServer.ReadAsync(); var data = await message.Reader.ReadAsync(); message.Reader.AdvanceTo(data.Data.End); }