Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[browser] [wasm] Refactor Request Streaming to use HttpContent.CopyToAsync #91699

Merged
merged 15 commits into from
Sep 21, 2023
Merged
99 changes: 92 additions & 7 deletions src/libraries/Common/tests/System/Net/Http/ResponseStreamTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,8 @@ public async Task BrowserHttpHandler_Streaming()

int readOffset = 0;
req.Content = new StreamContent(new DelegateStream(
canReadFunc: () => true,
readFunc: (buffer, offset, count) => throw new FormatException(),
readAsyncFunc: async (buffer, offset, count, cancellationToken) =>
{
await Task.Delay(1);
Expand Down Expand Up @@ -295,8 +297,11 @@ public async Task BrowserHttpHandler_StreamingRequest()
req.Options.Set(WebAssemblyEnableStreamingRequestKey, true);

int size = 1500 * 1024 * 1024;
int multipartOverhead = 125 + 4 /* "test" */;
int remaining = size;
req.Content = new StreamContent(new DelegateStream(
var content = new MultipartFormDataContent();
content.Add(new StreamContent(new DelegateStream(
canReadFunc: () => true,
readAsyncFunc: (buffer, offset, count, cancellationToken) =>
{
if (remaining > 0)
Expand All @@ -307,15 +312,16 @@ public async Task BrowserHttpHandler_StreamingRequest()
return Task.FromResult(send);
}
return Task.FromResult(0);
}));
})), "test");
req.Content = content;

req.Content.Headers.Add("Content-MD5-Skip", "browser");

using (HttpClient client = CreateHttpClientForRemoteServer(Configuration.Http.RemoteHttp2Server))
using (HttpResponseMessage response = await client.SendAsync(req))
{
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
Assert.Equal(size.ToString(), Assert.Single(response.Headers.GetValues("X-HttpRequest-Body-Length")));
Assert.Equal((size + multipartOverhead).ToString(), Assert.Single(response.Headers.GetValues("X-HttpRequest-Body-Length")));
// Streaming requests can't set Content-Length
Assert.False(response.Headers.Contains("X-HttpRequest-Headers-ContentLength"));
}
Expand All @@ -335,22 +341,101 @@ public async Task BrowserHttpHandler_StreamingRequest_ThrowFromContentCopy_Reque
req.Options.Set(WebAssemblyEnableStreamingRequestKey, true);

Exception error = new FormatException();
var content = new StreamContent(new DelegateStream(
req.Content = new StreamContent(new DelegateStream(
canSeekFunc: () => true,
lengthFunc: () => 12345678,
positionGetFunc: () => 0,
canReadFunc: () => true,
readFunc: (buffer, offset, count) => throw error,
readFunc: (buffer, offset, count) => throw new FormatException(),
readAsyncFunc: (buffer, offset, count, cancellationToken) => syncFailure ? throw error : Task.Delay(1).ContinueWith<int>(_ => throw error)));

req.Content = content;

using (HttpClient client = CreateHttpClientForRemoteServer(Configuration.Http.RemoteHttp2Server))
{
Assert.Same(error, await Assert.ThrowsAsync<FormatException>(() => client.SendAsync(req)));
}
}

public static TheoryData CancelRequestReadFunctions
=> new TheoryData<bool, Func<Task<int>>>
{
{ false, () => Task.FromResult(0) },
{ true, () => Task.FromResult(0) },
{ false, () => Task.FromResult(1) },
{ true, () => Task.FromResult(1) },
{ false, () => throw new FormatException() },
{ true, () => throw new FormatException() },
};

[OuterLoop]
[ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsBrowser))]
[MemberData(nameof(CancelRequestReadFunctions))]
public async Task BrowserHttpHandler_StreamingRequest_CancelRequest(bool cancelAsync, Func<Task<int>> readFunc)
{
var WebAssemblyEnableStreamingRequestKey = new HttpRequestOptionsKey<bool>("WebAssemblyEnableStreamingRequest");

var req = new HttpRequestMessage(HttpMethod.Post, Configuration.Http.Http2RemoteEchoServer);

req.Options.Set(WebAssemblyEnableStreamingRequestKey, true);

using var cts = new CancellationTokenSource();
var token = cts.Token;
int readNotCancelledCount = 0, readCancelledCount = 0;
req.Content = new StreamContent(new DelegateStream(
canReadFunc: () => true,
readFunc: (buffer, offset, count) => throw new FormatException(),
readAsyncFunc: async (buffer, offset, count, cancellationToken) =>
{
if (cancelAsync) await Task.Delay(1);
Assert.Equal(token.IsCancellationRequested, cancellationToken.IsCancellationRequested);
if (!token.IsCancellationRequested)
{
readNotCancelledCount++;
cts.Cancel();
}
else
{
readCancelledCount++;
}
return await readFunc();
}));

using (HttpClient client = CreateHttpClientForRemoteServer(Configuration.Http.RemoteHttp2Server))
{
TaskCanceledException ex = await Assert.ThrowsAsync<TaskCanceledException>(() => client.SendAsync(req, token));
Assert.Equal(token, ex.CancellationToken);
Assert.Equal(1, readNotCancelledCount);
Assert.Equal(0, readCancelledCount);
}
}

