From 5080a3b95a01a02ae4834fe508f7efdbadc3c135 Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Tue, 25 Oct 2022 23:16:44 +0100 Subject: [PATCH 1/2] add disposable stack temp ref struct and use --- .../src/Microsoft.Data.SqlClient.csproj | 3 + .../Microsoft/Data/SqlClient/SqlDataReader.cs | 139 +++++++++--------- .../netfx/src/Microsoft.Data.SqlClient.csproj | 3 + .../Microsoft/Data/SqlClient/SqlDataReader.cs | 112 +++++++------- .../SqlClient/DisposableTemporaryOnStack.cs | 40 +++++ 5 files changed, 177 insertions(+), 120 deletions(-) create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/DisposableTemporaryOnStack.cs diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index 242b636730..07644b1355 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -109,6 +109,9 @@ Microsoft\Data\SqlClient\DataClassification\SensitivityClassification.cs + + Microsoft\Data\SqlClient\DisposableTemporaryOnStack.cs + Microsoft\Data\SqlClient\EnclaveDelegate.cs diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs index cf271ea749..d3e96aa8f2 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -4399,6 +4399,7 @@ private void AssertReaderState(bool requireData, bool permitAsync, int? columnIn public override Task NextResultAsync(CancellationToken cancellationToken) { using (TryEventScope.Create("SqlDataReader.NextResultAsync | API | Object Id {0}", ObjectID)) + using (var registrationHolder = new DisposableTemporaryOnStack()) { TaskCompletionSource source = new TaskCompletionSource(); @@ -4408,7 +4409,6 @@ public override Task NextResultAsync(CancellationToken cancellationToken) return source.Task; } - CancellationTokenRegistration registration = default; if (cancellationToken.CanBeCanceled) { if (cancellationToken.IsCancellationRequested) @@ -4416,7 +4416,7 @@ public override Task NextResultAsync(CancellationToken cancellationToken) source.SetCanceled(); return source.Task; } - registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); + registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command)); } Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null); @@ -4434,7 +4434,7 @@ public override Task NextResultAsync(CancellationToken cancellationToken) return source.Task; } - return InvokeAsyncCall(new HasNextResultAsyncCallContext(this, source, registration)); + return InvokeAsyncCall(new HasNextResultAsyncCallContext(this, source, registrationHolder.Take())); } } @@ -4729,6 +4729,7 @@ out bytesRead public override Task ReadAsync(CancellationToken cancellationToken) { using (TryEventScope.Create("SqlDataReader.ReadAsync | API | Object Id {0}", ObjectID)) + using (var registrationHolder = new DisposableTemporaryOnStack()) { if (IsClosed) { @@ -4736,10 +4737,9 @@ public override Task ReadAsync(CancellationToken cancellationToken) } // Register first to catch any already expired tokens to be able to trigger cancellation event. - CancellationTokenRegistration registration = default; if (cancellationToken.CanBeCanceled) { - registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); + registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command)); } // If user's token is canceled, return a canceled task @@ -4852,7 +4852,7 @@ public override Task ReadAsync(CancellationToken cancellationToken) Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == default, "cached ReadAsyncCallContext was not properly disposed"); - context.Set(this, source, registration); + context.Set(this, source, registrationHolder.Take()); context._hasMoreData = more; context._hasReadRowToken = rowTokenRead; @@ -4990,49 +4990,51 @@ override public Task IsDBNullAsync(int i, CancellationToken cancellationTo return Task.FromException(ex); } - // Setup and check for pending task - TaskCompletionSource source = new TaskCompletionSource(); - Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null); - if (original != null) + using (var registrationHolder = new DisposableTemporaryOnStack()) { - source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending())); - return source.Task; - } + // Setup and check for pending task + TaskCompletionSource source = new TaskCompletionSource(); + Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null); + if (original != null) + { + source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending())); + return source.Task; + } - // Check if cancellation due to close is requested (this needs to be done after setting _currentTask) - if (_cancelAsyncOnCloseToken.IsCancellationRequested) - { - source.SetCanceled(); - _currentTask = null; - return source.Task; - } + // Check if cancellation due to close is requested (this needs to be done after setting _currentTask) + if (_cancelAsyncOnCloseToken.IsCancellationRequested) + { + source.SetCanceled(); + _currentTask = null; + return source.Task; + } - // Setup cancellations - CancellationTokenRegistration registration = default; - if (cancellationToken.CanBeCanceled) - { - registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); - } + // Setup cancellations + if (cancellationToken.CanBeCanceled) + { + registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command)); + } - IsDBNullAsyncCallContext context = null; - if (_connection?.InnerConnection is SqlInternalConnection sqlInternalConnection) - { - context = Interlocked.Exchange(ref sqlInternalConnection.CachedDataReaderIsDBNullContext, null); - } - if (context is null) - { - context = new IsDBNullAsyncCallContext(); - } + IsDBNullAsyncCallContext context = null; + if (_connection?.InnerConnection is SqlInternalConnection sqlInternalConnection) + { + context = Interlocked.Exchange(ref sqlInternalConnection.CachedDataReaderIsDBNullContext, null); + } + if (context is null) + { + context = new IsDBNullAsyncCallContext(); + } - Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == default, "cached ISDBNullAsync context not properly disposed"); + Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == default, "cached ISDBNullAsync context not properly disposed"); - context.Set(this, source, registration); - context._columnIndex = i; + context.Set(this, source, registrationHolder.Take()); + context._columnIndex = i; - // Setup async - PrepareAsyncInvocation(useSnapshot: true); + // Setup async + PrepareAsyncInvocation(useSnapshot: true); - return InvokeAsyncCall(context); + return InvokeAsyncCall(context); + } } } @@ -5137,37 +5139,39 @@ override public Task GetFieldValueAsync(int i, CancellationToken cancellat return Task.FromException(ex); } - // Setup and check for pending task - TaskCompletionSource source = new TaskCompletionSource(); - Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null); - if (original != null) + using (var registrationHolder = new DisposableTemporaryOnStack()) { - source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending())); - return source.Task; - } + // Setup and check for pending task + TaskCompletionSource source = new TaskCompletionSource(); + Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null); + if (original != null) + { + source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending())); + return source.Task; + } - // Check if cancellation due to close is requested (this needs to be done after setting _currentTask) - if (_cancelAsyncOnCloseToken.IsCancellationRequested) - { - source.SetCanceled(); - _currentTask = null; - return source.Task; - } + // Check if cancellation due to close is requested (this needs to be done after setting _currentTask) + if (_cancelAsyncOnCloseToken.IsCancellationRequested) + { + source.SetCanceled(); + _currentTask = null; + return source.Task; + } - // Setup cancellations - CancellationTokenRegistration registration = default; - if (cancellationToken.CanBeCanceled) - { - registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); - } + // Setup cancellations + if (cancellationToken.CanBeCanceled) + { + registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command)); + } - // Setup async - PrepareAsyncInvocation(useSnapshot: true); + // Setup async + PrepareAsyncInvocation(useSnapshot: true); - GetFieldValueAsyncCallContext context = new GetFieldValueAsyncCallContext(this, source, registration); - context._columnIndex = i; + GetFieldValueAsyncCallContext context = new GetFieldValueAsyncCallContext(this, source, registrationHolder.Take()); + context._columnIndex = i; - return InvokeAsyncCall(context); + return InvokeAsyncCall(context); + } } private static Task GetFieldValueAsyncExecute(Task task, object state) @@ -5382,6 +5386,9 @@ protected override void Clear() internal override Func> Execute => s_execute; } + + + /// /// Starts the process of executing an async call using an SqlDataReaderAsyncCallContext derived context object. /// After this call the context lifetime is handled by BeginAsyncCall ContinueAsyncCall and CompleteAsyncCall AsyncCall methods diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj index bf6b3f3ff2..ff5b7e3c63 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj @@ -185,6 +185,9 @@ Microsoft\Data\SqlClient\DataClassification\SensitivityClassification.cs + + Microsoft\Data\SqlClient\DisposableTemporaryOnStack.cs + Microsoft\Data\SqlClient\EnclaveDelegate.cs diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs index c733b7fc8a..fd7306a8f7 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -4966,6 +4966,7 @@ private void AssertReaderState(bool requireData, bool permitAsync, int? columnIn public override Task NextResultAsync(CancellationToken cancellationToken) { using (TryEventScope.Create(" {0}", ObjectID)) + using (var registrationHolder = new DisposableTemporaryOnStack()) { TaskCompletionSource source = new TaskCompletionSource(); @@ -4976,7 +4977,6 @@ public override Task NextResultAsync(CancellationToken cancellationToken) return source.Task; } - IDisposable registration = null; if (cancellationToken.CanBeCanceled) { if (cancellationToken.IsCancellationRequested) @@ -4984,7 +4984,7 @@ public override Task NextResultAsync(CancellationToken cancellationToken) source.SetCanceled(); return source.Task; } - registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); + registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command)); } Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null); @@ -5002,7 +5002,7 @@ public override Task NextResultAsync(CancellationToken cancellationToken) return source.Task; } - return InvokeAsyncCall(new HasNextResultAsyncCallContext(this, source, registration)); + return InvokeAsyncCall(new HasNextResultAsyncCallContext(this, source, registrationHolder.Take())); } } @@ -5303,6 +5303,7 @@ out bytesRead public override Task ReadAsync(CancellationToken cancellationToken) { using (TryEventScope.Create(" {0}", ObjectID)) + using (var registrationHolder = new DisposableTemporaryOnStack()) { if (IsClosed) { @@ -5310,10 +5311,9 @@ public override Task ReadAsync(CancellationToken cancellationToken) } // Register first to catch any already expired tokens to be able to trigger cancellation event. - IDisposable registration = null; if (cancellationToken.CanBeCanceled) { - registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); + registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command)); } // If user's token is canceled, return a canceled task @@ -5419,7 +5419,7 @@ public override Task ReadAsync(CancellationToken cancellationToken) Debug.Assert(context._reader == null && context._source == null && context._disposable == null, "cached ReadAsyncCallContext was not properly disposed"); - context.Set(this, source, registration); + context.Set(this, source, registrationHolder.Take()); context._hasMoreData = more; context._hasReadRowToken = rowTokenRead; @@ -5551,41 +5551,43 @@ override public Task IsDBNullAsync(int i, CancellationToken cancellationTo return ADP.CreatedTaskWithException(ex); } - // Setup and check for pending task - TaskCompletionSource source = new TaskCompletionSource(); - Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null); - if (original != null) + using (var registrationHolder = new DisposableTemporaryOnStack()) { - source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending())); - return source.Task; - } + // Setup and check for pending task + TaskCompletionSource source = new TaskCompletionSource(); + Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null); + if (original != null) + { + source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending())); + return source.Task; + } - // Check if cancellation due to close is requested (this needs to be done after setting _currentTask) - if (_cancelAsyncOnCloseToken.IsCancellationRequested) - { - source.SetCanceled(); - _currentTask = null; - return source.Task; - } + // Check if cancellation due to close is requested (this needs to be done after setting _currentTask) + if (_cancelAsyncOnCloseToken.IsCancellationRequested) + { + source.SetCanceled(); + _currentTask = null; + return source.Task; + } - // Setup cancellations - IDisposable registration = null; - if (cancellationToken.CanBeCanceled) - { - registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); - } + // Setup cancellations + if (cancellationToken.CanBeCanceled) + { + registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command)); + } - IsDBNullAsyncCallContext context = Interlocked.Exchange(ref _cachedIsDBNullContext, null) ?? new IsDBNullAsyncCallContext(); + IsDBNullAsyncCallContext context = Interlocked.Exchange(ref _cachedIsDBNullContext, null) ?? new IsDBNullAsyncCallContext(); - Debug.Assert(context._reader == null && context._source == null && context._disposable == null, "cached ISDBNullAsync context not properly disposed"); + Debug.Assert(context._reader == null && context._source == null && context._disposable == null, "cached ISDBNullAsync context not properly disposed"); - context.Set(this, source, registration); - context._columnIndex = i; + context.Set(this, source, registrationHolder.Take()); + context._columnIndex = i; - // Setup async - PrepareAsyncInvocation(useSnapshot: true); + // Setup async + PrepareAsyncInvocation(useSnapshot: true); - return InvokeAsyncCall(context); + return InvokeAsyncCall(context); + } } } @@ -5687,31 +5689,33 @@ override public Task GetFieldValueAsync(int i, CancellationToken cancellat return ADP.CreatedTaskWithException(ex); } - // Setup and check for pending task - TaskCompletionSource source = new TaskCompletionSource(); - Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null); - if (original != null) + using (var registrationHolder = new DisposableTemporaryOnStack()) { - source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending())); - return source.Task; - } + // Setup and check for pending task + TaskCompletionSource source = new TaskCompletionSource(); + Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null); + if (original != null) + { + source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending())); + return source.Task; + } - // Check if cancellation due to close is requested (this needs to be done after setting _currentTask) - if (_cancelAsyncOnCloseToken.IsCancellationRequested) - { - source.SetCanceled(); - _currentTask = null; - return source.Task; - } + // Check if cancellation due to close is requested (this needs to be done after setting _currentTask) + if (_cancelAsyncOnCloseToken.IsCancellationRequested) + { + source.SetCanceled(); + _currentTask = null; + return source.Task; + } - // Setup cancellations - IDisposable registration = null; - if (cancellationToken.CanBeCanceled) - { - registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command); - } + // Setup cancellations + if (cancellationToken.CanBeCanceled) + { + registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command)); + } - return InvokeAsyncCall(new GetFieldValueAsyncCallContext(this, source, registration, i)); + return InvokeAsyncCall(new GetFieldValueAsyncCallContext(this, source, registrationHolder.Take(), i)); + } } private static Task GetFieldValueAsyncExecute(Task task, object state) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/DisposableTemporaryOnStack.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/DisposableTemporaryOnStack.cs new file mode 100644 index 0000000000..b57b80d78f --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/DisposableTemporaryOnStack.cs @@ -0,0 +1,40 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + + +using System; + +namespace Microsoft.Data.SqlClient +{ + internal ref struct DisposableTemporaryOnStack + where T : IDisposable + { + private T _value; + private bool _hasValue; + + public void Set(T value) + { + _value = value; + _hasValue = true; + } + + public T Take() + { + T value = _value; + _value = default; + _hasValue = false; + return value; + } + + public void Dispose() + { + if (_hasValue) + { + _value.Dispose(); + _value = default; + _hasValue = false; + } + } + } +} From 1682b89d83b910a694e3f7c47c0fc65d5235e287 Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Thu, 3 Nov 2022 02:04:40 +0000 Subject: [PATCH 2/2] address feedback --- .../netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs index d3e96aa8f2..d3c35e9c17 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -5386,9 +5386,6 @@ protected override void Clear() internal override Func> Execute => s_execute; } - - - /// /// Starts the process of executing an async call using an SqlDataReaderAsyncCallContext derived context object. /// After this call the context lifetime is handled by BeginAsyncCall ContinueAsyncCall and CompleteAsyncCall AsyncCall methods