Skip to content

Commit

Permalink
Removed all state machine creation for sync path in message reader.
Browse files Browse the repository at this point in the history
  • Loading branch information
mattnischan committed Feb 2, 2020
1 parent 4cbc411 commit c373ed0
Show file tree
Hide file tree
Showing 6 changed files with 235 additions and 67 deletions.
18 changes: 16 additions & 2 deletions src/Bedrock.Framework/Protocols/WebSockets/WebSocketFrameReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ namespace Bedrock.Framework.Protocols.WebSockets
/// </summary>
public class WebSocketFrameReader : IMessageReader<WebSocketReadFrame>
{
/// <summary>
/// An instance of the WebSocketFrameReader.
/// </summary>
private WebSocketPayloadReader _payloadReader;

/// <summary>
/// Attempts to parse a message from a sequence.
/// </summary>
Expand All @@ -30,7 +35,7 @@ public bool TryParseMessage(in ReadOnlySequence<byte> input, ref SequencePositio
return false;
}

if (input.IsSingleSegment)
if (input.IsSingleSegment || input.FirstSpan.Length >= 14)
{
if (TryParseSpan(input.FirstSpan, input.Length, out var bytesRead, out message))
{
Expand Down Expand Up @@ -123,8 +128,17 @@ private bool TryParseSpan(in ReadOnlySpan<byte> span, long inputLength, out int
}

var header = new WebSocketHeader(fin, opcode, masked, payloadLength, maskingKey);
message = new WebSocketReadFrame(header, new WebSocketPayloadReader(header));

if(_payloadReader == null)
{
_payloadReader = new WebSocketPayloadReader(header);
}
else
{
_payloadReader.Reset(header);
}

message = new WebSocketReadFrame(header, _payloadReader);
bytesRead = 2 + extendedPayloadLengthSize + maskSize;
return true;
}
Expand Down
168 changes: 129 additions & 39 deletions src/Bedrock.Framework/Protocols/WebSockets/WebSocketMessageReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -101,62 +101,91 @@ public WebSocketMessageReader(PipeReader transport, IControlFrameHandler control
/// </summary>
/// <param name="cancellationToken">A cancellation token, if any.</param>
/// <returns>A message read result.</returns>
public async ValueTask<MessageReadResult> ReadAsync(CancellationToken cancellationToken = default)
public ValueTask<MessageReadResult> ReadAsync(CancellationToken cancellationToken = default)
{
if (_awaitingHeader)
{
var frame = await GetNextMessageFrameAsync(cancellationToken).ConfigureAwait(false);
ValidateHeader(frame.Header);
var readTask = GetNextMessageFrameAsync(cancellationToken);
if(readTask.IsCompletedSuccessfully)
{
var frame = readTask.Result;
ValidateHeader(frame.Header);

_header = frame.Header;
_payloadReader = frame.Payload;
_header = frame.Header;
_payloadReader = frame.Payload;
}
else
{
return DoReadHeaderRequiredAsync(readTask, cancellationToken);
}
}

return ReadPayloadAsync(cancellationToken);
}

/// <summary>
/// Completes an async read when reading a header is required.
/// </summary>
/// <param name="readTask">The active async read task from the ProtocolReader.</param>
/// <param name="cancellationToken">A cancellation token.</param>
/// <returns>A MessageReadResult.</returns>
private async ValueTask<MessageReadResult> DoReadHeaderRequiredAsync(ValueTask<WebSocketReadFrame> readTask, CancellationToken cancellationToken)
{
var frame = await readTask.ConfigureAwait(false);

ValidateHeader(frame.Header);

_header = frame.Header;
_payloadReader = frame.Payload;

return await ReadPayloadAsync(cancellationToken);
}

/// <summary>
/// Reads a portion of a message payload.
/// </summary>
/// <param name="cancellationToken">A cancellation token.</param>
/// <returns>A MessageReadResult.</returns>
private ValueTask<MessageReadResult> ReadPayloadAsync(CancellationToken cancellationToken)
{
//Don't keep reading data into the buffer if we've hit a threshold
//TODO: Is this even the right value to use in this context?
if (_buffer.UnconsumedWrittenCount < _options.PauseWriterThreshold)
{
var readTask = _protocolReader.ReadAsync(_payloadReader, cancellationToken);
ProtocolReadResult<ReadOnlySequence<byte>> payloadSequence;

if (readTask.IsCompletedSuccessfully)
{
payloadSequence = readTask.Result;
PopulateFromRead(readTask.Result);
}
else
{
payloadSequence = await readTask;
}

if (payloadSequence.IsCanceled)
{
throw new OperationCanceledException("Read canceled while attempting to read WebSocket payload.");
}

var sequence = payloadSequence.Message;

//If there is already data in the buffer, we'll need to add to it
if (_buffer.UnconsumedWrittenCount > 0)
{
if (sequence.IsSingleSegment)
{
_buffer.Write(sequence.FirstSpan);
}
else
{
foreach (var segment in sequence)
{
_buffer.Write(segment.Span);
}
}
}
return CreateMessageReadResultAsync(readTask, cancellationToken);
}
}

_currentSequence = payloadSequence.Message;
_isCompleted = payloadSequence.IsCompleted;
_isCanceled = payloadSequence.IsCanceled;
var endOfMessage = _header.Fin && _payloadReader.BytesRemaining == 0;

_awaitingHeader = _payloadReader.BytesRemaining == 0;
//Serve back buffered data, if it exists, else give the direct sequence without buffering
if (_buffer.UnconsumedWrittenCount > 0)
{
return new ValueTask<MessageReadResult>(
new MessageReadResult(new ReadOnlySequence<byte>(_buffer.WrittenMemory), endOfMessage, _isCanceled, _isCompleted));
}
else
{
return new ValueTask<MessageReadResult>(new MessageReadResult(_currentSequence, endOfMessage, _isCanceled, _isCompleted));
}
}

/// <summary>
/// Creates a new MessageReadResult asynchronously.
/// </summary>
/// <param name="readTask">The active read task from the ProtocolReader.</param>
/// <param name="cancellationToken">A cancellation token.</param>
/// <returns>A new MessageReadResult.</returns>
private async ValueTask<MessageReadResult> CreateMessageReadResultAsync(ValueTask<ProtocolReadResult<ReadOnlySequence<byte>>> readTask, CancellationToken cancellationToken)
{
PopulateFromRead(await readTask);

var endOfMessage = _header.Fin && _payloadReader.BytesRemaining == 0;

Expand All @@ -171,6 +200,42 @@ public async ValueTask<MessageReadResult> ReadAsync(CancellationToken cancellati
}
}

/// <summary>
/// Populates the message reader from a payload read result.
/// </summary>
/// <param name="readResult">The read result to populate the message reader from.</param>
private void PopulateFromRead(ProtocolReadResult<ReadOnlySequence<byte>> readResult)
{
if (readResult.IsCanceled)
{
throw new OperationCanceledException("Read canceled while attempting to read WebSocket payload.");
}

var sequence = readResult.Message;

//If there is already data in the buffer, we'll need to add to it
if (_buffer.UnconsumedWrittenCount > 0)
{
if (sequence.IsSingleSegment)
{
_buffer.Write(sequence.FirstSpan);
}
else
{
foreach (var segment in sequence)
{
_buffer.Write(segment.Span);
}
}
}

_currentSequence = readResult.Message;
_isCompleted = readResult.IsCompleted;
_isCanceled = readResult.IsCanceled;

_awaitingHeader = _payloadReader.BytesRemaining == 0;
}

/// <summary>
/// Advances the reader to the provided position.
/// </summary>
Expand Down Expand Up @@ -223,15 +288,40 @@ public void AdvanceTo(SequencePosition consumed, SequencePosition examined)
/// </summary>
/// <param name="cancellationToken">A cancellation token, if any.</param>
/// <returns>True if the message is text, false otherwise.</returns>
public async ValueTask<bool> MoveNextMessageAsync(CancellationToken cancellationToken = default)
public ValueTask<bool> MoveNextMessageAsync(CancellationToken cancellationToken = default)
{
if (_payloadReader is object && _payloadReader.BytesRemaining != 0)
{
throw new InvalidOperationException("MoveNextMessageAsync cannot be called while a message is still being read.");
}

var frame = await GetNextMessageFrameAsync(cancellationToken);
var readTask = GetNextMessageFrameAsync(cancellationToken);
if (readTask.IsCompletedSuccessfully)
{
return new ValueTask<bool>(SetNextMessageAndGetIsText(readTask.Result));
}

return DoSetNextMessageAsync(readTask);
}

/// <summary>
/// Sets the next message frame asynchronously.
/// </summary>
/// <param name="readTask">The active ProtocolReader read task.</param>
/// <returns>True if the next message is a text message, false otherwise.</returns>
private async ValueTask<bool> DoSetNextMessageAsync(ValueTask<WebSocketReadFrame> readTask)
{
return SetNextMessageAndGetIsText(await readTask);
}

/// <summary>
/// Sets the message reader up with the next message frame data and determines if the message
/// is a text or binary message.
/// </summary>
/// <param name="frame">The read frame to set the message reader with.</param>
/// <returns>True if the next message is text, false otherwise.</returns>
private bool SetNextMessageAndGetIsText(WebSocketReadFrame frame)
{
if (frame.Header.Opcode != WebSocketOpcode.Binary && frame.Header.Opcode != WebSocketOpcode.Text)
{
ThrowBadProtocol($"Expected a start of message frame of Binary or Text but received {frame.Header.Opcode} instead.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace Bedrock.Framework.Protocols.WebSockets
/// Masks or unmasks a WebSocket payload according to the provided masking key, tracking the
/// masking key index accross mask or unmasking requests.
/// </summary>
internal struct WebSocketPayloadEncoder
internal class WebSocketPayloadEncoder
{
/// <summary>
/// The masking key to use to mask or unmask the payload.
Expand All @@ -30,6 +30,15 @@ internal struct WebSocketPayloadEncoder
/// </summary>
/// <param name="maskingKey">The masking key to use to mask or unmask payloads.</param>
public WebSocketPayloadEncoder(int maskingKey)
{
Reset(maskingKey);
}

/// <summary>
/// Resets the payload encoder.
/// </summary>
/// <param name="maskingKey">The masking key to use to mask or unmask payloads.</param>
public void Reset(int maskingKey)
{
_maskingKey = maskingKey;
_currentMaskIndex = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,18 @@ public WebSocketPayloadReader(WebSocketHeader header)
_masked = header.Masked;
}

/// <summary>
/// Resets the payload reader.
/// </summary>
/// <param name="header">The WebSocketHeader associated with this payload.</param>
public void Reset(WebSocketHeader header)
{
BytesRemaining = header.PayloadLength;
_masked = header.Masked;

_payloadEncoder.Reset(header.MaskingKey);
}

/// <summary>
/// Attempts to read the WebSocket payload from a sequence.
/// </summary>
Expand Down
1 change: 1 addition & 0 deletions tests/Bedrock.Framework.Benchmarks/Program.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using BenchmarkDotNet.Running;
using System;
using System.Threading.Tasks;

namespace Bedrock.Framework.Benchmarks
{
Expand Down
Loading

0 comments on commit c373ed0

Please sign in to comment.