diff --git a/src/libraries/Common/src/Interop/Interop.zlib.cs b/src/libraries/Common/src/Interop/Interop.zlib.cs
index 280c5558667eb9..ad517da4079ca2 100644
--- a/src/libraries/Common/src/Interop/Interop.zlib.cs
+++ b/src/libraries/Common/src/Interop/Interop.zlib.cs
@@ -20,6 +20,9 @@ internal static extern ZLibNative.ErrorCode DeflateInit2_(
[DllImport(Libraries.CompressionNative, EntryPoint = "CompressionNative_Deflate")]
internal static extern ZLibNative.ErrorCode Deflate(ref ZLibNative.ZStream stream, ZLibNative.FlushCode flush);
+ [DllImport(Libraries.CompressionNative, EntryPoint = "CompressionNative_DeflateReset")]
+ internal static extern ZLibNative.ErrorCode DeflateReset(ref ZLibNative.ZStream stream);
+
[DllImport(Libraries.CompressionNative, EntryPoint = "CompressionNative_DeflateEnd")]
internal static extern ZLibNative.ErrorCode DeflateEnd(ref ZLibNative.ZStream stream);
@@ -29,6 +32,9 @@ internal static extern ZLibNative.ErrorCode DeflateInit2_(
[DllImport(Libraries.CompressionNative, EntryPoint = "CompressionNative_Inflate")]
internal static extern ZLibNative.ErrorCode Inflate(ref ZLibNative.ZStream stream, ZLibNative.FlushCode flush);
+ [DllImport(Libraries.CompressionNative, EntryPoint = "CompressionNative_InflateReset")]
+ internal static extern ZLibNative.ErrorCode InflateReset(ref ZLibNative.ZStream stream);
+
[DllImport(Libraries.CompressionNative, EntryPoint = "CompressionNative_InflateEnd")]
internal static extern ZLibNative.ErrorCode InflateEnd(ref ZLibNative.ZStream stream);
diff --git a/src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/ZLibNative.ZStream.cs b/src/libraries/Common/src/System/IO/Compression/ZLibNative.ZStream.cs
similarity index 100%
rename from src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/ZLibNative.ZStream.cs
rename to src/libraries/Common/src/System/IO/Compression/ZLibNative.ZStream.cs
diff --git a/src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/ZLibNative.cs b/src/libraries/Common/src/System/IO/Compression/ZLibNative.cs
similarity index 97%
rename from src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/ZLibNative.cs
rename to src/libraries/Common/src/System/IO/Compression/ZLibNative.cs
index 8118aeba0ecb82..f0393ebbf35cb5 100644
--- a/src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/ZLibNative.cs
+++ b/src/libraries/Common/src/System/IO/Compression/ZLibNative.cs
@@ -23,6 +23,7 @@ public enum FlushCode : int
NoFlush = 0,
SyncFlush = 2,
Finish = 4,
+ Block = 5
}
public enum ErrorCode : int
@@ -281,6 +282,13 @@ public ErrorCode Deflate(FlushCode flush)
}
+ public ErrorCode DeflateReset()
+ {
+ EnsureNotDisposed();
+ EnsureState(State.InitializedForDeflate);
+ return Interop.zlib.DeflateReset(ref _zStream);
+ }
+
public ErrorCode DeflateEnd()
{
EnsureNotDisposed();
@@ -313,6 +321,13 @@ public ErrorCode Inflate(FlushCode flush)
}
+ public ErrorCode InflateReset()
+ {
+ EnsureNotDisposed();
+ EnsureState(State.InitializedForInflate);
+ return Interop.zlib.InflateReset(ref _zStream);
+ }
+
public ErrorCode InflateEnd()
{
EnsureNotDisposed();
diff --git a/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs b/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs
index a092c966483896..d074f618bf16db 100644
--- a/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs
+++ b/src/libraries/Common/src/System/Net/WebSockets/WebSocketValidate.cs
@@ -9,6 +9,19 @@ namespace System.Net.WebSockets
{
internal static partial class WebSocketValidate
{
+ ///
+ /// The minimum value for window bits that the websocket per-message-deflate extension can support.
+ /// For the current implementation of deflate(), a windowBits value of 8 (a window size of 256 bytes) is not supported.
+ /// We cannot use silently 9 instead of 8, because the websocket produces raw deflate stream
+ /// and thus it needs to know the window bits in advance.
+ ///
+ internal const int MinDeflateWindowBits = 9;
+
+ ///
+ /// The maximum value for window bits that the websocket per-message-deflate extension can support.
+ ///
+ internal const int MaxDeflateWindowBits = 15;
+
internal const int MaxControlFramePayloadLength = 123;
private const int CloseStatusCodeAbort = 1006;
private const int CloseStatusCodeFailedTLSHandshake = 1015;
diff --git a/src/libraries/Native/AnyOS/System.IO.Compression.Native/entrypoints.c b/src/libraries/Native/AnyOS/System.IO.Compression.Native/entrypoints.c
index b194b978debe23..f363a91eb1add3 100644
--- a/src/libraries/Native/AnyOS/System.IO.Compression.Native/entrypoints.c
+++ b/src/libraries/Native/AnyOS/System.IO.Compression.Native/entrypoints.c
@@ -28,9 +28,11 @@ static const Entry s_compressionNative[] =
DllImportEntry(CompressionNative_Crc32)
DllImportEntry(CompressionNative_Deflate)
DllImportEntry(CompressionNative_DeflateEnd)
+ DllImportEntry(CompressionNative_DeflateReset)
DllImportEntry(CompressionNative_DeflateInit2_)
DllImportEntry(CompressionNative_Inflate)
DllImportEntry(CompressionNative_InflateEnd)
+ DllImportEntry(CompressionNative_InflateReset)
DllImportEntry(CompressionNative_InflateInit2_)
};
diff --git a/src/libraries/Native/AnyOS/zlib/pal_zlib.c b/src/libraries/Native/AnyOS/zlib/pal_zlib.c
index 2c399639d0fa92..aa4dcdca8a29e8 100644
--- a/src/libraries/Native/AnyOS/zlib/pal_zlib.c
+++ b/src/libraries/Native/AnyOS/zlib/pal_zlib.c
@@ -135,6 +135,17 @@ int32_t CompressionNative_Deflate(PAL_ZStream* stream, int32_t flush)
return result;
}
+int32_t CompressionNative_DeflateReset(PAL_ZStream* stream)
+{
+ assert(stream != NULL);
+
+ z_stream* zStream = GetCurrentZStream(stream);
+ int32_t result = deflateReset(zStream);
+ TransferStateToPalZStream(zStream, stream);
+
+ return result;
+}
+
int32_t CompressionNative_DeflateEnd(PAL_ZStream* stream)
{
assert(stream != NULL);
@@ -172,6 +183,17 @@ int32_t CompressionNative_Inflate(PAL_ZStream* stream, int32_t flush)
return result;
}
+int32_t CompressionNative_InflateReset(PAL_ZStream* stream)
+{
+ assert(stream != NULL);
+
+ z_stream* zStream = GetCurrentZStream(stream);
+ int32_t result = inflateReset(zStream);
+ TransferStateToPalZStream(zStream, stream);
+
+ return result;
+}
+
int32_t CompressionNative_InflateEnd(PAL_ZStream* stream)
{
assert(stream != NULL);
diff --git a/src/libraries/Native/AnyOS/zlib/pal_zlib.h b/src/libraries/Native/AnyOS/zlib/pal_zlib.h
index b317091b843f62..1eb1baa6b3846b 100644
--- a/src/libraries/Native/AnyOS/zlib/pal_zlib.h
+++ b/src/libraries/Native/AnyOS/zlib/pal_zlib.h
@@ -95,6 +95,14 @@ Returns a PAL_ErrorCode indicating success or an error number on failure.
*/
FUNCTIONEXPORT int32_t FUNCTIONCALLINGCONVENCTION CompressionNative_Deflate(PAL_ZStream* stream, int32_t flush);
+/*
+This function is equivalent to DeflateEnd followed by DeflateInit, but does not free and reallocate
+the internal compression state. The stream will leave the compression level and any other attributes that may have been set unchanged.
+
+Returns a PAL_ErrorCode indicating success or an error number on failure.
+*/
+FUNCTIONEXPORT int32_t FUNCTIONCALLINGCONVENCTION CompressionNative_DeflateReset(PAL_ZStream* stream);
+
/*
All dynamically allocated data structures for this stream are freed.
@@ -117,6 +125,14 @@ Returns a PAL_ErrorCode indicating success or an error number on failure.
*/
FUNCTIONEXPORT int32_t FUNCTIONCALLINGCONVENCTION CompressionNative_Inflate(PAL_ZStream* stream, int32_t flush);
+/*
+This function is equivalent to InflateEnd followed by InflateInit, but does not free and reallocate
+the internal decompression state. The The stream will keep attributes that may have been set by InflateInit.
+
+Returns a PAL_ErrorCode indicating success or an error number on failure.
+*/
+FUNCTIONEXPORT int32_t FUNCTIONCALLINGCONVENCTION CompressionNative_InflateReset(PAL_ZStream* stream);
+
/*
All dynamically allocated data structures for this stream are freed.
diff --git a/src/libraries/Native/Unix/System.IO.Compression.Native/System.IO.Compression.Native_unixexports.src b/src/libraries/Native/Unix/System.IO.Compression.Native/System.IO.Compression.Native_unixexports.src
index 08dd1700a52f21..2ac827035f271b 100644
--- a/src/libraries/Native/Unix/System.IO.Compression.Native/System.IO.Compression.Native_unixexports.src
+++ b/src/libraries/Native/Unix/System.IO.Compression.Native/System.IO.Compression.Native_unixexports.src
@@ -15,7 +15,9 @@ BrotliEncoderSetParameter
CompressionNative_Crc32
CompressionNative_Deflate
CompressionNative_DeflateEnd
+CompressionNative_DeflateReset
CompressionNative_DeflateInit2_
CompressionNative_Inflate
CompressionNative_InflateEnd
+CompressionNative_InflateReset
CompressionNative_InflateInit2_
diff --git a/src/libraries/Native/Windows/System.IO.Compression.Native/System.IO.Compression.Native.def b/src/libraries/Native/Windows/System.IO.Compression.Native/System.IO.Compression.Native.def
index 6821d0e538f51f..aecd0dd974618a 100644
--- a/src/libraries/Native/Windows/System.IO.Compression.Native/System.IO.Compression.Native.def
+++ b/src/libraries/Native/Windows/System.IO.Compression.Native/System.IO.Compression.Native.def
@@ -15,7 +15,9 @@ EXPORTS
CompressionNative_Crc32
CompressionNative_Deflate
CompressionNative_DeflateEnd
+ CompressionNative_DeflateReset
CompressionNative_DeflateInit2_
CompressionNative_Inflate
CompressionNative_InflateEnd
+ CompressionNative_InflateReset
CompressionNative_InflateInit2_
diff --git a/src/libraries/System.IO.Compression/src/System.IO.Compression.csproj b/src/libraries/System.IO.Compression/src/System.IO.Compression.csproj
index e2a7adee12f579..0ffa0044e2a167 100644
--- a/src/libraries/System.IO.Compression/src/System.IO.Compression.csproj
+++ b/src/libraries/System.IO.Compression/src/System.IO.Compression.csproj
@@ -25,8 +25,10 @@
-
-
+
+
diff --git a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs
index cee3a5170b8625..96cecd9e30f471 100644
--- a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs
+++ b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs
@@ -36,6 +36,8 @@ internal ClientWebSocketOptions() { }
[System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")]
public System.TimeSpan KeepAliveInterval { get { throw null; } set { } }
[System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")]
+ public System.Net.WebSockets.WebSocketDeflateOptions? DangerousDeflateOptions { get { throw null; } set { } }
+ [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")]
public System.Net.IWebProxy? Proxy { get { throw null; } set { } }
[System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")]
public System.Net.Security.RemoteCertificateValidationCallback? RemoteCertificateValidationCallback { get { throw null; } set { } }
diff --git a/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx b/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx
index 3259b86c99fcba..7b4718b554a151 100644
--- a/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx
+++ b/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx
@@ -1,63 +1,4 @@
-
@@ -193,8 +134,14 @@
Connection was aborted.
-
+
WebSocket binary type '{0}' not supported.
-
-
+
+
+ The WebSocket failed to negotiate max server window bits. The client requested {0} but the server responded with {1}.
+
+
+ The WebSocket failed to negotiate max client window bits. The client requested {0} but the server responded with {1}.
+
+
\ No newline at end of file
diff --git a/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj b/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj
index b74f3d8962be63..e84ea02f895ba8 100644
--- a/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj
+++ b/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj
@@ -6,6 +6,7 @@
+
@@ -37,6 +38,7 @@
+
diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs
index 85b0f025b46502..79dd04229b9c33 100644
--- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs
+++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs
@@ -100,6 +100,13 @@ public TimeSpan KeepAliveInterval
set => throw new PlatformNotSupportedException();
}
+ [UnsupportedOSPlatform("browser")]
+ public WebSocketDeflateOptions? DangerousDeflateOptions
+ {
+ get => throw new PlatformNotSupportedException();
+ set => throw new PlatformNotSupportedException();
+ }
+
[UnsupportedOSPlatform("browser")]
public void SetBuffer(int receiveBufferSize, int sendBufferSize)
{
diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketDeflateConstants.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketDeflateConstants.cs
new file mode 100644
index 00000000000000..3faa886d5c3068
--- /dev/null
+++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketDeflateConstants.cs
@@ -0,0 +1,23 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace System.Net.WebSockets
+{
+ internal static class ClientWebSocketDeflateConstants
+ {
+ ///
+ /// The maximum length that this extension can have, assuming that we're not abusing white space.
+ ///
+ /// "permessage-deflate; client_max_window_bits=15; client_no_context_takeover; server_max_window_bits=15; server_no_context_takeover"
+ ///
+ public const int MaxExtensionLength = 128;
+
+ public const string Extension = "permessage-deflate";
+
+ public const string ClientMaxWindowBits = "client_max_window_bits";
+ public const string ClientNoContextTakeover = "client_no_context_takeover";
+
+ public const string ServerMaxWindowBits = "server_max_window_bits";
+ public const string ServerNoContextTakeover = "server_no_context_takeover";
+ }
+}
diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs
index a7609a0ff09057..5ab2ad51d94eb3 100644
--- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs
+++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs
@@ -148,6 +148,18 @@ public TimeSpan KeepAliveInterval
}
}
+ ///
+ /// Gets or sets the options for the per-message-deflate extension.
+ /// When present, the options are sent to the server during the handshake phase. If the server
+ /// supports per-message-deflate and the options are accepted, the instance
+ /// will be created with compression enabled by default for all messages.
+ /// Be aware that enabling compression makes the application subject to CRIME/BREACH type of attacks.
+ /// It is strongly advised to turn off compression when sending data containing secrets by
+ /// specifying flag for such messages.
+ ///
+ [UnsupportedOSPlatform("browser")]
+ public WebSocketDeflateOptions? DangerousDeflateOptions { get; set; }
+
internal int ReceiveBufferSize => _receiveBufferSize;
internal ArraySegment? Buffer => _buffer;
diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs
index d61f368e7aae8e..e0c3902a915909 100644
--- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs
+++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs
@@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
+using System.Globalization;
using System.IO;
using System.Net.Http;
using System.Net.Http.Headers;
@@ -22,6 +23,7 @@ internal sealed class WebSocketHandle
private readonly CancellationTokenSource _abortSource = new CancellationTokenSource();
private WebSocketState _state = WebSocketState.Connecting;
+ private WebSocketDeflateOptions? _negotiatedDeflateOptions;
public WebSocket? WebSocket { get; private set; }
public WebSocketState State => WebSocket?.State ?? _state;
@@ -183,6 +185,21 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli
}
}
+ // Because deflate options are negotiated we need a new object
+ WebSocketDeflateOptions? negotiatedDeflateOptions = null;
+
+ if (options.DangerousDeflateOptions is not null && response.Headers.TryGetValues(HttpKnownHeaderNames.SecWebSocketExtensions, out IEnumerable? extensions))
+ {
+ foreach (ReadOnlySpan extension in extensions)
+ {
+ if (extension.TrimStart().StartsWith(ClientWebSocketDeflateConstants.Extension))
+ {
+ negotiatedDeflateOptions = ParseDeflateOptions(extension, options.DangerousDeflateOptions);
+ break;
+ }
+ }
+ }
+
if (response.Content is null)
{
throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely);
@@ -192,11 +209,14 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli
Stream connectedStream = response.Content.ReadAsStream();
Debug.Assert(connectedStream.CanWrite);
Debug.Assert(connectedStream.CanRead);
- WebSocket = WebSocket.CreateFromStream(
- connectedStream,
- isServer: false,
- subprotocol,
- options.KeepAliveInterval);
+ WebSocket = WebSocket.CreateFromStream(connectedStream, new WebSocketCreationOptions
+ {
+ IsServer = false,
+ SubProtocol = subprotocol,
+ KeepAliveInterval = options.KeepAliveInterval,
+ DangerousDeflateOptions = negotiatedDeflateOptions
+ });
+ _negotiatedDeflateOptions = negotiatedDeflateOptions;
}
catch (Exception exc)
{
@@ -226,6 +246,73 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli
}
}
+ private static WebSocketDeflateOptions ParseDeflateOptions(ReadOnlySpan extension, WebSocketDeflateOptions original)
+ {
+ var options = new WebSocketDeflateOptions();
+
+ while (true)
+ {
+ int end = extension.IndexOf(';');
+ ReadOnlySpan value = (end >= 0 ? extension[..end] : extension).Trim();
+
+ if (value.Length > 0)
+ {
+ if (value.SequenceEqual(ClientWebSocketDeflateConstants.ClientNoContextTakeover))
+ {
+ options.ClientContextTakeover = false;
+ }
+ else if (value.SequenceEqual(ClientWebSocketDeflateConstants.ServerNoContextTakeover))
+ {
+ options.ServerContextTakeover = false;
+ }
+ else if (value.StartsWith(ClientWebSocketDeflateConstants.ClientMaxWindowBits))
+ {
+ options.ClientMaxWindowBits = ParseWindowBits(value);
+ }
+ else if (value.StartsWith(ClientWebSocketDeflateConstants.ServerMaxWindowBits))
+ {
+ options.ServerMaxWindowBits = ParseWindowBits(value);
+ }
+
+ static int ParseWindowBits(ReadOnlySpan value)
+ {
+ var startIndex = value.IndexOf('=');
+
+ if (startIndex < 0 ||
+ !int.TryParse(value.Slice(startIndex + 1), NumberStyles.Integer, CultureInfo.InvariantCulture, out int windowBits) ||
+ windowBits < WebSocketValidate.MinDeflateWindowBits ||
+ windowBits > WebSocketValidate.MaxDeflateWindowBits)
+ {
+ throw new WebSocketException(WebSocketError.HeaderError,
+ SR.Format(SR.net_WebSockets_InvalidResponseHeader, ClientWebSocketDeflateConstants.Extension, value.ToString()));
+ }
+
+ return windowBits;
+ }
+ }
+
+ if (end < 0)
+ {
+ break;
+ }
+ extension = extension[(end + 1)..];
+ }
+
+ if (options.ClientMaxWindowBits > original.ClientMaxWindowBits)
+ {
+ throw new WebSocketException(string.Format(SR.net_WebSockets_ClientWindowBitsNegotiationFailure,
+ original.ClientMaxWindowBits, options.ClientMaxWindowBits));
+ }
+
+ if (options.ServerMaxWindowBits > original.ServerMaxWindowBits)
+ {
+ throw new WebSocketException(string.Format(SR.net_WebSockets_ServerWindowBitsNegotiationFailure,
+ original.ServerMaxWindowBits, options.ServerMaxWindowBits));
+ }
+
+ return options;
+ }
+
/// Adds the necessary headers for the web socket request.
/// The request to which the headers should be added.
/// The generated security key to send in the Sec-WebSocket-Key header.
@@ -240,6 +327,47 @@ private static void AddWebSocketHeaders(HttpRequestMessage request, string secKe
{
request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.SecWebSocketProtocol, string.Join(", ", options.RequestedSubProtocols));
}
+ if (options.DangerousDeflateOptions is not null)
+ {
+ request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.SecWebSocketExtensions, GetDeflateOptions(options.DangerousDeflateOptions));
+
+ static string GetDeflateOptions(WebSocketDeflateOptions options)
+ {
+ var builder = new StringBuilder(ClientWebSocketDeflateConstants.MaxExtensionLength);
+ builder.Append(ClientWebSocketDeflateConstants.Extension).Append("; ");
+
+ if (options.ClientMaxWindowBits != WebSocketValidate.MaxDeflateWindowBits)
+ {
+ builder.Append(ClientWebSocketDeflateConstants.ClientMaxWindowBits).Append('=')
+ .Append(options.ClientMaxWindowBits.ToString(CultureInfo.InvariantCulture));
+ }
+ else
+ {
+ // Advertise that we support this option
+ builder.Append(ClientWebSocketDeflateConstants.ClientMaxWindowBits);
+ }
+
+ if (!options.ClientContextTakeover)
+ {
+ builder.Append("; ").Append(ClientWebSocketDeflateConstants.ClientNoContextTakeover);
+ }
+
+ if (options.ServerMaxWindowBits != WebSocketValidate.MaxDeflateWindowBits)
+ {
+ builder.Append("; ")
+ .Append(ClientWebSocketDeflateConstants.ServerMaxWindowBits).Append('=')
+ .Append(options.ServerMaxWindowBits.ToString(CultureInfo.InvariantCulture));
+ }
+
+ if (!options.ServerContextTakeover)
+ {
+ builder.Append("; ").Append(ClientWebSocketDeflateConstants.ServerNoContextTakeover);
+ }
+
+ Debug.Assert(builder.Length <= ClientWebSocketDeflateConstants.MaxExtensionLength);
+ return builder.ToString();
+ }
+ }
}
///
diff --git a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs
new file mode 100644
index 00000000000000..e0a0e1e59fd846
--- /dev/null
+++ b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs
@@ -0,0 +1,103 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Generic;
+using System.Net.Test.Common;
+using System.Reflection;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+
+using Xunit;
+using Xunit.Abstractions;
+
+namespace System.Net.WebSockets.Client.Tests
+{
+ [PlatformSpecific(~TestPlatforms.Browser)]
+ public class DeflateTests : ClientWebSocketTestBase
+ {
+ public DeflateTests(ITestOutputHelper output) : base(output)
+ {
+ }
+
+ [ConditionalTheory(nameof(WebSocketsSupported))]
+ [ActiveIssue("https://github.com/dotnet/runtime/issues/34690", TestPlatforms.Windows, TargetFrameworkMonikers.Netcoreapp, TestRuntimes.Mono)]
+ [InlineData(15, true, 15, true, "permessage-deflate; client_max_window_bits")]
+ [InlineData(14, true, 15, true, "permessage-deflate; client_max_window_bits=14")]
+ [InlineData(15, true, 14, true, "permessage-deflate; client_max_window_bits; server_max_window_bits=14")]
+ [InlineData(10, true, 11, true, "permessage-deflate; client_max_window_bits=10; server_max_window_bits=11")]
+ [InlineData(15, false, 15, true, "permessage-deflate; client_max_window_bits; client_no_context_takeover")]
+ [InlineData(15, true, 15, false, "permessage-deflate; client_max_window_bits; server_no_context_takeover")]
+ public async Task PerMessageDeflateHeaders(int clientWindowBits, bool clientContextTakeover,
+ int serverWindowBits, bool serverContextTakover,
+ string expected)
+ {
+ await LoopbackServer.CreateClientAndServerAsync(async uri =>
+ {
+ using var client = new ClientWebSocket();
+ using var cancellation = new CancellationTokenSource(TimeOutMilliseconds);
+
+ client.Options.DangerousDeflateOptions = new WebSocketDeflateOptions
+ {
+ ClientMaxWindowBits = clientWindowBits,
+ ClientContextTakeover = clientContextTakeover,
+ ServerMaxWindowBits = serverWindowBits,
+ ServerContextTakeover = serverContextTakover
+ };
+
+ await client.ConnectAsync(uri, cancellation.Token);
+
+ object webSocketHandle = client.GetType().GetField("_innerWebSocket", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(client);
+ WebSocketDeflateOptions negotiatedDeflateOptions = (WebSocketDeflateOptions)webSocketHandle.GetType()
+ .GetField("_negotiatedDeflateOptions", BindingFlags.NonPublic | BindingFlags.Instance)
+ .GetValue(webSocketHandle);
+
+ Assert.Equal(clientWindowBits - 1, negotiatedDeflateOptions.ClientMaxWindowBits);
+ Assert.Equal(clientContextTakeover, negotiatedDeflateOptions.ClientContextTakeover);
+ Assert.Equal(serverWindowBits - 1, negotiatedDeflateOptions.ServerMaxWindowBits);
+ Assert.Equal(serverContextTakover, negotiatedDeflateOptions.ServerContextTakeover);
+ }, server => server.AcceptConnectionAsync(async connection =>
+ {
+ var extensionsReply = CreateDeflateOptionsHeader(new WebSocketDeflateOptions
+ {
+ ClientMaxWindowBits = clientWindowBits - 1,
+ ClientContextTakeover = clientContextTakeover,
+ ServerMaxWindowBits = serverWindowBits - 1,
+ ServerContextTakeover = serverContextTakover
+ });
+ Dictionary headers = await LoopbackHelper.WebSocketHandshakeAsync(connection, extensionsReply);
+ Assert.NotNull(headers);
+ Assert.True(headers.TryGetValue("Sec-WebSocket-Extensions", out string extensions));
+ Assert.Equal(expected, extensions);
+ }), new LoopbackServer.Options { WebSocketEndpoint = true });
+ }
+
+ private static string CreateDeflateOptionsHeader(WebSocketDeflateOptions options)
+ {
+ var builder = new StringBuilder();
+ builder.Append("permessage-deflate");
+
+ if (options.ClientMaxWindowBits != 15)
+ {
+ builder.Append("; client_max_window_bits=").Append(options.ClientMaxWindowBits);
+ }
+
+ if (!options.ClientContextTakeover)
+ {
+ builder.Append("; client_no_context_takeover");
+ }
+
+ if (options.ServerMaxWindowBits != 15)
+ {
+ builder.Append("; server_max_window_bits=").Append(options.ServerMaxWindowBits);
+ }
+
+ if (!options.ServerContextTakeover)
+ {
+ builder.Append("; server_no_context_takeover");
+ }
+
+ return builder.ToString();
+ }
+ }
+}
diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs
index 5726326c6ab8fa..48d167b072f781 100644
--- a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs
+++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs
@@ -11,7 +11,7 @@ namespace System.Net.WebSockets.Client.Tests
{
public static class LoopbackHelper
{
- public static async Task> WebSocketHandshakeAsync(LoopbackServer.Connection connection)
+ public static async Task> WebSocketHandshakeAsync(LoopbackServer.Connection connection, string? extensions = null)
{
string serverResponse = null;
List headers = await connection.ReadRequestHeaderAsync().ConfigureAwait(false);
@@ -34,6 +34,7 @@ public static async Task> WebSocketHandshakeAsync(Loo
"Content-Length: 0\r\n" +
"Upgrade: websocket\r\n" +
"Connection: Upgrade\r\n" +
+ (extensions is null ? null : $"Sec-WebSocket-Extensions: {extensions}\r\n") +
"Sec-WebSocket-Accept: " + responseSecurityAcceptValue + "\r\n\r\n";
}
}
diff --git a/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj b/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj
index a1323fa83db1ed..21ba2a12dd5247 100644
--- a/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj
+++ b/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj
@@ -46,6 +46,7 @@
+
diff --git a/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs b/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs
index e4ff945bb5b647..32ebf5eb1e804e 100644
--- a/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs
+++ b/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs
@@ -29,6 +29,7 @@ protected WebSocket() { }
[System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
public static System.Net.WebSockets.WebSocket CreateClientWebSocket(System.IO.Stream innerStream, string? subProtocol, int receiveBufferSize, int sendBufferSize, System.TimeSpan keepAliveInterval, bool useZeroMaskingKey, System.ArraySegment internalBuffer) { throw null; }
public static System.Net.WebSockets.WebSocket CreateFromStream(System.IO.Stream stream, bool isServer, string? subProtocol, System.TimeSpan keepAliveInterval) { throw null; }
+ public static System.Net.WebSockets.WebSocket CreateFromStream(System.IO.Stream stream, System.Net.WebSockets.WebSocketCreationOptions options) { throw null; }
public static System.ArraySegment CreateServerBuffer(int receiveBufferSize) { throw null; }
public abstract void Dispose();
[System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
@@ -41,6 +42,7 @@ protected WebSocket() { }
public static void RegisterPrefixes() { }
public abstract System.Threading.Tasks.Task SendAsync(System.ArraySegment buffer, System.Net.WebSockets.WebSocketMessageType messageType, bool endOfMessage, System.Threading.CancellationToken cancellationToken);
public virtual System.Threading.Tasks.ValueTask SendAsync(System.ReadOnlyMemory buffer, System.Net.WebSockets.WebSocketMessageType messageType, bool endOfMessage, System.Threading.CancellationToken cancellationToken) { throw null; }
+ public virtual System.Threading.Tasks.ValueTask SendAsync(System.ReadOnlyMemory buffer, System.Net.WebSockets.WebSocketMessageType messageType, System.Net.WebSockets.WebSocketMessageFlags messageFlags, System.Threading.CancellationToken cancellationToken) { throw null; }
protected static void ThrowOnInvalidState(System.Net.WebSockets.WebSocketState state, params System.Net.WebSockets.WebSocketState[] validStates) { }
}
public enum WebSocketCloseStatus
@@ -131,4 +133,25 @@ public enum WebSocketState
Closed = 5,
Aborted = 6,
}
+ public sealed partial class WebSocketCreationOptions
+ {
+ public bool IsServer { get { throw null; } set { } }
+ public string? SubProtocol { get { throw null; } set { } }
+ public System.TimeSpan KeepAliveInterval { get { throw null; } set { } }
+ public System.Net.WebSockets.WebSocketDeflateOptions? DangerousDeflateOptions { get { throw null; } set { } }
+ }
+ public sealed partial class WebSocketDeflateOptions
+ {
+ public int ClientMaxWindowBits { get { throw null; } set { } }
+ public bool ClientContextTakeover { get { throw null; } set { } }
+ public int ServerMaxWindowBits { get { throw null; } set { } }
+ public bool ServerContextTakeover { get { throw null; } set { } }
+ }
+ [Flags]
+ public enum WebSocketMessageFlags
+ {
+ None = 0,
+ EndOfMessage = 1,
+ DisableCompression = 2
+ }
}
diff --git a/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx b/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx
index a4f630ea24c039..693f8d3863fd7f 100644
--- a/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx
+++ b/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx
@@ -138,4 +138,31 @@
The base stream is not writeable.
-
+
+ The argument must be a value between {0} and {1}.
+
+
+ The WebSocket received a continuation frame with Per-Message Compressed flag set.
+
+
+ The WebSocket received compressed frame when compression is not enabled.
+
+
+ The underlying compression routine could not be loaded correctly.
+
+
+ The stream state of the underlying compression routine is inconsistent.
+
+
+ The underlying compression routine could not reserve sufficient memory.
+
+
+ The underlying compression routine returned an unexpected error code {0}.
+
+
+ The message was compressed using an unsupported compression method.
+
+
+ The compression options for a continuation cannot be different than the options used to send the first fragment of the message.
+
+
\ No newline at end of file
diff --git a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj
index d65e6c55737af0..215cf6b4a91649 100644
--- a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj
+++ b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj
@@ -1,26 +1,48 @@
True
- $(NetCoreAppCurrent)
+ $(NetCoreAppCurrent)-windows;$(NetCoreAppCurrent)-Unix;$(NetCoreAppCurrent)-Browserenable
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs
new file mode 100644
index 00000000000000..e7f18072842433
--- /dev/null
+++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs
@@ -0,0 +1,235 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Buffers;
+using System.Diagnostics;
+using static System.IO.Compression.ZLibNative;
+
+namespace System.Net.WebSockets.Compression
+{
+ ///
+ /// Provides a wrapper around the ZLib compression API.
+ ///
+ internal sealed class WebSocketDeflater : IDisposable
+ {
+ private readonly int _windowBits;
+ private ZLibStreamHandle? _stream;
+ private readonly bool _persisted;
+
+ private byte[]? _buffer;
+
+ internal WebSocketDeflater(int windowBits, bool persisted)
+ {
+ _windowBits = -windowBits; // Negative for raw deflate
+ _persisted = persisted;
+ }
+
+ public void Dispose()
+ {
+ if (_stream is not null)
+ {
+ _stream.Dispose();
+ _stream = null;
+ }
+ }
+
+ public void ReleaseBuffer()
+ {
+ if (_buffer is not null)
+ {
+ ArrayPool.Shared.Return(_buffer);
+ _buffer = null;
+ }
+ }
+
+ public ReadOnlySpan Deflate(ReadOnlySpan payload, bool endOfMessage)
+ {
+ Debug.Assert(_buffer is null, "Invalid state, ReleaseBuffer not called.");
+
+ // Do not try to rent more than 1MB initially, because it will actually allocate
+ // instead of renting. Be optimistic that what we're sending is actually going to fit.
+ const int MaxInitialBufferLength = 1024 * 1024;
+
+ // For small payloads there might actually be overhead in the compression and the resulting
+ // output might be larger than the payload. This is why we rent at least 4KB initially.
+ const int MinInitialBufferLength = 4 * 1024;
+
+ _buffer = ArrayPool.Shared.Rent(Math.Clamp(payload.Length, MinInitialBufferLength, MaxInitialBufferLength));
+ int position = 0;
+
+ while (true)
+ {
+ DeflatePrivate(payload, _buffer.AsSpan(position), endOfMessage,
+ out int consumed, out int written, out bool needsMoreOutput);
+ position += written;
+
+ if (!needsMoreOutput)
+ {
+ Debug.Assert(consumed == payload.Length);
+ break;
+ }
+
+ payload = payload.Slice(consumed);
+
+ // Rent a 30% bigger buffer
+ byte[] newBuffer = ArrayPool.Shared.Rent((int)(_buffer.Length * 1.3));
+ _buffer.AsSpan(0, position).CopyTo(newBuffer);
+ ArrayPool.Shared.Return(_buffer);
+ _buffer = newBuffer;
+ }
+
+ return new ReadOnlySpan(_buffer, 0, position);
+ }
+
+ private void DeflatePrivate(ReadOnlySpan payload, Span output, bool endOfMessage,
+ out int consumed, out int written, out bool needsMoreOutput)
+ {
+ _stream ??= CreateDeflater();
+
+ if (payload.Length == 0)
+ {
+ consumed = 0;
+ written = 0;
+ }
+ else
+ {
+ UnsafeDeflate(payload, output, out consumed, out written, out needsMoreOutput);
+
+ if (needsMoreOutput)
+ {
+ Debug.Assert(written == output.Length);
+ return;
+ }
+ }
+
+ written += UnsafeFlush(output.Slice(written), out needsMoreOutput);
+
+ if (needsMoreOutput)
+ {
+ return;
+ }
+ Debug.Assert(output.Slice(written - WebSocketInflater.FlushMarkerLength, WebSocketInflater.FlushMarkerLength)
+ .EndsWith(WebSocketInflater.FlushMarker), "The deflated block must always end with a flush marker.");
+
+ if (endOfMessage)
+ {
+ // As per RFC we need to remove the flush markers
+ written -= WebSocketInflater.FlushMarkerLength;
+ }
+
+ if (endOfMessage && !_persisted)
+ {
+ _stream.Dispose();
+ _stream = null;
+ }
+ }
+
+ private unsafe void UnsafeDeflate(ReadOnlySpan input, Span output, out int consumed, out int written, out bool needsMoreBuffer)
+ {
+ Debug.Assert(_stream is not null);
+
+ fixed (byte* fixedInput = input)
+ fixed (byte* fixedOutput = output)
+ {
+ _stream.NextIn = (IntPtr)fixedInput;
+ _stream.AvailIn = (uint)input.Length;
+
+ _stream.NextOut = (IntPtr)fixedOutput;
+ _stream.AvailOut = (uint)output.Length;
+
+ // The flush is set to Z_NO_FLUSH, which allows deflate to decide
+ // how much data to accumulate before producing output,
+ // in order to maximize compression.
+ var errorCode = Deflate(_stream, FlushCode.NoFlush);
+
+ consumed = input.Length - (int)_stream.AvailIn;
+ written = output.Length - (int)_stream.AvailOut;
+
+ needsMoreBuffer = errorCode == ErrorCode.BufError || _stream.AvailIn > 0;
+ }
+ }
+
+ private unsafe int UnsafeFlush(Span output, out bool needsMoreBuffer)
+ {
+ Debug.Assert(_stream is not null);
+ Debug.Assert(_stream.AvailIn == 0);
+
+ fixed (byte* fixedOutput = output)
+ {
+ _stream.NextIn = IntPtr.Zero;
+ _stream.AvailIn = 0;
+
+ _stream.NextOut = (IntPtr)fixedOutput;
+ _stream.AvailOut = (uint)output.Length;
+
+ // We need to use Z_BLOCK_FLUSH to instruct the zlib to flush all outstanding
+ // data but also not to emit a deflate block boundary. After we know that there is no
+ // more data, we can safely proceed to instruct the library to emit the boundary markers.
+ ErrorCode errorCode = Deflate(_stream, FlushCode.Block);
+ Debug.Assert(errorCode is ErrorCode.Ok or ErrorCode.BufError);
+
+ // We need at least 6 bytes to guarantee that we can emit a deflate block boundary.
+ needsMoreBuffer = _stream.AvailOut < 6;
+
+ if (!needsMoreBuffer)
+ {
+ // The flush is set to Z_SYNC_FLUSH, all pending output is flushed
+ // to the output buffer and the output is aligned on a byte boundary,
+ // so that the decompressor can get all input data available so far.
+ // This completes the current deflate block and follows it with an empty
+ // stored block that is three bits plus filler bits to the next byte,
+ // followed by four bytes (00 00 ff ff).
+ errorCode = Deflate(_stream, FlushCode.SyncFlush);
+ Debug.Assert(errorCode == ErrorCode.Ok);
+ }
+
+ return output.Length - (int)_stream.AvailOut;
+ }
+ }
+
+ private static ErrorCode Deflate(ZLibStreamHandle stream, FlushCode flushCode)
+ {
+ ErrorCode errorCode = stream.Deflate(flushCode);
+
+ if (errorCode is ErrorCode.Ok or ErrorCode.StreamEnd or ErrorCode.BufError)
+ {
+ return errorCode;
+ }
+
+ string message = errorCode == ErrorCode.StreamError
+ ? SR.ZLibErrorInconsistentStream
+ : string.Format(SR.ZLibErrorUnexpected, (int)errorCode);
+ throw new WebSocketException(message);
+ }
+
+ private ZLibStreamHandle CreateDeflater()
+ {
+ ZLibStreamHandle stream;
+ ErrorCode errorCode;
+ try
+ {
+ errorCode = CreateZLibStreamForDeflate(out stream,
+ level: CompressionLevel.DefaultCompression,
+ windowBits: _windowBits,
+ memLevel: Deflate_DefaultMemLevel,
+ strategy: CompressionStrategy.DefaultStrategy);
+ }
+ catch (Exception cause)
+ {
+ throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause);
+ }
+
+ if (errorCode == ErrorCode.Ok)
+ {
+ return stream;
+ }
+
+ stream.Dispose();
+
+ string message = errorCode == ErrorCode.MemError
+ ? SR.ZLibErrorNotEnoughMemory
+ : string.Format(SR.ZLibErrorUnexpected, (int)errorCode);
+ throw new WebSocketException(message);
+ }
+ }
+}
diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs
new file mode 100644
index 00000000000000..6ade12d539a440
--- /dev/null
+++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs
@@ -0,0 +1,285 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Buffers;
+using System.Diagnostics;
+using static System.IO.Compression.ZLibNative;
+
+namespace System.Net.WebSockets.Compression
+{
+ ///
+ /// Provides a wrapper around the ZLib decompression API.
+ ///
+ internal sealed class WebSocketInflater : IDisposable
+ {
+ internal const int FlushMarkerLength = 4;
+ internal static ReadOnlySpan FlushMarker => new byte[] { 0x00, 0x00, 0xFF, 0xFF };
+
+ private readonly int _windowBits;
+ private ZLibStreamHandle? _stream;
+ private readonly bool _persisted;
+
+ ///
+ /// There is no way of knowing, when decoding data, if the underlying inflater
+ /// has flushed all outstanding data to consumer other than to provide a buffer
+ /// and see whether any bytes are written. There are cases when the consumers
+ /// provide a buffer exactly the size of the uncompressed data and in this case
+ /// to avoid requiring another read we will use this field.
+ ///
+ private byte? _remainingByte;
+
+ ///
+ /// The last added bytes to the inflater were part of the final
+ /// payload for the message being sent.
+ ///
+ private bool _endOfMessage;
+
+ private byte[]? _buffer;
+
+ ///
+ /// The position for the next unconsumed byte in the inflate buffer.
+ ///
+ private int _position;
+
+ ///
+ /// How many unconsumed bytes are left in the inflate buffer.
+ ///
+ private int _available;
+
+ internal WebSocketInflater(int windowBits, bool persisted)
+ {
+ _windowBits = -windowBits; // Negative for raw deflate
+ _persisted = persisted;
+ }
+
+ public Memory Memory => _buffer.AsMemory(_position + _available);
+
+ public Span Span => _buffer.AsSpan(_position + _available);
+
+ public void Dispose()
+ {
+ if (_stream is not null)
+ {
+ _stream.Dispose();
+ _stream = null;
+ }
+ ReleaseBuffer();
+ }
+
+ ///
+ /// Initializes the inflater by allocating a buffer so the websocket can receive directly onto it.
+ ///
+ /// the length of the message payload
+ /// the length of the buffer where the payload will be inflated
+ public void Prepare(long payloadLength, int userBufferLength)
+ {
+ if (_buffer is not null)
+ {
+ Debug.Assert(_available > 0);
+
+ _buffer.AsSpan(_position, _available).CopyTo(_buffer);
+ _position = 0;
+ }
+ else
+ {
+ // Rent a buffer as close to the size of the user buffer as possible,
+ // but not try to rent anything above 1MB because the array pool will allocate.
+ // If the payload is smaller than the user buffer, rent only as much as we need.
+ _buffer = ArrayPool.Shared.Rent(Math.Min(userBufferLength, (int)Math.Min(payloadLength, 1024 * 1024)));
+ }
+ }
+
+ public void AddBytes(int totalBytesReceived, bool endOfMessage)
+ {
+ Debug.Assert(totalBytesReceived == 0 || _buffer is not null, "Prepare must be called.");
+
+ _available += totalBytesReceived;
+ _endOfMessage = endOfMessage;
+
+ if (endOfMessage)
+ {
+ if (_buffer is null)
+ {
+ Debug.Assert(_available == 0);
+
+ _buffer = ArrayPool.Shared.Rent(FlushMarkerLength);
+ _available = FlushMarkerLength;
+ FlushMarker.CopyTo(_buffer);
+ }
+ else
+ {
+ if (_buffer.Length < _available + FlushMarkerLength)
+ {
+ byte[] newBuffer = ArrayPool.Shared.Rent(_available + FlushMarkerLength);
+ _buffer.AsSpan(0, _available).CopyTo(newBuffer);
+ ArrayPool.Shared.Return(_buffer);
+
+ _buffer = newBuffer;
+ }
+
+ FlushMarker.CopyTo(_buffer.AsSpan(_available));
+ _available += FlushMarkerLength;
+ }
+ }
+ }
+
+ ///
+ /// Inflates the last receive payload into the provided buffer.
+ ///
+ public unsafe bool Inflate(Span output, out int written)
+ {
+ _stream ??= CreateInflater();
+
+ if (_available > 0 && output.Length > 0)
+ {
+ int consumed;
+
+ fixed (byte* bufferPtr = _buffer)
+ {
+ _stream.NextIn = (IntPtr)(bufferPtr + _position);
+ _stream.AvailIn = (uint)_available;
+
+ written = Inflate(_stream, output, FlushCode.NoFlush);
+ consumed = _available - (int)_stream.AvailIn;
+ }
+
+ _position += consumed;
+ _available -= consumed;
+ }
+ else
+ {
+ written = 0;
+ }
+
+ if (_available == 0)
+ {
+ ReleaseBuffer();
+ return _endOfMessage ? Finish(output, ref written) : true;
+ }
+
+ return false;
+ }
+
+ ///
+ /// Finishes the decoding by flushing any outstanding data to the output.
+ ///
+ /// true if the flush completed, false to indicate that there is more outstanding data.
+ private unsafe bool Finish(Span output, ref int written)
+ {
+ Debug.Assert(_stream is not null && _stream.AvailIn == 0);
+ Debug.Assert(_available == 0);
+
+ if (_remainingByte is not null)
+ {
+ if (output.Length == written)
+ {
+ return false;
+ }
+ output[written] = _remainingByte.GetValueOrDefault();
+ _remainingByte = null;
+ written += 1;
+ }
+
+ // If we have more space in the output, try to inflate
+ if (output.Length > written)
+ {
+ written += Inflate(_stream, output[written..], FlushCode.SyncFlush);
+ }
+
+ // After inflate, if we have more space in the output then it means that we
+ // have finished. Otherwise we need to manually check for more data.
+ if (written < output.Length || IsFinished(_stream, out _remainingByte))
+ {
+ if (!_persisted)
+ {
+ _stream.Dispose();
+ _stream = null;
+ }
+ return true;
+ }
+
+ return false;
+ }
+
+ private void ReleaseBuffer()
+ {
+ if (_buffer is not null)
+ {
+ ArrayPool.Shared.Return(_buffer);
+ _buffer = null;
+ _available = 0;
+ _position = 0;
+ }
+ }
+
+ private static unsafe bool IsFinished(ZLibStreamHandle stream, out byte? remainingByte)
+ {
+ // There is no other way to make sure that we've consumed all data
+ // but to try to inflate again with at least one byte of output buffer.
+ byte b;
+ if (Inflate(stream, new Span(&b, 1), FlushCode.SyncFlush) == 0)
+ {
+ remainingByte = null;
+ return true;
+ }
+
+ remainingByte = b;
+ return false;
+ }
+
+ private static unsafe int Inflate(ZLibStreamHandle stream, Span destination, FlushCode flushCode)
+ {
+ Debug.Assert(destination.Length > 0);
+ ErrorCode errorCode;
+
+ fixed (byte* bufPtr = destination)
+ {
+ stream.NextOut = (IntPtr)bufPtr;
+ stream.AvailOut = (uint)destination.Length;
+
+ errorCode = stream.Inflate(flushCode);
+
+ if (errorCode is ErrorCode.Ok or ErrorCode.StreamEnd or ErrorCode.BufError)
+ {
+ return destination.Length - (int)stream.AvailOut;
+ }
+ }
+
+ string message = errorCode switch
+ {
+ ErrorCode.MemError => SR.ZLibErrorNotEnoughMemory,
+ ErrorCode.DataError => SR.ZLibUnsupportedCompression,
+ ErrorCode.StreamError => SR.ZLibErrorInconsistentStream,
+ _ => string.Format(SR.ZLibErrorUnexpected, (int)errorCode)
+ };
+ throw new WebSocketException(message);
+ }
+
+ private ZLibStreamHandle CreateInflater()
+ {
+ ZLibStreamHandle stream;
+ ErrorCode errorCode;
+
+ try
+ {
+ errorCode = CreateZLibStreamForInflate(out stream, _windowBits);
+ }
+ catch (Exception exception)
+ {
+ throw new WebSocketException(SR.ZLibErrorDLLLoadError, exception);
+ }
+
+ if (errorCode == ErrorCode.Ok)
+ {
+ return stream;
+ }
+
+ stream.Dispose();
+
+ string message = errorCode == ErrorCode.MemError
+ ? SR.ZLibErrorNotEnoughMemory
+ : string.Format(SR.ZLibErrorUnexpected, (int)errorCode);
+ throw new WebSocketException(message);
+ }
+ }
+}
diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs
index 9a0142f9c73b36..971c2ceff82be4 100644
--- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs
+++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs
@@ -4,6 +4,7 @@
using System.Buffers;
using System.Diagnostics;
using System.IO;
+using System.Net.WebSockets.Compression;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
@@ -25,18 +26,6 @@ namespace System.Net.WebSockets
///
internal sealed partial class ManagedWebSocket : WebSocket
{
- /// Creates a from a connected to a websocket endpoint.
- /// The connected Stream.
- /// true if this is the server-side of the connection; false if this is the client-side of the connection.
- /// The agreed upon subprotocol for the connection.
- /// The interval to use for keep-alive pings.
- /// The created instance.
- public static ManagedWebSocket CreateFromConnectedStream(
- Stream stream, bool isServer, string? subprotocol, TimeSpan keepAliveInterval)
- {
- return new ManagedWebSocket(stream, isServer, subprotocol, keepAliveInterval);
- }
-
/// Thread-safe random number generator used to generate masks for each send.
private static readonly RandomNumberGenerator s_random = RandomNumberGenerator.Create();
/// Encoding for the payload of text messages: UTF8 encoding that throws if invalid bytes are discovered, per the RFC.
@@ -113,7 +102,7 @@ public static ManagedWebSocket CreateFromConnectedStream(
/// remaining to be received for that header. As a result, between fragments, the payload
/// length in this header should be 0.
///
- private MessageHeader _lastReceiveHeader = new MessageHeader { Opcode = MessageOpcode.Text, Fin = true };
+ private MessageHeader _lastReceiveHeader = new MessageHeader { Opcode = MessageOpcode.Text, Fin = true, Processed = true };
/// The offset of the next available byte in the _receiveBuffer.
private int _receiveBufferOffset;
/// The number of bytes available in the _receiveBuffer.
@@ -137,6 +126,10 @@ public static ManagedWebSocket CreateFromConnectedStream(
///
private bool _lastSendWasFragment;
///
+ /// Whether the last SendAsync had flag set.
+ ///
+ private bool _lastSendHadDisableCompression;
+ ///
/// The task returned from the last ReceiveAsync(ArraySegment, ...) operation to not complete synchronously.
/// If this is not null and not completed when a subsequent ReceiveAsync is issued, an exception occurs.
///
@@ -151,12 +144,15 @@ public static ManagedWebSocket CreateFromConnectedStream(
///
private object ReceiveAsyncLock => _utf8TextState; // some object, as we're simply lock'ing on it
+ private readonly WebSocketInflater? _inflater;
+ private readonly WebSocketDeflater? _deflater;
+
/// Initializes the websocket.
/// The connected Stream.
/// true if this is the server-side of the connection; false if this is the client-side of the connection.
/// The agreed upon subprotocol for the connection.
/// The interval to use for keep-alive pings.
- private ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, TimeSpan keepAliveInterval)
+ internal ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, TimeSpan keepAliveInterval)
{
Debug.Assert(StateUpdateLock != null, $"Expected {nameof(StateUpdateLock)} to be non-null");
Debug.Assert(ReceiveAsyncLock != null, $"Expected {nameof(ReceiveAsyncLock)} to be non-null");
@@ -212,6 +208,29 @@ private ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, Time
}
}
+ /// Initializes the websocket.
+ /// The connected Stream.
+ /// The options with which the websocket must be created.
+ internal ManagedWebSocket(Stream stream, WebSocketCreationOptions options)
+ : this(stream, options.IsServer, options.SubProtocol, options.KeepAliveInterval)
+ {
+ var deflateOptions = options.DangerousDeflateOptions;
+
+ if (deflateOptions is not null)
+ {
+ if (options.IsServer)
+ {
+ _inflater = new WebSocketInflater(deflateOptions.ClientMaxWindowBits, deflateOptions.ClientContextTakeover);
+ _deflater = new WebSocketDeflater(deflateOptions.ServerMaxWindowBits, deflateOptions.ServerContextTakeover);
+ }
+ else
+ {
+ _inflater = new WebSocketInflater(deflateOptions.ServerMaxWindowBits, deflateOptions.ServerContextTakeover);
+ _deflater = new WebSocketDeflater(deflateOptions.ClientMaxWindowBits, deflateOptions.ClientContextTakeover);
+ }
+ }
+ }
+
public override void Dispose()
{
lock (StateUpdateLock)
@@ -227,7 +246,10 @@ private void DisposeCore()
{
_disposed = true;
_keepAliveTimer?.Dispose();
- _stream?.Dispose();
+ _stream.Dispose();
+ _inflater?.Dispose();
+ _deflater?.Dispose();
+
if (_state < WebSocketState.Aborted)
{
_state = WebSocketState.Closed;
@@ -255,10 +277,10 @@ public override Task SendAsync(ArraySegment buffer, WebSocketMessageType m
WebSocketValidate.ValidateArraySegment(buffer, nameof(buffer));
- return SendPrivateAsync(buffer, messageType, endOfMessage, cancellationToken).AsTask();
+ return SendPrivateAsync(buffer, messageType, endOfMessage ? WebSocketMessageFlags.EndOfMessage : default, cancellationToken).AsTask();
}
- private ValueTask SendPrivateAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
+ private ValueTask SendPrivateAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken)
{
if (messageType != WebSocketMessageType.Text && messageType != WebSocketMessageType.Binary)
{
@@ -277,13 +299,27 @@ private ValueTask SendPrivateAsync(ReadOnlyMemory buffer, WebSocketMessage
return new ValueTask(Task.FromException(exc));
}
- MessageOpcode opcode =
- _lastSendWasFragment ? MessageOpcode.Continuation :
- messageType == WebSocketMessageType.Binary ? MessageOpcode.Binary :
- MessageOpcode.Text;
+ bool endOfMessage = messageFlags.HasFlag(WebSocketMessageFlags.EndOfMessage);
+ bool disableCompression = messageFlags.HasFlag(WebSocketMessageFlags.DisableCompression);
+ MessageOpcode opcode;
- ValueTask t = SendFrameAsync(opcode, endOfMessage, buffer, cancellationToken);
+ if (_lastSendWasFragment)
+ {
+ if (_lastSendHadDisableCompression != disableCompression)
+ {
+ throw new ArgumentException(SR.net_WebSockets_Argument_MessageFlagsHasDifferentCompressionOptions, nameof(messageFlags));
+ }
+ opcode = MessageOpcode.Continuation;
+ }
+ else
+ {
+ opcode = messageType == WebSocketMessageType.Binary ? MessageOpcode.Binary : MessageOpcode.Text;
+ }
+
+ ValueTask t = SendFrameAsync(opcode, endOfMessage, disableCompression, buffer, cancellationToken);
_lastSendWasFragment = !endOfMessage;
+ _lastSendHadDisableCompression = disableCompression;
+
return t;
}
@@ -299,7 +335,7 @@ public override Task ReceiveAsync(ArraySegment buf
lock (ReceiveAsyncLock) // synchronize with receives in CloseAsync
{
ThrowIfOperationInProgress(_lastReceiveAsync.IsCompleted);
- Task t = ReceiveAsyncPrivate(buffer, cancellationToken).AsTask();
+ Task t = ReceiveAsyncPrivate(buffer, cancellationToken).AsTask();
_lastReceiveAsync = t;
return t;
}
@@ -357,7 +393,12 @@ public override void Abort()
public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
{
- return SendPrivateAsync(buffer, messageType, endOfMessage, cancellationToken);
+ return SendPrivateAsync(buffer, messageType, endOfMessage ? WebSocketMessageFlags.EndOfMessage : default, cancellationToken);
+ }
+
+ public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken)
+ {
+ return SendPrivateAsync(buffer, messageType, messageFlags, cancellationToken);
}
public override ValueTask ReceiveAsync(Memory buffer, CancellationToken cancellationToken)
@@ -371,7 +412,7 @@ public override ValueTask ReceiveAsync(Memory
{
ThrowIfOperationInProgress(_lastReceiveAsync.IsCompleted);
- ValueTask receiveValueTask = ReceiveAsyncPrivate(buffer, cancellationToken);
+ ValueTask receiveValueTask = ReceiveAsyncPrivate(buffer, cancellationToken);
if (receiveValueTask.IsCompletedSuccessfully)
{
_lastReceiveAsync = receiveValueTask.Result.MessageType == WebSocketMessageType.Close ? s_cachedCloseTask : Task.CompletedTask;
@@ -400,7 +441,7 @@ private Task ValidateAndReceiveAsync(Task receiveTask, byte[] buffer, Cancellati
!(receiveTask is Task wsrr && wsrr.Result.MessageType == WebSocketMessageType.Close) &&
!(receiveTask is Task vwsrr && vwsrr.Result.MessageType == WebSocketMessageType.Close)))
{
- ValueTask vt = ReceiveAsyncPrivate(buffer, cancellationToken);
+ ValueTask vt = ReceiveAsyncPrivate(buffer, cancellationToken);
receiveTask =
vt.IsCompletedSuccessfully ? (vt.Result.MessageType == WebSocketMessageType.Close ? s_cachedCloseTask : Task.CompletedTask) :
vt.AsTask();
@@ -409,19 +450,13 @@ private Task ValidateAndReceiveAsync(Task receiveTask, byte[] buffer, Cancellati
return receiveTask;
}
- /// implementation for .
- private readonly struct ValueWebSocketReceiveResultGetter : IWebSocketReceiveResultGetter
- {
- public ValueWebSocketReceiveResult GetResult(int count, WebSocketMessageType messageType, bool endOfMessage, WebSocketCloseStatus? closeStatus, string? closeDescription) =>
- new ValueWebSocketReceiveResult(count, messageType, endOfMessage); // closeStatus/closeDescription are ignored
- }
-
/// Sends a websocket frame to the network.
/// The opcode for the message.
/// The value of the FIN bit for the message.
- /// The buffer containing the payload data fro the message.
+ /// Disables compression for the message.
+ /// The buffer containing the payload data from the message.
/// The CancellationToken to use to cancel the websocket.
- private ValueTask SendFrameAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemory payloadBuffer, CancellationToken cancellationToken)
+ private ValueTask SendFrameAsync(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlyMemory payloadBuffer, CancellationToken cancellationToken)
{
// If a cancelable cancellation token was provided, that would require registering with it, which means more state we have to
// pass around (the CancellationTokenRegistration), so if it is cancelable, just immediately go to the fallback path.
@@ -430,15 +465,16 @@ private ValueTask SendFrameAsync(MessageOpcode opcode, bool endOfMessage, ReadOn
#pragma warning disable CA1416 // Validate platform compatibility, will not wait because timeout equals 0
return cancellationToken.CanBeCanceled || !_sendFrameAsyncLock.Wait(0, default) ?
#pragma warning restore CA1416
- SendFrameFallbackAsync(opcode, endOfMessage, payloadBuffer, cancellationToken) :
- SendFrameLockAcquiredNonCancelableAsync(opcode, endOfMessage, payloadBuffer);
+ SendFrameFallbackAsync(opcode, endOfMessage, disableCompression, payloadBuffer, cancellationToken) :
+ SendFrameLockAcquiredNonCancelableAsync(opcode, endOfMessage, disableCompression, payloadBuffer);
}
/// Sends a websocket frame to the network. The caller must hold the sending lock.
/// The opcode for the message.
/// The value of the FIN bit for the message.
+ /// Disables compression for the message.
/// The buffer containing the payload data fro the message.
- private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemory payloadBuffer)
+ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlyMemory payloadBuffer)
{
Debug.Assert(_sendFrameAsyncLock.CurrentCount == 0, "Caller should hold the _sendFrameAsyncLock");
@@ -449,7 +485,7 @@ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode,
try
{
// Write the payload synchronously to the buffer, then write that buffer out to the network.
- int sendBytes = WriteFrameToSendBuffer(opcode, endOfMessage, payloadBuffer.Span);
+ int sendBytes = WriteFrameToSendBuffer(opcode, endOfMessage, disableCompression, payloadBuffer.Span);
writeTask = _stream.WriteAsync(new ReadOnlyMemory(_sendBuffer, 0, sendBytes));
// If the operation happens to complete synchronously (or, more specifically, by
@@ -503,12 +539,12 @@ private async ValueTask WaitForWriteTaskAsync(ValueTask writeTask)
}
}
- private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemory payloadBuffer, CancellationToken cancellationToken)
+ private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlyMemory payloadBuffer, CancellationToken cancellationToken)
{
await _sendFrameAsyncLock.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
- int sendBytes = WriteFrameToSendBuffer(opcode, endOfMessage, payloadBuffer.Span);
+ int sendBytes = WriteFrameToSendBuffer(opcode, endOfMessage, disableCompression, payloadBuffer.Span);
using (cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this))
{
await _stream.WriteAsync(new ReadOnlyMemory(_sendBuffer, 0, sendBytes), cancellationToken).ConfigureAwait(false);
@@ -528,10 +564,16 @@ private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfM
}
/// Writes a frame into the send buffer, which can then be sent over the network.
- private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, ReadOnlySpan payloadBuffer)
+ private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, bool disableCompression, ReadOnlySpan payloadBuffer)
{
- // Ensure we have a _sendBuffer.
- AllocateSendBuffer(payloadBuffer.Length + MaxMessageHeaderLength);
+ if (_deflater is not null && !disableCompression)
+ {
+ payloadBuffer = _deflater.Deflate(payloadBuffer, endOfMessage);
+ }
+ int payloadLength = payloadBuffer.Length;
+
+ // Ensure we have a _sendBuffer
+ AllocateSendBuffer(payloadLength + MaxMessageHeaderLength);
Debug.Assert(_sendBuffer != null);
// Write the message header data to the buffer.
@@ -541,31 +583,34 @@ private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, Read
{
// The server doesn't send a mask, so the mask offset returned by WriteHeader
// is actually the end of the header.
- headerLength = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: false);
+ headerLength = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: false, compressed: _deflater is not null && !disableCompression);
}
else
{
// We need to know where the mask starts so that we can use the mask to manipulate the payload data,
// and we need to know the total length for sending it on the wire.
- maskOffset = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: true);
+ maskOffset = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: true, compressed: _deflater is not null && !disableCompression);
headerLength = maskOffset.GetValueOrDefault() + MaskLength;
}
// Write the payload
if (payloadBuffer.Length > 0)
{
- payloadBuffer.CopyTo(new Span(_sendBuffer, headerLength, payloadBuffer.Length));
+ payloadBuffer.CopyTo(new Span(_sendBuffer, headerLength, payloadLength));
+
+ // Release the deflater buffer if any, we're not going to need the payloadBuffer anymore.
+ _deflater?.ReleaseBuffer();
// If we added a mask to the header, XOR the payload with the mask. We do the manipulation in the send buffer so as to avoid
// changing the data in the caller-supplied payload buffer.
if (maskOffset.HasValue)
{
- ApplyMask(new Span(_sendBuffer, headerLength, payloadBuffer.Length), _sendBuffer, maskOffset.Value, 0);
+ ApplyMask(new Span(_sendBuffer, headerLength, payloadLength), _sendBuffer, maskOffset.Value, 0);
}
}
// Return the number of bytes in the send buffer
- return headerLength + payloadBuffer.Length;
+ return headerLength + payloadLength;
}
private void SendKeepAliveFrameAsync()
@@ -578,7 +623,7 @@ private void SendKeepAliveFrameAsync()
// This exists purely to keep the connection alive; don't wait for the result, and ignore any failures.
// The call will handle releasing the lock. We send a pong rather than ping, since it's allowed by
// the RFC as a unidirectional heartbeat and we're not interested in waiting for a response.
- ValueTask t = SendFrameLockAcquiredNonCancelableAsync(MessageOpcode.Pong, true, ReadOnlyMemory.Empty);
+ ValueTask t = SendFrameLockAcquiredNonCancelableAsync(MessageOpcode.Pong, endOfMessage: true, disableCompression: true, ReadOnlyMemory.Empty);
if (t.IsCompletedSuccessfully)
{
t.GetAwaiter().GetResult();
@@ -599,7 +644,7 @@ private void SendKeepAliveFrameAsync()
}
}
- private static int WriteHeader(MessageOpcode opcode, byte[] sendBuffer, ReadOnlySpan payload, bool endOfMessage, bool useMask)
+ private static int WriteHeader(MessageOpcode opcode, byte[] sendBuffer, ReadOnlySpan payload, bool endOfMessage, bool useMask, bool compressed)
{
// Client header format:
// 1 bit - FIN - 1 if this is the final fragment in the message (it could be the only fragment), otherwise 0
@@ -629,6 +674,11 @@ private static int WriteHeader(MessageOpcode opcode, byte[] sendBuffer, ReadOnly
{
sendBuffer[0] |= 0x80; // 1 bit for FIN
}
+ if (compressed && opcode != MessageOpcode.Continuation)
+ {
+ // Per-Message Deflate flag needs to be set only in the first frame
+ sendBuffer[0] |= 0b_0100_0000;
+ }
// Store the payload length.
int maskOffset;
@@ -680,13 +730,8 @@ private static void WriteRandomMask(byte[] buffer, int offset) =>
///
/// The buffer into which payload data should be written.
/// The CancellationToken used to cancel the websocket.
- /// Used to get the result. Allows the same method to be used with both WebSocketReceiveResult and ValueWebSocketReceiveResult.
/// Information about the received message.
- private async ValueTask ReceiveAsyncPrivate(
- Memory payloadBuffer,
- CancellationToken cancellationToken,
- TWebSocketReceiveResultGetter resultGetter = default)
- where TWebSocketReceiveResultGetter : struct, IWebSocketReceiveResultGetter // constrained to avoid boxing and enable inlining
+ private async ValueTask ReceiveAsyncPrivate(Memory payloadBuffer, CancellationToken cancellationToken)
{
// This is a long method. While splitting it up into pieces would arguably help with readability, doing so would
// also result in more allocations, as each async method that yields ends up with multiple allocations. The impact
@@ -707,7 +752,7 @@ private async ValueTask ReceiveAsyncPrivate ReceiveAsyncPrivate ReceiveAsyncPrivate(0, WebSocketMessageType.Close, true);
}
// If this is a continuation, replace the opcode with the one of the message it's continuing
if (header.Opcode == MessageOpcode.Continuation)
{
header.Opcode = _lastReceiveHeader.Opcode;
+ header.Compressed = _lastReceiveHeader.Compressed;
}
// The message should now be a binary or text message. Handle it by reading the payload and returning the contents.
Debug.Assert(header.Opcode == MessageOpcode.Binary || header.Opcode == MessageOpcode.Text, $"Unexpected opcode {header.Opcode}");
// If there's no data to read, return an appropriate result.
- if (header.PayloadLength == 0 || payloadBuffer.Length == 0)
+ if (header.Processed || payloadBuffer.Length == 0)
{
_lastReceiveHeader = header;
- return resultGetter.GetResult(
- 0,
- header.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary,
- header.Fin && header.PayloadLength == 0,
- null, null);
+ return GetReceiveResult(
+ count: 0,
+ messageType: header.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary,
+ endOfMessage: header.EndOfMessage);
}
// Otherwise, read as much of the payload as we can efficiently, and update the header to reflect how much data
@@ -779,56 +832,86 @@ private async ValueTask ReceiveAsyncPrivate 0)
- {
- int receiveBufferBytesToCopy = Math.Min(payloadBuffer.Length, (int)Math.Min(header.PayloadLength, _receiveBufferCount));
- Debug.Assert(receiveBufferBytesToCopy > 0);
- _receiveBuffer.Span.Slice(_receiveBufferOffset, receiveBufferBytesToCopy).CopyTo(payloadBuffer.Span);
- ConsumeFromBuffer(receiveBufferBytesToCopy);
- totalBytesReceived += receiveBufferBytesToCopy;
- Debug.Assert(
- _receiveBufferCount == 0 ||
- totalBytesReceived == payloadBuffer.Length ||
- totalBytesReceived == header.PayloadLength);
- }
- // Then read directly into the payload buffer until we've hit a limit.
- while (totalBytesReceived < payloadBuffer.Length &&
- totalBytesReceived < header.PayloadLength)
+ // Only start a new receive if we haven't received the entire frame.
+ if (header.PayloadLength > 0)
{
- int numBytesRead = await _stream.ReadAsync(payloadBuffer.Slice(
- totalBytesReceived,
- (int)Math.Min(payloadBuffer.Length, header.PayloadLength) - totalBytesReceived), cancellationToken).ConfigureAwait(false);
- if (numBytesRead <= 0)
+ if (header.Compressed)
+ {
+ Debug.Assert(_inflater is not null);
+ _inflater.Prepare(header.PayloadLength, payloadBuffer.Length);
+ }
+
+ // Read directly into the appropriate buffer until we've hit a limit.
+ int limit = (int)Math.Min(header.Compressed ? _inflater!.Span.Length : payloadBuffer.Length, header.PayloadLength);
+
+ if (_receiveBufferCount > 0)
+ {
+ int receiveBufferBytesToCopy = Math.Min(limit, _receiveBufferCount);
+ Debug.Assert(receiveBufferBytesToCopy > 0);
+
+ _receiveBuffer.Span.Slice(_receiveBufferOffset, receiveBufferBytesToCopy).CopyTo(
+ header.Compressed ? _inflater!.Span : payloadBuffer.Span);
+ ConsumeFromBuffer(receiveBufferBytesToCopy);
+ totalBytesReceived += receiveBufferBytesToCopy;
+ }
+
+ while (totalBytesReceived < limit)
{
- ThrowIfEOFUnexpected(throwOnPrematureClosure: true);
- break;
+ int numBytesRead = await _stream.ReadAsync(header.Compressed ?
+ _inflater!.Memory.Slice(totalBytesReceived, limit - totalBytesReceived) :
+ payloadBuffer.Slice(totalBytesReceived, limit - totalBytesReceived),
+ cancellationToken).ConfigureAwait(false);
+ if (numBytesRead <= 0)
+ {
+ ThrowIfEOFUnexpected(throwOnPrematureClosure: true);
+ break;
+ }
+ totalBytesReceived += numBytesRead;
+ }
+
+ if (_isServer)
+ {
+ _receivedMaskOffsetOffset = ApplyMask(header.Compressed ?
+ _inflater!.Span.Slice(0, totalBytesReceived) :
+ payloadBuffer.Span.Slice(0, totalBytesReceived), header.Mask, _receivedMaskOffsetOffset);
+ }
+
+ header.PayloadLength -= totalBytesReceived;
+
+ if (header.Compressed)
+ {
+ _inflater!.AddBytes(totalBytesReceived, endOfMessage: header.Fin && header.PayloadLength == 0);
}
- totalBytesReceived += numBytesRead;
}
- if (_isServer)
+ if (header.Compressed)
+ {
+ // In case of compression totalBytesReceived should actually represent how much we've
+ // inflated, rather than how much we've read from the stream.
+ header.Processed = _inflater!.Inflate(payloadBuffer.Span, out totalBytesReceived) && header.PayloadLength == 0;
+ }
+ else
{
- _receivedMaskOffsetOffset = ApplyMask(payloadBuffer.Span.Slice(0, totalBytesReceived), header.Mask, _receivedMaskOffsetOffset);
+ // Without compression the frame is processed as soon as we've received everything
+ header.Processed = header.PayloadLength == 0;
}
- header.PayloadLength -= totalBytesReceived;
// If this a text message, validate that it contains valid UTF8.
if (header.Opcode == MessageOpcode.Text &&
- !TryValidateUtf8(payloadBuffer.Span.Slice(0, totalBytesReceived), header.Fin && header.PayloadLength == 0, _utf8TextState))
+ !TryValidateUtf8(payloadBuffer.Span.Slice(0, totalBytesReceived), header.EndOfMessage, _utf8TextState))
{
await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.InvalidPayloadData, WebSocketError.Faulted).ConfigureAwait(false);
}
_lastReceiveHeader = header;
- return resultGetter.GetResult(
+ return GetReceiveResult(
totalBytesReceived,
header.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary,
- header.Fin && header.PayloadLength == 0,
- null, null);
+ header.EndOfMessage);
}
}
- catch (Exception exc) when (!(exc is OperationCanceledException))
+ catch (Exception exc) when (exc is not OperationCanceledException)
{
if (_state == WebSocketState.Aborted)
{
@@ -849,6 +932,23 @@ private async ValueTask ReceiveAsyncPrivate
+ /// Returns either or .
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private TResult GetReceiveResult(int count, WebSocketMessageType messageType, bool endOfMessage)
+ {
+ if (typeof(TResult) == typeof(ValueWebSocketReceiveResult))
+ {
+ // Although it might seem that this will incur boxing of the struct,
+ // the JIT is smart enough to figure out it is unncessessary and will emit
+ // bytecode that returns the ValueWebSocketReceiveResult directly.
+ return (TResult)(object)new ValueWebSocketReceiveResult(count, messageType, endOfMessage);
+ }
+
+ return (TResult)(object)new WebSocketReceiveResult(count, messageType, endOfMessage, _closeStatus, _closeStatusDescription);
+ }
+
/// Processes a received close message.
/// The message header.
/// The CancellationToken used to cancel the websocket operation.
@@ -967,6 +1067,7 @@ private async ValueTask HandleReceivedPingPongAsync(MessageHeader header, Cancel
await SendFrameAsync(
MessageOpcode.Pong,
endOfMessage: true,
+ disableCompression: true,
_receiveBuffer.Slice(_receiveBufferOffset, (int)header.PayloadLength),
cancellationToken).ConfigureAwait(false);
}
@@ -1051,8 +1152,9 @@ private async ValueTask CloseWithReceiveErrorAndThrowAsync(
Span receiveBufferSpan = _receiveBuffer.Span;
header.Fin = (receiveBufferSpan[_receiveBufferOffset] & 0x80) != 0;
- bool reservedSet = (receiveBufferSpan[_receiveBufferOffset] & 0x70) != 0;
+ bool reservedSet = (receiveBufferSpan[_receiveBufferOffset] & 0b_0011_0000) != 0;
header.Opcode = (MessageOpcode)(receiveBufferSpan[_receiveBufferOffset] & 0xF);
+ header.Compressed = (receiveBufferSpan[_receiveBufferOffset] & 0b_0100_0000) != 0;
bool masked = (receiveBufferSpan[_receiveBufferOffset + 1] & 0x80) != 0;
header.PayloadLength = receiveBufferSpan[_receiveBufferOffset + 1] & 0x7F;
@@ -1083,6 +1185,12 @@ private async ValueTask CloseWithReceiveErrorAndThrowAsync(
return SR.net_Websockets_ReservedBitsSet;
}
+ if (header.Compressed && _inflater is null)
+ {
+ resultHeader = default;
+ return SR.net_Websockets_PerMessageCompressedFlagWhenNotEnabled;
+ }
+
if (masked)
{
if (!_isServer)
@@ -1106,6 +1214,16 @@ private async ValueTask CloseWithReceiveErrorAndThrowAsync(
resultHeader = default;
return SR.net_Websockets_ContinuationFromFinalFrame;
}
+ if (header.Compressed)
+ {
+ // Must not mark continuations as compressed
+ resultHeader = default;
+ return SR.net_Websockets_PerMessageCompressedFlagInContinuation;
+ }
+
+ // Set the compressed flag from the previous header so the receive procedure can use it
+ // directly without needing to check the previous header in case of continuations.
+ header.Compressed = _lastReceiveHeader.Compressed;
break;
case MessageOpcode.Binary:
@@ -1137,6 +1255,7 @@ private async ValueTask CloseWithReceiveErrorAndThrowAsync(
// Return the read header
resultHeader = header;
+ resultHeader.Processed = header.PayloadLength == 0;
return null;
}
@@ -1248,7 +1367,7 @@ private async ValueTask SendCloseFrameAsync(WebSocketCloseStatus closeStatus, st
buffer[0] = (byte)(closeStatusValue >> 8);
buffer[1] = (byte)(closeStatusValue & 0xFF);
- await SendFrameAsync(MessageOpcode.Close, true, new Memory(buffer, 0, count), cancellationToken).ConfigureAwait(false);
+ await SendFrameAsync(MessageOpcode.Close, endOfMessage: true, disableCompression: true, new Memory(buffer, 0, count), cancellationToken).ConfigureAwait(false);
}
finally
{
@@ -1580,24 +1699,18 @@ private struct MessageHeader
internal MessageOpcode Opcode;
internal bool Fin;
internal long PayloadLength;
+ internal bool Compressed;
internal int Mask;
- }
- ///
- /// Interface used by to enable it to return
- /// different result types in an efficient manner.
- ///
- /// The type of the result
- private interface IWebSocketReceiveResultGetter
- {
- TResult GetResult(int count, WebSocketMessageType messageType, bool endOfMessage, WebSocketCloseStatus? closeStatus, string? closeDescription);
- }
+ ///
+ /// Returns if frame has been received and processed.
+ ///
+ internal bool Processed { get; set; }
- /// implementation for .
- private readonly struct WebSocketReceiveResultGetter : IWebSocketReceiveResultGetter
- {
- public WebSocketReceiveResult GetResult(int count, WebSocketMessageType messageType, bool endOfMessage, WebSocketCloseStatus? closeStatus, string? closeDescription) =>
- new WebSocketReceiveResult(count, messageType, endOfMessage, closeStatus, closeDescription);
+ ///
+ /// Returns if message has been received and processed.
+ ///
+ internal bool EndOfMessage => Fin && Processed && PayloadLength == 0;
}
}
}
diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs
index 3bd6835a16f1d4..044c7b95536bef 100644
--- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs
+++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs
@@ -58,6 +58,11 @@ public virtual ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessage
new ValueTask(SendAsync(arraySegment, messageType, endOfMessage, cancellationToken)) :
SendWithArrayPoolAsync(buffer, messageType, endOfMessage, cancellationToken);
+ public virtual ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken = default)
+ {
+ return SendAsync(buffer, messageType, messageFlags.HasFlag(WebSocketMessageFlags.EndOfMessage), cancellationToken);
+ }
+
private async ValueTask SendWithArrayPoolAsync(
ReadOnlyMemory buffer,
WebSocketMessageType messageType,
@@ -157,7 +162,24 @@ public static WebSocket CreateFromStream(Stream stream, bool isServer, string? s
0));
}
- return ManagedWebSocket.CreateFromConnectedStream(stream, isServer, subProtocol, keepAliveInterval);
+ return new ManagedWebSocket(stream, isServer, subProtocol, keepAliveInterval);
+ }
+
+ /// Creates a that operates on a representing a web socket connection.
+ /// The for the connection.
+ /// The options with which the websocket must be created.
+ public static WebSocket CreateFromStream(Stream stream, WebSocketCreationOptions options)
+ {
+ if (stream is null)
+ throw new ArgumentNullException(nameof(stream));
+
+ if (options is null)
+ throw new ArgumentNullException(nameof(options));
+
+ if (!stream.CanRead || !stream.CanWrite)
+ throw new ArgumentException(!stream.CanRead ? SR.NotReadableStream : SR.NotWriteableStream, nameof(stream));
+
+ return new ManagedWebSocket(stream, options);
}
[EditorBrowsable(EditorBrowsableState.Never)]
@@ -209,8 +231,7 @@ public static WebSocket CreateClientWebSocket(Stream innerStream,
// Ignore useZeroMaskingKey. ManagedWebSocket doesn't currently support that debugging option.
// Ignore internalBuffer. ManagedWebSocket uses its own small buffer for headers/control messages.
-
- return ManagedWebSocket.CreateFromConnectedStream(innerStream, false, subProtocol, keepAliveInterval);
+ return new ManagedWebSocket(innerStream, false, subProtocol, keepAliveInterval);
}
}
}
diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs
new file mode 100644
index 00000000000000..d042583da54448
--- /dev/null
+++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs
@@ -0,0 +1,63 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Threading;
+
+namespace System.Net.WebSockets
+{
+ ///
+ /// Options that control how a is created.
+ ///
+ public sealed class WebSocketCreationOptions
+ {
+ private string? _subProtocol;
+ private TimeSpan _keepAliveInterval;
+
+ ///
+ /// Defines if this websocket is the server-side of the connection. The default value is false.
+ ///
+ public bool IsServer { get; set; }
+
+ ///
+ /// The agreed upon sub-protocol that was used when creating the connection.
+ ///
+ public string? SubProtocol
+ {
+ get => _subProtocol;
+ set
+ {
+ if (value is not null)
+ {
+ WebSocketValidate.ValidateSubprotocol(value);
+ }
+ _subProtocol = value;
+ }
+ }
+
+ ///
+ /// The keep-alive interval to use, or or to disable keep-alives.
+ /// The default is .
+ ///
+ public TimeSpan KeepAliveInterval
+ {
+ get => _keepAliveInterval;
+ set
+ {
+ if (value != Timeout.InfiniteTimeSpan && value < TimeSpan.Zero)
+ {
+ throw new ArgumentOutOfRangeException(nameof(KeepAliveInterval), value,
+ SR.Format(SR.net_WebSockets_ArgumentOutOfRange_TooSmall, 0));
+ }
+ _keepAliveInterval = value;
+ }
+ }
+
+ ///
+ /// The agreed upon options for per message deflate.
+ /// Be aware that enabling compression makes the application subject to CRIME/BREACH type of attacks.
+ /// It is strongly advised to turn off compression when sending data containing secrets by
+ /// specifying flag for such messages.
+ ///
+ public WebSocketDeflateOptions? DangerousDeflateOptions { get; set; }
+ }
+}
diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs
new file mode 100644
index 00000000000000..e497751db288e4
--- /dev/null
+++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs
@@ -0,0 +1,71 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace System.Net.WebSockets
+{
+ ///
+ /// Options to enable per-message deflate compression for .
+ ///
+ ///
+ /// Although the WebSocket spec allows window bits from 8 to 15, the current implementation doesn't support 8 bits.
+ ///
+ public sealed class WebSocketDeflateOptions
+ {
+ private int _clientMaxWindowBits = WebSocketValidate.MaxDeflateWindowBits;
+ private int _serverMaxWindowBits = WebSocketValidate.MaxDeflateWindowBits;
+
+ ///
+ /// This parameter indicates the base-2 logarithm for the LZ77 sliding window size used by
+ /// the client to compress messages and by the server to decompress them.
+ /// Must be a value between 9 and 15. The default is 15.
+ ///
+ /// https://tools.ietf.org/html/rfc7692#section-7.1.2.2
+ public int ClientMaxWindowBits
+ {
+ get => _clientMaxWindowBits;
+ set
+ {
+ if (value < WebSocketValidate.MinDeflateWindowBits || value > WebSocketValidate.MaxDeflateWindowBits)
+ {
+ throw new ArgumentOutOfRangeException(nameof(ClientMaxWindowBits), value,
+ SR.Format(SR.net_WebSockets_ArgumentOutOfRange, WebSocketValidate.MinDeflateWindowBits, WebSocketValidate.MaxDeflateWindowBits));
+ }
+ _clientMaxWindowBits = value;
+ }
+ }
+
+ ///
+ /// When true the client-side of the connection indicates that it will persist the deflate context accross messages.
+ /// The default is true.
+ ///
+ /// https://tools.ietf.org/html/rfc7692#section-7.1.1.2
+ public bool ClientContextTakeover { get; set; } = true;
+
+ ///
+ /// This parameter indicates the base-2 logarithm for the LZ77 sliding window size used by
+ /// the server to compress messages and by the client to decompress them.
+ /// Must be a value between 9 and 15. The default is 15.
+ ///
+ /// https://tools.ietf.org/html/rfc7692#section-7.1.2.1
+ public int ServerMaxWindowBits
+ {
+ get => _serverMaxWindowBits;
+ set
+ {
+ if (value < WebSocketValidate.MinDeflateWindowBits || value > WebSocketValidate.MaxDeflateWindowBits)
+ {
+ throw new ArgumentOutOfRangeException(nameof(ServerMaxWindowBits), value,
+ SR.Format(SR.net_WebSockets_ArgumentOutOfRange, WebSocketValidate.MinDeflateWindowBits, WebSocketValidate.MaxDeflateWindowBits));
+ }
+ _serverMaxWindowBits = value;
+ }
+ }
+
+ ///
+ /// When true the server-side of the connection indicates that it will persist the deflate context accross messages.
+ /// The default is true.
+ ///
+ /// https://tools.ietf.org/html/rfc7692#section-7.1.1.1
+ public bool ServerContextTakeover { get; set; } = true;
+ }
+}
diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketMessageFlags.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketMessageFlags.cs
new file mode 100644
index 00000000000000..9ce165d8de8433
--- /dev/null
+++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketMessageFlags.cs
@@ -0,0 +1,27 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace System.Net.WebSockets
+{
+ ///
+ /// Flags for controlling how the should send a message.
+ ///
+ [Flags]
+ public enum WebSocketMessageFlags
+ {
+ ///
+ /// None
+ ///
+ None = 0,
+
+ ///
+ /// Indicates that the data in "buffer" is the last part of a message.
+ ///
+ EndOfMessage = 1,
+
+ ///
+ /// Disables compression for the message if compression has been enabled for the instance.
+ ///
+ DisableCompression = 2
+ }
+}
diff --git a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj
index 7cf0328df31ca8..4e0bc74ebdaeca 100644
--- a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj
+++ b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj
@@ -7,6 +7,9 @@
+
+
+ (() => options.ClientMaxWindowBits = 8);
+ Assert.Throws(() => options.ClientMaxWindowBits = 16);
+
+ options.ClientMaxWindowBits = 14;
+ Assert.Equal(14, options.ClientMaxWindowBits);
+ }
+
+ [Fact]
+ public void ServerMaxWindowBits()
+ {
+ WebSocketDeflateOptions options = new();
+ Assert.Equal(15, options.ServerMaxWindowBits);
+
+ Assert.Throws(() => options.ServerMaxWindowBits = 8);
+ Assert.Throws(() => options.ServerMaxWindowBits = 16);
+
+ options.ServerMaxWindowBits = 14;
+ Assert.Equal(14, options.ServerMaxWindowBits);
+ }
+
+ [Fact]
+ public void ContextTakeover()
+ {
+ WebSocketDeflateOptions options = new();
+
+ Assert.True(options.ClientContextTakeover);
+ Assert.True(options.ServerContextTakeover);
+
+ options.ClientContextTakeover = false;
+ Assert.False(options.ClientContextTakeover);
+
+ options.ServerContextTakeover = false;
+ Assert.False(options.ServerContextTakeover);
+ }
+ }
+}
diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs
new file mode 100644
index 00000000000000..25efbe94b1d5bd
--- /dev/null
+++ b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs
@@ -0,0 +1,626 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Buffers;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.IO;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+using Xunit;
+
+namespace System.Net.WebSockets.Tests
+{
+ public class WebSocketDeflateTests
+ {
+ private readonly CancellationTokenSource? _cancellation;
+
+ public WebSocketDeflateTests()
+ {
+ if (!Debugger.IsAttached)
+ {
+ _cancellation = new CancellationTokenSource(TimeSpan.FromSeconds(5));
+ }
+ }
+
+ public CancellationToken CancellationToken => _cancellation?.Token ?? default;
+
+ public static IEnumerable