From a7f122a4108ad940ef460dbb15c0ab7b2edfb46d Mon Sep 17 00:00:00 2001 From: Viacheslav Rostovtsev Date: Fri, 4 Nov 2022 15:40:29 -0700 Subject: [PATCH] feat: rest server streaming support --- .../Rest/PartialDecodingStreamReaderTest.cs | 166 ++++++++++++++++++ .../Rest/PartialDecodingStreamReader.cs | 152 ++++++++++++++++ .../Rest/ReadHttpResponseMessage.cs | 7 +- Google.Api.Gax.Grpc/Rest/RestCallInvoker.cs | 2 +- Google.Api.Gax.Grpc/Rest/RestChannel.cs | 64 +++++-- Google.Api.Gax.Grpc/Rest/RestMethod.cs | 17 ++ .../Rest/RestServiceCollection.cs | 4 +- 7 files changed, 397 insertions(+), 15 deletions(-) create mode 100644 Google.Api.Gax.Grpc.Tests/Rest/PartialDecodingStreamReaderTest.cs create mode 100644 Google.Api.Gax.Grpc/Rest/PartialDecodingStreamReader.cs diff --git a/Google.Api.Gax.Grpc.Tests/Rest/PartialDecodingStreamReaderTest.cs b/Google.Api.Gax.Grpc.Tests/Rest/PartialDecodingStreamReaderTest.cs new file mode 100644 index 00000000..11da3362 --- /dev/null +++ b/Google.Api.Gax.Grpc.Tests/Rest/PartialDecodingStreamReaderTest.cs @@ -0,0 +1,166 @@ +/* + * Copyright 2022 Google LLC + * Use of this source code is governed by a BSD-style + * license that can be found in the LICENSE file or at + * https://developers.google.com/open-source/licenses/bsd + */ + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using Xunit; + +namespace Google.Api.Gax.Grpc.Rest.Tests +{ + public class PartialDecodingStreamReaderTest + { + private static readonly string ArrayOfObjectsJson = @" +[ + { + ""foo"": 1 + }, + { + ""bar"": 2 + } +] +"; + + private static readonly string IncompleteArrayOfObjectsJson = @" +[ + { + ""foo"": 1 + },"; + + /// + /// Test coarse split data. + /// + [Fact] + public async void DecodingStreamReaderTestByLine() + { + StreamReader reader = new ReplayingStreamReader(ArrayOfObjectsJson.Split(new []{Environment.NewLine}, StringSplitOptions.RemoveEmptyEntries)); + var decodingReader = new PartialDecodingStreamReader(Task.FromResult(reader), JObject.Parse); + + var result = await decodingReader.MoveNext(CancellationToken.None); + Assert.True(result); + Assert.NotNull(decodingReader.Current); + Assert.Equal(decodingReader.Current["foo"], 1); + + result = await decodingReader.MoveNext(CancellationToken.None); + Assert.True(result); + Assert.NotNull(decodingReader.Current); + Assert.Equal(decodingReader.Current["bar"], 2); + + result = await decodingReader.MoveNext(CancellationToken.None); + Assert.False(result); + + result = await decodingReader.MoveNext(CancellationToken.None); + Assert.False(result); + } + + /// + /// Test data split by characters. + /// + [Fact] + public async void DecodingStreamReaderTestByChar() + { + StreamReader reader = new ReplayingStreamReader(ArrayOfObjectsJson.Select(c => c.ToString())); + var decodingReader = new PartialDecodingStreamReader(Task.FromResult(reader), JObject.Parse); + + var result = await decodingReader.MoveNext(CancellationToken.None); + Assert.True(result); + Assert.NotNull(decodingReader.Current); + Assert.Equal(decodingReader.Current["foo"], 1); + + result = await decodingReader.MoveNext(CancellationToken.None); + Assert.True(result); + Assert.NotNull(decodingReader.Current); + Assert.Equal(decodingReader.Current["bar"], 2); + + result = await decodingReader.MoveNext(CancellationToken.None); + Assert.False(result); + + result = await decodingReader.MoveNext(CancellationToken.None); + Assert.False(result); + } + + /// + /// Test when data breaks off unexpectedly. + /// + [Fact] + public async void DecodingStreamReaderTestIncomplete() + { + StreamReader reader = new ReplayingStreamReader(IncompleteArrayOfObjectsJson.Split(new []{Environment.NewLine}, StringSplitOptions.RemoveEmptyEntries)); + var decodingReader = new PartialDecodingStreamReader(Task.FromResult(reader), JObject.Parse); + + var result = await decodingReader.MoveNext(CancellationToken.None); + Assert.True(result); + Assert.NotNull(decodingReader.Current); + Assert.Equal(decodingReader.Current["foo"], 1); + + var ex = await Assert.ThrowsAsync(async () => + await decodingReader.MoveNext(CancellationToken.None)); + Assert.Contains("Closing `]` bracket not received after iterating through the stream.", ex.Message); + } + + /// + /// Test when data is empty array. + /// + [Fact] + public async void DecodingStreamReaderTestEmpty() + { + StreamReader reader = new ReplayingStreamReader(new[] {"[]"}); + var decodingReader = new PartialDecodingStreamReader(Task.FromResult(reader), JObject.Parse); + + var result = await decodingReader.MoveNext(CancellationToken.None); + Assert.False(result); + + result = await decodingReader.MoveNext(CancellationToken.None); + Assert.False(result); + } + } + + /// + /// A fake of a StreamReader emitting given strings + /// + internal class ReplayingStreamReader : StreamReader + { + private readonly Queue _queue; + + /// + /// Cannot override EndOfStream, so have to nudge + /// the base class to do this. + /// Initialize it with a non-empty stream and later read + /// that one to end. + /// + /// + public ReplayingStreamReader(IEnumerable strings) : base(new MemoryStream(new byte[1])) + { + _queue = new Queue(strings); + } + + public override Task ReadAsync(char[] buffer, int index, int count) + { + if (_queue.Count <= 0) + { + base.ReadToEnd(); // EndOfStream starts to return true + return Task.FromResult(0); + } + + var nextString = _queue.Dequeue(); + + Assert.True(count > nextString.Length); + + for (int i = 0; i < nextString.Length; i++) + { + buffer[index + i] = nextString[i]; + } + + return Task.FromResult(nextString.Length); + } + } +} diff --git a/Google.Api.Gax.Grpc/Rest/PartialDecodingStreamReader.cs b/Google.Api.Gax.Grpc/Rest/PartialDecodingStreamReader.cs new file mode 100644 index 00000000..2404be23 --- /dev/null +++ b/Google.Api.Gax.Grpc/Rest/PartialDecodingStreamReader.cs @@ -0,0 +1,152 @@ +/* + * Copyright 2022 Google LLC + * Use of this source code is governed by a BSD-style + * license that can be found in the LICENSE file or at + * https://developers.google.com/open-source/licenses/bsd + */ + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Grpc.Core; +using Newtonsoft.Json.Linq; + +namespace Google.Api.Gax.Grpc.Rest; + +/// +/// An IAsyncStreamReader implementation that reads an array of messages +/// from HTTP stream as they arrive in (partial) JSON chunks. +/// +/// Type of proto messages in the stream +internal class PartialDecodingStreamReader : IAsyncStreamReader +{ + private readonly Task _streamReaderTask; + private readonly Func _responseConverter; + + private readonly Queue _readyResults; + private readonly StringBuilder _currentBuffer; + + private StreamReader _streamReader; + private bool _arrayClosed; + + /// + /// Creates the StreamReader + /// + /// A stream reader returning partial JSON chunks + /// A function to transform a well-formed JSON object into the proto message. + public PartialDecodingStreamReader(Task streamReaderTask, Func responseConverter) + { + _streamReaderTask = streamReaderTask; + _responseConverter = responseConverter; + + _readyResults = new Queue(); + _currentBuffer = new StringBuilder(); + + _streamReader = null; + _arrayClosed = false; + } + + /// + public async Task MoveNext(CancellationToken cancellationToken) + { + _streamReader ??= await _streamReaderTask.ConfigureAwait(false); + + if (_readyResults.Count > 0) + { + Current = _readyResults.Dequeue(); + return true; + } + + if (_streamReader.EndOfStream) + { + return false; + } + + var buffer = new char[8000]; + while (_readyResults.Count == 0) + { + var taskRead = _streamReader.ReadAsync(buffer, 0, buffer.Length); + var cancellationTask = Task.Delay(-1, cancellationToken); + var resultTask = await Task.WhenAny(taskRead, cancellationTask).ConfigureAwait(false); + + if (resultTask == cancellationTask) + { + // If the cancellationTask "wins" `Task.WhenAny` by being cancelled, the following await will throw TaskCancelledException. + await cancellationTask.ConfigureAwait(false); + } + + var readLen = await taskRead.ConfigureAwait(false); + if (readLen == 0) + { + if (!_arrayClosed) + { + var errorText = "Closing `]` bracket not received after iterating through the stream. " + + "This means that streaming ended without all objects transmitted. " + + "It is likely a result of server or network error."; + throw new InvalidOperationException(errorText); + } + + return false; + } + + var readChars = buffer.Take(readLen); + foreach (var c in readChars) + { + // Closing bracket for the top-level array + if (_currentBuffer.Length == 0 && c == ']') + { + // TODO[virost, jskeet, 11/2022] Fix with tokenizer: + // it's possible to receive more data after the closing `]` + _arrayClosed = true; + continue; + } + + // Between-objects commas and spaces, as well as an opening bracket + // for the top-level array. + if (_currentBuffer.Length == 0 && c != '{') + { + continue; + } + + _currentBuffer.Append(c); + if (c != '}') + { + continue; + } + + var currentStr = _currentBuffer.ToString(); + try + { + // This will throw unless the characters in the _currentBuffer + // add up to a correct JSON and since the _currentBuffer always + // starts with an opening `{` bracket from one of the + // top-level array's element's, + // this will throw unless _currentBuffer contains one message. + // TODO[virost, jskeet, 11/2022] Use a JSON tokenizer instead + JObject.Parse(currentStr); + } + catch (Newtonsoft.Json.JsonReaderException) + { + // Tried to parse a partial json because the `}` was a part of + // a string or a child inner object. + continue; + } + + TResponse responseElement = _responseConverter(currentStr); + _readyResults.Enqueue(responseElement); + _currentBuffer.Clear(); + } + } + + Current = _readyResults.Dequeue(); + return true; + } + + /// + public TResponse Current { get; private set; } +} diff --git a/Google.Api.Gax.Grpc/Rest/ReadHttpResponseMessage.cs b/Google.Api.Gax.Grpc/Rest/ReadHttpResponseMessage.cs index 335087c8..ba23357c 100644 --- a/Google.Api.Gax.Grpc/Rest/ReadHttpResponseMessage.cs +++ b/Google.Api.Gax.Grpc/Rest/ReadHttpResponseMessage.cs @@ -13,6 +13,7 @@ using System.Linq; using System.Net; using System.Net.Http; +using System.Net.Http.Headers; using System.Runtime.ExceptionServices; namespace Google.Api.Gax.Grpc.Rest @@ -63,11 +64,13 @@ internal ReadHttpResponseMessage(HttpResponseMessage response, ExceptionDispatch // We'll bubble up the _readException instead. (OriginalResponseMessage, _readException) = (response, readException); - internal Metadata GetHeaders() + internal Metadata GetHeaders() => ReadHeaders(OriginalResponseMessage.Headers); + + internal static Metadata ReadHeaders(HttpResponseHeaders headers) { // TODO: This could be very wrong. I don't know what headers we should really return, and I don't know about semi-colon joining. var metadata = new Metadata(); - foreach (var header in OriginalResponseMessage.Headers) + foreach (var header in headers) { metadata.Add(header.Key, string.Join(";", header.Value)); } diff --git a/Google.Api.Gax.Grpc/Rest/RestCallInvoker.cs b/Google.Api.Gax.Grpc/Rest/RestCallInvoker.cs index a56c96da..47c56a78 100644 --- a/Google.Api.Gax.Grpc/Rest/RestCallInvoker.cs +++ b/Google.Api.Gax.Grpc/Rest/RestCallInvoker.cs @@ -31,7 +31,7 @@ public override AsyncDuplexStreamingCall AsyncDuplexStreami /// public override AsyncServerStreamingCall AsyncServerStreamingCall(Method method, string host, CallOptions options, TRequest request) => - throw new NotSupportedException("Streaming methods are not supported by the hybrid REST/gRPC mode"); + _channel.AsyncServerStreamingCall(method, host, options, request); /// public override AsyncUnaryCall AsyncUnaryCall(Method method, string host, CallOptions options, TRequest request) => diff --git a/Google.Api.Gax.Grpc/Rest/RestChannel.cs b/Google.Api.Gax.Grpc/Rest/RestChannel.cs index e347b959..2df0d41f 100644 --- a/Google.Api.Gax.Grpc/Rest/RestChannel.cs +++ b/Google.Api.Gax.Grpc/Rest/RestChannel.cs @@ -75,25 +75,29 @@ internal AsyncUnaryCall AsyncUnaryCall(Method(httpResponseTask); - var responseHeadersTask = ReadHeadersAsync(httpResponseTask); - Func statusFunc = () => GetStatus(httpResponseTask); - Func trailersFunc = () => GetTrailers(httpResponseTask); + var httpResponseTask = SendAsync(restMethod, host, options, request, cancellationTokenSource.Token, HttpCompletionOption.ResponseContentRead); + var readResponseTask = ReadResponseAsync(httpResponseTask); + + var responseTask = restMethod.ReadResponseAsync(readResponseTask); + var responseHeadersTask = ReadHeadersAsync(readResponseTask); + Func statusFunc = () => GetStatus(readResponseTask); + Func trailersFunc = () => GetTrailers(readResponseTask); return new AsyncUnaryCall(responseTask, responseHeadersTask, statusFunc, trailersFunc, cancellationTokenSource.Cancel); } - private async Task SendAsync(RestMethod restMethod, string host, CallOptions options, TRequest request, CancellationToken deadlineToken) + private async Task SendAsync(RestMethod restMethod, string host, CallOptions options, TRequest request, + CancellationToken deadlineToken, HttpCompletionOption httpCompletionOption) { // Ideally, add the header in the client builder instead of in the ServiceSettingsBase... - var httpRequest = restMethod.CreateRequest((IMessage) request, host); + var httpRequest = restMethod.CreateRequest((IMessage)request, host); foreach (var headerKeyValue in options.Headers .Where(mh => !mh.IsBinary) - .Where(mh=> mh.Key != VersionHeaderBuilder.HeaderName)) + .Where(mh => mh.Key != VersionHeaderBuilder.HeaderName)) { httpRequest.Headers.Add(headerKeyValue.Key, headerKeyValue.Value); } + httpRequest.Headers.Add(VersionHeaderBuilder.HeaderName, RestVersion); HttpResponseMessage httpResponseMessage; @@ -102,14 +106,51 @@ private async Task SendAsync(RestMethod restM try { await AddAuthHeadersAsync(httpRequest, restMethod, linkedCts.Token).ConfigureAwait(false); - httpResponseMessage = await _httpClient.SendAsync(httpRequest, HttpCompletionOption.ResponseContentRead, linkedCts.Token).ConfigureAwait(false); + httpResponseMessage = await _httpClient.SendAsync(httpRequest, httpCompletionOption, linkedCts.Token).ConfigureAwait(false); } catch (TaskCanceledException ex) when (deadlineToken.IsCancellationRequested) { - throw new RpcException(new Status(StatusCode.DeadlineExceeded, $"The timeout was reached when calling a method `{restMethod.FullName}`", ex)); + throw new RpcException(new Status(StatusCode.DeadlineExceeded, $"The timeout was reached when calling a method `{restMethod.FullName}`", ex)); + } + } + + return httpResponseMessage; + } + + /// + /// Equivalent to . + /// + public AsyncServerStreamingCall AsyncServerStreamingCall(Method method, string host, CallOptions options, TRequest request) + { + // TODO[virost, 11/2022] Refactor this and the Unary call to remove duplication + var restMethod = _serviceCollection.GetRestMethod(method); + + var cancellationTokenSource = new CancellationTokenSource(); + if (options.Deadline.HasValue) + { + // TODO: [virost, 2021-12] Use IClock. + var delayInterval = options.Deadline.Value - DateTime.UtcNow; + if (delayInterval.TotalMilliseconds <= 0) + { + throw new RpcException(new Status(StatusCode.DeadlineExceeded, $"The timeout was reached when calling a method `{restMethod.FullName}`")); } + cancellationTokenSource = new CancellationTokenSource(delayInterval); } + Task httpResponseTask = SendAsync(restMethod, host, options, request, cancellationTokenSource.Token, HttpCompletionOption.ResponseHeadersRead); + + Task responseHeadersTask = ReadHeadersAsync(httpResponseTask); + Func statusFunc = () => GetStatus(ReadResponseAsync(httpResponseTask)); + Func trailersFunc = () => GetTrailers(ReadResponseAsync(httpResponseTask)); + + IAsyncStreamReader responseStream = restMethod.ResponseStreamAsync(httpResponseTask); + return new AsyncServerStreamingCall(responseStream, responseHeadersTask, statusFunc, trailersFunc, cancellationTokenSource.Cancel); + } + + private async Task ReadResponseAsync(Task msgTask) + { + HttpResponseMessage httpResponseMessage = await msgTask.ConfigureAwait(false); + try { string content = await httpResponseMessage.Content.ReadAsStringAsync().ConfigureAwait(false); @@ -156,6 +197,9 @@ private async Task AddAuthHeadersAsync(HttpRequestMessage request, RestMethod re private async Task ReadHeadersAsync(Task httpResponseTask) => (await httpResponseTask.ConfigureAwait(false)).GetHeaders(); + private async Task ReadHeadersAsync(Task httpResponseTask) => + ReadHttpResponseMessage.ReadHeaders((await httpResponseTask.ConfigureAwait(false)).Headers); + private static Status GetStatus(Task httpResponseTask) => httpResponseTask.Status switch { TaskStatus.RanToCompletion => httpResponseTask.Result.GetStatus(), diff --git a/Google.Api.Gax.Grpc/Rest/RestMethod.cs b/Google.Api.Gax.Grpc/Rest/RestMethod.cs index 4b5f9dfb..b73865b2 100644 --- a/Google.Api.Gax.Grpc/Rest/RestMethod.cs +++ b/Google.Api.Gax.Grpc/Rest/RestMethod.cs @@ -9,6 +9,7 @@ using Google.Protobuf.Reflection; using Grpc.Core; using System; +using System.IO; using System.Net.Http; using System.Threading.Tasks; @@ -85,4 +86,20 @@ internal async Task ReadResponseAsync(Task ResponseStreamAsync(Task httpResponseTask) + { + var streamReaderTask = GetStreamReader(httpResponseTask); + Func responseConverter = json => (TResponse) _parser.Parse(json, _protoMethod.OutputType); + + return new PartialDecodingStreamReader(streamReaderTask, responseConverter); + } + + private static async Task GetStreamReader(Task httpResponseTask) + { + var httpResponse = await httpResponseTask.ConfigureAwait(false); + httpResponse.EnsureSuccessStatusCode(); + var stream = await httpResponse.Content.ReadAsStreamAsync().ConfigureAwait(false); + return new StreamReader(stream); + } } diff --git a/Google.Api.Gax.Grpc/Rest/RestServiceCollection.cs b/Google.Api.Gax.Grpc/Rest/RestServiceCollection.cs index bee81b50..aa62b7a9 100644 --- a/Google.Api.Gax.Grpc/Rest/RestServiceCollection.cs +++ b/Google.Api.Gax.Grpc/Rest/RestServiceCollection.cs @@ -37,8 +37,8 @@ internal static RestServiceCollection Create(ApiMetadata metadata) var typeRegistry = TypeRegistry.FromFiles(fileDescriptors.ToArray()); var parser = new JsonParser(JsonParser.Settings.Default.WithIgnoreUnknownFields(true).WithTypeRegistry(typeRegistry)); var methodsByName = services.SelectMany(service => service.Methods) - // We don't yet support streaming methods. - .Where(x => !x.IsClientStreaming && !x.IsServerStreaming) + // We don't support client streaming (and bidi) methods with REST. + .Where(x => !x.IsClientStreaming) // Ignore methods without HTTP annotations. Ideally there wouldn't be any, but // operations.proto doesn't specify an HTTP rule for WaitOperation. .Where(x => x.GetOptions()?.GetExtension(AnnotationsExtensions.Http) is not null)