[OuterLoop]
[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsBrowser))]
public async Task BrowserHttpHandler_StreamingRequest_Http1Fails()
{
var WebAssemblyEnableStreamingRequestKey = new HttpRequestOptionsKey<bool>("WebAssemblyEnableStreamingRequest");

var req = new HttpRequestMessage(HttpMethod.Post, Configuration.Http.RemoteHttp11Server.BaseUri);

req.Options.Set(WebAssemblyEnableStreamingRequestKey, true);

int readCount = 0;
req.Content = new StreamContent(new DelegateStream(
canReadFunc: () => true,
readFunc: (buffer, offset, count) => throw new FormatException(),
readAsyncFunc: (buffer, offset, count, cancellationToken) =>
{
readCount++;
return Task.FromResult(1);
}));

using (HttpClient client = CreateHttpClientForRemoteServer(Configuration.Http.RemoteHttp11Server))
{
HttpRequestException ex = await Assert.ThrowsAsync<HttpRequestException>(() => client.SendAsync(req));
Assert.Equal("TypeError: Failed to fetch", ex.Message);
Assert.Equal(1, readCount);
}
}

[OuterLoop]
[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsBrowser))]
public async Task BrowserHttpHandler_StreamingResponse()
Expand Down
3 changes: 3 additions & 0 deletions src/libraries/System.Net.Http/src/Resources/Strings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,9 @@
<data name="net_http_synchronous_reads_not_supported" xml:space="preserve">
<value>Synchronous reads are not supported, use ReadAsync instead.</value>
</data>
<data name="net_http_synchronous_writes_not_supported" xml:space="preserve">
<value>Synchronous writes are not supported, use WriteAsync instead.</value>
</data>
<data name="net_socks_auth_failed" xml:space="preserve">
<value>Failed to authenticate with the SOCKS server.</value>
</data>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ private static async Task<WasmFetchResponse> CallFetch(HttpRequestMessage reques
}
}

