diff --git a/src/Cbs/AmqpCbsLink.cs b/src/Cbs/AmqpCbsLink.cs index 4ad438d5..59ab9679 100644 --- a/src/Cbs/AmqpCbsLink.cs +++ b/src/Cbs/AmqpCbsLink.cs @@ -4,7 +4,6 @@ namespace Microsoft.Azure.Amqp { using System; - using System.Collections.Generic; using System.Threading.Tasks; using Microsoft.Azure.Amqp.Framing; @@ -23,7 +22,7 @@ public AmqpCbsLink(AmqpConnection connection) { this.connection = connection ?? throw new ArgumentNullException(nameof(connection)); this.linkFactory = new FaultTolerantAmqpObject( - t => TaskHelpers.CreateTask((c, s) => this.BeginCreateCbsLink(t, c, s), this.EndCreateCbsLink), + timeout => this.CreateCbsLinkAsync(timeout), link => CloseLink(link)); this.connection.AddExtension(this); @@ -34,277 +33,116 @@ public void Close() this.linkFactory.Close(); } - public Task SendTokenAsync(ICbsTokenProvider tokenProvider, Uri namespaceAddress, string audience, string resource, string[] requiredClaims, TimeSpan timeout) + public async Task SendTokenAsync(ICbsTokenProvider tokenProvider, Uri namespaceAddress, string audience, string resource, string[] requiredClaims, TimeSpan timeout) { - return TaskHelpers.CreateTask( - (c, s) => this.BeginSendToken( - tokenProvider, namespaceAddress, audience, resource, requiredClaims, timeout, c, s), - (a) => this.EndSendToken(a)); - } - - public IAsyncResult BeginSendToken(ICbsTokenProvider tokenProvider, Uri namespaceAddress, string audience, string resource, string[] requiredClaims, TimeSpan timeout, AsyncCallback callback, object state) - { - if (tokenProvider == null || namespaceAddress == null || audience == null || resource == null || requiredClaims == null) - { - throw new ArgumentNullException( - tokenProvider == null ? "tokenProvider" : namespaceAddress == null ? "namespaceAddress" : audience == null ? "audience" : resource == null ? "resource" : "requiredClaims"); - } - if (this.connection.IsClosing()) { - throw new ObjectDisposedException(CbsConstants.CbsAddress); + throw new OperationCanceledException("Connection is closing or closed."); } - return new SendTokenAsyncResult(this, tokenProvider, namespaceAddress, audience, resource, requiredClaims, timeout, callback, state); - } - - public DateTime EndSendToken(IAsyncResult result) - { - return SendTokenAsyncResult.End(result).ExpiresAtUtc; - } - - static void CloseLink(RequestResponseAmqpLink link) - { - AmqpSession session = link.SendingLink?.Session; - link.Abort(); - session?.SafeClose(); - } - - IAsyncResult BeginCreateCbsLink(TimeSpan timeout, AsyncCallback callback, object state) - { - return new OpenCbsRequestResponseLinkAsyncResult(this.connection, timeout, callback, state); - } - - RequestResponseAmqpLink EndCreateCbsLink(IAsyncResult result) - { - RequestResponseAmqpLink link = OpenCbsRequestResponseLinkAsyncResult.End(result).Link; - return link; - } - - sealed class OpenCbsRequestResponseLinkAsyncResult : IteratorAsyncResult, ILinkFactory - { - readonly AmqpConnection connection; - AmqpSession session = null; - - public OpenCbsRequestResponseLinkAsyncResult(AmqpConnection connection, TimeSpan timeout, AsyncCallback callback, object state) - : base(timeout, callback, state) + CbsToken token = await tokenProvider.GetTokenAsync(namespaceAddress, resource, requiredClaims); + string tokenType = token.TokenType; + if (tokenType == null) { - this.connection = connection; - - this.Start(); + throw new NotSupportedException(AmqpResources.AmqpUnsupportedTokenType); } - public RequestResponseAmqpLink Link { get; private set; } - - protected override IEnumerator GetAsyncSteps() + RequestResponseAmqpLink requestResponseLink; + if (!this.linkFactory.TryGetOpenedObject(out requestResponseLink)) { - string address = CbsConstants.CbsAddress; - while (this.RemainingTime() > TimeSpan.Zero) - { - try - { - AmqpSessionSettings sessionSettings = new AmqpSessionSettings() { Properties = new Fields() }; - this.session = new AmqpSession(this.connection, sessionSettings, this); - connection.AddSession(session, null); - } - catch (InvalidOperationException exception) - { - this.Complete(exception); - yield break; - } - - yield return this.CallAsync( - (thisPtr, t, c, s) => thisPtr.session.BeginOpen(t, c, s), - (thisPtr, r) => thisPtr.session.EndOpen(r), - ExceptionPolicy.Continue); - - Exception lastException = this.LastAsyncStepException; - if (lastException != null) - { - AmqpTrace.Provider.AmqpOpenEntityFailed(this, string.Empty, address, lastException); - this.session.Abort(); - this.Complete(lastException); - yield break; - } - - Fields properties = new Fields(); - properties.Add(CbsConstants.TimeoutName, (uint)this.RemainingTime().TotalMilliseconds); - this.Link = new RequestResponseAmqpLink("cbs", this.session, address, properties); - yield return this.CallAsync( - (thisPtr, t, c, s) => thisPtr.Link.BeginOpen(t, c, s), - (thisPtr, r) => thisPtr.Link.EndOpen(r), - ExceptionPolicy.Continue); - - lastException = this.LastAsyncStepException; - if (lastException != null) - { - AmqpTrace.Provider.AmqpOpenEntityFailed(this, this.Link.Name, address, lastException); - this.session.SafeClose(); - this.Link = null; - this.Complete(lastException); - yield break; - } - - AmqpTrace.Provider.AmqpOpenEntitySucceeded(this.Link, this.Link.Name, address); - yield break; - } - - if (this.session != null) - { - this.session.SafeClose(); - } - - this.Complete(new TimeoutException(AmqpResources.GetString(AmqpResources.AmqpTimeout, this.OriginalTimeout, address))); + requestResponseLink = await this.linkFactory.GetOrCreateAsync(timeout); } - AmqpLink ILinkFactory.CreateLink(AmqpSession session, AmqpLinkSettings settings) - { - AmqpLink link; - if (settings.IsReceiver()) - { - link = new ReceivingAmqpLink(session, settings); - } - else - { - link = new SendingAmqpLink(session, settings); - } + AmqpValue value = new AmqpValue(); + value.Value = token.TokenValue; + AmqpMessage putTokenRequest = AmqpMessage.Create(value); + putTokenRequest.ApplicationProperties.Map[CbsConstants.Operation] = CbsConstants.PutToken.OperationValue; + putTokenRequest.ApplicationProperties.Map[CbsConstants.PutToken.Type] = tokenType; + putTokenRequest.ApplicationProperties.Map[CbsConstants.PutToken.Audience] = audience; + putTokenRequest.ApplicationProperties.Map[CbsConstants.PutToken.Expiration] = token.ExpiresAtUtc; - AmqpTrace.Provider.AmqpLogOperationInformational(this, TraceOperation.Create, link); - return link; - } + AmqpMessage putTokenResponse = await requestResponseLink.RequestAsync(putTokenRequest, timeout); - IAsyncResult ILinkFactory.BeginOpenLink(AmqpLink link, TimeSpan timeout, AsyncCallback callback, object state) + int statusCode = (int)putTokenResponse.ApplicationProperties.Map[CbsConstants.PutToken.StatusCode]; + string statusDescription = (string)putTokenResponse.ApplicationProperties.Map[CbsConstants.PutToken.StatusDescription]; + if (statusCode == (int)AmqpResponseStatusCode.Accepted || statusCode == (int)AmqpResponseStatusCode.OK) { - return new CompletedAsyncResult(callback, state); + return token.ExpiresAtUtc; } - void ILinkFactory.EndOpenLink(IAsyncResult result) + Exception exception; + AmqpResponseStatusCode amqpResponseStatusCode = (AmqpResponseStatusCode)statusCode; + switch (amqpResponseStatusCode) { - CompletedAsyncResult.End(result); + case AmqpResponseStatusCode.BadRequest: + exception = new AmqpException(AmqpErrorCode.InvalidField, AmqpResources.GetString(AmqpResources.AmqpPutTokenFailed, statusCode, statusDescription)); + break; + case AmqpResponseStatusCode.NotFound: + exception = new AmqpException(AmqpErrorCode.NotFound, AmqpResources.GetString(AmqpResources.AmqpPutTokenFailed, statusCode, statusDescription)); + break; + case AmqpResponseStatusCode.Forbidden: + exception = new AmqpException(AmqpErrorCode.TransferLimitExceeded, AmqpResources.GetString(AmqpResources.AmqpPutTokenFailed, statusCode, statusDescription)); + break; + case AmqpResponseStatusCode.Unauthorized: + exception = new AmqpException(AmqpErrorCode.UnauthorizedAccess, AmqpResources.GetString(AmqpResources.AmqpPutTokenFailed, statusCode, statusDescription)); + break; + default: + exception = new AmqpException(AmqpErrorCode.InvalidField, AmqpResources.GetString(AmqpResources.AmqpPutTokenFailed, statusCode, statusDescription)); + break; } + + throw exception; } - sealed class SendTokenAsyncResult : IteratorAsyncResult + static void CloseLink(RequestResponseAmqpLink link) { - readonly ICbsTokenProvider tokenProvider; - readonly AmqpCbsLink cbsLink; - readonly string[] requiredClaims; - readonly Uri namespaceAddress; - readonly string audience; - readonly string resource; - CbsToken token; - Task requestResponseLinkTask; - - public SendTokenAsyncResult( - AmqpCbsLink cbsLink, - ICbsTokenProvider tokenProvider, - Uri namespaceAddress, - string audience, - string resource, - string[] requiredClaims, - TimeSpan timeout, - AsyncCallback callback, - object state) - : base(timeout, callback, state) - { - this.cbsLink = cbsLink; - this.namespaceAddress = namespaceAddress; - this.audience = audience; - this.resource = resource; - this.requiredClaims = requiredClaims; - this.tokenProvider = tokenProvider; - - this.Start(); - } + AmqpSession session = link.SendingLink?.Session; + link.Abort(); + session?.SafeClose(); + } - public DateTime ExpiresAtUtc { get; private set; } + async Task CreateCbsLinkAsync(TimeSpan timeout) + { + string address = CbsConstants.CbsAddress; + TimeoutHelper timeoutHelper = new TimeoutHelper(timeout); + AmqpSession session = null; + RequestResponseAmqpLink link = null; + Exception lastException = null; - protected override IEnumerator GetAsyncSteps() + while (timeoutHelper.RemainingTime() > TimeSpan.Zero) { - Task getTokenTask = null; - yield return this.CallTask( - (thisPtr, t) => getTokenTask = thisPtr.tokenProvider.GetTokenAsync(thisPtr.namespaceAddress, thisPtr.resource, thisPtr.requiredClaims), - ExceptionPolicy.Transfer); - this.token = getTokenTask.Result; - - string tokenType = this.token.TokenType; - if (tokenType == null) + try { - this.Complete(new InvalidOperationException(AmqpResources.AmqpUnsupportedTokenType)); - yield break; - } + AmqpSessionSettings sessionSettings = new AmqpSessionSettings() { Properties = new Fields() }; + session = this.connection.CreateSession(sessionSettings); + await session.OpenAsync(timeoutHelper.RemainingTime()); - RequestResponseAmqpLink requestResponseLink; - if (this.cbsLink.linkFactory.TryGetOpenedObject(out requestResponseLink)) - { - this.requestResponseLinkTask = Task.FromResult(requestResponseLink); - } - else - { - yield return this.CallTask( - (thisPtr, t) => thisPtr.requestResponseLinkTask = thisPtr.cbsLink.linkFactory.GetOrCreateAsync(t), - ExceptionPolicy.Transfer); - } - - AmqpValue value = new AmqpValue(); - value.Value = this.token.TokenValue; - AmqpMessage putTokenRequest = AmqpMessage.Create(value); - putTokenRequest.ApplicationProperties.Map[CbsConstants.Operation] = CbsConstants.PutToken.OperationValue; - putTokenRequest.ApplicationProperties.Map[CbsConstants.PutToken.Type] = tokenType; - putTokenRequest.ApplicationProperties.Map[CbsConstants.PutToken.Audience] = this.audience; - putTokenRequest.ApplicationProperties.Map[CbsConstants.PutToken.Expiration] = this.token.ExpiresAtUtc; - - AmqpMessage putTokenResponse = null; - Fx.Assert(this.requestResponseLinkTask.Result != null, "requestResponseLink cannot be null without exception"); - yield return this.CallAsync( - (thisPtr, t, c, s) => thisPtr.requestResponseLinkTask.Result.BeginRequest(putTokenRequest, t, c, s), - (thisPtr, r) => putTokenResponse = thisPtr.requestResponseLinkTask.Result.EndRequest(r), - ExceptionPolicy.Transfer); + Fields properties = new Fields(); + properties.Add(CbsConstants.TimeoutName, (uint)timeoutHelper.RemainingTime().TotalMilliseconds); + link = new RequestResponseAmqpLink("cbs", session, address, properties); + await link.OpenAsync(timeoutHelper.RemainingTime()); - int statusCode = (int)putTokenResponse.ApplicationProperties.Map[CbsConstants.PutToken.StatusCode]; - string statusDescription = (string)putTokenResponse.ApplicationProperties.Map[CbsConstants.PutToken.StatusDescription]; - if (statusCode == (int)AmqpResponseStatusCode.Accepted || statusCode == (int)AmqpResponseStatusCode.OK) - { - this.ExpiresAtUtc = this.token.ExpiresAtUtc; - } - else - { - this.Complete(ConvertToException(statusCode, statusDescription)); + AmqpTrace.Provider.AmqpOpenEntitySucceeded(this, link.Name, address); + return link; } - } - - static Exception ConvertToException(int statusCode, string statusDescription) - { - Exception exception; - if (Enum.IsDefined(typeof(AmqpResponseStatusCode), statusCode)) + catch (Exception exception) { - AmqpResponseStatusCode amqpResponseStatusCode = (AmqpResponseStatusCode)statusCode; - switch (amqpResponseStatusCode) + if (this.connection.IsClosing()) { - case AmqpResponseStatusCode.BadRequest: - exception = new AmqpException(AmqpErrorCode.InvalidField, AmqpResources.GetString(AmqpResources.AmqpPutTokenFailed, statusCode, statusDescription)); - break; - case AmqpResponseStatusCode.NotFound: - exception = new AmqpException(AmqpErrorCode.NotFound, AmqpResources.GetString(AmqpResources.AmqpPutTokenFailed, statusCode, statusDescription)); - break; - case AmqpResponseStatusCode.Forbidden: - exception = new AmqpException(AmqpErrorCode.TransferLimitExceeded, AmqpResources.GetString(AmqpResources.AmqpPutTokenFailed, statusCode, statusDescription)); - break; - case AmqpResponseStatusCode.Unauthorized: - exception = new AmqpException(AmqpErrorCode.UnauthorizedAccess, AmqpResources.GetString(AmqpResources.AmqpPutTokenFailed, statusCode, statusDescription)); - break; - default: - exception = new AmqpException(AmqpErrorCode.InvalidField, AmqpResources.GetString(AmqpResources.AmqpPutTokenFailed, statusCode, statusDescription)); - break; + throw new OperationCanceledException("Connection is closing or closed.", exception); } - } - else - { - exception = new AmqpException(AmqpErrorCode.InvalidField, AmqpResources.GetString(AmqpResources.AmqpPutTokenFailed, statusCode, statusDescription)); + + lastException = exception; + AmqpTrace.Provider.AmqpOpenEntityFailed(this, this.GetType().Name, address, exception); } - return exception; + await Task.Delay(1000); } + + link?.Abort(); + session?.SafeClose(); + + throw new TimeoutException(AmqpResources.GetString(AmqpResources.AmqpTimeout, timeout, address), lastException); } } } diff --git a/src/Fx/IteratorAsyncResult.cs b/src/Fx/IteratorAsyncResult.cs deleted file mode 100644 index 21af3689..00000000 --- a/src/Fx/IteratorAsyncResult.cs +++ /dev/null @@ -1,560 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. - -namespace Microsoft.Azure.Amqp -{ - using System; - using System.Collections.Generic; - using System.Diagnostics; - using System.Threading; - using System.Threading.Tasks; - - // --------------------------------------------------------------------------------- - // Base class for async results that compose other async operations. - // - // The goal of this class is to make it easy to write async code - // that composes other async code with arbitrary logic. Here is - // an example: - // - // yield return this.CallAsync( - // (thisPtr, t, c, s) => thisPtr.channel.BeginOpen(t, c, s), - // (thisPtr, r) => thisPtr.channel.EndOpen(r), - // ExceptionPolicy.Transfer - // ); - // - // foreach (Message m in this.GetMessages()) - // { - // if (this.ShouldSend(m)) - // { - // yield return this.CallAsync( - // (thisPtr, t, c, s) => thisPtr.channel.BeginSend(t, c, s), - // (thisPtr, r) => thisPtr.channel.EndSend(r), - // ExceptionPolicy.Transfer - // ); - // } - // } - // - // yield return this.CallAsync( - // (thisPtr, t, c, s) => thisPtr.channel.BeginClose(t, c, s), - // (thisPtr, r) => thisPtr.channel.EndClose(r), - // ExceptionPolicy.Transfer - // ); - // - // The idiom of (thisPtr, t, c, s) represents: - // (XXXAsyncResult thisPtr, TimeSpan timeout, AsyncCallback callback, object state) - // - // The idiom of (thisPtr, r) represents: - // (XXXAsyncResult thisPtr, IAsyncResult result) - // - // Because these patterns arise so frequently with - // IteratorAsyncResult, best practice is to use the single letter - // parameters to maximize the chances the lambda will fit on a - // single line. - // - // This sort of code forms the body of the GetAsyncSteps method. This - // method returns an IEnumerable, and by using "yield return", - // the implementation releases the thread exactly at the async points in - // the flow. The Begin and End lambdas are called by the base class, and - // control resumes just after the yield statement. The C# iterator syntax - // handles capturing all the local state of the method. - // - // The third parameter is a "ExceptionPolicy". - // ExceptionPolicy.Transfer is used if you want the async result to be completed with - // the exception when the step throws. ExceptionPolicy.Continue is used if you want to ignore - // the excpetion and continue next steps. The last exception ignored can be accessed with - // LastAsyncException property. - // --------------------------------------------------------------------------------- - // - // The TIteratorAsyncResult here is type of the derived class itself that you're implementing: - // Ex) - // class OpenAsyncResult : IteratorAsyncResult - // { - // ... - // } - // - // This form permits the iterator to be passed as a parameter to CallAsync delegates, so that instance members - // can be referenced without reference to this or local variables, which would result in allocating additional memory to pass the state. - [DebuggerStepThrough] - abstract class IteratorAsyncResult : AsyncResult - where TIteratorAsyncResult : IteratorAsyncResult - { - static readonly Action onFinally = IteratorAsyncResult.Finally; - - static AsyncCompletion stepCallbackDelegate; - - // DON'T make TimeoutHelper readonly field. - // It is very unfortunate design but TimeoutHelper is a struct (value type) that is mutating. - // Declarating it as readonly has side impact that it prevents TimeoutHelper mutating itself causing RemainingTime() method - // returning the original timeout value everytime. - TimeoutHelper timeoutHelper; - volatile bool everCompletedAsynchronously; - IEnumerator steps; - Exception lastAsyncStepException; - - protected IteratorAsyncResult(TimeSpan timeout, AsyncCallback callback, object state) - : base(callback, state) - { - this.timeoutHelper = new TimeoutHelper(timeout, true); - this.OnCompleting += IteratorAsyncResult.onFinally; - } - - protected delegate IAsyncResult BeginCall(TIteratorAsyncResult thisPtr, TimeSpan timeout, AsyncCallback callback, object state); - - protected delegate void EndCall(TIteratorAsyncResult thisPtr, IAsyncResult ar); - - protected delegate void Call(TIteratorAsyncResult thisPtr, TimeSpan timeout); - - private enum CurrentThreadType - { - Synchronous, - StartingThread, - Callback - } - - protected Exception LastAsyncStepException - { - get { return this.lastAsyncStepException; } - set { this.lastAsyncStepException = value; } - } - - public TimeSpan OriginalTimeout - { - get - { - return this.timeoutHelper.OriginalTimeout; - } - } - - private static AsyncCompletion StepCallbackDelegate - { - get - { - // The race here is intentional and harmless. - if (stepCallbackDelegate == null) - { - stepCallbackDelegate = new AsyncCompletion(StepCallback); - } - - return stepCallbackDelegate; - } - } - - // This is typically called at the end of the derived AsyncResult - // constructor, to start the async operation. - public IAsyncResult Start() - { - Debug.Assert(this.steps == null, "IteratorAsyncResult.Start called twice"); - try - { - this.steps = this.GetAsyncSteps(); - - this.EnumerateSteps(CurrentThreadType.StartingThread); - } - catch (Exception e) when (!Fx.IsFatal(e)) - { - this.Complete(e); - } - - return this; - } - - public TIteratorAsyncResult RunSynchronously() - { - Debug.Assert(this.steps == null, "IteratorAsyncResult.RunSynchronously or .Start called twice"); - try - { - this.steps = this.GetAsyncSteps(); - this.EnumerateSteps(CurrentThreadType.Synchronous); - } - catch (Exception e) when (!Fx.IsFatal(e)) - { - this.Complete(e); - } - - return End(this); - } - - // Utility method to be called from GetAsyncSteps. To create an implementation - // of IAsyncCatch, use the CatchAndTransfer or CatchAndContinue methods. - protected AsyncStep CallAsync(BeginCall beginCall, EndCall endCall, Call call, ExceptionPolicy policy) - { - return new AsyncStep(beginCall, endCall, call, policy); - } - - protected AsyncStep CallAsync(BeginCall beginCall, EndCall endCall, ExceptionPolicy policy) - { - return new AsyncStep(beginCall, endCall, null, policy); - } - - protected AsyncStep CallParallelAsync(ICollection workItems, BeginCall beginCall, EndCall endCall, ExceptionPolicy policy) - { - return this.CallAsync( - (thisPtr, t, c, s) => new ParallelAsyncResult(thisPtr, workItems, beginCall, endCall, t, c, s), - (thisPtr, r) => ParallelAsyncResult.End(r), - policy); - } - - protected AsyncStep CallParallelAsync(ICollection workItems, BeginCall beginCall, EndCall endCall, TimeSpan timeout, ExceptionPolicy policy) - { - return this.CallAsync( - (thisPtr, t, c, s) => new ParallelAsyncResult(thisPtr, workItems, beginCall, endCall, timeout, c, s), - (thisPtr, r) => ParallelAsyncResult.End(r), - policy); - } - - protected AsyncStep CallTask(Func taskFunc, ExceptionPolicy policy) - { - return this.CallAsync( - (thisPtr, t, c, s) => - { - var task = taskFunc(thisPtr, t); - if (task.Status == TaskStatus.Created) - { - // User func might have created a Task without starting. - // This can potentially hang threads. - task.Start(); - } - - return task.ToAsyncResult(c, s); - }, - (thisPtr, r) => TaskHelpers.EndAsyncResult(r), - policy); - } - - protected AsyncStep CallCompletedAsyncStep() - { - return this.CallAsync( - (thisPtr, t, c, s) => new CompletedAsyncResult(c, s), - (thisPtr, r) => CompletedAsyncResult.End(r), - ExceptionPolicy.Transfer); - } - - protected TimeSpan RemainingTime() - { - return this.timeoutHelper.RemainingTime(); - } - - // The derived AsyncResult implements this method as a C# iterator. - // The implementation should make no blocking calls. Instead, it - // runs synchronous code and can "yield return" the result of calling - // "CallAsync" to cause an async method invocation. - protected abstract IEnumerator GetAsyncSteps(); - - protected void Complete(Exception operationException) - { - this.Complete(!this.everCompletedAsynchronously, operationException); - } - - static bool StepCallback(IAsyncResult result) - { - var thisPtr = (IteratorAsyncResult)result.AsyncState; - - bool syncContinue = thisPtr.CheckSyncContinue(result); - - if (!syncContinue) - { - thisPtr.everCompletedAsynchronously = true; - - try - { - // Don't refactor this into a seperate method. It adds one extra call stack reducing readibility of call stack in trace. - thisPtr.steps.Current.EndCall((TIteratorAsyncResult)thisPtr, result); - } - catch (Exception e) when (!Fx.IsFatal(e) && thisPtr.HandleException(e)) - { - } - - thisPtr.EnumerateSteps(CurrentThreadType.Callback); - } - - return syncContinue; - } - - static void Finally(AsyncResult result, Exception exception) - { - var thisPtr = (IteratorAsyncResult)result; - try - { - IEnumerator steps = thisPtr.steps; - if (steps != null) - { - steps.Dispose(); - } - } - catch (Exception e) when (!Fx.IsFatal(e)) - { - ////MessagingClientEtwProvider.Provider.EventWriteExceptionAsWarning(e.ToStringSlim()); - if (exception == null) - { - throw; - } - } - } - - bool MoveNextStep() - { - return this.steps.MoveNext(); - } - - // This runs async steps until one of them completes asynchronously, or until - // Begin throws on the Start thread with a policy of PassThrough. - void EnumerateSteps(CurrentThreadType state) - { - while (!this.IsCompleted && this.MoveNextStep()) - { - this.LastAsyncStepException = null; - AsyncStep step = this.steps.Current; - if (step.BeginCall != null) - { - IAsyncResult result = null; - - if (state == CurrentThreadType.Synchronous && step.HasSynchronous) - { - if (step.Policy == ExceptionPolicy.Transfer) - { - step.Call((TIteratorAsyncResult)this, this.timeoutHelper.RemainingTime()); - } - else - { - try - { - step.Call((TIteratorAsyncResult)this, this.timeoutHelper.RemainingTime()); - } - catch (Exception e) when (!Fx.IsFatal(e) && this.HandleException(e)) - { - } - } - } - else - { - if (step.Policy == ExceptionPolicy.Transfer) - { - // Don't refactor this into a seperate method. It adds one extra call stack reducing readibility of call stack in trace. - result = step.BeginCall( - (TIteratorAsyncResult)this, - this.timeoutHelper.RemainingTime(), - this.PrepareAsyncCompletion(IteratorAsyncResult.StepCallbackDelegate), - this); - } - else - { - try - { - // Don't refactor this into a seperate method. It adds one extra call stack reducing readibility of call stack in trace. - result = step.BeginCall( - (TIteratorAsyncResult)this, - this.timeoutHelper.RemainingTime(), - this.PrepareAsyncCompletion(IteratorAsyncResult.StepCallbackDelegate), - this); - } - catch (Exception e) when (!Fx.IsFatal(e) && this.HandleException(e)) - { - } - } - } - - if (result != null) - { - if (!this.CheckSyncContinue(result)) - { - return; - } - - try - { - // Don't refactor this into a seperate method. It adds one extra call stack reducing readibility of call stack in trace. - this.steps.Current.EndCall((TIteratorAsyncResult)this, result); - } - catch (Exception e) when (!Fx.IsFatal(e) && this.HandleException(e)) - { - } - } - } - } - - if (!this.IsCompleted) - { - this.Complete(!this.everCompletedAsynchronously); - } - } - - // Returns true if a handler matched the Exception, false otherwise. - bool HandleException(Exception e) - { - bool handled; - - this.LastAsyncStepException = e; - AsyncStep step = this.steps.Current; - - switch (step.Policy) - { - case ExceptionPolicy.Continue: - handled = true; - break; - case ExceptionPolicy.Transfer: - handled = false; - if (!this.IsCompleted) - { - this.Complete(e); - handled = true; - } - break; - default: - handled = false; - break; - } - - return handled; - } - - [DebuggerStepThrough] - protected struct AsyncStep - { - readonly ExceptionPolicy policy; - readonly BeginCall beginCall; - readonly EndCall endCall; - readonly Call call; - - public static readonly AsyncStep Empty = new AsyncStep(); - - public AsyncStep( - BeginCall beginCall, - EndCall endCall, - Call call, - ExceptionPolicy policy) - { - this.policy = policy; - this.beginCall = beginCall; - this.endCall = endCall; - this.call = call; - } - - public BeginCall BeginCall - { - get { return this.beginCall; } - } - - public EndCall EndCall - { - get { return this.endCall; } - } - - public Call Call - { - get { return this.call; } - } - - public bool HasSynchronous - { - get - { - return this.call != null; - } - } - - public ExceptionPolicy Policy - { - get - { - return this.policy; - } - } - } - - protected enum ExceptionPolicy - { - /// - /// ExceptionPolicy.Transfer is used if you want the async result to be completed with - /// the exception when the step throws. - /// - Transfer, - - /// - /// ExceptionPolicy.Continue is used if you want to ignore - /// the exception and continue next steps. The last exception ignored can be accessed with - /// LastAsyncException property. - /// - Continue - } - - protected delegate IAsyncResult BeginCall(TIteratorAsyncResult thisPtr, TWorkItem workItem, TimeSpan timeout, AsyncCallback callback, object state); - - protected delegate void EndCall(TIteratorAsyncResult thisPtr, TWorkItem workItem, IAsyncResult ar); - - sealed class ParallelAsyncResult : AsyncResult> - { - static readonly AsyncCallback completed = new AsyncCallback(OnCompleted); - - readonly TIteratorAsyncResult iteratorAsyncResult; - readonly ICollection workItems; - readonly EndCall endCall; - long actions; - Exception firstException; - - public ParallelAsyncResult(TIteratorAsyncResult iteratorAsyncResult, ICollection workItems, BeginCall beginCall, EndCall endCall, TimeSpan timeout, AsyncCallback callback, object state) - : base(callback, state) - { - this.iteratorAsyncResult = iteratorAsyncResult; - this.workItems = workItems; - this.endCall = endCall; - this.actions = this.workItems.Count + 1; - - foreach (TWorkItem source in workItems) - { - try - { - beginCall(iteratorAsyncResult, source, timeout, completed, new CallbackState(this, source)); - } - catch (Exception e) when (!Fx.IsFatal(e)) - { - TryComplete(e, true); - } - } - - TryComplete(null, true); - } - - void TryComplete(Exception exception, bool completedSynchronously) - { - if (this.firstException == null && exception != null) - { - this.firstException = exception; - } - - if (Interlocked.Decrement(ref this.actions) == 0) - { - this.Complete(completedSynchronously, this.firstException); - } - } - - static void OnCompleted(IAsyncResult ar) - { - CallbackState state = (CallbackState)ar.AsyncState; - ParallelAsyncResult thisPtr = state.AsyncResult; - - try - { - thisPtr.endCall(thisPtr.iteratorAsyncResult, state.AsyncData, ar); - thisPtr.TryComplete(null, ar.CompletedSynchronously); - } - catch (Exception e) when (!Fx.IsFatal(e)) - { - thisPtr.TryComplete(e, ar.CompletedSynchronously); - } - } - - sealed class CallbackState - { - public CallbackState(ParallelAsyncResult asyncResult, TWorkItem data) - { - this.AsyncResult = asyncResult; - this.AsyncData = data; - } - - public ParallelAsyncResult AsyncResult { get; private set; } - - public TWorkItem AsyncData { get; private set; } - } - } - } -} \ No newline at end of file diff --git a/src/Fx/Singleton.cs b/src/Fx/Singleton.cs index 0c477ece..5f280750 100644 --- a/src/Fx/Singleton.cs +++ b/src/Fx/Singleton.cs @@ -50,7 +50,7 @@ public Task CloseAsync() { this.Dispose(); - return TaskHelpers.CompletedTask; + return Task.CompletedTask; } public void Close() diff --git a/src/Fx/TaskHelpers.cs b/src/Fx/TaskHelpers.cs deleted file mode 100644 index a5bcd2d6..00000000 --- a/src/Fx/TaskHelpers.cs +++ /dev/null @@ -1,212 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. - -namespace Microsoft.Azure.Amqp -{ - using System; - using System.Runtime.InteropServices; - using System.Threading; - using System.Threading.Tasks; - - static class TaskHelpers - { - public static readonly Task CompletedTask = Task.FromResult(default(VoidTaskResult)); - - /// - /// Create a Task based on Begin/End IAsyncResult pattern. - /// - /// - /// - /// - /// This parameter helps reduce allocations by passing state to the Funcs. e.g.: - /// await TaskHelpers.CreateTask( - /// (c, s) => ((Transaction)s).BeginCommit(c, s), - /// (a) => ((Transaction)a.AsyncState).EndCommit(a), - /// transaction); - /// - public static Task CreateTask(Func begin, Action end, object state = null) - { - Task retval; - try - { - retval = Task.Factory.FromAsync(begin, end, state); - } - catch (Exception ex) when (!Fx.IsFatal(ex)) - { - var completionSource = new TaskCompletionSource(state); - completionSource.SetException(ex); - retval = completionSource.Task; - } - - return retval; - } - - public static Task CreateTask(Func begin, Func end, object state = null) - { - Task retval; - try - { - retval = Task.Factory.FromAsync(begin, end, state); - } - catch (Exception ex) when (!Fx.IsFatal(ex)) - { - var completionSource = new TaskCompletionSource(state); - completionSource.SetException(ex); - retval = completionSource.Task; - } - - return retval; - } - - public static IAsyncResult ToAsyncResult(this Task task, AsyncCallback callback, object state) - { - if (task.AsyncState == state) - { - if (callback != null) - { - task.ContinueWith( - t => callback(t), - TaskContinuationOptions.ExecuteSynchronously); - } - - return task; - } - - var tcs = new TaskCompletionSource(state); - task.ContinueWith( - t => - { - if (t.IsFaulted) - { - tcs.TrySetException(t.Exception.InnerExceptions); - } - else if (t.IsCanceled) - { - tcs.TrySetCanceled(); - } - else - { - tcs.TrySetResult(null); - } - - callback?.Invoke(tcs.Task); - }, - TaskContinuationOptions.ExecuteSynchronously); - - return tcs.Task; - } - - public static IAsyncResult ToAsyncResult(this Task task, AsyncCallback callback, object state) - { - if (task.AsyncState == state) - { - if (callback != null) - { - task.ContinueWith( - t => callback(t), - TaskContinuationOptions.ExecuteSynchronously); - } - - return task; - } - - var tcs = new TaskCompletionSource(state); - task.ContinueWith( - t => - { - if (t.IsFaulted) - { - tcs.TrySetException(t.Exception.InnerExceptions); - } - else if (t.IsCanceled) - { - tcs.TrySetCanceled(); - } - else - { - tcs.TrySetResult(t.Result); - } - - callback?.Invoke(tcs.Task); - }, - TaskContinuationOptions.ExecuteSynchronously); - - return tcs.Task; - } - - public static void EndAsyncResult(IAsyncResult asyncResult) - { - Task task = asyncResult as Task; - if (task == null) - { - throw new ArgumentException(CommonResources.InvalidAsyncResult); - } - - task.GetAwaiter().GetResult(); - } - - public static TResult EndAsyncResult(IAsyncResult asyncResult) - { - Task task = asyncResult as Task; - if (task == null) - { - throw new ArgumentException(CommonResources.InvalidAsyncResult); - } - - return task.GetAwaiter().GetResult(); - } - - public static Task WithTimeout(this Task task, TimeSpan timeout, Func errorMessage) - { - return WithTimeout(task, timeout, errorMessage, CancellationToken.None); - } - - public static async Task WithTimeout(this Task task, TimeSpan timeout, Func errorMessage, CancellationToken token) - { - if (timeout == TimeSpan.MaxValue) - { - timeout = Timeout.InfiniteTimeSpan; - } - else if (timeout.TotalMilliseconds > Int32.MaxValue) - { - timeout = TimeSpan.FromMilliseconds(Int32.MaxValue); - } - - if (task.IsCompleted || (timeout == Timeout.InfiniteTimeSpan && token == CancellationToken.None)) - { - await task.ConfigureAwait(false); - return; - } - - using (var cts = CancellationTokenSource.CreateLinkedTokenSource(token)) - { - if (task == await Task.WhenAny(task, CreateDelayTask(timeout, cts.Token)).ConfigureAwait(false)) - { - cts.Cancel(); - await task.ConfigureAwait(false); - return; - } - } - - throw new TimeoutException(errorMessage()); - } - - static async Task CreateDelayTask(TimeSpan timeout, CancellationToken token) - { - try - { - await Task.Delay(timeout, token).ConfigureAwait(false); - } - catch (TaskCanceledException) - { - // No need to throw. Caller is responsible for detecting - // which task completed and throwing appropriate Timeout Exception - } - } - - [StructLayout(LayoutKind.Sequential, Size = 1)] - internal struct VoidTaskResult - { - } - } -} diff --git a/src/ReceivingAmqpLink.cs b/src/ReceivingAmqpLink.cs index 50861904..56070bfc 100644 --- a/src/ReceivingAmqpLink.cs +++ b/src/ReceivingAmqpLink.cs @@ -117,7 +117,7 @@ public IAsyncResult BeginReceiveRemoteMessages(int messageCount, TimeSpan batchW public Task ReceiveMessageAsync(TimeSpan timeout) { - return TaskHelpers.CreateTask( + return Task.Factory.FromAsync( (c, s) => ((ReceivingAmqpLink)s).BeginReceiveMessage(timeout, c, s), (a) => { @@ -244,8 +244,8 @@ public Task DisposeMessageAsync(ArraySegment deliveryTag, Outcome public Task DisposeMessageAsync(ArraySegment deliveryTag, ArraySegment txnId, Outcome outcome, bool batchable, TimeSpan timeout) { - return TaskHelpers.CreateTask( - (c, s) => this.BeginDisposeMessage(deliveryTag, txnId, outcome, batchable, timeout, c, s), + return Task.Factory.FromAsync( + (c, s) => ((ReceivingAmqpLink)s).BeginDisposeMessage(deliveryTag, txnId, outcome, batchable, timeout, c, s), a => ((ReceivingAmqpLink)a.AsyncState).EndDisposeMessage(a), this); } diff --git a/src/RequestResponseAmqpLink.cs b/src/RequestResponseAmqpLink.cs index 6659e9af..f3e559c0 100644 --- a/src/RequestResponseAmqpLink.cs +++ b/src/RequestResponseAmqpLink.cs @@ -128,7 +128,7 @@ public AmqpSession Session public Task RequestAsync(AmqpMessage request, TimeSpan timeout) { - return TaskHelpers.CreateTask( + return Task.Factory.FromAsync( (c, s) => ((RequestResponseAmqpLink)s).BeginRequest(request, timeout, c, s), (r) => ((RequestResponseAmqpLink)r.AsyncState).EndRequest(r), this); diff --git a/src/Transport/TransportStream.cs b/src/Transport/TransportStream.cs index 4aa9673a..ade8ae48 100644 --- a/src/Transport/TransportStream.cs +++ b/src/Transport/TransportStream.cs @@ -92,9 +92,9 @@ public override void SetLength(long value) public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - return TaskHelpers.CreateTask( - (c, s) => this.BeginWrite(buffer, offset, count, c, s), - (a) => this.EndWrite(a), + return Task.Factory.FromAsync( + (c, s) => ((Stream)s).BeginWrite(buffer, offset, count, c, s), + (a) => ((Stream)a.AsyncState).EndWrite(a), this); } @@ -124,9 +124,9 @@ public override void EndWrite(IAsyncResult asyncResult) public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - return TaskHelpers.CreateTask( - (c, s) => this.BeginRead(buffer, offset, count, c, s), - (a) => this.EndRead(a), + return Task.Factory.FromAsync( + (c, s) => ((Stream)s).BeginRead(buffer, offset, count, c, s), + (a) => ((Stream)a.AsyncState).EndRead(a), this); } diff --git a/src/Transport/WebSocketTransportInitiator.cs b/src/Transport/WebSocketTransportInitiator.cs index b75abf93..ff2a1de6 100644 --- a/src/Transport/WebSocketTransportInitiator.cs +++ b/src/Transport/WebSocketTransportInitiator.cs @@ -27,7 +27,12 @@ public override bool ConnectAsync(TimeSpan timeout, TransportAsyncCallbackArgs c cws.Options.Proxy = this.settings.Proxy; } - Task task = cws.ConnectAsync(this.settings.Uri, CancellationToken.None).WithTimeout(timeout, () => "timeout"); + var task = new TimeoutTaskSource( + cws, + s => s.ConnectAsync(this.settings.Uri, CancellationToken.None), + s => s.Abort(), + timeout).Task; + if (task.IsCompleted) { callbackArgs.Transport = new WebSocketTransport(cws, this.settings.Uri); @@ -53,5 +58,48 @@ public override bool ConnectAsync(TimeSpan timeout, TransportAsyncCallbackArgs c }); return true; } + + sealed class TimeoutTaskSource : TaskCompletionSource where T : class + { + readonly T t; + readonly TimeSpan timeout; + readonly ITimer timer; + readonly Action onTimeout; + + public TimeoutTaskSource(T t, Func onStart, Action onTimeout, TimeSpan timeout) + { + this.t = t; + this.onTimeout = onTimeout; + this.timeout = timeout; + this.timer = SystemTimerFactory.Default.Create(OnTimer, this, timeout); + + Task task = onStart(t); + task.ContinueWith((_t, _s) => ((TimeoutTaskSource)_s).OnTask(_t), this); + } + + static void OnTimer(object state) + { + var thisPtr = (TimeoutTaskSource)state; + thisPtr.onTimeout(thisPtr.t); + thisPtr.TrySetException(new TimeoutException(AmqpResources.GetString(AmqpResources.AmqpTimeout, thisPtr.timeout, typeof(T).Name))); + } + + void OnTask(Task inner) + { + this.timer.Cancel(); + if (inner.IsFaulted) + { + this.TrySetException(inner.Exception.InnerException); + } + else if (inner.IsCanceled) + { + this.TrySetCanceled(); + } + else + { + this.TrySetResult(this.t); + } + } + } } } \ No newline at end of file