Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add group member query service message protocol #2118

Merged
merged 4 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 68 additions & 34 deletions src/Microsoft.Azure.SignalR.Common/Utilities/AckHandler.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
using System;
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Buffers;
using System.Collections.Concurrent;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.SignalR.Common;
Expand All @@ -12,6 +17,7 @@ namespace Microsoft.Azure.SignalR
internal sealed class AckHandler : IDisposable
{
public static readonly AckHandler Singleton = new();
public static readonly ServiceProtocol _serviceProtocol = new();
private readonly ConcurrentDictionary<int, IAckInfo> _acks = new();
private readonly Timer _timer;
private readonly TimeSpan _defaultAckTimeout;
Expand All @@ -35,15 +41,27 @@ public Task<AckStatus> CreateSingleAck(out int id, TimeSpan? ackTimeout = defaul
{
return Task.FromResult(AckStatus.Ok);
}
var info = (IAckInfo<AckStatus>)_acks.GetOrAdd(id, _ => new SingleAckInfo(ackTimeout ?? _defaultAckTimeout));
if (info is MultiAckInfo)
var info = (IAckInfo<AckStatus>)_acks.GetOrAdd(id, _ => new SingleStatusAck(ackTimeout ?? _defaultAckTimeout));
if (info is MultiAckWithStatusInfo)
{
throw new InvalidOperationException();
}
cancellationToken.Register(() => info.Cancel());
return info.Task;
}

public Task<T> CreateSingleAck<T>(out int id, TimeSpan? ackTimeout = default, CancellationToken cancellationToken = default) where T : IMessagePackSerializable, new()
{
id = NextId();
if (_disposed)
{
return Task.FromResult(new T());
}
var info = (IAckInfo<IMessagePackSerializable>)_acks.GetOrAdd(id, _ => new SinglePayloadAck<T>(ackTimeout ?? _defaultAckTimeout));
cancellationToken.Register(info.Cancel);
return info.Task.ContinueWith(task => (T)task.Result);
Y-Sindo marked this conversation as resolved.
Show resolved Hide resolved
}

public static bool HandleAckStatus(IAckableMessage message, AckStatus status)
{
return status switch
Expand All @@ -62,29 +80,19 @@ public Task<AckStatus> CreateMultiAck(out int id, TimeSpan? ackTimeout = default
{
return Task.FromResult(AckStatus.Ok);
}
var info = (IAckInfo<AckStatus>)_acks.GetOrAdd(id, _ => new MultiAckInfo(ackTimeout ?? _defaultAckTimeout));
if (info is SingleAckInfo)
var info = (IAckInfo<AckStatus>)_acks.GetOrAdd(id, _ => new MultiAckWithStatusInfo(ackTimeout ?? _defaultAckTimeout));
if (info is SingleAckInfo<AckStatus>)
{
throw new InvalidOperationException();
}
return info.Task;
}

public void TriggerAck(int id, AckStatus status = AckStatus.Ok)
public void TriggerAck(int id, AckStatus status = AckStatus.Ok, ReadOnlySequence<byte>? payload = default)
{
if (_acks.TryGetValue(id, out var info))
if (_acks.TryGetValue(id, out var info) && info.Ack(status, payload))
{
switch (info)
{
case IAckInfo<AckStatus> ackInfo:
if (ackInfo.Ack(status))
{
_acks.TryRemove(id, out _);
}
break;
default:
throw new InvalidCastException($"Expected: IAckInfo<{typeof(IAckInfo<AckStatus>).Name}>, actual type: {info.GetType().Name}");
}
_acks.TryRemove(id, out _);
}
Y-Sindo marked this conversation as resolved.
Show resolved Hide resolved
}

Expand Down Expand Up @@ -125,16 +133,13 @@ private void CheckAcks()
{
if (_acks.TryRemove(id, out _))
{
if (ack is SingleAckInfo singleAckInfo)
{
singleAckInfo.Ack(AckStatus.Timeout);
}
else if (ack is MultiAckInfo multipleAckInfo)
if (ack is MultiAckWithStatusInfo multipleAckInfo)
{
multipleAckInfo.ForceAck(AckStatus.Timeout);
}
else
{
ack.Ack(AckStatus.Timeout);
ack.Cancel();
}
}
Expand Down Expand Up @@ -170,39 +175,68 @@ private interface IAckInfo
{
DateTime TimeoutAt { get; }
void Cancel();
bool Ack(AckStatus status, ReadOnlySequence<byte>? payload = null);
}

private interface IAckInfo<T> : IAckInfo
{
Task<T> Task { get; }
bool Ack(T status);
}

public interface IMultiAckInfo
{
bool SetExpectedCount(int expectedCount);
}

private sealed class SingleAckInfo : IAckInfo<AckStatus>
private abstract class SingleAckInfo<T> : IAckInfo<T>
{
public readonly TaskCompletionSource<AckStatus> _tcs = new(TaskCreationOptions.RunContinuationsAsynchronously);

public readonly TaskCompletionSource<T> _tcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
public DateTime TimeoutAt { get; }

public SingleAckInfo(TimeSpan timeout)
{
TimeoutAt = DateTime.UtcNow + timeout;
}
public abstract bool Ack(AckStatus status, ReadOnlySequence<byte>? payload = null);
public Task<T> Task => _tcs.Task;
public void Cancel() => _tcs.TrySetCanceled();
}

private class SingleStatusAck : SingleAckInfo<AckStatus>
{

public bool Ack(AckStatus status = AckStatus.Ok) =>
public SingleStatusAck(TimeSpan timeout) : base(timeout) { }

public override bool Ack(AckStatus status, ReadOnlySequence<byte>? payload = null) =>
_tcs.TrySetResult(status);
}

public Task<AckStatus> Task => _tcs.Task;
private sealed class SinglePayloadAck<T> : SingleAckInfo<IMessagePackSerializable> where T : IMessagePackSerializable, new()
{
public SinglePayloadAck(TimeSpan timeout) : base(timeout) { }
public override bool Ack(AckStatus status, ReadOnlySequence<byte>? payload = null)
{
Y-Sindo marked this conversation as resolved.
Show resolved Hide resolved
if (status == AckStatus.Timeout)
{
return _tcs.TrySetException(new TimeoutException($"Waiting for a {typeof(T).Name} response timed out."));
}
if (payload == null)
{
return _tcs.TrySetException(new InvalidDataException($"The expected payload is null."));
}

public void Cancel() => _tcs.TrySetCanceled();
try
{
var result = _serviceProtocol.ParseMessagePayload<T>(payload.Value);
return _tcs.TrySetResult(result);
}
catch (Exception e)
{
return _tcs.TrySetException(e);
}
}
}

private sealed class MultiAckInfo : IAckInfo<AckStatus>, IMultiAckInfo
private sealed class MultiAckWithStatusInfo : IAckInfo<AckStatus>, IMultiAckInfo
{
public readonly TaskCompletionSource<AckStatus> _tcs = new(TaskCreationOptions.RunContinuationsAsynchronously);

Expand All @@ -211,7 +245,7 @@ private sealed class MultiAckInfo : IAckInfo<AckStatus>, IMultiAckInfo

public DateTime TimeoutAt { get; }

public MultiAckInfo(TimeSpan timeout)
public MultiAckWithStatusInfo(TimeSpan timeout)
{
TimeoutAt = DateTime.UtcNow + timeout;
}
Expand Down Expand Up @@ -239,7 +273,7 @@ public bool SetExpectedCount(int expectedCount)
return result;
}

public bool Ack(AckStatus status = AckStatus.Ok)
public bool Ack(AckStatus status = AckStatus.Ok, ReadOnlySequence<byte>? payload = null)
{
bool result;
lock (_tcs)
Expand Down
Loading
Loading