Task<JSObject>? promise;
JSObject? fetchResponse;
cancellationToken.ThrowIfCancellationRequested();
if (request.Content != null)
{
Expand All @@ -220,28 +220,43 @@ private static async Task<WasmFetchResponse> CallFetch(HttpRequestMessage reques

if (streamingEnabled)
{
Stream stream = await request.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(true);
cancellationToken.ThrowIfCancellationRequested();

ReadableStreamPullState pullState = new ReadableStreamPullState(stream, cancellationToken);

promise = BrowserHttpInterop.Fetch(uri, headerNames.ToArray(), headerValues.ToArray(), optionNames, optionValues, abortController, ReadableStreamPull, pullState);
using (JSObject transformStream = BrowserHttpInterop.CreateTransformStream())
{
Task<JSObject> fetchPromise = BrowserHttpInterop.Fetch(uri, headerNames.ToArray(), headerValues.ToArray(), optionNames, optionValues, abortController, transformStream);
ValueTask<JSObject> fetchTask = BrowserHttpInterop.CancelationHelper(fetchPromise, cancellationToken, abortController, null); // initialize fetch cancellation

using (WasmHttpWriteStream stream = new WasmHttpWriteStream(transformStream))
{
try
{
await request.Content.CopyToAsync(stream, cancellationToken).ConfigureAwait(true);
await BrowserHttpInterop.TransformStreamClose(transformStream).ConfigureAwait(true);
campersau marked this conversation as resolved.
Show resolved Hide resolved
}
catch (Exception ex)
{
BrowserHttpInterop.TransformStreamAbort(transformStream, ex);
// don't rethrow, prefer exceptions from fetch
}
}

fetchResponse = await fetchTask.ConfigureAwait(true);
}
}
else
{
byte[] buffer = await request.Content.ReadAsByteArrayAsync(cancellationToken).ConfigureAwait(true);
cancellationToken.ThrowIfCancellationRequested();

promise = BrowserHttpInterop.Fetch(uri, headerNames.ToArray(), headerValues.ToArray(), optionNames, optionValues, abortController, buffer);
Task<JSObject> fetchPromise = BrowserHttpInterop.Fetch(uri, headerNames.ToArray(), headerValues.ToArray(), optionNames, optionValues, abortController, buffer);
pavelsavara marked this conversation as resolved.
Show resolved Hide resolved
fetchResponse = await BrowserHttpInterop.CancelationHelper(fetchPromise, cancellationToken, abortController, null).ConfigureAwait(true);
}
}
else
{
promise = BrowserHttpInterop.Fetch(uri, headerNames.ToArray(), headerValues.ToArray(), optionNames, optionValues, abortController);
Task<JSObject> fetchPromise = BrowserHttpInterop.Fetch(uri, headerNames.ToArray(), headerValues.ToArray(), optionNames, optionValues, abortController);
fetchResponse = await BrowserHttpInterop.CancelationHelper(fetchPromise, cancellationToken, abortController, null).ConfigureAwait(true);
}

cancellationToken.ThrowIfCancellationRequested();
JSObject fetchResponse = await BrowserHttpInterop.CancelationHelper(promise, cancellationToken, abortController, null).ConfigureAwait(true);
return new WasmFetchResponse(fetchResponse, abortController, abortRegistration.Value);
}
catch (JSException jse)
Expand All @@ -257,14 +272,6 @@ private static async Task<WasmFetchResponse> CallFetch(HttpRequestMessage reques
}
}

private static void ReadableStreamPull(object state)
{
ReadableStreamPullState pullState = (ReadableStreamPullState)state;
#pragma warning disable CS4014 // intentionally not awaited
pullState.PullAsync();
#pragma warning restore CS4014
}

private static HttpResponseMessage ConvertResponse(HttpRequestMessage request, WasmFetchResponse fetchResponse)
{
#if FEATURE_WASM_THREADS
Expand Down Expand Up @@ -329,41 +336,81 @@ static async Task<HttpResponseMessage> Impl(HttpRequestMessage request, Cancella
}
}

internal sealed class ReadableStreamPullState
internal sealed class WasmHttpWriteStream : Stream
{
private readonly Stream _stream;
private readonly CancellationToken _cancellationToken;
private readonly byte[] _buffer;
private readonly JSObject _transformStream;

public ReadableStreamPullState(Stream stream, CancellationToken cancellationToken)
public WasmHttpWriteStream(JSObject transformStream)
{
ArgumentNullException.ThrowIfNull(stream);
ArgumentNullException.ThrowIfNull(transformStream);

_stream = stream;
_cancellationToken = cancellationToken;
_buffer = new byte[65536];
_transformStream = transformStream;
}

public async Task PullAsync()
private async Task WriteAsyncCore(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken)
{
try
{
int length = await _stream.ReadAsync(_buffer, _cancellationToken).ConfigureAwait(true);
ReadableStreamControllerEnqueueUnsafe(this, _buffer, length);
}
catch (Exception ex)
cancellationToken.ThrowIfCancellationRequested();
campersau marked this conversation as resolved.
Show resolved Hide resolved
using (Buffers.MemoryHandle handle = buffer.Pin())
{
BrowserHttpInterop.ReadableStreamControllerError(this, ex);
await TransformStreamWriteUnsafe(_transformStream, buffer, handle).ConfigureAwait(true);
campersau marked this conversation as resolved.
Show resolved Hide resolved
}

static unsafe Task TransformStreamWriteUnsafe(JSObject transformStream, ReadOnlyMemory<byte> buffer, Buffers.MemoryHandle handle)
=> BrowserHttpInterop.TransformStreamWrite(transformStream, (nint)handle.Pointer, buffer.Length);
}

private static unsafe void ReadableStreamControllerEnqueueUnsafe(object pullState, byte[] buffer, int length)
public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken)
{
return new ValueTask(WriteAsyncCore(buffer, cancellationToken));
}

