From c9a1f8793c644b417e72204c6f6e26ce6fe3db86 Mon Sep 17 00:00:00 2001 From: Pavel Krymets Date: Tue, 19 Jan 2021 15:57:07 -0800 Subject: [PATCH] Avoid first chance exception on the hot path of RetriableStream (#18064) --- .../Azure.Core/src/Shared/RetriableStream.cs | 27 +++++-------- .../Azure.Core/tests/RetriableStreamTests.cs | 40 +++++++++++++++---- 2 files changed, 44 insertions(+), 23 deletions(-) diff --git a/sdk/core/Azure.Core/src/Shared/RetriableStream.cs b/sdk/core/Azure.Core/src/Shared/RetriableStream.cs index 85715f82bfa4..596ffe47c0fd 100644 --- a/sdk/core/Azure.Core/src/Shared/RetriableStream.cs +++ b/sdk/core/Azure.Core/src/Shared/RetriableStream.cs @@ -53,8 +53,7 @@ private class RetriableStreamImpl : Stream private readonly int _maxRetries; - private readonly long _length; - private readonly ExceptionDispatchInfo _lengthException; + private readonly long? _length; private Stream _currentStream; @@ -66,13 +65,16 @@ private class RetriableStreamImpl : Stream public RetriableStreamImpl(Stream initialStream, Func streamFactory, Func> asyncStreamFactory, ResponseClassifier responseClassifier, int maxRetries) { - try + if (initialStream.CanSeek) { - _length = EnsureStream(initialStream).Length; - } - catch (Exception ex) - { - _lengthException = ExceptionDispatchInfo.Capture(ex); + try + { + _length = EnsureStream(initialStream).Length; + } + catch + { + // ignore + } } _currentStream = EnsureStream(initialStream); @@ -152,14 +154,7 @@ public override int Read(byte[] buffer, int offset, int count) public override bool CanRead => _currentStream.CanRead; public override bool CanSeek { get; } - public override long Length - { - get - { - _lengthException?.Throw(); - return _length; - } - } + public override long Length => _length ?? throw new NotSupportedException(); public override long Position { diff --git a/sdk/core/Azure.Core/tests/RetriableStreamTests.cs b/sdk/core/Azure.Core/tests/RetriableStreamTests.cs index 2248cc1c0f82..6dae7202832f 100644 --- a/sdk/core/Azure.Core/tests/RetriableStreamTests.cs +++ b/sdk/core/Azure.Core/tests/RetriableStreamTests.cs @@ -24,7 +24,7 @@ public RetriableStreamTests(bool isAsync) : base(isAsync) [Test] public async Task MaintainsGlobalLengthAndPosition() { - var stream1 = new MockReadStream(100, throwAfter: 50); + var stream1 = new MockReadStream(100, throwAfter: 50, canSeek: true); var stream2 = new MockReadStream(50, offset: 50); MockTransport mockTransport = CreateMockTransport( @@ -88,7 +88,7 @@ public async Task DisposesStreams() [Test] public async Task DoesntRetryNonRetryableExceptions() { - var stream1 = new MockReadStream(100, throwAfter: 50); + var stream1 = new MockReadStream(100, throwAfter: 50, canSeek: true); var stream2 = new MockReadStream(50, offset: 50, throwAfter: 0, exceptionType: typeof(InvalidOperationException)); MockTransport mockTransport = CreateMockTransport( @@ -125,7 +125,7 @@ public async Task DoesntRetryCustomerCancellationTokens() Assert.Ignore(); } - var stream1 = new MockReadStream(100); + var stream1 = new MockReadStream(100, canSeek: true); MockTransport mockTransport = CreateMockTransport( new MockResponse(200) { ContentStream = stream1 }); @@ -149,7 +149,7 @@ public async Task DoesntRetryCustomerCancellationTokens() [Test] public async Task RetriesOnNonCustomerCancellationToken() { - var stream1 = new MockReadStream(100, throwAfter: 50, exceptionType: typeof(OperationCanceledException)); + var stream1 = new MockReadStream(100, throwAfter: 50, exceptionType: typeof(OperationCanceledException), canSeek: true); var stream2 = new MockReadStream(50, offset: 50); MockTransport mockTransport = CreateMockTransport( @@ -288,6 +288,26 @@ public async Task RetriesMaxCountAndThrowsAggregateException() Assert.AreEqual(4, mockTransport.Requests.Count); } + [Test] + public void ThrowsForLengthOnNonSeekableStream() + { + Assert.Throws(() => _ = RetriableStream.Create( + _ => new MockReadStream(100, canSeek: false), + _ => default, + new ResponseClassifier(), + 5).Length); + } + + [Test] + public void IgnoresMisbehavingStreams() + { + Assert.Throws(() => _ = RetriableStream.Create( + _ => new NoLengthStream(canSeek: true), + _ => default, + new ResponseClassifier(), + 5).Length); + } + [Test] public void ThrowsIfInitialRequestThrow() { @@ -370,6 +390,11 @@ private static Request CreateRequest(HttpPipeline pipeline, long offset) private class NoLengthStream : ReadOnlyStream { + public NoLengthStream(bool canSeek = false) + { + CanSeek = canSeek; + } + public override int Read(byte[] buffer, int offset, int count) { throw new IOException(); @@ -381,7 +406,7 @@ public override long Seek(long offset, SeekOrigin origin) } public override bool CanRead { get; } = true; - public override bool CanSeek { get; } = false; + public override bool CanSeek { get; } public override long Length => throw new NotSupportedException(); public override long Position { get; set; } } @@ -393,12 +418,13 @@ private class MockReadStream : ReadOnlyStream private byte _offset; private readonly Type _exceptionType; - public MockReadStream(long length, long throwAfter = int.MaxValue, byte offset = 0, Type exceptionType = null) + public MockReadStream(long length, long throwAfter = int.MaxValue, byte offset = 0, Type exceptionType = null, bool canSeek = false) { _throwAfter = throwAfter; _offset = offset; _exceptionType = exceptionType ?? typeof(IOException); Length = length; + CanSeek = canSeek; } public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) @@ -433,7 +459,7 @@ public override long Seek(long offset, SeekOrigin origin) } public override bool CanRead { get; } = true; - public override bool CanSeek { get; } = false; + public override bool CanSeek { get; } public override long Length { get; } public override long Position { get; set; } public bool IsDisposed { get; set; }