Skip to content

Commit

Permalink
Fix race condition when cancelling pending HTTP connection attempts (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
MihaZupan authored Jan 8, 2025
1 parent 599c02d commit 0ded51b
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -492,8 +492,7 @@ private async Task AddHttp11ConnectionAsync(RequestQueue<HttpConnection>.QueueIt
HttpConnection? connection = null;
Exception? connectionException = null;

CancellationTokenSource cts = GetConnectTimeoutCancellationTokenSource();
waiter.ConnectionCancellationTokenSource = cts;
CancellationTokenSource cts = GetConnectTimeoutCancellationTokenSource(waiter);
try
{
connection = await CreateHttp11ConnectionAsync(queueItem.Request, true, cts.Token).ConfigureAwait(false);
Expand Down Expand Up @@ -691,8 +690,7 @@ private async Task AddHttp2ConnectionAsync(RequestQueue<Http2Connection?>.QueueI
Exception? connectionException = null;
HttpConnectionWaiter<Http2Connection?> waiter = queueItem.Waiter;

CancellationTokenSource cts = GetConnectTimeoutCancellationTokenSource();
waiter.ConnectionCancellationTokenSource = cts;
CancellationTokenSource cts = GetConnectTimeoutCancellationTokenSource(waiter);
try
{
(Stream stream, TransportContext? transportContext, IPEndPoint? remoteEndPoint) = await ConnectAsync(queueItem.Request, true, cts.Token).ConfigureAwait(false);
Expand Down Expand Up @@ -1520,7 +1518,27 @@ public ValueTask<HttpResponseMessage> SendAsync(HttpRequestMessage request, bool
return SendWithProxyAuthAsync(request, async, doRequestAuth, cancellationToken);
}

private CancellationTokenSource GetConnectTimeoutCancellationTokenSource() => new CancellationTokenSource(Settings._connectTimeout);
private CancellationTokenSource GetConnectTimeoutCancellationTokenSource<T>(HttpConnectionWaiter<T> waiter)
where T : HttpConnectionBase?
{
var cts = new CancellationTokenSource(Settings._connectTimeout);

lock (waiter)
{
waiter.ConnectionCancellationTokenSource = cts;

// The initiating request for this connection attempt may complete concurrently at any time.
// If it completed before we've set the CTS, CancelIfNecessary would no-op.
// Check it again now that we're holding the lock and ensure we always set a timeout.
if (waiter.Task.IsCompleted)
{
CancelIfNecessary(waiter, requestCancelled: waiter.Task.IsCanceled);
waiter.ConnectionCancellationTokenSource = null;
}
}

return cts;
}

private async ValueTask<(Stream, TransportContext?, IPEndPoint?)> ConnectAsync(HttpRequestMessage request, bool async, CancellationToken cancellationToken)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,68 @@ public void PendingConnectionTimeout_HighValue_PendingConnectionIsNotCancelled(i
}, UseVersion.ToString(), timeout.ToString()).Dispose();
}

[OuterLoop("We wait for PendingConnectionTimeout which defaults to 5 seconds.")]
[Fact]
public async Task PendingConnectionTimeout_SignalsAllConnectionAttempts()
{
if (UseVersion == HttpVersion.Version30)
{
// HTTP3 does not support ConnectCallback
return;
}

int pendingConnectionAttempts = 0;
bool connectionAttemptTimedOut = false;

using var handler = new SocketsHttpHandler
{
ConnectCallback = async (context, cancellation) =>
{
Interlocked.Increment(ref pendingConnectionAttempts);
try
{
await Assert.ThrowsAsync<TaskCanceledException>(() => Task.Delay(-1, cancellation)).WaitAsync(TestHelper.PassingTestTimeout);
cancellation.ThrowIfCancellationRequested();
throw new UnreachableException();
}
catch (TimeoutException)
{
connectionAttemptTimedOut = true;
throw;
}
finally
{
Interlocked.Decrement(ref pendingConnectionAttempts);
}
}
};

using HttpClient client = CreateHttpClient(handler);
client.Timeout = TimeSpan.FromSeconds(2);

// Many of these requests should trigger new connection attempts, and all of those should eventually be cleaned up.
await Parallel.ForAsync(0, 100, async (_, _) =>
{
await Assert.ThrowsAnyAsync<TaskCanceledException>(() => client.GetAsync("https://dummy"));
});

Stopwatch stopwatch = Stopwatch.StartNew();

while (Volatile.Read(ref pendingConnectionAttempts) > 0)
{
Assert.False(connectionAttemptTimedOut);

if (stopwatch.Elapsed > 2 * TestHelper.PassingTestTimeout)
{
Assert.Fail("Connection attempts took too long to get cleaned up");
}

await Task.Delay(100);
}

Assert.False(connectionAttemptTimedOut);
}

private sealed class SetTcsContent : StreamContent
{
private readonly TaskCompletionSource<bool> _tcs;
Expand Down

0 comments on commit 0ded51b

Please sign in to comment.