public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
ValidateBufferArguments(buffer, offset, count);
return WriteAsyncCore(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken);
}

public override bool CanRead => false;
public override bool CanSeek => false;
public override bool CanWrite => true;

protected override void Dispose(bool disposing)
{
_transformStream.Dispose();
}

public override void Flush()
{
fixed (byte* ptr = buffer)
{
BrowserHttpInterop.ReadableStreamControllerEnqueue(pullState, (nint)ptr, length);
}
}

#region PlatformNotSupported

public override long Position
{
get => throw new NotSupportedException();
set => throw new NotSupportedException();
}
public override long Length => throw new NotSupportedException();
public override int Read(byte[] buffer, int offset, int count)
{
throw new NotSupportedException();
}

public override long Seek(long offset, SeekOrigin origin)
{
throw new NotSupportedException();
}

public override void SetLength(long value)
{
throw new NotSupportedException();
}

public override void Write(byte[] buffer, int offset, int count)
{
throw new NotSupportedException(SR.net_http_synchronous_writes_not_supported);
}
#endregion
}

internal sealed class WasmFetchResponse : IDisposable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,22 @@ public static partial void AbortRequest(
public static partial void AbortResponse(
JSObject fetchResponse);

[JSImport("INTERNAL.http_wasm_readable_stream_controller_enqueue")]
public static partial void ReadableStreamControllerEnqueue(
[JSMarshalAs<JSType.Any>] object pullState,
[JSImport("INTERNAL.http_wasm_create_transform_stream")]
public static partial JSObject CreateTransformStream();

[JSImport("INTERNAL.http_wasm_transform_stream_write")]
public static partial Task TransformStreamWrite(
JSObject transformStream,
IntPtr bufferPtr,
int bufferLength);

[JSImport("INTERNAL.http_wasm_readable_stream_controller_error")]
public static partial void ReadableStreamControllerError(
[JSMarshalAs<JSType.Any>] object pullState,
[JSImport("INTERNAL.http_wasm_transform_stream_close")]
public static partial Task TransformStreamClose(
JSObject transformStream);

[JSImport("INTERNAL.http_wasm_transform_stream_abort")]
public static partial void TransformStreamAbort(
JSObject transformStream,
Exception error);

[JSImport("INTERNAL.http_wasm_get_response_header_names")]
Expand Down Expand Up @@ -79,8 +86,7 @@ public static partial Task<JSObject> Fetch(
string[] optionNames,
[JSMarshalAs<JSType.Array<JSType.Any>>] object?[] optionValues,
JSObject abortControler,
[JSMarshalAs<JSType.Function<JSType.Any>>] Action<object> pull,
[JSMarshalAs<JSType.Any>] object pullState);
JSObject transformStream);

[JSImport("INTERNAL.http_wasm_fetch_bytes")]
private static partial Task<JSObject> FetchBytes(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ public async Task HttpRequest_StringContent_WithoutMediaType()
await LoopbackServer.CreateServerAsync(async (server, uri) =>
{
var request = new HttpRequestMessage(HttpMethod.Post, uri);
request.Content = new StringContent("Hello World", null, ((MediaTypeHeaderValue)null)!);
request.Content = new StringContent("", null, ((MediaTypeHeaderValue)null)!);

Task<HttpResponseMessage> requestTask = client.SendAsync(request);
await server.AcceptConnectionAsync(async connection =>
Expand Down
Loading