From 418f0449e1df8de30a3e3a79e231320297784012 Mon Sep 17 00:00:00 2001
From: "David R. Williamson" <drwill@microsoft.com>
Date: Fri, 31 Mar 2023 09:59:00 -0700
Subject: [PATCH 1/4] Pass byte[] through as payload for C2D for binary
 payloads

---
 .../service/src/Messaging/Models/OutgoingMessage.cs   | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/iothub/service/src/Messaging/Models/OutgoingMessage.cs b/iothub/service/src/Messaging/Models/OutgoingMessage.cs
index d2d9e44670..3f2865c694 100644
--- a/iothub/service/src/Messaging/Models/OutgoingMessage.cs
+++ b/iothub/service/src/Messaging/Models/OutgoingMessage.cs
@@ -221,12 +221,17 @@ public DeliveryAcknowledgement Ack
         internal ArraySegment<byte> DeliveryTag { get; set; }
 
         /// <summary>
-        /// Gets the payload as a byte array.
+        /// Gets the payload as a byte array, serialized and encoded if necessary.
         /// </summary>
-        /// <returns>A fully encoded serialized string as bytes.</returns>
+        /// <remarks>
+        /// If needed, serialization uses Newtonsoft.Json and encoding is UTF8.
+        /// </remarks>
+        /// <returns>A payload as a byte array.</returns>
         internal byte[] GetPayloadObjectBytes()
         {
-            return Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(Payload));
+            return Payload is byte[] payloadAsByteArray
+                ? payloadAsByteArray
+                : Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(Payload));
         }
 
         private T GetSystemProperty<T>(string key)

From ddef360941d5619eac8e6670a75ed450e53f4983 Mon Sep 17 00:00:00 2001
From: "David R. Williamson" <drwill@microsoft.com>
Date: Fri, 31 Mar 2023 11:25:29 -0700
Subject: [PATCH 2/4] Add a test and rename other test classes to match newest
 names

---
 .../helpers/TaskCompletionSourceHelper.cs     | 42 +++++------
 .../CombinedClientOperationsPoolAmqpTests.cs  |  4 +-
 .../DeviceClientX509AuthenticationE2ETests.cs |  2 +-
 ...s => DirectMethodE2eCustomPayloadTests.cs} |  2 +-
 ....MessageSendFaultInjectionPoolAmqpTests.cs |  2 +-
 ...ncomingMessageCallbackE2ePoolAmqpTests.cs} | 10 +--
 ....cs => IncomingMessageCallbackE2eTests.cs} | 14 ++--
 ...mingMessageCallbackFaultInjectionTests.cs} |  4 +-
 .../device/MessageSendFaultInjectionTests.cs  |  2 +-
 ...2eTests.cs => TelemetryMessageE2eTests.cs} |  4 +-
 ...> TelemetryMessageSendE2ePoolAmqpTests.cs} |  8 +--
 .../iothub/service/MessagingClientE2ETests.cs | 70 +++++++++++++++----
 .../tests/Messaging/MessageClientTests.cs     | 24 ++-----
 13 files changed, 107 insertions(+), 81 deletions(-)
 rename e2e/Tests/iothub/device/{MethodE2ECustomPayloadTests.cs => DirectMethodE2eCustomPayloadTests.cs} (99%)
 rename e2e/Tests/iothub/device/{MessageReceiveE2EPoolAmqpTests.cs => IncomingMessageCallbackE2ePoolAmqpTests.cs} (92%)
 rename e2e/Tests/iothub/device/{MessageReceiveE2ETests.cs => IncomingMessageCallbackE2eTests.cs} (96%)
 rename e2e/Tests/iothub/device/{MessageReceiveFaultInjectionTests.cs => IncomingMessageCallbackFaultInjectionTests.cs} (98%)
 rename e2e/Tests/iothub/device/{TelemetryE2eTests.cs => TelemetryMessageE2eTests.cs} (98%)
 rename e2e/Tests/iothub/device/{MessageSendE2EPoolAmqpTests.cs => TelemetryMessageSendE2ePoolAmqpTests.cs} (84%)

diff --git a/e2e/Tests/helpers/TaskCompletionSourceHelper.cs b/e2e/Tests/helpers/TaskCompletionSourceHelper.cs
index 4a74d647f6..d544a7a3cb 100644
--- a/e2e/Tests/helpers/TaskCompletionSourceHelper.cs
+++ b/e2e/Tests/helpers/TaskCompletionSourceHelper.cs
@@ -1,37 +1,31 @@
 // Copyright (c) Microsoft. All rights reserved.
 // Licensed under the MIT license. See LICENSE file in the project root for full license information.
 
-using System;
 using System.Threading;
 using System.Threading.Tasks;
 
-namespace Microsoft.Azure.Devices.E2ETests.helpers
+namespace Microsoft.Azure.Devices.E2ETests
 {
-    public class TaskCompletionSourceHelper
+    /// <summary>
+    /// Modern .NET supports waiting on the TaskCompletionSource with a cancellation token, but older ones
+    /// do not. We can bind that task with a call to Task.Delay to get the same effect, though.
+    /// </summary>
+    internal static class TaskCompletionSourceHelper
     {
-        /// <summary>
-        /// Gets the result of the provided task completion source or throws OperationCancelledException if the provided
-        /// cancellation token is cancelled beforehand.
-        /// </summary>
-        /// <typeparam name="T">The type of the result of the task completion source.</typeparam>
-        /// <param name="taskCompletionSource">The task completion source to asynchronously wait for the result of.</param>
-        /// <param name="cancellationToken">The cancellation token.</param>
-        /// <returns>The result of the provided task completion source if it completes before the provided cancellation token is cancelled.</returns>
-        /// <exception cref="OperationCanceledException">If the cancellation token is cancelled before the provided task completion source finishes.</exception>
-        public static async Task<T> GetTaskCompletionSourceResultAsync<T>(TaskCompletionSource<T> taskCompletionSource, CancellationToken cancellationToken)
+        internal static async Task<T> WaitAsync<T>(this TaskCompletionSource<T> taskCompletionSource, CancellationToken ct)
         {
-            // Note that Task.Delay(-1, cancellationToken) effectively waits until the cancellation token is cancelled. The -1 value
-            // just means that the task is allowed to run indefinitely.
-            Task finishedTask = await Task.WhenAny(taskCompletionSource.Task, Task.Delay(-1, cancellationToken)).ConfigureAwait(false);
+#if NET5_0_OR_GREATER
+            return await taskCompletionSource.Task.WaitAsync(ct).ConfigureAwait(false);
+#else
+            Task finishedTask = await Task
+                .WhenAny(
+                    taskCompletionSource.Task,
+                    Task.Delay(-1, ct))
+                .ConfigureAwait(false);
 
-            // If the finished task is not the cancellation token
-            if (finishedTask is Task<T>)
-            {
-                return await ((Task<T>)finishedTask).ConfigureAwait(false);
-            }
-
-            // Otherwise throw operation cancelled exception since the cancellation token was cancelled before the task finished.
-            throw new OperationCanceledException();
+            ct.ThrowIfCancellationRequested();
+            return await taskCompletionSource.Task.ConfigureAwait(false);
+#endif
         }
     }
 }
