Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Storage] [DataMovement] Fixed bug where adding multiple transfer in parallel could cause a Dictionary Collision in the transfers stored #46919

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -1255,7 +1255,7 @@ public async Task PauseAllTriggersCorrectPauses()
{
unpausable.Add(transfer);
}
manager._dataTransfers.Add(Guid.NewGuid().ToString(), transfer.Object);
manager._dataTransfers.TryAdd(Guid.NewGuid().ToString(), transfer.Object);
}

await manager.PauseAllRunningTransfersAsync(_mockingToken);
Expand Down
1 change: 1 addition & 0 deletions sdk/storage/Azure.Storage.DataMovement/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
### Breaking Changes

### Bugs Fixed
- Fixed bug where adding multiple transfers in parallel could cause a collision (`InvalidOperationException`) in the data transfers stored within the `TransferManager`.

### Other Changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,5 +131,8 @@ public static ArgumentException UnexpectedPropertyType(string propertyName, para

public static InvalidOperationException CheckpointerDisabled(string method)
=> new InvalidOperationException($"Unable to perform {method}. The transfer checkpointer is disabled.");

public static InvalidOperationException CollisionTransferId(string id)
=> new InvalidOperationException($"Transfer Id Collision: The transfer id, {id}, already exists in the transfer manager.");
}
}
28 changes: 18 additions & 10 deletions sdk/storage/Azure.Storage.DataMovement/src/TransferManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Buffers;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
Expand All @@ -28,7 +29,7 @@ public class TransferManager : IAsyncDisposable
/// <summary>
/// Ongoing transfers indexed at the transfer id.
/// </summary>
internal readonly Dictionary<string, DataTransfer> _dataTransfers = new();
internal readonly ConcurrentDictionary<string, DataTransfer> _dataTransfers = new();

/// <summary>
/// Designated checkpointer for the respective transfer manager.
Expand Down Expand Up @@ -290,12 +291,11 @@ bool TryGetStorageResourceProvider(DataTransferProperties properties, bool getSo

transferOptions ??= new DataTransferOptions();

if (_dataTransfers.ContainsKey(dataTransferProperties.TransferId))
{
// Remove the stale DataTransfer so we can pass a new DataTransfer object
// to the user and also track the transfer from the DataTransfer object
_dataTransfers.Remove(dataTransferProperties.TransferId);
}
// Remove the stale DataTransfer so we can pass a new DataTransfer object
// to the user and also track the transfer from the DataTransfer object
// No need to check if we were able to remove the transfer or not.
// If there's no stale DataTransfer to remove, move on.
_dataTransfers.TryRemove(dataTransferProperties.TransferId, out DataTransfer transfer);

if (!TryGetStorageResourceProvider(dataTransferProperties, getSource: true, out StorageResourceProvider sourceProvider))
{
Expand Down Expand Up @@ -409,7 +409,10 @@ private async Task<DataTransfer> BuildAndAddTransferJobAsync(
.ConfigureAwait(false);

transfer.TransferManager = this;
_dataTransfers.Add(transfer.Id, transfer);
if (!_dataTransfers.TryAdd(transfer.Id, transfer))
{
throw Errors.CollisionTransferId(transfer.Id);
}
await _jobsProcessor.QueueAsync(transferJobInternal, cancellationToken).ConfigureAwait(false);

return transfer;
Expand All @@ -424,12 +427,17 @@ private async Task SetDataTransfers(CancellationToken cancellationToken = defaul
foreach (string transferId in storedTransfers)
{
DataTransferStatus jobStatus = await _checkpointer.GetJobStatusAsync(transferId, cancellationToken).ConfigureAwait(false);
_dataTransfers.Add(transferId, new DataTransfer(
// If TryAdd fails here, we need to check if in other places where we are
// adding that every transferId is unique.
if (!_dataTransfers.TryAdd(transferId, new DataTransfer(
id: transferId,
status: jobStatus)
{
TransferManager = this,
});
}))
{
throw Errors.CollisionTransferId(transferId);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,10 @@ public async Task GetTransfers_LocalCheckpointer()

// Act
IList<DataTransfer> result = await manager.GetTransfersAsync().ToListAsync();
List<string> resultIds = result.Select(t => t.Id).ToList();

// Assert
Assert.AreEqual(checkpointerTransfers, result.Select(d => d.Id).ToList());
Assert.IsTrue(Enumerable.SequenceEqual(checkpointerTransfers.OrderBy(id => id), result.Select(t => t.Id).OrderBy(id => id)));
}

[Test]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -14,7 +15,7 @@ namespace Azure.Storage.DataMovement.Tests
/// <typeparam name="T"></typeparam>
internal class StepProcessor<T> : IProcessor<T>
{
private readonly Queue<T> _queue = new();
private readonly ConcurrentQueue<T> _queue = new();

public int ItemsInQueue => _queue.Count;

Expand All @@ -39,7 +40,8 @@ public async ValueTask<bool> TryStepAsync(CancellationToken cancellationToken =
{
if (_queue.Count > 0)
{
await Process?.Invoke(_queue.Dequeue(), cancellationToken);
_queue.TryDequeue(out T result);
await Process?.Invoke(result, cancellationToken);
return true;
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Buffers;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Linq;
Expand Down Expand Up @@ -430,6 +431,38 @@ public async Task TransferFailAtPartProcess(
// TODO determine checkpointer status of job chunks
// need checkpointer API refactor for this
}

[Test]
[TestCase(5)]
[TestCase(10)]
[TestCase(12345)]
public async Task MultipleTransfersAddedCheckpointer(int numJobs)
{
Uri srcUri = new("file:///foo/bar");
Uri dstUri = new("https://example.com/fizz/buzz");

(var jobsProcessor, var partsProcessor, var chunksProcessor) = StepProcessors();
JobBuilder jobBuilder = new(ArrayPool<byte>.Shared, default, new(ClientOptions.Default));
Mock<ITransferCheckpointer> checkpointer = new(MockBehavior.Loose);

(StorageResource srcResource, StorageResource dstResource, Func<IDisposable> srcThrowScope, Func<IDisposable> dstThrowScope)
= GetBasicSetupResources(false, srcUri, dstUri);

await using TransferManager transferManager = new(
jobsProcessor,
partsProcessor,
chunksProcessor,
jobBuilder,
checkpointer.Object,
default);

// Add jobs on separate Tasks
var loopResult = Parallel.For(0, numJobs, i =>
{
Task<DataTransfer> task = transferManager.StartTransferAsync(srcResource, dstResource);
});
Assert.That(jobsProcessor.ItemsInQueue, Is.EqualTo(numJobs), "Error during initial Job queueing.");
}
}

internal static partial class MockExtensions
Expand Down