Skip to content

Commit

Permalink
Update IAsyncEnumerable implementation and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
stephentoub committed Jun 24, 2022
1 parent 05b43d4 commit 18aa3de
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,40 +11,56 @@ public partial class TransformManyBlock<TInput, TOutput>
{
/// <summary>Initializes the <see cref="TransformManyBlock{TInput,TOutput}"/> with the specified function.</summary>
/// <param name="transform">
/// The function to invoke with each data element received. All of the data from the returned <see cref="System.Collections.Generic.IAsyncEnumerable{TOutput}"/>
/// The function to invoke with each data element received. All of the data from the returned <see cref="IAsyncEnumerable{TOutput}"/>
/// will be made available as output from this <see cref="TransformManyBlock{TInput,TOutput}"/>.
/// </param>
/// <exception cref="System.ArgumentNullException">The <paramref name="transform"/> is <see langword="null" />.</exception>
/// <exception cref="ArgumentNullException">The <paramref name="transform"/> is <see langword="null" />.</exception>
public TransformManyBlock(Func<TInput, IAsyncEnumerable<TOutput>> transform) :
this(transform, ExecutionDataflowBlockOptions.Default)
{ }
{
}

/// <summary>Initializes the <see cref="TransformManyBlock{TInput,TOutput}"/> with the specified function and <see cref="ExecutionDataflowBlockOptions"/>.</summary>
/// <param name="transform">
/// The function to invoke with each data element received. All of the data from the returned <see cref="System.Collections.Generic.IAsyncEnumerable{TOutput}"/>
/// The function to invoke with each data element received. All of the data from the returned <see cref="IAsyncEnumerable{TOutput}"/>
/// will be made available as output from this <see cref="TransformManyBlock{TInput,TOutput}"/>.
/// </param>
/// <param name="dataflowBlockOptions">The options with which to configure this <see cref="TransformManyBlock{TInput,TOutput}"/>.</param>
/// <exception cref="System.ArgumentNullException">The <paramref name="transform"/> or <paramref name="dataflowBlockOptions"/> is <see langword="null" />.</exception>
/// <exception cref="ArgumentNullException">The <paramref name="transform"/> or <paramref name="dataflowBlockOptions"/> is <see langword="null" />.</exception>
public TransformManyBlock(Func<TInput, IAsyncEnumerable<TOutput>> transform, ExecutionDataflowBlockOptions dataflowBlockOptions)
{
// Validate arguments.
if (transform == null) throw new ArgumentNullException(nameof(transform));
Initialize(messageWithId => ProcessMessage(transform, messageWithId), dataflowBlockOptions, ref _source!, ref _target!, ref _reorderingBuffer, TargetCoreOptions.UsesAsyncCompletion);
if (transform is null)
{
throw new ArgumentNullException(nameof(transform));
}

Initialize(messageWithId =>
{
Task t = ProcessMessageAsync(transform, messageWithId);
#if DEBUG
// Task returned from ProcessMessageAsync is explicitly ignored.
// That function handles all exceptions.
t.ContinueWith(t => Debug.Assert(t.IsCompletedSuccessfully), CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default);
#endif
}, dataflowBlockOptions, ref _source, ref _target, ref _reorderingBuffer, TargetCoreOptions.UsesAsyncCompletion);
}

// Note:
// Enumerating the IAsyncEnumerable is done with ConfigureAwait(true), using the default behavior of
// paying attention to the current context/scheduler. This makes it so that the enumerable code runs on the target scheduler.
// For this to work correctly, there can't be any ConfigureAwait(false) in the same method prior to
// these await foreach loops, nor in the call chain prior to the method invocation.

/// <summary>Processes the message with a user-provided transform function that returns an async enumerable.</summary>
/// <param name="transformFunction">The transform function to use to process the message.</param>
/// <param name="messageWithId">The message to be processed.</param>
private void ProcessMessage(Func<TInput, IAsyncEnumerable<TOutput>> transformFunction, KeyValuePair<TInput, long> messageWithId)
private async Task ProcessMessageAsync(Func<TInput, IAsyncEnumerable<TOutput>> transformFunction, KeyValuePair<TInput, long> messageWithId)
{
Debug.Assert(transformFunction != null, "Function to invoke is required.");

try
{
// Run the user transform and store the results.
IAsyncEnumerable<TOutput> outputItems = transformFunction(messageWithId.Key);
StoreOutputItemsAsync(messageWithId, outputItems).GetAwaiter().GetResult();
await StoreOutputItemsAsync(messageWithId, outputItems).ConfigureAwait(false);
}
catch (Exception exc)
{
Expand Down Expand Up @@ -77,12 +93,12 @@ private async Task StoreOutputItemsAsync(
{
// If there's a reordering buffer, pass the data along to it.
// The reordering buffer will handle all details, including bounding.
if (_reorderingBuffer != null)
if (_reorderingBuffer is not null)
{
await StoreOutputItemsReorderedAsync(messageWithId.Value, outputItems).ConfigureAwait(false);
}
// Otherwise, output the data directly.
else if (outputItems != null)
else if (outputItems is not null)
{
await StoreOutputItemsNonReorderedWithIterationAsync(outputItems).ConfigureAwait(false);
}
Expand All @@ -103,36 +119,32 @@ private async Task StoreOutputItemsAsync(
/// <param name="item">The async enumerable.</param>
private async Task StoreOutputItemsReorderedAsync(long id, IAsyncEnumerable<TOutput>? item)
{
Debug.Assert(_reorderingBuffer != null, "Expected a reordering buffer");
Debug.Assert(_reorderingBuffer is not null, "Expected a reordering buffer");
Debug.Assert(id != Common.INVALID_REORDERING_ID, "This ID should never have been handed out.");

// Grab info about the transform
TargetCore<TInput> target = _target;
bool isBounded = target.IsBounded;

// Handle invalid items (null enumerables) by delegating to the base
if (item == null)
if (item is null)
{
_reorderingBuffer.AddItem(id, null, false);
if (isBounded) target.ChangeBoundingCount(count: -1);
if (isBounded)
{
target.ChangeBoundingCount(count: -1);
}
return;
}

// Determine whether this id is the next item, and if it is and if we have a trusted list,
// try to output it immediately on the fast path. If it can be output, we're done.
// Otherwise, make forward progress based on whether we're next in line.
bool? isNextNullable = _reorderingBuffer.AddItemIfNextAndTrusted(id, null, false);
if (!isNextNullable.HasValue) return; // data was successfully output
bool isNextItem = isNextNullable.Value;

// By this point, either we're not the next item, in which case we need to make a copy of the
// data and store it, or we are the next item and can store it immediately but we need to enumerate
// the items and store them individually because we don't want to enumerate while holding a lock.
List<TOutput>? itemCopy = null;
try
{
// If this is the next item, we can output it now.
if (isNextItem)
if (_reorderingBuffer.IsNext(id))
{
await StoreOutputItemsNonReorderedWithIterationAsync(item).ConfigureAwait(false);
// here itemCopy remains null, so that base.AddItem will finish our interactions with the reordering buffer
Expand All @@ -145,7 +157,7 @@ private async Task StoreOutputItemsReorderedAsync(long id, IAsyncEnumerable<TOut
try
{
itemCopy = new List<TOutput>();
await foreach (TOutput element in item.ConfigureAwait(false))
await foreach (TOutput element in item.ConfigureAwait(true))
{
itemCopy.Add(element);
}
Expand All @@ -158,7 +170,10 @@ private async Task StoreOutputItemsReorderedAsync(long id, IAsyncEnumerable<TOut
// If we're here because ToList threw an exception, then itemCount will be 0,
// and we still need to update the bounding count with this in order to counteract
// the increased bounding count for the corresponding input.
if (isBounded) UpdateBoundingCountWithOutputCount(count: itemCount);
if (isBounded)
{
UpdateBoundingCountWithOutputCount(count: itemCount);
}
}
}
// else if the item isn't valid, the finally block will see itemCopy as null and output invalid
Expand All @@ -169,7 +184,7 @@ private async Task StoreOutputItemsReorderedAsync(long id, IAsyncEnumerable<TOut
// all of the data, itemCopy will be null, and we just pass down the invalid item.
// If we haven't, pass down the real thing. We do this even in the case of an exception,
// in which case this will be a dummy element.
_reorderingBuffer.AddItem(id, itemCopy, itemIsValid: itemCopy != null);
_reorderingBuffer.AddItem(id, itemCopy, itemIsValid: itemCopy is not null);
}
}

Expand All @@ -187,7 +202,7 @@ private async Task StoreOutputItemsNonReorderedWithIterationAsync(IAsyncEnumerab
// it guarantees that we're invoked serially, and we don't need to lock.
bool isSerial =
_target.DataflowBlockOptions.MaxDegreeOfParallelism == 1 ||
_reorderingBuffer != null;
_reorderingBuffer is not null;

// If we're bounding, we need to increment the bounded count
// for each individual item as we enumerate it.
Expand All @@ -200,10 +215,13 @@ private async Task StoreOutputItemsNonReorderedWithIterationAsync(IAsyncEnumerab
bool outputFirstItem = false;
try
{
await foreach (TOutput item in outputItems.ConfigureAwait(false))
await foreach (TOutput item in outputItems.ConfigureAwait(true))
{
if (outputFirstItem) _target.ChangeBoundingCount(count: 1);
else outputFirstItem = true;
if (outputFirstItem)
{
_target.ChangeBoundingCount(count: 1);
}
outputFirstItem = true;

if (isSerial)
{
Expand All @@ -220,20 +238,25 @@ private async Task StoreOutputItemsNonReorderedWithIterationAsync(IAsyncEnumerab
}
finally
{
if (!outputFirstItem) _target.ChangeBoundingCount(count: -1);
if (!outputFirstItem)
{
_target.ChangeBoundingCount(count: -1);
}
}
}
// If we're not bounding, just output each individual item.
else
{
if (isSerial)
{
await foreach (TOutput item in outputItems.ConfigureAwait(false))
await foreach (TOutput item in outputItems.ConfigureAwait(true))
{
_source.AddMessage(item);
}
}
else
{
await foreach (TOutput item in outputItems.ConfigureAwait(false))
await foreach (TOutput item in outputItems.ConfigureAwait(true))
{
lock (ParallelSourceLock) // don't hold lock while enumerating
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ public sealed partial class TransformManyBlock<TInput, TOutput> : IPropagatorBlo
/// <exception cref="System.ArgumentNullException">The <paramref name="transform"/> is null (Nothing in Visual Basic).</exception>
public TransformManyBlock(Func<TInput, IEnumerable<TOutput>> transform) :
this(transform, ExecutionDataflowBlockOptions.Default)
{ }
{
}

/// <summary>Initializes the <see cref="TransformManyBlock{TInput,TOutput}"/> with the specified function and <see cref="ExecutionDataflowBlockOptions"/>.</summary>
/// <param name="transform">
Expand All @@ -66,9 +67,8 @@ public TransformManyBlock(Func<TInput, IEnumerable<TOutput>> transform) :
/// <exception cref="System.ArgumentNullException">The <paramref name="dataflowBlockOptions"/> is null (Nothing in Visual Basic).</exception>
public TransformManyBlock(Func<TInput, IEnumerable<TOutput>> transform, ExecutionDataflowBlockOptions dataflowBlockOptions)
{
// Validate arguments.
if (transform == null) throw new ArgumentNullException(nameof(transform));
Initialize(messageWithId => ProcessMessage(transform, messageWithId), dataflowBlockOptions, ref _source!, ref _target!, ref _reorderingBuffer, TargetCoreOptions.None);
Initialize(messageWithId => ProcessMessage(transform, messageWithId), dataflowBlockOptions, ref _source, ref _target, ref _reorderingBuffer, TargetCoreOptions.None);
}

/// <summary>Initializes the <see cref="TransformManyBlock{TInput,TOutput}"/> with the specified function.</summary>
Expand All @@ -91,20 +91,18 @@ public TransformManyBlock(Func<TInput, Task<IEnumerable<TOutput>>> transform) :
/// <exception cref="System.ArgumentNullException">The <paramref name="dataflowBlockOptions"/> is null (Nothing in Visual Basic).</exception>
public TransformManyBlock(Func<TInput, Task<IEnumerable<TOutput>>> transform, ExecutionDataflowBlockOptions dataflowBlockOptions)
{
// Validate arguments.
if (transform == null) throw new ArgumentNullException(nameof(transform));
Initialize(messageWithId => ProcessMessageWithTask(transform, messageWithId), dataflowBlockOptions, ref _source!, ref _target!, ref _reorderingBuffer, TargetCoreOptions.UsesAsyncCompletion);
Initialize(messageWithId => ProcessMessageWithTask(transform, messageWithId), dataflowBlockOptions, ref _source, ref _target, ref _reorderingBuffer, TargetCoreOptions.UsesAsyncCompletion);
}

private void Initialize(
Action<KeyValuePair<TInput, long>> processMessageAction,
ExecutionDataflowBlockOptions dataflowBlockOptions,
ref SourceCore<TOutput> source,
ref TargetCore<TInput> target,
[NotNull] ref SourceCore<TOutput>? source,
[NotNull] ref TargetCore<TInput>? target,
ref ReorderingBuffer<IEnumerable<TOutput>>? reorderingBuffer,
TargetCoreOptions targetCoreOptions)
{
// Validate arguments.
if (dataflowBlockOptions == null) throw new ArgumentNullException(nameof(dataflowBlockOptions));

// Ensure we have options that can't be changed by the caller
Expand All @@ -113,7 +111,9 @@ private void Initialize(
// Initialize onItemsRemoved delegate if necessary
Action<ISourceBlock<TOutput>, int>? onItemsRemoved = null;
if (dataflowBlockOptions.BoundedCapacity > 0)
{
onItemsRemoved = (owningSource, count) => ((TransformManyBlock<TInput, TOutput>)owningSource)._target.ChangeBoundingCount(-count);
}

// Initialize source component
source = new SourceCore<TOutput>(this, dataflowBlockOptions,
Expand Down Expand Up @@ -169,8 +169,6 @@ private void Initialize(
/// <param name="messageWithId">The message to be processed.</param>
private void ProcessMessage(Func<TInput, IEnumerable<TOutput>> transformFunction, KeyValuePair<TInput, long> messageWithId)
{
Debug.Assert(transformFunction != null, "Function to invoke is required.");

bool userDelegateSucceeded = false;
try
{
Expand All @@ -179,10 +177,9 @@ private void ProcessMessage(Func<TInput, IEnumerable<TOutput>> transformFunction
userDelegateSucceeded = true;
StoreOutputItems(messageWithId, outputItems);
}
catch (Exception exc)
catch (Exception exc) when (Common.IsCooperativeCancellation(exc))
{
// If this exception represents cancellation, swallow it rather than shutting down the block.
if (!Common.IsCooperativeCancellation(exc)) throw;
}
finally
{
Expand All @@ -197,8 +194,6 @@ private void ProcessMessage(Func<TInput, IEnumerable<TOutput>> transformFunction
/// <param name="messageWithId">The message to be processed.</param>
private void ProcessMessageWithTask(Func<TInput, Task<IEnumerable<TOutput>>> function, KeyValuePair<TInput, long> messageWithId)
{
Debug.Assert(function != null, "Function to invoke is required.");

// Run the transform function to get the resulting task
Task<IEnumerable<TOutput>>? task = null;
Exception? caughtException = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,20 @@ internal void AddItem(long id, TOutput? item, bool itemIsValid)
}
}

/// <summary>Determines whether the specified id is next to be output.</summary>
/// <param name="id">The id of the item.</param>
/// <returns>true if the item is next in line; otherwise, false.</returns>
internal bool IsNext(long id)
{
Debug.Assert(id != Common.INVALID_REORDERING_ID, "This ID should never have been handed out.");
Common.ContractAssertMonitorStatus(ValueLock, held: false);

lock (ValueLock)
{
return _nextReorderedIdToOutput == id;
}
}

/// <summary>
/// Determines whether the specified id is next to be output, and if it is
/// and if the item is "trusted" (meaning it may be output into the output
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ internal static partial class DataflowTestHelpers

internal static partial class AsyncEnumerable
{
#pragma warning disable 1998
internal static async IAsyncEnumerable<int> Repeat(int item, int count)
{
for (int i = 0; i < count; i++)
{
await Task.Yield();
yield return item;
}
}
Expand All @@ -26,6 +26,7 @@ internal static async IAsyncEnumerable<int> Range(int start, int count)
var end = start + count;
for (int i = start; i < end; i++)
{
await Task.Yield();
yield return i;
}
}
Expand All @@ -34,9 +35,9 @@ internal static async IAsyncEnumerable<T> ToAsyncEnumerable<T>(this IEnumerable<
{
foreach (T item in enumerable)
{
await Task.Yield();
yield return item;
}
}
#pragma warning restore 1998
}
}
Loading

0 comments on commit 18aa3de

Please sign in to comment.