Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid first chance exception on the hot path of RetriableStream #18064

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 11 additions & 16 deletions sdk/core/Azure.Core/src/Shared/RetriableStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ private class RetriableStreamImpl : Stream

private readonly int _maxRetries;

private readonly long _length;
private readonly ExceptionDispatchInfo _lengthException;
private readonly long? _length;

private Stream _currentStream;

Expand All @@ -66,13 +65,16 @@ private class RetriableStreamImpl : Stream

public RetriableStreamImpl(Stream initialStream, Func<long, Stream> streamFactory, Func<long, ValueTask<Stream>> asyncStreamFactory, ResponseClassifier responseClassifier, int maxRetries)
{
try
if (initialStream.CanSeek)
{
_length = EnsureStream(initialStream).Length;
}
catch (Exception ex)
{
_lengthException = ExceptionDispatchInfo.Capture(ex);
try
{
_length = EnsureStream(initialStream).Length;
}
catch
{
// ignore
}
}

_currentStream = EnsureStream(initialStream);
Expand Down Expand Up @@ -152,14 +154,7 @@ public override int Read(byte[] buffer, int offset, int count)

public override bool CanRead => _currentStream.CanRead;
public override bool CanSeek { get; }
public override long Length
{
get
{
_lengthException?.Throw();
return _length;
}
}
public override long Length => _length ?? throw new NotSupportedException();

public override long Position
{
Expand Down
40 changes: 33 additions & 7 deletions sdk/core/Azure.Core/tests/RetriableStreamTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public RetriableStreamTests(bool isAsync) : base(isAsync)
[Test]
public async Task MaintainsGlobalLengthAndPosition()
{
var stream1 = new MockReadStream(100, throwAfter: 50);
var stream1 = new MockReadStream(100, throwAfter: 50, canSeek: true);
var stream2 = new MockReadStream(50, offset: 50);

MockTransport mockTransport = CreateMockTransport(
Expand Down Expand Up @@ -88,7 +88,7 @@ public async Task DisposesStreams()
[Test]
public async Task DoesntRetryNonRetryableExceptions()
{
var stream1 = new MockReadStream(100, throwAfter: 50);
var stream1 = new MockReadStream(100, throwAfter: 50, canSeek: true);
var stream2 = new MockReadStream(50, offset: 50, throwAfter: 0, exceptionType: typeof(InvalidOperationException));

MockTransport mockTransport = CreateMockTransport(
Expand Down Expand Up @@ -125,7 +125,7 @@ public async Task DoesntRetryCustomerCancellationTokens()
Assert.Ignore();
}

var stream1 = new MockReadStream(100);
var stream1 = new MockReadStream(100, canSeek: true);

MockTransport mockTransport = CreateMockTransport(
new MockResponse(200) { ContentStream = stream1 });
Expand All @@ -149,7 +149,7 @@ public async Task DoesntRetryCustomerCancellationTokens()
[Test]
public async Task RetriesOnNonCustomerCancellationToken()
{
var stream1 = new MockReadStream(100, throwAfter: 50, exceptionType: typeof(OperationCanceledException));
var stream1 = new MockReadStream(100, throwAfter: 50, exceptionType: typeof(OperationCanceledException), canSeek: true);
var stream2 = new MockReadStream(50, offset: 50);

MockTransport mockTransport = CreateMockTransport(
Expand Down Expand Up @@ -288,6 +288,26 @@ public async Task RetriesMaxCountAndThrowsAggregateException()
Assert.AreEqual(4, mockTransport.Requests.Count);
}

[Test]
public void ThrowsForLengthOnNonSeekableStream()
{
Assert.Throws<NotSupportedException>(() => _ = RetriableStream.Create(
_ => new MockReadStream(100, canSeek: false),
_ => default,
new ResponseClassifier(),
5).Length);
}

[Test]
public void IgnoresMisbehavingStreams()
{
Assert.Throws<NotSupportedException>(() => _ = RetriableStream.Create(
_ => new NoLengthStream(canSeek: true),
_ => default,
new ResponseClassifier(),
5).Length);
}

[Test]
public void ThrowsIfInitialRequestThrow()
{
Expand Down Expand Up @@ -370,6 +390,11 @@ private static Request CreateRequest(HttpPipeline pipeline, long offset)

private class NoLengthStream : ReadOnlyStream
{
public NoLengthStream(bool canSeek = false)
{
CanSeek = canSeek;
}

public override int Read(byte[] buffer, int offset, int count)
{
throw new IOException();
Expand All @@ -381,7 +406,7 @@ public override long Seek(long offset, SeekOrigin origin)
}

public override bool CanRead { get; } = true;
public override bool CanSeek { get; } = false;
public override bool CanSeek { get; }
public override long Length => throw new NotSupportedException();
public override long Position { get; set; }
}
Expand All @@ -393,12 +418,13 @@ private class MockReadStream : ReadOnlyStream
private byte _offset;
private readonly Type _exceptionType;

public MockReadStream(long length, long throwAfter = int.MaxValue, byte offset = 0, Type exceptionType = null)
public MockReadStream(long length, long throwAfter = int.MaxValue, byte offset = 0, Type exceptionType = null, bool canSeek = false)
{
_throwAfter = throwAfter;
_offset = offset;
_exceptionType = exceptionType ?? typeof(IOException);
Length = length;
CanSeek = canSeek;
}

public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
Expand Down Expand Up @@ -433,7 +459,7 @@ public override long Seek(long offset, SeekOrigin origin)
}

public override bool CanRead { get; } = true;
public override bool CanSeek { get; } = false;
public override bool CanSeek { get; }
public override long Length { get; }
public override long Position { get; set; }
public bool IsDisposed { get; set; }
Expand Down