diff --git a/e2e/Tests/iothub/device/CombinedClientOperationsPoolAmqpTests.cs b/e2e/Tests/iothub/device/CombinedClientOperationsPoolAmqpTests.cs
index 70823c1b40..82a0ecbadc 100644
--- a/e2e/Tests/iothub/device/CombinedClientOperationsPoolAmqpTests.cs
+++ b/e2e/Tests/iothub/device/CombinedClientOperationsPoolAmqpTests.cs
@@ -95,7 +95,7 @@ async Task TestOperationAsync(TestDevice testDevice, TestDeviceCallbackHandler t
 
                 // D2C Operation
                 VerboseTestLogger.WriteLine($"{nameof(CombinedClientOperationsPoolAmqpTests)}: Operation 1: Send D2C for device={testDevice.Id}");
-                TelemetryMessage message = TelemetryE2ETests.ComposeD2cTestMessage(out string _, out string _);
+                TelemetryMessage message = TelemetryMessageE2eTests.ComposeD2cTestMessage(out string _, out string _);
                 Task sendD2cMessage = testDevice.DeviceClient.SendTelemetryAsync(message);
                 clientOperations.Add(sendD2cMessage);
 
@@ -105,7 +105,7 @@ async Task TestOperationAsync(TestDevice testDevice, TestDeviceCallbackHandler t
                 OutgoingMessage msg = msgSent.Item1;
                 string payload = msgSent.Item2;
 
-                Task verifyDeviceClientReceivesMessage = MessageReceiveE2ETests.VerifyReceivedC2dMessageAsync(testDevice.DeviceClient, testDevice.Id, msg, payload);
+                Task verifyDeviceClientReceivesMessage = IncomingMessageCallbackE2eTests.VerifyReceivedC2dMessageAsync(testDevice.DeviceClient, testDevice.Id, msg, payload);
                 clientOperations.Add(verifyDeviceClientReceivesMessage);
 
                 // Invoke direct methods
diff --git a/e2e/Tests/iothub/device/DeviceClientX509AuthenticationE2ETests.cs b/e2e/Tests/iothub/device/DeviceClientX509AuthenticationE2ETests.cs
index 1c1c308bde..2ada65d671 100644
--- a/e2e/Tests/iothub/device/DeviceClientX509AuthenticationE2ETests.cs
+++ b/e2e/Tests/iothub/device/DeviceClientX509AuthenticationE2ETests.cs
@@ -176,7 +176,7 @@ private static async Task SendMessageTestAsync(IotHubClientTransportSettings tra
             await using TestDevice testDevice = await TestDevice.GetTestDeviceAsync(s_devicePrefix, TestDeviceType.X509).ConfigureAwait(false);
             IotHubDeviceClient deviceClient = testDevice.CreateDeviceClient(new IotHubClientOptions(transportSetting));
             await testDevice.OpenWithRetryAsync().ConfigureAwait(false);
-            TelemetryMessage message = TelemetryE2ETests.ComposeD2cTestMessage(out string _, out string _);
+            TelemetryMessage message = TelemetryMessageE2eTests.ComposeD2cTestMessage(out string _, out string _);
             await deviceClient.SendTelemetryAsync(message).ConfigureAwait(false);
         }
 
diff --git a/e2e/Tests/iothub/device/MethodE2ECustomPayloadTests.cs b/e2e/Tests/iothub/device/DirectMethodE2eCustomPayloadTests.cs
similarity index 99%
rename from e2e/Tests/iothub/device/MethodE2ECustomPayloadTests.cs
rename to e2e/Tests/iothub/device/DirectMethodE2eCustomPayloadTests.cs
index 1b95984f52..4579d45b7f 100644
--- a/e2e/Tests/iothub/device/MethodE2ECustomPayloadTests.cs
+++ b/e2e/Tests/iothub/device/DirectMethodE2eCustomPayloadTests.cs
@@ -16,7 +16,7 @@ namespace Microsoft.Azure.Devices.E2ETests.Methods
     [TestClass]
     [TestCategory("E2E")]
     [TestCategory("IoTHub-Client")]
-    public class MethodE2ECustomPayloadTests : E2EMsTestBase
+    public class DirectMethodE2eCustomPayloadTests : E2EMsTestBase
     {
         private static readonly DirectMethodRequestPayload _customTypeRequest = new() { DesiredState = "on" };
         private static readonly DirectMethodResponsePayload _customTypeResponse = new() { CurrentState = "off" };
diff --git a/e2e/Tests/iothub/device/FaultInjectionPoolAmqpTests.MessageSendFaultInjectionPoolAmqpTests.cs b/e2e/Tests/iothub/device/FaultInjectionPoolAmqpTests.MessageSendFaultInjectionPoolAmqpTests.cs
index f6e0308fa0..ab26b75fea 100644
--- a/e2e/Tests/iothub/device/FaultInjectionPoolAmqpTests.MessageSendFaultInjectionPoolAmqpTests.cs
+++ b/e2e/Tests/iothub/device/FaultInjectionPoolAmqpTests.MessageSendFaultInjectionPoolAmqpTests.cs
@@ -272,7 +272,7 @@ async Task TestInitAsync(IotHubDeviceClient deviceClient, TestDevice testDevice,
 
             async Task TestOperationAsync(IotHubDeviceClient deviceClient, TestDevice testDevice, TestDeviceCallbackHandler _)
             {
-                TelemetryMessage testMessage = TelemetryE2ETests.ComposeD2cTestMessage(out string payload, out string p1Value);
+                TelemetryMessage testMessage = TelemetryMessageE2eTests.ComposeD2cTestMessage(out string payload, out string p1Value);
 
                 VerboseTestLogger.WriteLine($"{nameof(FaultInjectionPoolAmqpTests)}.{testDevice.Id}: payload='{payload}' p1Value='{p1Value}'");
                 using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(20));
diff --git a/e2e/Tests/iothub/device/MessageReceiveE2EPoolAmqpTests.cs b/e2e/Tests/iothub/device/IncomingMessageCallbackE2ePoolAmqpTests.cs
similarity index 92%
rename from e2e/Tests/iothub/device/MessageReceiveE2EPoolAmqpTests.cs
rename to e2e/Tests/iothub/device/IncomingMessageCallbackE2ePoolAmqpTests.cs
index af440307c9..3e7648b7c5 100644
--- a/e2e/Tests/iothub/device/MessageReceiveE2EPoolAmqpTests.cs
+++ b/e2e/Tests/iothub/device/IncomingMessageCallbackE2ePoolAmqpTests.cs
@@ -15,9 +15,9 @@ namespace Microsoft.Azure.Devices.E2ETests.Messaging
     [TestClass]
     [TestCategory("E2E")]
     [TestCategory("IoTHub-Client")]
-    public class MessageReceiveE2EPoolAmqpTests : E2EMsTestBase
+    public class IncomingMessageCallbackE2ePoolAmqpTests : E2EMsTestBase
     {
-        private readonly string DevicePrefix = $"{nameof(MessageReceiveE2EPoolAmqpTests)}_";
+        private readonly string DevicePrefix = $"{nameof(IncomingMessageCallbackE2ePoolAmqpTests)}_";
 
         [TestMethod]
         [Timeout(TestTimeoutMilliseconds)]
@@ -132,11 +132,11 @@ async Task InitOperationAsync(TestDevice testDevice, TestDeviceCallbackHandler _
 
             async Task TestOperationAsync(TestDevice testDevice, TestDeviceCallbackHandler _)
             {
-                VerboseTestLogger.WriteLine($"{nameof(MessageReceiveE2EPoolAmqpTests)}: Preparing to receive message for device {testDevice.Id}");
+                VerboseTestLogger.WriteLine($"{nameof(IncomingMessageCallbackE2ePoolAmqpTests)}: Preparing to receive message for device {testDevice.Id}");
                 await testDevice.OpenWithRetryAsync().ConfigureAwait(false);
 
                 Tuple<OutgoingMessage, string> msgSent = messagesSent[testDevice.Id];
-                await MessageReceiveE2ETests.VerifyReceivedC2dMessageAsync(testDevice.DeviceClient, testDevice.Id, msgSent.Item1, msgSent.Item2).ConfigureAwait(false);
+                await IncomingMessageCallbackE2eTests.VerifyReceivedC2dMessageAsync(testDevice.DeviceClient, testDevice.Id, msgSent.Item1, msgSent.Item2).ConfigureAwait(false);
             }
 
             await PoolingOverAmqp
@@ -176,7 +176,7 @@ async Task InitOperationAsync(TestDevice testDevice, TestDeviceCallbackHandler t
 
             async Task TestOperationAsync(TestDevice testDevice, TestDeviceCallbackHandler testDeviceCallbackHandler)
             {
-                VerboseTestLogger.WriteLine($"{nameof(MessageReceiveE2EPoolAmqpTests)}: Preparing to receive message for device {testDevice.Id}");
+                VerboseTestLogger.WriteLine($"{nameof(IncomingMessageCallbackE2ePoolAmqpTests)}: Preparing to receive message for device {testDevice.Id}");
 
                 using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(20));
                 await testDeviceCallbackHandler.WaitForReceiveMessageCallbackAsync(cts.Token).ConfigureAwait(false);
diff --git a/e2e/Tests/iothub/device/MessageReceiveE2ETests.cs b/e2e/Tests/iothub/device/IncomingMessageCallbackE2eTests.cs
similarity index 96%
rename from e2e/Tests/iothub/device/MessageReceiveE2ETests.cs
rename to e2e/Tests/iothub/device/IncomingMessageCallbackE2eTests.cs
index bfd7b1367f..be97a9b483 100644
--- a/e2e/Tests/iothub/device/MessageReceiveE2ETests.cs
+++ b/e2e/Tests/iothub/device/IncomingMessageCallbackE2eTests.cs
@@ -5,11 +5,11 @@
 using System.Collections.Generic;
 using System.Diagnostics;
 using System.Linq;
+using System.Text;
 using System.Threading;
 using System.Threading.Tasks;
 using FluentAssertions;
 using Microsoft.Azure.Devices.Client;
-using Microsoft.Azure.Devices.E2ETests.helpers;
 using Microsoft.Azure.Devices.E2ETests.Helpers;
 using Microsoft.Azure.Devices.E2ETests.Helpers.Templates;
 using Microsoft.VisualStudio.TestTools.UnitTesting;
@@ -20,9 +20,9 @@ namespace Microsoft.Azure.Devices.E2ETests.Messaging
     [TestCategory("E2E")]
     [TestCategory("IoTHub-Client")]
     [TestCategory("LongRunning")]
-    public class MessageReceiveE2ETests : E2EMsTestBase
+    public class IncomingMessageCallbackE2eTests : E2EMsTestBase
     {
-        private static readonly string s_devicePrefix = $"{nameof(MessageReceiveE2ETests)}_";
+        private static readonly string s_devicePrefix = $"{nameof(IncomingMessageCallbackE2eTests)}_";
 
         private static readonly TimeSpan s_oneSecond = TimeSpan.FromSeconds(1);
         private static readonly TimeSpan s_fiveSeconds = TimeSpan.FromSeconds(5);
@@ -88,16 +88,14 @@ public static async Task VerifyReceivedC2dMessageAsync(IotHubDeviceClient dc, st
 
                 using var cts = new CancellationTokenSource(s_oneMinute);
                 var c2dMessageReceived = new TaskCompletionSource<IncomingMessage>(TaskCreationOptions.RunContinuationsAsynchronously);
-                Func<IncomingMessage, Task<MessageAcknowledgement>> OnC2DMessageReceived = (message) =>
+                Task<MessageAcknowledgement> OnC2DMessageReceived(IncomingMessage message)
                 {
                     c2dMessageReceived.TrySetResult(message);
                     return Task.FromResult(MessageAcknowledgement.Complete);
-                };
+                }
                 await dc.SetIncomingMessageCallbackAsync(OnC2DMessageReceived).ConfigureAwait(false);
 
-                IncomingMessage receivedMessage = await TaskCompletionSourceHelper
-                    .GetTaskCompletionSourceResultAsync(c2dMessageReceived, cts.Token)
-                    .ConfigureAwait(false);
+                IncomingMessage receivedMessage = await c2dMessageReceived.WaitAsync(cts.Token).ConfigureAwait(false);
 
                 receivedMessage.MessageId.Should().Be(message.MessageId, "Received message Id is not what was sent by service");
                 receivedMessage.UserId.Should().Be(message.UserId, "Received user Id is not what was sent by service");
diff --git a/e2e/Tests/iothub/device/MessageReceiveFaultInjectionTests.cs b/e2e/Tests/iothub/device/IncomingMessageCallbackFaultInjectionTests.cs
similarity index 98%
rename from e2e/Tests/iothub/device/MessageReceiveFaultInjectionTests.cs
rename to e2e/Tests/iothub/device/IncomingMessageCallbackFaultInjectionTests.cs
index c8343d9f3a..7d698d5325 100644
--- a/e2e/Tests/iothub/device/MessageReceiveFaultInjectionTests.cs
+++ b/e2e/Tests/iothub/device/IncomingMessageCallbackFaultInjectionTests.cs
@@ -14,9 +14,9 @@ namespace Microsoft.Azure.Devices.E2ETests.Messaging
     [TestClass]
     [TestCategory("FaultInjection")]
     [TestCategory("IoTHub-Client")]
-    public partial class MessageReceiveFaultInjectionTests : E2EMsTestBase
+    public partial class IncomingMessageCallbackFaultInjectionTests : E2EMsTestBase
     {
-        private readonly string DevicePrefix = $"{nameof(MessageReceiveFaultInjectionTests)}_";
+        private readonly string DevicePrefix = $"{nameof(IncomingMessageCallbackFaultInjectionTests)}_";
 
         [TestMethod]
         [Timeout(TestTimeoutMilliseconds)]
diff --git a/e2e/Tests/iothub/device/MessageSendFaultInjectionTests.cs b/e2e/Tests/iothub/device/MessageSendFaultInjectionTests.cs
index c0c537b269..37b3637528 100644
--- a/e2e/Tests/iothub/device/MessageSendFaultInjectionTests.cs
+++ b/e2e/Tests/iothub/device/MessageSendFaultInjectionTests.cs
@@ -320,7 +320,7 @@ async Task InitAsync(IotHubDeviceClient deviceClient, TestDevice testDevice)
 
             async Task TestOperationAsync(IotHubDeviceClient deviceClient, TestDevice testDevice)
             {
-                TelemetryMessage testMessage = TelemetryE2ETests.ComposeD2cTestMessage(out string _, out string _);
+                TelemetryMessage testMessage = TelemetryMessageE2eTests.ComposeD2cTestMessage(out string _, out string _);
                 using var cts = new CancellationTokenSource(operationTimeout);
                 await deviceClient.SendTelemetryAsync(testMessage, cts.Token).ConfigureAwait(false);
             };
diff --git a/e2e/Tests/iothub/device/TelemetryE2eTests.cs b/e2e/Tests/iothub/device/TelemetryMessageE2eTests.cs
similarity index 98%
rename from e2e/Tests/iothub/device/TelemetryE2eTests.cs
rename to e2e/Tests/iothub/device/TelemetryMessageE2eTests.cs
index 0c52516f34..0137a012fb 100644
--- a/e2e/Tests/iothub/device/TelemetryE2eTests.cs
+++ b/e2e/Tests/iothub/device/TelemetryMessageE2eTests.cs
@@ -16,7 +16,7 @@ namespace Microsoft.Azure.Devices.E2ETests.Messaging
     [TestClass]
     [TestCategory("E2E")]
     [TestCategory("IoTHub-Client")]
-    public partial class TelemetryE2ETests : E2EMsTestBase
+    public partial class TelemetryMessageE2eTests : E2EMsTestBase
     {
         private const int MessageBatchCount = 5;
 
@@ -28,7 +28,7 @@ public partial class TelemetryE2ETests : E2EMsTestBase
         // the message size is less than 1 MB.
         private const int OverlyExceedAllowedMessageSizeInBytes = 3000 * 1024;
 
-        private readonly string _idPrefix = $"{nameof(TelemetryE2ETests)}_";
+        private readonly string _idPrefix = $"{nameof(TelemetryMessageE2eTests)}_";
         private static readonly string s_proxyServerAddress = TestConfiguration.IotHub.ProxyServerAddress;
 
         [TestMethod]
diff --git a/e2e/Tests/iothub/device/MessageSendE2EPoolAmqpTests.cs b/e2e/Tests/iothub/device/TelemetryMessageSendE2ePoolAmqpTests.cs
similarity index 84%
rename from e2e/Tests/iothub/device/MessageSendE2EPoolAmqpTests.cs
rename to e2e/Tests/iothub/device/TelemetryMessageSendE2ePoolAmqpTests.cs
index 82ad05248f..13ee125b68 100644
--- a/e2e/Tests/iothub/device/MessageSendE2EPoolAmqpTests.cs
+++ b/e2e/Tests/iothub/device/TelemetryMessageSendE2ePoolAmqpTests.cs
@@ -14,9 +14,9 @@ namespace Microsoft.Azure.Devices.E2ETests.Messaging
     [TestClass]
     [TestCategory("E2E")]
     [TestCategory("IoTHub-Client")]
-    public class MessageSendE2EPoolAmqpTests : E2EMsTestBase
+    public class TelemetryMessageSendE2ePoolAmqpTests : E2EMsTestBase
     {
-        private readonly string _devicePrefix = $"{nameof(MessageSendE2EPoolAmqpTests)}_";
+        private readonly string _devicePrefix = $"{nameof(TelemetryMessageSendE2ePoolAmqpTests)}_";
 
         [TestMethod]
         [Timeout(LongRunningTestTimeoutMilliseconds)]
@@ -53,8 +53,8 @@ async Task InitAsync(TestDevice testDevice, TestDeviceCallbackHandler c)
 
             async Task TestOperationAsync(TestDevice testDevice, TestDeviceCallbackHandler _)
             {
-                TelemetryMessage testMessage = TelemetryE2ETests.ComposeD2cTestMessage(out string payload, out string p1Value);
-                VerboseTestLogger.WriteLine($"{nameof(MessageSendE2EPoolAmqpTests)}.{testDevice.Id}: messageId='{testMessage.MessageId}' payload='{payload}' p1Value='{p1Value}'");
+                TelemetryMessage testMessage = TelemetryMessageE2eTests.ComposeD2cTestMessage(out string payload, out string p1Value);
+                VerboseTestLogger.WriteLine($"{nameof(TelemetryMessageSendE2ePoolAmqpTests)}.{testDevice.Id}: messageId='{testMessage.MessageId}' payload='{payload}' p1Value='{p1Value}'");
                 await testDevice.DeviceClient.SendTelemetryAsync(testMessage).ConfigureAwait(false);
             }
 
diff --git a/e2e/Tests/iothub/service/MessagingClientE2ETests.cs b/e2e/Tests/iothub/service/MessagingClientE2ETests.cs
index 11bc7fe9cc..682bb8d841 100644
--- a/e2e/Tests/iothub/service/MessagingClientE2ETests.cs
+++ b/e2e/Tests/iothub/service/MessagingClientE2ETests.cs
@@ -4,9 +4,11 @@
 using System;
 using System.Diagnostics;
 using System.Net;
+using System.Text;
 using System.Threading;
 using System.Threading.Tasks;
 using FluentAssertions;
+using Microsoft.Azure.Devices.Client;
 using Microsoft.Azure.Devices.E2ETests.Helpers;
 using Microsoft.VisualStudio.TestTools.UnitTesting;
 
@@ -50,10 +52,10 @@ private async Task DefaultTimeout()
         private async Task TestTimeout(CancellationToken cancellationToken)
         {
             await using TestDevice testDevice = await TestDevice.GetTestDeviceAsync(DevicePrefix).ConfigureAwait(false);
-            IotHubServiceClient sender = TestDevice.ServiceClient;
+            IotHubServiceClient serviceClient = TestDevice.ServiceClient;
 
             // don't pass in cancellation token here. This test is for seeing how SendAsync reacts with an valid or expired token.
-            await sender.Messages.OpenAsync(CancellationToken.None).ConfigureAwait(false);
+            await serviceClient.Messages.OpenAsync(CancellationToken.None).ConfigureAwait(false);
 
             var sw = new Stopwatch();
             sw.Start();
@@ -62,14 +64,14 @@ private async Task TestTimeout(CancellationToken cancellationToken)
             try
             {
                 var testMessage = new OutgoingMessage("Test Message");
-                await sender.Messages.SendAsync(testDevice.Id, testMessage, cancellationToken).ConfigureAwait(false);
+                await serviceClient.Messages.SendAsync(testDevice.Id, testMessage, cancellationToken).ConfigureAwait(false);
 
                 // Pass in the cancellation token to see how the operation reacts to it.
-                await sender.Messages.SendAsync(testDevice.Id, testMessage, cancellationToken).ConfigureAwait(false);
+                await serviceClient.Messages.SendAsync(testDevice.Id, testMessage, cancellationToken).ConfigureAwait(false);
             }
             finally
             {
-                await sender.Messages.CloseAsync(CancellationToken.None).ConfigureAwait(false);
+                await serviceClient.Messages.CloseAsync(CancellationToken.None).ConfigureAwait(false);
                 sw.Stop();
                 VerboseTestLogger.WriteLine($"Testing ServiceClient SendAsync(): exiting test after time={sw.Elapsed}; ticks={sw.ElapsedTicks}");
             }
@@ -156,7 +158,7 @@ public async Task MessagingClient_OpeningAlreadyOpenClient_DoesNotThrow(IotHubTr
         [TestMethod]
         [Timeout(TestTimeoutMilliseconds)]
         [DataRow(IotHubTransportProtocol.Tcp)]
-        [DataRow (IotHubTransportProtocol.WebSocket)]
+        [DataRow(IotHubTransportProtocol.WebSocket)]
         public async Task MessagingClient_SendMessageOnClosedClient_ThrowsInvalidOperationException(IotHubTransportProtocol protocol)
         {
             // arrange
@@ -165,11 +167,11 @@ public async Task MessagingClient_SendMessageOnClosedClient_ThrowsInvalidOperati
                 Protocol = protocol
             };
 
-            await using TestDevice testDevice = await TestDevice.GetTestDeviceAsync (DevicePrefix).ConfigureAwait(false);
+            await using TestDevice testDevice = await TestDevice.GetTestDeviceAsync(DevicePrefix).ConfigureAwait(false);
             using var serviceClient = new IotHubServiceClient(TestConfiguration.IotHub.ConnectionString, options);
 
             // act
-            var message = new OutgoingMessage(new byte[10]);
+            var message = new OutgoingMessage(new object());
             await serviceClient.Messages.OpenAsync().ConfigureAwait(false);
             await serviceClient.Messages.CloseAsync();
 
@@ -213,8 +215,8 @@ public async Task MessageIdDefaultNotSet_SendEventDoesNotSetMessageId()
         {
             // arrange
             await using TestDevice testDevice = await TestDevice.GetTestDeviceAsync(DevicePrefix).ConfigureAwait(false);
-            using var sender = new IotHubServiceClient(TestConfiguration.IotHub.ConnectionString);
-            await sender.Messages.OpenAsync().ConfigureAwait(false);
+            IotHubServiceClient serviceClient = TestDevice.ServiceClient;
+            await serviceClient.Messages.OpenAsync().ConfigureAwait(false);
             string messageId = Guid.NewGuid().ToString();
 
             // act
@@ -223,10 +225,10 @@ public async Task MessageIdDefaultNotSet_SendEventDoesNotSetMessageId()
             {
                 MessageId = messageId,
             };
-            await sender.Messages.SendAsync(testDevice.Id, messageWithoutId).ConfigureAwait(false);
-            await sender.Messages.SendAsync(testDevice.Id, messageWithId).ConfigureAwait(false);
+            await serviceClient.Messages.SendAsync(testDevice.Id, messageWithoutId).ConfigureAwait(false);
+            await serviceClient.Messages.SendAsync(testDevice.Id, messageWithId).ConfigureAwait(false);
 
-            await sender.Messages.CloseAsync().ConfigureAwait(false);
+            await serviceClient.Messages.CloseAsync().ConfigureAwait(false);
 
             // assert
             messageWithoutId.MessageId.Should().BeNull();
@@ -367,5 +369,47 @@ public async Task MessagingClient_SendToNonexistentModule_ThrowIotHubServiceExce
                 await sender.Messages.CloseAsync().ConfigureAwait(false);
             }
         }
+
+        // By default, the service client serializes to JSON and encodes with UTF8. For clients wishing to use a binary payload
+        // They should be able to specify the payload as a byte array and not have it serialized and encoded.
+        // Then on the receiving end in the device client, rather than use TryGetPayload<T> which uses the configured payload
+        // convention, they can get the payload as bytes and do their own deserialization.
+        [TestMethod]
+        public async Task OutgoingMessage_GetPayloadObjectBytes_DoesNotSerialize()
+        {
+            // arrange
+            string actualPayloadString = null;
+            var messageReceived = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
+            using var cts = new CancellationTokenSource(TestTimeoutMilliseconds);
+
+            Encoding binaryEncoder = Encoding.UTF32; // use a different encoder than JSON
+
+            const string payload = "My custom payload";
+            byte[] payloadBytes = binaryEncoder.GetBytes(payload);
+            var outgoingMessage = new OutgoingMessage(payloadBytes);
+
+            await using TestDevice testDevice = await TestDevice.GetTestDeviceAsync(nameof(OutgoingMessage_GetPayloadObjectBytes_DoesNotSerialize)).ConfigureAwait(false);
+            IotHubDeviceClient deviceClient = testDevice.CreateDeviceClient();
+            await testDevice.OpenWithRetryAsync(cts.Token).ConfigureAwait(false);
+
+            await deviceClient
+                .SetIncomingMessageCallbackAsync((incomingMessage) =>
+                    {
+                        byte[] actualPayloadBytes = incomingMessage.GetPayloadAsBytes();
+                        actualPayloadString = binaryEncoder.GetString(actualPayloadBytes);
+                        messageReceived.TrySetResult(true);
+                        return Task.FromResult(MessageAcknowledgement.Complete);
+                    },
+                     cts.Token)
+                .ConfigureAwait(false);
+
+            // act
+            await TestDevice.ServiceClient.Messages.OpenAsync(cts.Token).ConfigureAwait(false);
+            await TestDevice.ServiceClient.Messages.SendAsync(testDevice.Id, outgoingMessage, cts.Token).ConfigureAwait(false);
+            await messageReceived.WaitAsync(cts.Token).ConfigureAwait(false);
+
+            // assert
+            actualPayloadString.Should().Be(payload);
+        }
     }
 }
diff --git a/iothub/service/tests/Messaging/MessageClientTests.cs b/iothub/service/tests/Messaging/MessageClientTests.cs
index e5fd4e305b..bd5a96388a 100644
--- a/iothub/service/tests/Messaging/MessageClientTests.cs
+++ b/iothub/service/tests/Messaging/MessageClientTests.cs
@@ -70,10 +70,7 @@ public async Task MessagesClient_SendAsync_WithModule_NullDeviceIdThrows()
         {
             // arrange
             string payloadString = "Hello, World!";
-            byte[] payloadBytes = Encoding.UTF8.GetBytes(payloadString);
-            var mockCredentialProvider = new Mock<IotHubConnectionProperties>();
-
-            var msg = new OutgoingMessage(payloadBytes);
+            var msg = new OutgoingMessage(payloadString);
 
             using var serviceClient = new IotHubServiceClient(
                 s_connectionString,
@@ -91,10 +88,9 @@ public async Task MessagesClient_SendAsync_WithModule_NullModuleIdThrows()
         {
             // arrange
             string payloadString = "Hello, World!";
-            byte[] payloadBytes = Encoding.UTF8.GetBytes(payloadString);
             var mockCredentialProvider = new Mock<IotHubConnectionProperties>();
 
-            var msg = new OutgoingMessage(payloadBytes);
+            var msg = new OutgoingMessage(payloadString);
 
             using var serviceClient = new IotHubServiceClient(
                 s_connectionString,
@@ -114,9 +110,7 @@ public async Task MessagesClient_SendAsync_NullParamsThrows(string deviceId, str
         {
             // arrange
             string payloadString = "Hello, World!";
-            byte[] payloadBytes = Encoding.UTF8.GetBytes(payloadString);
-            var mockCredentialProvider = new Mock<IotHubConnectionProperties>();
-            var msg = new OutgoingMessage(payloadBytes);
+            var msg = new OutgoingMessage(payloadString);
 
             using var serviceClient = new IotHubServiceClient(
                 s_connectionString,
@@ -136,8 +130,7 @@ public async Task MessagesClient_SendAsync_EmptyAndSpaceInParamsThrows(string de
         {
             // arrange
             string payloadString = "Hello, World!";
-            byte[] payloadBytes = Encoding.UTF8.GetBytes(payloadString);
-            var msg = new OutgoingMessage(payloadBytes);
+            var msg = new OutgoingMessage(payloadString);
 
             // arrange
             using var serviceClient = new IotHubServiceClient(
@@ -156,8 +149,7 @@ public async Task MessageClient_SendAsync_WithoutExplicitOpenAsync_ThrowsInvalid
         {
             // arrange
             string payloadString = "Hello, World!";
-            byte[] payloadBytes = Encoding.UTF8.GetBytes(payloadString);
-            var msg = new OutgoingMessage(payloadBytes);
+            var msg = new OutgoingMessage(payloadString);
 
             using var serviceClient = new IotHubServiceClient(
                 s_connectionString,
@@ -282,8 +274,7 @@ public async Task MessageClient_SendAsync()
         {
             // arrange
             string payloadString = "Hello, World!";
-            byte[] payloadBytes = Encoding.UTF8.GetBytes(payloadString);
-            var msg = new OutgoingMessage(payloadBytes);
+            var msg = new OutgoingMessage(payloadString);
 
             var mockCredentialProvider = new Mock<IotHubConnectionProperties>();
             mockCredentialProvider
@@ -319,8 +310,7 @@ public async Task MessageClient_SendAsync_DescriptiorCodeNotAcceptedThrows()
         {
             // arrange
             string payloadString = "Hello, World!";
-            byte[] payloadBytes = Encoding.UTF8.GetBytes(payloadString);
-            var msg = new OutgoingMessage(payloadBytes);
+            var msg = new OutgoingMessage(payloadString);
 
             var mockCredentialProvider = new Mock<IotHubConnectionProperties>();
             mockCredentialProvider

From 6821fa5f425b598bbeedaa3911cc4a45a8f70413 Mon Sep 17 00:00:00 2001
From: "David R. Williamson" <drwill@microsoft.com>
Date: Fri, 31 Mar 2023 12:54:15 -0700
Subject: [PATCH 3/4] Share task completion helper more broadly

---
 .../helpers/TaskCompletionSourceHelper.cs     |  2 +-
 .../service/FileUploadNotificationE2ETest.cs  | 21 +++++-----
 .../service/MessageFeedbackReceiverE2ETest.cs | 16 ++------
 .../iothub/service/MessagingClientE2ETests.cs |  8 ++--
 .../Transport/Mqtt/MqttTransportHandler.cs    | 41 +------------------
 .../Utilities/TaskCompletionSourceHelper.cs   | 31 ++++++++++++++
 .../Transports/Amqp/AmqpClientConnection.cs   |  2 +-
 .../Mqtt/ProvisioningTransportHandlerMqtt.cs  | 41 +------------------
 .../Utilities/TaskCompletionSourceHelper.cs   | 31 ++++++++++++++
 9 files changed, 87 insertions(+), 106 deletions(-)
 create mode 100644 iothub/device/src/Utilities/TaskCompletionSourceHelper.cs
 create mode 100644 provisioning/device/src/Utilities/TaskCompletionSourceHelper.cs

diff --git a/e2e/Tests/helpers/TaskCompletionSourceHelper.cs b/e2e/Tests/helpers/TaskCompletionSourceHelper.cs
index d544a7a3cb..bf192c94d3 100644
--- a/e2e/Tests/helpers/TaskCompletionSourceHelper.cs
+++ b/e2e/Tests/helpers/TaskCompletionSourceHelper.cs
@@ -17,7 +17,7 @@ internal static async Task<T> WaitAsync<T>(this TaskCompletionSource<T> taskComp
 #if NET5_0_OR_GREATER
             return await taskCompletionSource.Task.WaitAsync(ct).ConfigureAwait(false);
 #else
-            Task finishedTask = await Task
+            await Task
                 .WhenAny(
                     taskCompletionSource.Task,
                     Task.Delay(-1, ct))
diff --git a/e2e/Tests/iothub/service/FileUploadNotificationE2ETest.cs b/e2e/Tests/iothub/service/FileUploadNotificationE2ETest.cs
index 007bc11dfc..2e376a4b26 100644
--- a/e2e/Tests/iothub/service/FileUploadNotificationE2ETest.cs
+++ b/e2e/Tests/iothub/service/FileUploadNotificationE2ETest.cs
@@ -41,12 +41,14 @@ public class FileUploadNotificationE2ETest : E2EMsTestBase
         [DataRow(IotHubTransportProtocol.WebSocket, 1, true)]
         public async Task FileUploadNotification_FileUploadNotificationProcessor_ReceivesNotifications(IotHubTransportProtocol protocol, int filesToUpload, bool shouldReconnect)
         {
+            // arrange
+
             var options = new IotHubServiceClientOptions
             {
                 Protocol = protocol
             };
 
-            using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(TestTimeoutMilliseconds));
+            using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));
             using var serviceClient = new IotHubServiceClient(TestConfiguration.IotHub.ConnectionString, options);
             using StorageContainer storage = await StorageContainer.GetInstanceAsync("fileupload", false).ConfigureAwait(false);
             using var fileNotification = new SemaphoreSlim(1, 1);
@@ -94,6 +96,7 @@ async Task<AcknowledgementType> OnFileUploadNotificationReceived(FileUploadNotif
                     await serviceClient.FileUploadNotifications.OpenAsync(cts.Token).ConfigureAwait(false);
                 }
 
+                // act
                 for (int i = 0; i < filesToUpload; ++i)
                 {
                     string fileName = $"TestPayload-{Guid.NewGuid()}.txt";
@@ -101,15 +104,15 @@ async Task<AcknowledgementType> OnFileUploadNotificationReceived(FileUploadNotif
                     await UploadFile(fileName, cts.Token).ConfigureAwait(false);
                 }
 
-                await Task
-                    .WhenAny(
-                        allFilesFound.Task,
-                        Task.Delay(-1, cts.Token))
-                    .ConfigureAwait(false);
+                VerboseTestLogger.WriteLine($"Waiting on file upload notification...");
+                await allFilesFound.WaitAsync(cts.Token).ConfigureAwait(false);
+
+                // assert
                 allFilesFound.Task.IsCompleted.Should().BeTrue();
             }
             finally
             {
+                VerboseTestLogger.WriteLine($"Cleanup: closing client...");
                 await serviceClient.FileUploadNotifications.CloseAsync().ConfigureAwait(false);
             }
         }
@@ -149,11 +152,7 @@ public async Task FileUploadNotification_ErrorProcessor_ReceivesNotifications(Io
                 // file upload notification without closing and re-opening as long as there is more
                 // than one file upload notification to consume.
                 using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(TestTimeoutMilliseconds));
-                await Task
-                    .WhenAny(
-                        errorProcessorNotified.Task,
-                        Task.Delay(-1, cts.Token))
-                    .ConfigureAwait(false);
+                await errorProcessorNotified.WaitAsync(cts.Token).ConfigureAwait(false);
                 errorProcessorNotified.Task.IsCompleted.Should().BeTrue();
             }
             finally
diff --git a/e2e/Tests/iothub/service/MessageFeedbackReceiverE2ETest.cs b/e2e/Tests/iothub/service/MessageFeedbackReceiverE2ETest.cs
index 382805e3e4..f954da6180 100644
--- a/e2e/Tests/iothub/service/MessageFeedbackReceiverE2ETest.cs
+++ b/e2e/Tests/iothub/service/MessageFeedbackReceiverE2ETest.cs
@@ -77,22 +77,14 @@ Task<MessageAcknowledgement> OnC2DMessageReceived(IncomingMessage message)
                 await serviceClient.Messages.SendAsync(testDevice.Device.Id, message).ConfigureAwait(false);
 
                 // Wait for the device to receive the message.
-                await Task
-                    .WhenAny(
-                        Task.Delay(TimeSpan.FromSeconds(20)),
-                        c2dMessageReceived.Task)
-                    .ConfigureAwait(false);
+                using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(20));
+                await c2dMessageReceived.WaitAsync(cts.Token).ConfigureAwait(false);
 
                 c2dMessageReceived.Task.IsCompleted.Should().BeTrue("Timed out waiting for C2D message to be received by device");
 
                 // Wait for the service to receive the feedback message.
-                await Task
-                    .WhenAny(
-                        // Wait for up to 200 seconds for the feedback message as the service may not send messages
-                        // until they can batch others, even up to a minute later.
-                        Task.Delay(TimeSpan.FromSeconds(200)),
-                        feedbackMessageReceived.Task)
-                    .ConfigureAwait(false);
+                using var cts2 = new CancellationTokenSource(TimeSpan.FromSeconds(200));
+                await feedbackMessageReceived.WaitAsync(cts2.Token).ConfigureAwait(false);
 
                 feedbackMessageReceived.Task.IsCompleted.Should().BeTrue("service client never received c2d feedback message even though the device received the message");
             }
diff --git a/e2e/Tests/iothub/service/MessagingClientE2ETests.cs b/e2e/Tests/iothub/service/MessagingClientE2ETests.cs
index 682bb8d841..3ab6c1bc45 100644
--- a/e2e/Tests/iothub/service/MessagingClientE2ETests.cs
+++ b/e2e/Tests/iothub/service/MessagingClientE2ETests.cs
@@ -382,10 +382,10 @@ public async Task OutgoingMessage_GetPayloadObjectBytes_DoesNotSerialize()
             var messageReceived = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
             using var cts = new CancellationTokenSource(TestTimeoutMilliseconds);
 
-            Encoding binaryEncoder = Encoding.UTF32; // use a different encoder than JSON
+            Encoding payloadEncoder = Encoding.UTF32; // use a different encoder than JSON
 
             const string payload = "My custom payload";
-            byte[] payloadBytes = binaryEncoder.GetBytes(payload);
+            byte[] payloadBytes = payloadEncoder.GetBytes(payload);
             var outgoingMessage = new OutgoingMessage(payloadBytes);
 
             await using TestDevice testDevice = await TestDevice.GetTestDeviceAsync(nameof(OutgoingMessage_GetPayloadObjectBytes_DoesNotSerialize)).ConfigureAwait(false);
@@ -396,7 +396,8 @@ await deviceClient
                 .SetIncomingMessageCallbackAsync((incomingMessage) =>
                     {
                         byte[] actualPayloadBytes = incomingMessage.GetPayloadAsBytes();
-                        actualPayloadString = binaryEncoder.GetString(actualPayloadBytes);
+                        actualPayloadString = payloadEncoder.GetString(actualPayloadBytes);
+                        VerboseTestLogger.WriteLine($"Received message with payload [{actualPayloadString}].");
                         messageReceived.TrySetResult(true);
                         return Task.FromResult(MessageAcknowledgement.Complete);
                     },
@@ -405,6 +406,7 @@ await deviceClient
 
             // act
             await TestDevice.ServiceClient.Messages.OpenAsync(cts.Token).ConfigureAwait(false);
+            VerboseTestLogger.WriteLine($"Sending message with payload [{payload}] encoded to bytes using {payloadEncoder}.");
             await TestDevice.ServiceClient.Messages.SendAsync(testDevice.Id, outgoingMessage, cts.Token).ConfigureAwait(false);
             await messageReceived.WaitAsync(cts.Token).ConfigureAwait(false);
 
diff --git a/iothub/device/src/Transport/Mqtt/MqttTransportHandler.cs b/iothub/device/src/Transport/Mqtt/MqttTransportHandler.cs
index d494a80669..6f28be6037 100644
--- a/iothub/device/src/Transport/Mqtt/MqttTransportHandler.cs
+++ b/iothub/device/src/Transport/Mqtt/MqttTransportHandler.cs
@@ -622,11 +622,7 @@ public override async Task<TwinProperties> GetTwinAsync(CancellationToken cancel
                     Logging.Info($"Sent get twin request. Waiting on service response with request id {requestId}");
 
                 // Wait until IoT hub sends a message to this client with the response to this patch twin request.
-                GetTwinResponse getTwinResponse = await GetTaskCompletionSourceResultAsync(
-                        taskCompletionSource,
-                        "Timed out waiting for the service to send the twin.",
-                        cancellationToken)
-                    .ConfigureAwait(false);
+                GetTwinResponse getTwinResponse = await taskCompletionSource.WaitAsync(cancellationToken).ConfigureAwait(false);
 
                 if (Logging.IsEnabled)
                     Logging.Info(this, $"Received get twin response for request id {requestId} with status {getTwinResponse.Status}.");
@@ -724,11 +720,7 @@ public override async Task<long> UpdateReportedPropertiesAsync(ReportedPropertie
                     Logging.Info(this, $"Sent twin patch request with request id {requestId}. Now waiting for the service response.");
 
                 // Wait until IoT hub sends a message to this client with the response to this patch twin request.
-                PatchTwinResponse patchTwinResponse = await GetTaskCompletionSourceResultAsync(
-                        taskCompletionSource,
-                        "Timed out waiting for the service to send the updated reported properties version.",
-                        cancellationToken)
-                    .ConfigureAwait(false);
+                PatchTwinResponse patchTwinResponse = await taskCompletionSource.WaitAsync(cancellationToken).ConfigureAwait(false);
 
                 if (Logging.IsEnabled)
                     Logging.Info(this, $"Received twin patch response for request id {requestId} with status {patchTwinResponse.Status}.");
@@ -1333,34 +1325,5 @@ internal static string PopulateMessagePropertiesFromMessage(string topicName, Te
                 : "/";
             return $"{topicName}{properties}{suffix}";
         }
-
-        /// <summary>
-        /// Gets the result of the provided task completion source or throws OperationCanceledException if the provided
-        /// cancellation token is cancelled beforehand.
-        /// </summary>
-        /// <typeparam name="T">The type of the result of the task completion source.</typeparam>
-        /// <param name="taskCompletionSource">The task completion source to asynchronously wait for the result of.</param>
-        /// <param name="timeoutErrorMessage">The error message to put in the OperationCanceledException if this taks times out.</param>
-        /// <param name="cancellationToken">The cancellation token.</param>
-        /// <returns>The result of the provided task completion source if it completes before the provided cancellation token is cancelled.</returns>
-        /// <exception cref="OperationCanceledException">If the cancellation token is cancelled before the provided task completion source finishes.</exception>
-        private static async Task<T> GetTaskCompletionSourceResultAsync<T>(
-            TaskCompletionSource<T> taskCompletionSource,
-            string timeoutErrorMessage,
-            CancellationToken cancellationToken)
-        {
-            // Note that Task.Delay(-1, cancellationToken) effectively waits until the cancellation token is cancelled. The -1 value
-            // just means that the task is allowed to run indefinitely.
-            Task finishedTask = await Task.WhenAny(taskCompletionSource.Task, Task.Delay(-1, cancellationToken)).ConfigureAwait(false);
-
-            // If the finished task is not the cancellation token
-            if (finishedTask == taskCompletionSource.Task)
-            {
-                return await ((Task<T>)finishedTask).ConfigureAwait(false);
-            }
-
-            // Otherwise throw operation cancelled exception since the cancellation token was cancelled before the task finished.
-            throw new OperationCanceledException(timeoutErrorMessage);
-        }
     }
 }
diff --git a/iothub/device/src/Utilities/TaskCompletionSourceHelper.cs b/iothub/device/src/Utilities/TaskCompletionSourceHelper.cs
new file mode 100644
index 0000000000..b107c6d81e
--- /dev/null
+++ b/iothub/device/src/Utilities/TaskCompletionSourceHelper.cs
@@ -0,0 +1,31 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Microsoft.Azure.Devices
+{
+    /// <summary>
+    /// Modern .NET supports waiting on the TaskCompletionSource with a cancellation token, but older ones
+    /// do not. We can bind that task with a call to Task.Delay to get the same effect, though.
+    /// </summary>
+    internal static class TaskCompletionSourceHelper
+    {
+        internal static async Task<T> WaitAsync<T>(this TaskCompletionSource<T> taskCompletionSource, CancellationToken ct)
+        {
+#if NET5_0_OR_GREATER
+            return await taskCompletionSource.Task.WaitAsync(ct).ConfigureAwait(false);
+#else
+            await Task
+                .WhenAny(
+                    taskCompletionSource.Task,
+                    Task.Delay(-1, ct))
+                .ConfigureAwait(false);
+
+            ct.ThrowIfCancellationRequested();
+            return await taskCompletionSource.Task.ConfigureAwait(false);
+#endif
+        }
+    }
+}
diff --git a/provisioning/device/src/Transports/Amqp/AmqpClientConnection.cs b/provisioning/device/src/Transports/Amqp/AmqpClientConnection.cs
index a62e9a0733..95caf6260a 100644
--- a/provisioning/device/src/Transports/Amqp/AmqpClientConnection.cs
+++ b/provisioning/device/src/Transports/Amqp/AmqpClientConnection.cs
@@ -148,7 +148,7 @@ internal virtual async Task OpenAsync(
                             args.CompletedCallback(args);
                         }
 
-                        _transport = await _tcs.Task.ConfigureAwait(false);
+                        _transport = await _tcs.WaitAsync(cancellationToken).ConfigureAwait(false);
                         await _transport.OpenAsync(cancellationToken).ConfigureAwait(false);
                     }
                 }
diff --git a/provisioning/device/src/Transports/Mqtt/ProvisioningTransportHandlerMqtt.cs b/provisioning/device/src/Transports/Mqtt/ProvisioningTransportHandlerMqtt.cs
index 0f65e40b84..e19d44ce4a 100644
--- a/provisioning/device/src/Transports/Mqtt/ProvisioningTransportHandlerMqtt.cs
+++ b/provisioning/device/src/Transports/Mqtt/ProvisioningTransportHandlerMqtt.cs
@@ -276,11 +276,7 @@ private async Task<RegistrationOperationStatus> PublishRegistrationRequestAsync(
             if (Logging.IsEnabled)
                 Logging.Info(this, "Published the initial registration request, now waiting for the service's response.");
 
-            RegistrationOperationStatus registrationStatus = await GetTaskCompletionSourceResultAsync(
-                    _startProvisioningRequestStatusSource,
-                    "Timed out when sending the registration request.",
-                    cancellationToken)
-                .ConfigureAwait(false);
+            RegistrationOperationStatus registrationStatus = await _startProvisioningRequestStatusSource.WaitAsync(cancellationToken).ConfigureAwait(false);
 
             if (Logging.IsEnabled)
                 Logging.Info(this, $"Service responded to the initial registration request with status '{registrationStatus.Status}'.");
@@ -309,11 +305,7 @@ private async Task<DeviceRegistrationResult> PollUntilProvisionigFinishesAsync(I
                     throw new ProvisioningClientException($"Failed to publish the MQTT registration message with reason code '{publishResult.ReasonCode}'.", true);
                 }
 
-                RegistrationOperationStatus currentStatus = await GetTaskCompletionSourceResultAsync(
-                        _checkRegistrationOperationStatusSource,
-                        "Timed out while polling the registration status.",
-                        cancellationToken)
-                    .ConfigureAwait(false);
+                RegistrationOperationStatus currentStatus = await _checkRegistrationOperationStatusSource.WaitAsync(cancellationToken).ConfigureAwait(false);
 
                 if (Logging.IsEnabled)
                     Logging.Info(this, $"Current provisioning state: {currentStatus.RegistrationState.Status}.");
@@ -454,35 +446,6 @@ private Task HandleReceivedMessageAsync(MqttApplicationMessageReceivedEventArgs
             return Task.CompletedTask;
         }
 
-        /// <summary>
-        /// Gets the result of the provided task completion source or throws OperationCanceledException if the provided
-        /// cancellation token is cancelled beforehand.
-        /// </summary>
-        /// <typeparam name="T">The type of the result of the task completion source.</typeparam>
-        /// <param name="taskCompletionSource">The task completion source to asynchronously wait for the result of.</param>
-        /// <param name="timeoutErrorMessage">The error message to put in the OperationCanceledException if this taks times out.</param>
-        /// <param name="cancellationToken">The cancellation token.</param>
-        /// <returns>The result of the provided task completion source if it completes before the provided cancellation token is cancelled.</returns>
-        /// <exception cref="OperationCanceledException">If the cancellation token is cancelled before the provided task completion source finishes.</exception>
-        private static async Task<T> GetTaskCompletionSourceResultAsync<T>(
-            TaskCompletionSource<T> taskCompletionSource,
-            string timeoutErrorMessage,
-            CancellationToken cancellationToken)
-        {
-            // Note that Task.Delay(-1, cancellationToken) effectively waits until the cancellation token is cancelled. The -1 value
-            // just means that the task is allowed to run indefinitely.
-            Task finishedTask = await Task.WhenAny(taskCompletionSource.Task, Task.Delay(-1, cancellationToken)).ConfigureAwait(false);
-
-            // If the finished task is not the cancellation token
-            if (finishedTask == taskCompletionSource.Task)
-            {
-                return await ((Task<T>)finishedTask).ConfigureAwait(false);
-            }
-
-            // Otherwise throw operation cancelled exception since the cancellation token was cancelled before the task finished.
-            throw new OperationCanceledException(timeoutErrorMessage);
-        }
-
         internal static bool ContainsAuthenticationException(Exception ex)
         {
             return ex != null
diff --git a/provisioning/device/src/Utilities/TaskCompletionSourceHelper.cs b/provisioning/device/src/Utilities/TaskCompletionSourceHelper.cs
new file mode 100644
index 0000000000..fa51df067e
--- /dev/null
+++ b/provisioning/device/src/Utilities/TaskCompletionSourceHelper.cs
@@ -0,0 +1,31 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Microsoft.Azure.Devices.Provisioning.Client
+{
+    /// <summary>
+    /// Modern .NET supports waiting on the TaskCompletionSource with a cancellation token, but older ones
+    /// do not. We can bind that task with a call to Task.Delay to get the same effect, though.
+    /// </summary>
+    internal static class TaskCompletionSourceHelper
+    {
+        internal static async Task<T> WaitAsync<T>(this TaskCompletionSource<T> taskCompletionSource, CancellationToken ct)
+        {
+#if NET5_0_OR_GREATER
+            return await taskCompletionSource.Task.WaitAsync(ct).ConfigureAwait(false);
+#else
+            await Task
+                .WhenAny(
+                    taskCompletionSource.Task,
+                    Task.Delay(-1, ct))
+                .ConfigureAwait(false);
+
+            ct.ThrowIfCancellationRequested();
+            return await taskCompletionSource.Task.ConfigureAwait(false);
+#endif
+        }
+    }
+}

From ad9f0caad68b8fc76640d6257123f6ea806c1731 Mon Sep 17 00:00:00 2001
From: "David R. Williamson" <drwill@microsoft.com>
Date: Fri, 31 Mar 2023 12:59:58 -0700
Subject: [PATCH 4/4] Fix .NET version requirement

---
 e2e/Tests/helpers/TaskCompletionSourceHelper.cs                 | 2 +-
 iothub/device/src/Utilities/TaskCompletionSourceHelper.cs       | 2 +-
 provisioning/device/src/Utilities/TaskCompletionSourceHelper.cs | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/e2e/Tests/helpers/TaskCompletionSourceHelper.cs b/e2e/Tests/helpers/TaskCompletionSourceHelper.cs
index bf192c94d3..ebc342fc4a 100644
--- a/e2e/Tests/helpers/TaskCompletionSourceHelper.cs
+++ b/e2e/Tests/helpers/TaskCompletionSourceHelper.cs
@@ -14,7 +14,7 @@ internal static class TaskCompletionSourceHelper
     {
         internal static async Task<T> WaitAsync<T>(this TaskCompletionSource<T> taskCompletionSource, CancellationToken ct)
         {
-#if NET5_0_OR_GREATER
+#if NET6_0_OR_GREATER
             return await taskCompletionSource.Task.WaitAsync(ct).ConfigureAwait(false);
 #else
             await Task
diff --git a/iothub/device/src/Utilities/TaskCompletionSourceHelper.cs b/iothub/device/src/Utilities/TaskCompletionSourceHelper.cs
index b107c6d81e..63297a5e62 100644
--- a/iothub/device/src/Utilities/TaskCompletionSourceHelper.cs
+++ b/iothub/device/src/Utilities/TaskCompletionSourceHelper.cs
@@ -14,7 +14,7 @@ internal static class TaskCompletionSourceHelper
     {
         internal static async Task<T> WaitAsync<T>(this TaskCompletionSource<T> taskCompletionSource, CancellationToken ct)
         {
-#if NET5_0_OR_GREATER
+#if NET6_0_OR_GREATER
             return await taskCompletionSource.Task.WaitAsync(ct).ConfigureAwait(false);
 #else
             await Task
diff --git a/provisioning/device/src/Utilities/TaskCompletionSourceHelper.cs b/provisioning/device/src/Utilities/TaskCompletionSourceHelper.cs
index fa51df067e..214e5b4408 100644
--- a/provisioning/device/src/Utilities/TaskCompletionSourceHelper.cs
+++ b/provisioning/device/src/Utilities/TaskCompletionSourceHelper.cs
@@ -14,7 +14,7 @@ internal static class TaskCompletionSourceHelper
     {
         internal static async Task<T> WaitAsync<T>(this TaskCompletionSource<T> taskCompletionSource, CancellationToken ct)
         {
-#if NET5_0_OR_GREATER
+#if NET6_0_OR_GREATER
             return await taskCompletionSource.Task.WaitAsync(ct).ConfigureAwait(false);
 #else
             await Task