diff --git a/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs b/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs index d271615f6f4b7..c92e025c0d6ab 100644 --- a/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs +++ b/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs @@ -1563,7 +1563,7 @@ ValueTask> Factory(long offset, bool force return StructuredMessageDecodingStream.WrapStream(result.Value.Content, result.Value.Details.ContentLength); } Stream stream; - if (response.GetRawResponse().Headers.Contains(Constants.StructuredMessage.CrcStructuredMessageHeader)) + if (response.GetRawResponse().Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader)) { (Stream decodingStream, StructuredMessageDecodingStream.DecodedData decodedData) = StructuredMessageDecodingStream.WrapStream( response.Value.Content, response.Value.Details.ContentLength); @@ -1600,7 +1600,7 @@ ValueTask> Factory(long offset, bool force validationOptions.ChecksumAlgorithm != StorageChecksumAlgorithm.None && validationOptions.AutoValidateChecksum && // structured message decoding does the validation for us - !response.GetRawResponse().Headers.Contains(Constants.StructuredMessage.CrcStructuredMessageHeader)) + !response.GetRawResponse().Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader)) { // safe-buffer; transactional hash download limit well below maxInt var readDestStream = new MemoryStream((int)response.Value.Details.ContentLength); diff --git a/sdk/storage/Azure.Storage.Blobs/src/PartitionedDownloader.cs b/sdk/storage/Azure.Storage.Blobs/src/PartitionedDownloader.cs index 276bdadb673fa..08a1090716f2b 100644 --- a/sdk/storage/Azure.Storage.Blobs/src/PartitionedDownloader.cs +++ b/sdk/storage/Azure.Storage.Blobs/src/PartitionedDownloader.cs @@ -417,7 +417,7 @@ private async Task CopyToInternal( { CancellationHelper.ThrowIfCancellationRequested(cancellationToken); // if structured message, this crc is validated in the decoding process. don't decode it here. - using IHasher hasher = response.GetRawResponse().Headers.Contains(Constants.StructuredMessage.CrcStructuredMessageHeader) + using IHasher hasher = response.GetRawResponse().Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader) ? null : ContentHasher.GetHasherFromAlgorithmId(_validationAlgorithm); using Stream rawSource = response.Value.Content; diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/Constants.cs b/sdk/storage/Azure.Storage.Common/src/Shared/Constants.cs index 0636041a65134..4893b971d6529 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/Constants.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/Constants.cs @@ -667,8 +667,8 @@ internal static class AccountResources internal static class StructuredMessage { - public const string CrcStructuredMessageHeader = "x-ms-structured-body"; - public const string CrcStructuredContentLength = "x-ms-structured-content-length"; + public const string StructuredMessageHeader = "x-ms-structured-body"; + public const string StructuredContentLength = "x-ms-structured-content-length"; public const string CrcStructuredMessage = "XSM/1.0; properties=crc64"; public const int DefaultSegmentContentLength = 4 * MB; public const int MaxDownloadCrcWithHeader = 4 * MB; diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/Errors.Clients.cs b/sdk/storage/Azure.Storage.Common/src/Shared/Errors.Clients.cs index 22def52ec719c..4d49edeb72ecf 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/Errors.Clients.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/Errors.Clients.cs @@ -106,11 +106,17 @@ public static ArgumentException VersionNotSupported(string paramName) public static RequestFailedException ClientRequestIdMismatch(Response response, string echo, string original) => new RequestFailedException(response.Status, $"Response x-ms-client-request-id '{echo}' does not match the original expected request id, '{original}'.", null); + public static InvalidDataException StructuredMessageNotAcknowledgedGET(Response response) + => new InvalidDataException($"Response does not acknowledge structured message was requested. Unknown data structure in response body."); + + public static InvalidDataException StructuredMessageNotAcknowledgedPUT(Response response) + => new InvalidDataException($"Response does not acknowledge structured message was sent. Unexpected data may have been persisted to storage."); + public static ArgumentException TransactionalHashingNotSupportedWithClientSideEncryption() => new ArgumentException("Client-side encryption and transactional hashing are not supported at the same time."); public static InvalidDataException ExpectedStructuredMessage() - => new InvalidDataException($"Expected {Constants.StructuredMessage.CrcStructuredMessageHeader} in response, but found none."); + => new InvalidDataException($"Expected {Constants.StructuredMessage.StructuredMessageHeader} in response, but found none."); public static void VerifyHttpsTokenAuth(Uri uri) { diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StorageRequestValidationPipelinePolicy.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StorageRequestValidationPipelinePolicy.cs index 0cef4f4d8d4ed..9f4ddb5249e82 100644 --- a/sdk/storage/Azure.Storage.Common/src/Shared/StorageRequestValidationPipelinePolicy.cs +++ b/sdk/storage/Azure.Storage.Common/src/Shared/StorageRequestValidationPipelinePolicy.cs @@ -33,6 +33,35 @@ public override void OnReceivedResponse(HttpMessage message) { throw Errors.ClientRequestIdMismatch(message.Response, echo.First(), original); } + + if (message.Request.Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader) && + message.Request.Headers.Contains(Constants.StructuredMessage.StructuredContentLength)) + { + AssertStructuredMessageAcknowledgedPUT(message); + } + else if (message.Request.Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader)) + { + AssertStructuredMessageAcknowledgedGET(message); + } + } + + private static void AssertStructuredMessageAcknowledgedPUT(HttpMessage message) + { + if (!message.Response.IsError && + !message.Response.Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader)) + { + throw Errors.StructuredMessageNotAcknowledgedPUT(message.Response); + } + } + + private static void AssertStructuredMessageAcknowledgedGET(HttpMessage message) + { + if (!message.Response.IsError && + !(message.Response.Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader) && + message.Response.Headers.Contains(Constants.StructuredMessage.StructuredContentLength))) + { + throw Errors.StructuredMessageNotAcknowledgedGET(message.Response); + } } } } diff --git a/sdk/storage/Azure.Storage.Common/tests/Shared/RequestExtensions.cs b/sdk/storage/Azure.Storage.Common/tests/Shared/RequestExtensions.cs index 0f9cdb07773f7..ad395e862f827 100644 --- a/sdk/storage/Azure.Storage.Common/tests/Shared/RequestExtensions.cs +++ b/sdk/storage/Azure.Storage.Common/tests/Shared/RequestExtensions.cs @@ -14,7 +14,7 @@ public static string AssertHeaderPresent(this Request request, string headerName { if (request.Headers.TryGetValue(headerName, out string value)) { - return headerName == Constants.StructuredMessage.CrcStructuredMessageHeader ? null : value; + return headerName == Constants.StructuredMessage.StructuredMessageHeader ? null : value; } StringBuilder sb = new StringBuilder() .AppendLine($"`{headerName}` expected on request but was not found.") diff --git a/sdk/storage/Azure.Storage.Common/tests/Shared/TransferValidationTestBase.cs b/sdk/storage/Azure.Storage.Common/tests/Shared/TransferValidationTestBase.cs index a406744f94d46..ed5651b0b0fc5 100644 --- a/sdk/storage/Azure.Storage.Common/tests/Shared/TransferValidationTestBase.cs +++ b/sdk/storage/Azure.Storage.Common/tests/Shared/TransferValidationTestBase.cs @@ -219,7 +219,7 @@ void AssertChecksum(Request req, string headerName) AssertChecksum(request, "Content-MD5"); break; case StorageChecksumAlgorithm.StorageCrc64: - AssertChecksum(request, Constants.StructuredMessage.CrcStructuredMessageHeader); + AssertChecksum(request, Constants.StructuredMessage.StructuredMessageHeader); break; default: throw new Exception($"Bad {nameof(StorageChecksumAlgorithm)} provided to {nameof(GetRequestChecksumHeaderAssertion)}."); @@ -302,7 +302,7 @@ void AssertChecksum(ResponseHeaders headers, string headerName) AssertChecksum(response.Headers, "Content-MD5"); break; case StorageChecksumAlgorithm.StorageCrc64: - AssertChecksum(response.Headers, Constants.StructuredMessage.CrcStructuredMessageHeader); + AssertChecksum(response.Headers, Constants.StructuredMessage.StructuredMessageHeader); break; default: throw new Exception($"Bad {nameof(StorageChecksumAlgorithm)} provided to {nameof(GetRequestChecksumHeaderAssertion)}."); @@ -1744,7 +1744,7 @@ public virtual async Task DownloadSuccessfulHashVerification(StorageChecksumAlgo Assert.True(response.Headers.Contains("Content-MD5")); break; case StorageChecksumAlgorithm.StorageCrc64: - Assert.True(response.Headers.Contains(Constants.StructuredMessage.CrcStructuredMessageHeader)); + Assert.True(response.Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader)); break; default: Assert.Fail("Test can't validate given algorithm type."); @@ -1908,7 +1908,7 @@ public virtual async Task DownloadUsesDefaultClientValidationOptions( Assert.True(response.Headers.Contains("Content-MD5")); break; case StorageChecksumAlgorithm.StorageCrc64: - Assert.True(response.Headers.Contains(Constants.StructuredMessage.CrcStructuredMessageHeader)); + Assert.True(response.Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader)); break; default: Assert.Fail("Test can't validate given algorithm type."); @@ -1968,7 +1968,7 @@ public virtual async Task DownloadOverwritesDefaultClientValidationOptions( Assert.True(response.Headers.Contains("Content-MD5")); break; case StorageChecksumAlgorithm.StorageCrc64: - Assert.True(response.Headers.Contains(Constants.StructuredMessage.CrcStructuredMessageHeader)); + Assert.True(response.Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader)); break; default: Assert.Fail("Test can't validate given algorithm type."); @@ -2072,7 +2072,7 @@ public virtual async Task DownloadRecoversFromInterruptWithValidation( Assert.True(response.Headers.Contains("Content-MD5")); break; case StorageChecksumAlgorithm.StorageCrc64: - Assert.True(response.Headers.Contains(Constants.StructuredMessage.CrcStructuredMessageHeader)); + Assert.True(response.Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader)); break; default: Assert.Fail("Test can't validate given algorithm type."); diff --git a/sdk/storage/Azure.Storage.Files.Shares/src/ShareFileClient.cs b/sdk/storage/Azure.Storage.Files.Shares/src/ShareFileClient.cs index 524f9cba6db8a..3f6a4890d9a89 100644 --- a/sdk/storage/Azure.Storage.Files.Shares/src/ShareFileClient.cs +++ b/sdk/storage/Azure.Storage.Files.Shares/src/ShareFileClient.cs @@ -2409,7 +2409,7 @@ async ValueTask> Factory(long offset, bool async return StructuredMessageDecodingStream.WrapStream(result.Value.Content, result.Value.ContentLength); } - if (initialResponse.GetRawResponse().Headers.Contains(Constants.StructuredMessage.CrcStructuredMessageHeader)) + if (initialResponse.GetRawResponse().Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader)) { (Stream decodingStream, StructuredMessageDecodingStream.DecodedData decodedData) = StructuredMessageDecodingStream.WrapStream( initialResponse.Value.Content, initialResponse.Value.ContentLength); @@ -2443,7 +2443,7 @@ async ValueTask> Factory(long offset, bool async validationOptions.ChecksumAlgorithm != StorageChecksumAlgorithm.None && validationOptions.AutoValidateChecksum && // structured message decoding does the validation for us - !initialResponse.GetRawResponse().Headers.Contains(Constants.StructuredMessage.CrcStructuredMessageHeader)) + !initialResponse.GetRawResponse().Headers.Contains(Constants.StructuredMessage.StructuredMessageHeader)) { // safe-buffer; transactional hash download limit well below maxInt var readDestStream = new MemoryStream((int)initialResponse.Value.ContentLength);