From 77294e9f1b7246b7fe051d94fd7a7d6a8267987a Mon Sep 17 00:00:00 2001 From: Jocelyn Schreppler Date: Thu, 9 May 2024 13:43:16 -0400 Subject: [PATCH 1/5] retriable decode --- .../Azure.Storage.Blobs/src/BlobBaseClient.cs | 2 +- ...tructuredMessageDecodingRetriableStream.cs | 181 ++++++++++++++++++ .../Shared/StructuredMessageDecodingStream.cs | 151 +++++++++++---- .../tests/Azure.Storage.Common.Tests.csproj | 3 + ...uredMessageDecodingRetriableStreamTests.cs | 43 +++++ .../StructuredMessageDecodingStreamTests.cs | 22 +-- .../StructuredMessageStreamRoundtripTests.cs | 4 +- 7 files changed, 354 insertions(+), 52 deletions(-) create mode 100644 sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs create mode 100644 sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs diff --git a/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs b/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs index b664a1c860d5b..d9fd59d0d4bdd 100644 --- a/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs +++ b/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs @@ -1757,7 +1757,7 @@ private async ValueTask> StartDownloadAsyn if (response.GetRawResponse().Headers.TryGetValue(Constants.StructuredMessage.CrcStructuredMessageHeader, out string _) && response.GetRawResponse().Headers.TryGetValue(Constants.HeaderNames.ContentLength, out string rawContentLength)) { - result.Content = new StructuredMessageDecodingStream(result.Content, long.Parse(rawContentLength)); + (result.Content, _) = StructuredMessageDecodingStream.WrapStream(result.Content, long.Parse(rawContentLength)); } // if not null, we expected a structured message response // but we didn't find one in the above condition diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs new file mode 100644 index 0000000000000..87e1800baa895 --- /dev/null +++ b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs @@ -0,0 +1,181 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core; +using Azure.Core.Pipeline; + +namespace Azure.Storage.Shared; + +internal class StructuredMessageDecodingRetriableStream : Stream +{ + private readonly Stream _innerRetriable; + private long _decodedBytesRead; + + private readonly List _decodedDatas = new(); + + private readonly Func _decodingStreamFactory; + private readonly Func> _decodingAsyncStreamFactory; + + public StructuredMessageDecodingRetriableStream( + Stream initialResponse, + Func decodingStreamFactory, + Func> decodingAsyncStreamFactory, + ResponseClassifier responseClassifier, + int maxRetries) + { + _decodingStreamFactory = decodingStreamFactory; + _decodingAsyncStreamFactory = decodingAsyncStreamFactory; + _innerRetriable = RetriableStream.Create(initialResponse, StreamFactory, StreamFactoryAsync, responseClassifier, maxRetries); + } + + private Stream StreamFactory(long _) + { + long offset = _decodedDatas.LastOrDefault()?.SegmentCrcs?.LastOrDefault().SegmentEnd ?? 0; + (Stream decodingStream, StructuredMessageDecodingStream.DecodedData decodedData) = _decodingStreamFactory(offset); + _decodedDatas.Add(decodedData); + FastForwardInternal(decodingStream, _decodedBytesRead - offset, false).EnsureCompleted(); + return decodingStream; + } + + private async ValueTask StreamFactoryAsync(long _) + { + long offset = _decodedDatas.LastOrDefault()?.SegmentCrcs?.LastOrDefault().SegmentEnd ?? 0; + (Stream decodingStream, StructuredMessageDecodingStream.DecodedData decodedData) = await _decodingAsyncStreamFactory(offset).ConfigureAwait(false); + _decodedDatas.Add(decodedData); + await FastForwardInternal(decodingStream, _decodedBytesRead - offset, true).ConfigureAwait(false); + return decodingStream; + } + + private static async ValueTask FastForwardInternal(Stream stream, long bytes, bool async) + { + using (ArrayPool.Shared.RentDisposable(4 * Constants.KB, out byte[] buffer)) + { + if (async) + { + while (bytes > 0) + { + bytes -= await stream.ReadAsync(buffer, 0, (int)Math.Min(bytes, buffer.Length)).ConfigureAwait(false); + } + } + else + { + while (bytes > 0) + { + bytes -= stream.Read(buffer, 0, (int)Math.Min(bytes, buffer.Length)); + } + } + } + } + + protected override void Dispose(bool disposing) + { + foreach (IDisposable data in _decodedDatas) + { + data.Dispose(); + } + _decodedDatas.Clear(); + _innerRetriable.Dispose(); + } + + #region Read + + public override int Read(byte[] buffer, int offset, int count) + { + int read = _innerRetriable.Read(buffer, offset, count); + _decodedBytesRead += read; + return read; + } + + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + int read = await _innerRetriable.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); + _decodedBytesRead += read; + return read; + } + +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + public override int Read(Span buffer) + { + int read = _innerRetriable.Read(buffer); + _decodedBytesRead += read; + return read; + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + int read = await _innerRetriable.ReadAsync(buffer, cancellationToken).ConfigureAwait(false); + _decodedBytesRead += read; + return read; + } +#endif + + public override int ReadByte() + { + int val = _innerRetriable.ReadByte(); + _decodedBytesRead += 1; + return val; + } + + public override int EndRead(IAsyncResult asyncResult) + { + int read = _innerRetriable.EndRead(asyncResult); + _decodedBytesRead += read; + return read; + } + #endregion + + #region Passthru + public override bool CanRead => _innerRetriable.CanRead; + + public override bool CanSeek => _innerRetriable.CanSeek; + + public override bool CanWrite => _innerRetriable.CanWrite; + + public override bool CanTimeout => _innerRetriable.CanTimeout; + + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) => _innerRetriable.CopyToAsync(destination, bufferSize, cancellationToken); + + public override long Length => _innerRetriable.Length; + + public override long Position { get => _innerRetriable.Position; set => _innerRetriable.Position = value; } + + public override void Flush() => _innerRetriable.Flush(); + + public override Task FlushAsync(CancellationToken cancellationToken) => _innerRetriable.FlushAsync(cancellationToken); + + public override long Seek(long offset, SeekOrigin origin) => _innerRetriable.Seek(offset, origin); + + public override void SetLength(long value) => _innerRetriable.SetLength(value); + + public override void Write(byte[] buffer, int offset, int count) => _innerRetriable.Write(buffer, offset, count); + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) => _innerRetriable.WriteAsync(buffer, offset, count, cancellationToken); + + public override void WriteByte(byte value) => _innerRetriable.WriteByte(value); + + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) => _innerRetriable.BeginWrite(buffer, offset, count, callback, state); + + public override void EndWrite(IAsyncResult asyncResult) => _innerRetriable.EndWrite(asyncResult); + + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) => _innerRetriable.BeginRead(buffer, offset, count, callback, state); + + public override int ReadTimeout { get => _innerRetriable.ReadTimeout; set => _innerRetriable.ReadTimeout = value; } + + public override int WriteTimeout { get => _innerRetriable.WriteTimeout; set => _innerRetriable.WriteTimeout = value; } + +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + public override void CopyTo(Stream destination, int bufferSize) => _innerRetriable.CopyTo(destination, bufferSize); + + public override void Write(ReadOnlySpan buffer) => _innerRetriable.Write(buffer); + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) => _innerRetriable.WriteAsync(buffer, cancellationToken); +#endif + #endregion +} diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingStream.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingStream.cs index aa94b8df350d2..37b15a2245750 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingStream.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingStream.cs @@ -38,6 +38,57 @@ namespace Azure.Storage.Shared; /// internal class StructuredMessageDecodingStream : Stream { + internal class DecodedData : IDisposable + { + private byte[] _crcBackingArray; + + public long? InnerStreamLength { get; private set; } + public int? TotalSegments { get; private set; } + public StructuredMessage.Flags? Flags { get; private set; } + public List<(ReadOnlyMemory SegmentCrc, long SegmentEnd)> SegmentCrcs { get; private set; } + public ReadOnlyMemory TotalCrc { get; private set; } + public bool DecodeCompleted { get; private set; } + + internal void SetStreamHeaderData(int totalSegments, long innerStreamLength, StructuredMessage.Flags flags) + { + TotalSegments = totalSegments; + InnerStreamLength = innerStreamLength; + Flags = flags; + + if (flags.HasFlag(StructuredMessage.Flags.StorageCrc64)) + { + _crcBackingArray = ArrayPool.Shared.Rent((totalSegments + 1) * StructuredMessage.Crc64Length); + SegmentCrcs = new(); + } + } + + internal void ReportSegmentCrc(ReadOnlySpan crc, int segmentNum, long segmentEnd) + { + int offset = (segmentNum - 1) * StructuredMessage.Crc64Length; + crc.CopyTo(new Span(_crcBackingArray, offset, StructuredMessage.Crc64Length)); + SegmentCrcs.Add((new ReadOnlyMemory(_crcBackingArray, offset, StructuredMessage.Crc64Length), segmentEnd)); + } + + internal void ReportTotalCrc(ReadOnlySpan crc) + { + int offset = (TotalSegments.Value) * StructuredMessage.Crc64Length; + crc.CopyTo(new Span(_crcBackingArray, offset, StructuredMessage.Crc64Length)); + TotalCrc = new ReadOnlyMemory(_crcBackingArray, offset, StructuredMessage.Crc64Length); + } + internal void MarkComplete() + { + DecodeCompleted = true; + } + + public void Dispose() + { + if (_crcBackingArray is not null) + { + ArrayPool.Shared.Return(_crcBackingArray); + } + } + } + private enum SMRegion { StreamHeader, @@ -58,16 +109,16 @@ private enum SMRegion private int _segmentHeaderLength; private int _segmentFooterLength; - private int _totalSegments; - private long _innerStreamLength; + private long? _expectedInnerStreamLength; - private StructuredMessage.Flags _flags; - private bool _processedFooter = false; private bool _disposed; + private readonly DecodedData _decodedData; private StorageCrc64HashAlgorithm _totalContentCrc; private StorageCrc64HashAlgorithm _segmentCrc; + private readonly bool _validateChecksums; + public override bool CanRead => true; public override bool CanWrite => false; @@ -88,18 +139,31 @@ public override long Position set => throw new NotSupportedException(); } - public StructuredMessageDecodingStream( + public static (Stream DecodedStream, DecodedData DecodedData) WrapStream( + Stream innerStream, + long? expextedStreamLength = default) + { + DecodedData data = new(); + return (new StructuredMessageDecodingStream(innerStream, data, expextedStreamLength), data); + } + + private StructuredMessageDecodingStream( Stream innerStream, - long? expectedStreamLength = default) + DecodedData decodedData, + long? expectedStreamLength) { Argument.AssertNotNull(innerStream, nameof(innerStream)); + Argument.AssertNotNull(decodedData, nameof(decodedData)); - _innerStreamLength = expectedStreamLength ?? -1; + _expectedInnerStreamLength = expectedStreamLength; _innerBufferedStream = new BufferedStream(innerStream); + _decodedData = decodedData; // Assumes stream will be structured message 1.0. Will validate this when consuming stream. _streamHeaderLength = StructuredMessage.V1_0.StreamHeaderLength; _segmentHeaderLength = StructuredMessage.V1_0.SegmentHeaderLength; + + _validateChecksums = true; } #region Write @@ -191,14 +255,15 @@ public override async ValueTask ReadAsync(Memory buf, CancellationTok private void AssertDecodeFinished() { - if (_streamFooterLength > 0 && !_processedFooter) + if (_streamFooterLength > 0 && !_decodedData.DecodeCompleted) { throw Errors.InvalidStructuredMessage("Premature end of stream."); } - _processedFooter = true; + _decodedData.MarkComplete(); } private long _innerStreamConsumed = 0; + private long _decodedContentConsumed = 0; private SMRegion _currentRegion = SMRegion.StreamHeader; private int _currentSegmentNum = 0; private long _currentSegmentContentLength; @@ -243,6 +308,7 @@ private int Decode(Span buffer) _totalContentCrc?.Append(buffer.Slice(bufferConsumed, read)); _segmentCrc?.Append(buffer.Slice(bufferConsumed, read)); bufferConsumed += read; + _decodedContentConsumed += read; _currentSegmentContentRemaining -= read; if (_currentSegmentContentRemaining == 0) { @@ -370,24 +436,25 @@ private int ProcessStreamHeader(ReadOnlySpan span) StructuredMessage.V1_0.ReadStreamHeader( span.Slice(0, _streamHeaderLength), out long streamLength, - out _flags, - out _totalSegments); + out StructuredMessage.Flags flags, + out int totalSegments); + + _decodedData.SetStreamHeaderData(totalSegments, streamLength, flags); - if (_innerStreamLength > 0 && streamLength != _innerStreamLength) + if (_expectedInnerStreamLength.HasValue && _expectedInnerStreamLength.Value != streamLength) { throw Errors.InvalidStructuredMessage("Unexpected message size."); } - else - { - _innerStreamLength = streamLength; - } - if (_flags.HasFlag(StructuredMessage.Flags.StorageCrc64)) + if (_decodedData.Flags.Value.HasFlag(StructuredMessage.Flags.StorageCrc64)) { - _segmentFooterLength = _flags.HasFlag(StructuredMessage.Flags.StorageCrc64) ? StructuredMessage.Crc64Length : 0; - _streamFooterLength = _flags.HasFlag(StructuredMessage.Flags.StorageCrc64) ? StructuredMessage.Crc64Length : 0; - _segmentCrc = StorageCrc64HashAlgorithm.Create(); - _totalContentCrc = StorageCrc64HashAlgorithm.Create(); + _segmentFooterLength = StructuredMessage.Crc64Length; + _streamFooterLength = StructuredMessage.Crc64Length; + if (_validateChecksums) + { + _segmentCrc = StorageCrc64HashAlgorithm.Create(); + _totalContentCrc = StorageCrc64HashAlgorithm.Create(); + } } _currentRegion = SMRegion.SegmentHeader; return _streamHeaderLength; @@ -396,30 +463,34 @@ private int ProcessStreamHeader(ReadOnlySpan span) private int ProcessStreamFooter(ReadOnlySpan span) { int totalProcessed = 0; - if (_flags.HasFlag(StructuredMessage.Flags.StorageCrc64)) + if (_decodedData.Flags.Value.HasFlag(StructuredMessage.Flags.StorageCrc64)) { totalProcessed += StructuredMessage.Crc64Length; - using (ArrayPool.Shared.RentAsSpanDisposable(StructuredMessage.Crc64Length, out Span calculated)) + ReadOnlySpan expected = span.Slice(0, StructuredMessage.Crc64Length); + _decodedData.ReportTotalCrc(expected); + if (_validateChecksums) { - _totalContentCrc.GetCurrentHash(calculated); - ReadOnlySpan expected = span.Slice(0, StructuredMessage.Crc64Length); - if (!calculated.SequenceEqual(expected)) + using (ArrayPool.Shared.RentAsSpanDisposable(StructuredMessage.Crc64Length, out Span calculated)) { - throw Errors.ChecksumMismatch(calculated, expected); + _totalContentCrc.GetCurrentHash(calculated); + if (!calculated.SequenceEqual(expected)) + { + throw Errors.ChecksumMismatch(calculated, expected); + } } } } - if (_innerStreamConsumed != _innerStreamLength) + if (_innerStreamConsumed != _decodedData.InnerStreamLength) { throw Errors.InvalidStructuredMessage("Unexpected message size."); } - if (_currentSegmentNum != _totalSegments) + if (_currentSegmentNum != _decodedData.TotalSegments) { throw Errors.InvalidStructuredMessage("Missing expected message segments."); } - _processedFooter = true; + _decodedData.MarkComplete(); return totalProcessed; } @@ -442,21 +513,25 @@ private int ProcessSegmentHeader(ReadOnlySpan span) private int ProcessSegmentFooter(ReadOnlySpan span) { int totalProcessed = 0; - if (_flags.HasFlag(StructuredMessage.Flags.StorageCrc64)) + if (_decodedData.Flags.Value.HasFlag(StructuredMessage.Flags.StorageCrc64)) { totalProcessed += StructuredMessage.Crc64Length; - using (ArrayPool.Shared.RentAsSpanDisposable(StructuredMessage.Crc64Length, out Span calculated)) + ReadOnlySpan expected = span.Slice(0, StructuredMessage.Crc64Length); + if (_validateChecksums) { - _segmentCrc.GetCurrentHash(calculated); - _segmentCrc = StorageCrc64HashAlgorithm.Create(); - ReadOnlySpan expected = span.Slice(0, StructuredMessage.Crc64Length); - if (!calculated.SequenceEqual(expected)) + using (ArrayPool.Shared.RentAsSpanDisposable(StructuredMessage.Crc64Length, out Span calculated)) { - throw Errors.ChecksumMismatch(calculated, expected); + _segmentCrc.GetCurrentHash(calculated); + _segmentCrc = StorageCrc64HashAlgorithm.Create(); + if (!calculated.SequenceEqual(expected)) + { + throw Errors.ChecksumMismatch(calculated, expected); + } } } + _decodedData.ReportSegmentCrc(expected, _currentSegmentNum, _decodedContentConsumed); } - _currentRegion = _currentSegmentNum == _totalSegments ? SMRegion.StreamFooter : SMRegion.SegmentHeader; + _currentRegion = _currentSegmentNum == _decodedData.TotalSegments ? SMRegion.StreamFooter : SMRegion.SegmentHeader; return totalProcessed; } #endregion diff --git a/sdk/storage/Azure.Storage.Common/tests/Azure.Storage.Common.Tests.csproj b/sdk/storage/Azure.Storage.Common/tests/Azure.Storage.Common.Tests.csproj index 0c3807d9b74ff..8bf802d14e766 100644 --- a/sdk/storage/Azure.Storage.Common/tests/Azure.Storage.Common.Tests.csproj +++ b/sdk/storage/Azure.Storage.Common/tests/Azure.Storage.Common.Tests.csproj @@ -13,6 +13,8 @@ + + @@ -46,6 +48,7 @@ + diff --git a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs new file mode 100644 index 0000000000000..2b25a887d2e39 --- /dev/null +++ b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.IO; +using System.Threading.Tasks; +using Azure.Storage.Shared; +using NUnit.Framework; + +namespace Azure.Storage.Tests; + +[TestFixture(true)] +[TestFixture(false)] +public class StructuredMessageDecodingRetriableStreamTests +{ + public bool Async { get; } + + public StructuredMessageDecodingRetriableStreamTests(bool async) + { + Async = async; + } + + [Test] + public async ValueTask UninterruptedStream() + { + byte[] data = new Random().NextBytesInline(4 * Constants.KB).ToArray(); + byte[] dest = new byte[data.Length]; + + using (Stream src = new MemoryStream(data)) + using (Stream dst = new MemoryStream(dest)) + using (StructuredMessageDecodingRetriableStream retriable = new( + src, + offset => (new MemoryStream(data, (int)offset, data.Length - (int)offset), new StructuredMessageDecodingStream.DecodedData()), + offset => new(Task.FromResult(((Stream)new MemoryStream(data, (int)offset, data.Length - (int)offset), new StructuredMessageDecodingStream.DecodedData()))), + null, + 5)) + { + await retriable.CopyToInternal(dst, Async, default); + } + + Assert.AreEqual(data, dest); + } +} diff --git a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingStreamTests.cs b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingStreamTests.cs index f881a70c8e78f..2789672df4976 100644 --- a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingStreamTests.cs +++ b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingStreamTests.cs @@ -116,7 +116,7 @@ public async Task DecodesData( new Random().NextBytes(originalData); byte[] encodedData = StructuredMessageHelper.MakeEncodedData(originalData, segmentContentLength, flags); - Stream decodingStream = new StructuredMessageDecodingStream(new MemoryStream(encodedData)); + (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream(encodedData)); byte[] decodedData; using (MemoryStream dest = new()) { @@ -136,7 +136,7 @@ public void BadStreamBadVersion() encodedData[0] = byte.MaxValue; - Stream decodingStream = new StructuredMessageDecodingStream(new MemoryStream(encodedData)); + (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream(encodedData)); Assert.That(async () => await CopyStream(decodingStream, Stream.Null), Throws.InnerException.TypeOf()); } @@ -154,7 +154,7 @@ public async Task BadSegmentCrcThrows() encodedData[badBytePos] = (byte)~encodedData[badBytePos]; MemoryStream encodedDataStream = new(encodedData); - Stream decodingStream = new StructuredMessageDecodingStream(encodedDataStream); + (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream(encodedDataStream); // manual try/catch to validate the proccess failed mid-stream rather than the end const int copyBufferSize = 4; @@ -183,7 +183,7 @@ public void BadStreamCrcThrows() encodedData[originalData.Length - 1] = (byte)~encodedData[originalData.Length - 1]; - Stream decodingStream = new StructuredMessageDecodingStream(new MemoryStream(encodedData)); + (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream(encodedData)); Assert.That(async () => await CopyStream(decodingStream, Stream.Null), Throws.InnerException.TypeOf()); } @@ -196,7 +196,7 @@ public void BadStreamWrongContentLength() BinaryPrimitives.WriteInt64LittleEndian(new Span(encodedData, V1_0.StreamHeaderMessageLengthOffset, 8), 123456789L); - Stream decodingStream = new StructuredMessageDecodingStream(new MemoryStream(encodedData)); + (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream(encodedData)); Assert.That(async () => await CopyStream(decodingStream, Stream.Null), Throws.InnerException.TypeOf()); } @@ -216,7 +216,7 @@ public void BadStreamWrongSegmentCount(int difference) BinaryPrimitives.WriteInt16LittleEndian( new Span(encodedData, V1_0.StreamHeaderSegmentCountOffset, 2), (short)(numSegments + difference)); - Stream decodingStream = new StructuredMessageDecodingStream(new MemoryStream(encodedData)); + (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream(encodedData)); Assert.That(async () => await CopyStream(decodingStream, Stream.Null), Throws.InnerException.TypeOf()); } @@ -230,7 +230,7 @@ public void BadStreamWrongSegmentNum() BinaryPrimitives.WriteInt16LittleEndian( new Span(encodedData, V1_0.StreamHeaderLength + V1_0.SegmentHeaderNumOffset, 2), 123); - Stream decodingStream = new StructuredMessageDecodingStream(new MemoryStream(encodedData)); + (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream(encodedData)); Assert.That(async () => await CopyStream(decodingStream, Stream.Null), Throws.InnerException.TypeOf()); } @@ -248,7 +248,7 @@ public async Task BadStreamWrongContentLength( new Span(encodedData, V1_0.StreamHeaderMessageLengthOffset, 8), encodedData.Length + difference); - Stream decodingStream = new StructuredMessageDecodingStream( + (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream( new MemoryStream(encodedData), lengthProvided ? (long?)encodedData.Length : default); @@ -284,14 +284,14 @@ public void BadStreamMissingExpectedStreamFooter() byte[] brokenData = new byte[encodedData.Length - Crc64Length]; new Span(encodedData, 0, encodedData.Length - Crc64Length).CopyTo(brokenData); - Stream decodingStream = new StructuredMessageDecodingStream(new MemoryStream(brokenData)); + (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream(brokenData)); Assert.That(async () => await CopyStream(decodingStream, Stream.Null), Throws.InnerException.TypeOf()); } [Test] public void NoSeek() { - StructuredMessageDecodingStream stream = new(new MemoryStream()); + (Stream stream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream()); Assert.That(stream.CanSeek, Is.False); Assert.That(() => stream.Length, Throws.TypeOf()); @@ -303,7 +303,7 @@ public void NoSeek() [Test] public void NoWrite() { - StructuredMessageDecodingStream stream = new(new MemoryStream()); + (Stream stream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream()); byte[] data = new byte[1024]; new Random().NextBytes(data); diff --git a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageStreamRoundtripTests.cs b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageStreamRoundtripTests.cs index 633233db2e73c..61583aa1ebe4e 100644 --- a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageStreamRoundtripTests.cs +++ b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageStreamRoundtripTests.cs @@ -113,8 +113,8 @@ public async Task RoundTrip( byte[] roundtripData; using (MemoryStream source = new(originalData)) - using (StructuredMessageEncodingStream encode = new(source, segmentLength, flags)) - using (StructuredMessageDecodingStream decode = new(encode)) + using (Stream encode = new StructuredMessageEncodingStream(source, segmentLength, flags)) + using (Stream decode = StructuredMessageDecodingStream.WrapStream(encode).DecodedStream) using (MemoryStream dest = new()) { await CopyStream(source, dest, readLen); From 57535c5b73d3c98638ff509f7c067666fe8a7e1d Mon Sep 17 00:00:00 2001 From: Jocelyn Schreppler Date: Mon, 13 May 2024 15:47:10 -0400 Subject: [PATCH 2/5] rewind mock test --- ...tructuredMessageDecodingRetriableStream.cs | 9 +- .../tests/Shared/FaultyStream.cs | 13 +- ...uredMessageDecodingRetriableStreamTests.cs | 140 +++++++++++++++++- 3 files changed, 148 insertions(+), 14 deletions(-) diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs index 87e1800baa895..bfca1739527c2 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs @@ -18,13 +18,14 @@ internal class StructuredMessageDecodingRetriableStream : Stream private readonly Stream _innerRetriable; private long _decodedBytesRead; - private readonly List _decodedDatas = new(); + private readonly List _decodedDatas; private readonly Func _decodingStreamFactory; private readonly Func> _decodingAsyncStreamFactory; public StructuredMessageDecodingRetriableStream( - Stream initialResponse, + Stream initialDecodingStream, + StructuredMessageDecodingStream.DecodedData initialDecodedData, Func decodingStreamFactory, Func> decodingAsyncStreamFactory, ResponseClassifier responseClassifier, @@ -32,7 +33,8 @@ public StructuredMessageDecodingRetriableStream( { _decodingStreamFactory = decodingStreamFactory; _decodingAsyncStreamFactory = decodingAsyncStreamFactory; - _innerRetriable = RetriableStream.Create(initialResponse, StreamFactory, StreamFactoryAsync, responseClassifier, maxRetries); + _innerRetriable = RetriableStream.Create(initialDecodingStream, StreamFactory, StreamFactoryAsync, responseClassifier, maxRetries); + _decodedDatas = new() { initialDecodedData }; } private Stream StreamFactory(long _) @@ -85,7 +87,6 @@ protected override void Dispose(bool disposing) } #region Read - public override int Read(byte[] buffer, int offset, int count) { int read = _innerRetriable.Read(buffer, offset, count); diff --git a/sdk/storage/Azure.Storage.Common/tests/Shared/FaultyStream.cs b/sdk/storage/Azure.Storage.Common/tests/Shared/FaultyStream.cs index 7411eb1499312..f4e4b92ed73c4 100644 --- a/sdk/storage/Azure.Storage.Common/tests/Shared/FaultyStream.cs +++ b/sdk/storage/Azure.Storage.Common/tests/Shared/FaultyStream.cs @@ -15,6 +15,7 @@ internal class FaultyStream : Stream private readonly Exception _exceptionToRaise; private int _remainingExceptions; private Action _onFault; + private long _position = 0; public FaultyStream( Stream innerStream, @@ -40,7 +41,7 @@ public FaultyStream( public override long Position { - get => _innerStream.Position; + get => CanSeek ? _innerStream.Position : _position; set => _innerStream.Position = value; } @@ -53,7 +54,9 @@ public override int Read(byte[] buffer, int offset, int count) { if (_remainingExceptions == 0 || Position + count <= _raiseExceptionAt || _raiseExceptionAt >= _innerStream.Length) { - return _innerStream.Read(buffer, offset, count); + int read = _innerStream.Read(buffer, offset, count); + _position += read; + return read; } else { @@ -61,11 +64,13 @@ public override int Read(byte[] buffer, int offset, int count) } } - public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { if (_remainingExceptions == 0 || Position + count <= _raiseExceptionAt || _raiseExceptionAt >= _innerStream.Length) { - return _innerStream.ReadAsync(buffer, offset, count, cancellationToken); + int read = await _innerStream.ReadAsync(buffer, offset, count, cancellationToken); + _position += read; + return read; } else { diff --git a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs index 2b25a887d2e39..c9b0171c4d593 100644 --- a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs +++ b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs @@ -3,8 +3,12 @@ using System; using System.IO; +using System.Threading; using System.Threading.Tasks; +using Azure.Core; using Azure.Storage.Shared; +using Azure.Storage.Test.Shared; +using Moq; using NUnit.Framework; namespace Azure.Storage.Tests; @@ -20,24 +24,148 @@ public StructuredMessageDecodingRetriableStreamTests(bool async) Async = async; } + private Mock AllExceptionsRetry() + { + Mock mock = new(MockBehavior.Strict); + mock.Setup(rc => rc.IsRetriableException(It.IsAny())).Returns(true); + return mock; + } + [Test] public async ValueTask UninterruptedStream() { byte[] data = new Random().NextBytesInline(4 * Constants.KB).ToArray(); byte[] dest = new byte[data.Length]; + // mock with a simple MemoryStream rather than an actual StructuredMessageDecodingStream using (Stream src = new MemoryStream(data)) + using (Stream retriableSrc = new StructuredMessageDecodingRetriableStream(src, new(), default, default, default, 1)) using (Stream dst = new MemoryStream(dest)) - using (StructuredMessageDecodingRetriableStream retriable = new( - src, - offset => (new MemoryStream(data, (int)offset, data.Length - (int)offset), new StructuredMessageDecodingStream.DecodedData()), + { + await retriableSrc.CopyToInternal(dst, Async, default); + } + + Assert.AreEqual(data, dest); + } + + [Test] + public async ValueTask OneInterrupt_DataIntact() + { + const int segments = 4; + const int segmentLen = Constants.KB; + const int interruptPos = segmentLen + 10; + + Random r = new(); + byte[] data = r.NextBytesInline(segments * Constants.KB).ToArray(); + byte[] dest = new byte[data.Length]; + + // Mock a decoded data for the mocked StructuredMessageDecodingStream + StructuredMessageDecodingStream.DecodedData initialDecodedData = new(); + initialDecodedData.SetStreamHeaderData(segments, data.Length, StructuredMessage.Flags.StorageCrc64); + // By the time of interrupt, there will be one segment reported + initialDecodedData.ReportSegmentCrc(r.NextBytesInline(StructuredMessage.Crc64Length), 1, segmentLen); + + // mock with a simple MemoryStream rather than an actual StructuredMessageDecodingStream + using (Stream src = new MemoryStream(data)) + using (Stream faultySrc = new FaultyStream(src, interruptPos, 1, new Exception(), default)) + using (Stream retriableSrc = new StructuredMessageDecodingRetriableStream( + faultySrc, + initialDecodedData, + offset => (new MemoryStream(data, (int)offset, data.Length - (int)offset), new()), offset => new(Task.FromResult(((Stream)new MemoryStream(data, (int)offset, data.Length - (int)offset), new StructuredMessageDecodingStream.DecodedData()))), - null, - 5)) + AllExceptionsRetry().Object, + 1)) + using (Stream dst = new MemoryStream(dest)) { - await retriable.CopyToInternal(dst, Async, default); + await retriableSrc.CopyToInternal(dst, 128, Async, default); } Assert.AreEqual(data, dest); } + + [Test] + public async ValueTask OneInterrupt_AppropriateRewind() + { + const int segments = 2; + const int segmentLen = Constants.KB; + const int dataLen = segments * segmentLen; + const int readLen = segmentLen / 4; + const int interruptOffset = 10; + const int interruptPos = segmentLen + (2 * readLen) + interruptOffset; + Random r = new(); + + // Mock a decoded data for the mocked StructuredMessageDecodingStream + StructuredMessageDecodingStream.DecodedData initialDecodedData = new(); + initialDecodedData.SetStreamHeaderData(segments, segments * segmentLen, StructuredMessage.Flags.StorageCrc64); + // By the time of interrupt, there will be one segment reported + initialDecodedData.ReportSegmentCrc(r.NextBytesInline(StructuredMessage.Crc64Length), 1, segmentLen); + + Mock mock = new(MockBehavior.Strict); + mock.SetupGet(s => s.CanRead).Returns(true); + mock.SetupGet(s => s.CanSeek).Returns(false); + if (Async) + { + mock.SetupSequence(s => s.ReadAsync(It.IsAny(), It.IsAny(), It.IsAny(), default)) + .Returns(Task.FromResult(readLen)) // start first segment + .Returns(Task.FromResult(readLen)) + .Returns(Task.FromResult(readLen)) + .Returns(Task.FromResult(readLen)) // finish first segment + .Returns(Task.FromResult(readLen)) // start second segment + .Returns(Task.FromResult(readLen)) + // faulty stream interrupt + .Returns(Task.FromResult(readLen * 2)) // restart second segment. fast-forward uses an internal 4KB buffer, so it will leap the 512 byte catchup all at once + .Returns(Task.FromResult(readLen)) + .Returns(Task.FromResult(readLen)) // end second segment + .Returns(Task.FromResult(0)) // signal end of stream + .Returns(Task.FromResult(0)) // second signal needed for stream wrapping reasons + ; + } + else + { + mock.SetupSequence(s => s.Read(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(readLen) // start first segment + .Returns(readLen) + .Returns(readLen) + .Returns(readLen) // finish first segment + .Returns(readLen) // start second segment + .Returns(readLen) + // faulty stream interrupt + .Returns(readLen * 2) // restart second segment. fast-forward uses an internal 4KB buffer, so it will leap the 512 byte catchup all at once + .Returns(readLen) + .Returns(readLen) // end second segment + .Returns(0) // signal end of stream + .Returns(0) // second signal needed for stream wrapping reasons + ; + } + Stream faultySrc = new FaultyStream(mock.Object, interruptPos, 1, new Exception(), default); + Stream retriableSrc = new StructuredMessageDecodingRetriableStream( + faultySrc, + initialDecodedData, + offset => (mock.Object, new()), + offset => new(Task.FromResult((mock.Object, new StructuredMessageDecodingStream.DecodedData()))), + AllExceptionsRetry().Object, + 1); + + int totalRead = 0; + int read = 0; + byte[] buf = new byte[readLen]; + if (Async) + { + while ((read = await retriableSrc.ReadAsync(buf, 0, buf.Length)) > 0) + { + totalRead += read; + } + } + else + { + while ((read = retriableSrc.Read(buf, 0, buf.Length)) > 0) + { + totalRead += read; + } + } + await retriableSrc.CopyToInternal(Stream.Null, readLen, Async, default); + + // Asserts we read exactly the data length, excluding the fastforward of the inner stream + Assert.That(totalRead, Is.EqualTo(dataLen)); + } } From 844312b07935fc3a54587ad1ad6fe51fc8367d95 Mon Sep 17 00:00:00 2001 From: Jocelyn Schreppler Date: Mon, 13 May 2024 16:12:39 -0400 Subject: [PATCH 3/5] bugfix --- .../src/Shared/StructuredMessageDecodingRetriableStream.cs | 4 ---- .../tests/StructuredMessageDecodingRetriableStreamTests.cs | 5 +++-- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs index bfca1739527c2..99f2dfec20d16 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs @@ -141,8 +141,6 @@ public override int EndRead(IAsyncResult asyncResult) public override bool CanTimeout => _innerRetriable.CanTimeout; - public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) => _innerRetriable.CopyToAsync(destination, bufferSize, cancellationToken); - public override long Length => _innerRetriable.Length; public override long Position { get => _innerRetriable.Position; set => _innerRetriable.Position = value; } @@ -172,8 +170,6 @@ public override int EndRead(IAsyncResult asyncResult) public override int WriteTimeout { get => _innerRetriable.WriteTimeout; set => _innerRetriable.WriteTimeout = value; } #if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER - public override void CopyTo(Stream destination, int bufferSize) => _innerRetriable.CopyTo(destination, bufferSize); - public override void Write(ReadOnlySpan buffer) => _innerRetriable.Write(buffer); public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) => _innerRetriable.WriteAsync(buffer, cancellationToken); diff --git a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs index c9b0171c4d593..6719cc0f75ecf 100644 --- a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs +++ b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs @@ -53,7 +53,8 @@ public async ValueTask OneInterrupt_DataIntact() { const int segments = 4; const int segmentLen = Constants.KB; - const int interruptPos = segmentLen + 10; + const int readLen = 128; + const int interruptPos = segmentLen + (3 * readLen) + 10; Random r = new(); byte[] data = r.NextBytesInline(segments * Constants.KB).ToArray(); @@ -77,7 +78,7 @@ public async ValueTask OneInterrupt_DataIntact() 1)) using (Stream dst = new MemoryStream(dest)) { - await retriableSrc.CopyToInternal(dst, 128, Async, default); + await retriableSrc.CopyToInternal(dst, readLen, Async, default); } Assert.AreEqual(data, dest); From c243bb600cfaadd815aa20a22bfb081d8a32c75f Mon Sep 17 00:00:00 2001 From: Jocelyn Schreppler Date: Thu, 16 May 2024 11:35:52 -0400 Subject: [PATCH 4/5] bugfix --- ...tructuredMessageDecodingRetriableStream.cs | 4 +-- ...uredMessageDecodingRetriableStreamTests.cs | 27 ++++++++++++++----- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs index 99f2dfec20d16..444fe3eb2e0a9 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs @@ -39,7 +39,7 @@ public StructuredMessageDecodingRetriableStream( private Stream StreamFactory(long _) { - long offset = _decodedDatas.LastOrDefault()?.SegmentCrcs?.LastOrDefault().SegmentEnd ?? 0; + long offset = _decodedDatas.Select(d => d.SegmentCrcs?.LastOrDefault().SegmentEnd ?? 0).Sum(); (Stream decodingStream, StructuredMessageDecodingStream.DecodedData decodedData) = _decodingStreamFactory(offset); _decodedDatas.Add(decodedData); FastForwardInternal(decodingStream, _decodedBytesRead - offset, false).EnsureCompleted(); @@ -48,7 +48,7 @@ private Stream StreamFactory(long _) private async ValueTask StreamFactoryAsync(long _) { - long offset = _decodedDatas.LastOrDefault()?.SegmentCrcs?.LastOrDefault().SegmentEnd ?? 0; + long offset = _decodedDatas.Select(d => d.SegmentCrcs?.LastOrDefault().SegmentEnd ?? 0).Sum(); (Stream decodingStream, StructuredMessageDecodingStream.DecodedData decodedData) = await _decodingAsyncStreamFactory(offset).ConfigureAwait(false); _decodedDatas.Add(decodedData); await FastForwardInternal(decodingStream, _decodedBytesRead - offset, true).ConfigureAwait(false); diff --git a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs index 6719cc0f75ecf..d3733dc2ffe9e 100644 --- a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs +++ b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs @@ -49,7 +49,7 @@ public async ValueTask UninterruptedStream() } [Test] - public async ValueTask OneInterrupt_DataIntact() + public async ValueTask Interrupt_DataIntact([Values(true, false)] bool multipleInterrupts) { const int segments = 4; const int segmentLen = Constants.KB; @@ -63,19 +63,34 @@ public async ValueTask OneInterrupt_DataIntact() // Mock a decoded data for the mocked StructuredMessageDecodingStream StructuredMessageDecodingStream.DecodedData initialDecodedData = new(); initialDecodedData.SetStreamHeaderData(segments, data.Length, StructuredMessage.Flags.StorageCrc64); - // By the time of interrupt, there will be one segment reported + // for test purposes, initialize a DecodedData, since we are not actively decoding in this test initialDecodedData.ReportSegmentCrc(r.NextBytesInline(StructuredMessage.Crc64Length), 1, segmentLen); + (Stream DecodingStream, StructuredMessageDecodingStream.DecodedData DecodedData) Factory(long offset, bool faulty) + { + Stream stream = new MemoryStream(data, (int)offset, data.Length - (int)offset); + if (faulty) + { + stream = new FaultyStream(stream, interruptPos, 1, new Exception(), () => { }); + } + // Mock a decoded data for the mocked StructuredMessageDecodingStream + StructuredMessageDecodingStream.DecodedData decodedData = new(); + decodedData.SetStreamHeaderData(segments, data.Length, StructuredMessage.Flags.StorageCrc64); + // for test purposes, initialize a DecodedData, since we are not actively decoding in this test + decodedData.ReportSegmentCrc(r.NextBytesInline(StructuredMessage.Crc64Length), 1, segmentLen); + return (stream, decodedData); + } + // mock with a simple MemoryStream rather than an actual StructuredMessageDecodingStream using (Stream src = new MemoryStream(data)) - using (Stream faultySrc = new FaultyStream(src, interruptPos, 1, new Exception(), default)) + using (Stream faultySrc = new FaultyStream(src, interruptPos, 1, new Exception(), () => { })) using (Stream retriableSrc = new StructuredMessageDecodingRetriableStream( faultySrc, initialDecodedData, - offset => (new MemoryStream(data, (int)offset, data.Length - (int)offset), new()), - offset => new(Task.FromResult(((Stream)new MemoryStream(data, (int)offset, data.Length - (int)offset), new StructuredMessageDecodingStream.DecodedData()))), + offset => Factory(offset, multipleInterrupts), + offset => new ValueTask<(Stream DecodingStream, StructuredMessageDecodingStream.DecodedData DecodedData)>(Factory(offset, multipleInterrupts)), AllExceptionsRetry().Object, - 1)) + int.MaxValue)) using (Stream dst = new MemoryStream(dest)) { await retriableSrc.CopyToInternal(dst, readLen, Async, default); From 68d9654662fb8e7a3c90c7e1e60e8031646b8650 Mon Sep 17 00:00:00 2001 From: Jocelyn Schreppler Date: Thu, 16 May 2024 15:57:06 -0400 Subject: [PATCH 5/5] tests --- ...uredMessageDecodingRetriableStreamTests.cs | 43 ++++++++++++++++++- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs index d3733dc2ffe9e..666933e546189 100644 --- a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs +++ b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs @@ -49,7 +49,7 @@ public async ValueTask UninterruptedStream() } [Test] - public async ValueTask Interrupt_DataIntact([Values(true, false)] bool multipleInterrupts) + public async Task Interrupt_DataIntact([Values(true, false)] bool multipleInterrupts) { const int segments = 4; const int segmentLen = Constants.KB; @@ -100,7 +100,7 @@ public async ValueTask Interrupt_DataIntact([Values(true, false)] bool multipleI } [Test] - public async ValueTask OneInterrupt_AppropriateRewind() + public async Task Interrupt_AppropriateRewind() { const int segments = 2; const int segmentLen = Constants.KB; @@ -184,4 +184,43 @@ public async ValueTask OneInterrupt_AppropriateRewind() // Asserts we read exactly the data length, excluding the fastforward of the inner stream Assert.That(totalRead, Is.EqualTo(dataLen)); } + + [Test] + public async Task Interrupt_ProperDecode([Values(true, false)] bool multipleInterrupts) + { + // decoding stream inserts a buffered layer of 4 KB. use larger sizes to avoid interference from it. + const int segments = 4; + const int segmentLen = 128 * Constants.KB; + const int readLen = 8 * Constants.KB; + const int interruptPos = segmentLen + (3 * readLen) + 10; + + Random r = new(); + byte[] data = r.NextBytesInline(segments * Constants.KB).ToArray(); + byte[] dest = new byte[data.Length]; + + (Stream DecodingStream, StructuredMessageDecodingStream.DecodedData DecodedData) Factory(long offset, bool faulty) + { + Stream stream = new MemoryStream(data, (int)offset, data.Length - (int)offset); + stream = new StructuredMessageEncodingStream(stream, segmentLen, StructuredMessage.Flags.StorageCrc64); + if (faulty) + { + stream = new FaultyStream(stream, interruptPos, 1, new Exception(), () => { }); + } + return StructuredMessageDecodingStream.WrapStream(stream); + } + + (Stream decodingStream, StructuredMessageDecodingStream.DecodedData decodedData) = Factory(0, true); + using Stream retriableSrc = new StructuredMessageDecodingRetriableStream( + decodingStream, + decodedData, + offset => Factory(offset, multipleInterrupts), + offset => new ValueTask<(Stream DecodingStream, StructuredMessageDecodingStream.DecodedData DecodedData)>(Factory(offset, multipleInterrupts)), + AllExceptionsRetry().Object, + int.MaxValue); + using Stream dst = new MemoryStream(dest); + + await retriableSrc.CopyToInternal(dst, readLen, Async, default); + + Assert.AreEqual(data, dest); + } }