Skip to content

Commit

Permalink
Add mechanism to respond with binary payload (#2118)
Browse files Browse the repository at this point in the history
To implement group member query API and client invocation in persistent mode, we need to add a mechanism to respond with binary payload from service.
* Move some util method in `ServiceProtocol` to a separate util class `MessagePackUtils`, so that they could be reused by other classes.
* Deprecate `string message` member in `AckMessage` and add `Payload` member. The `string message` member is never used. 
* Introduce `IMessagePackSerializable` interface, so that we could put the (de)serialization methods of model classes inside themselves.
* Refactor `AckHandler` to allow acking with binary payload
  • Loading branch information
Y-Sindo authored Jan 8, 2025
1 parent 0c8c5ef commit 38c636e
Show file tree
Hide file tree
Showing 8 changed files with 414 additions and 251 deletions.
102 changes: 68 additions & 34 deletions src/Microsoft.Azure.SignalR.Common/Utilities/AckHandler.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
using System;
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Buffers;
using System.Collections.Concurrent;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.SignalR.Common;
Expand All @@ -12,6 +17,7 @@ namespace Microsoft.Azure.SignalR
internal sealed class AckHandler : IDisposable
{
public static readonly AckHandler Singleton = new();
public static readonly ServiceProtocol _serviceProtocol = new();
private readonly ConcurrentDictionary<int, IAckInfo> _acks = new();
private readonly Timer _timer;
private readonly TimeSpan _defaultAckTimeout;
Expand All @@ -35,15 +41,27 @@ public Task<AckStatus> CreateSingleAck(out int id, TimeSpan? ackTimeout = defaul
{
return Task.FromResult(AckStatus.Ok);
}
var info = (IAckInfo<AckStatus>)_acks.GetOrAdd(id, _ => new SingleAckInfo(ackTimeout ?? _defaultAckTimeout));
if (info is MultiAckInfo)
var info = (IAckInfo<AckStatus>)_acks.GetOrAdd(id, _ => new SingleStatusAck(ackTimeout ?? _defaultAckTimeout));
if (info is MultiAckWithStatusInfo)
{
throw new InvalidOperationException();
}
cancellationToken.Register(() => info.Cancel());
return info.Task;
}

public Task<T> CreateSingleAck<T>(out int id, TimeSpan? ackTimeout = default, CancellationToken cancellationToken = default) where T : IMessagePackSerializable, new()
{
id = NextId();
if (_disposed)
{
return Task.FromResult(new T());
}
var info = (IAckInfo<IMessagePackSerializable>)_acks.GetOrAdd(id, _ => new SinglePayloadAck<T>(ackTimeout ?? _defaultAckTimeout));
cancellationToken.Register(info.Cancel);
return info.Task.ContinueWith(task => (T)task.Result);
}

public static bool HandleAckStatus(IAckableMessage message, AckStatus status)
{
return status switch
Expand All @@ -62,29 +80,19 @@ public Task<AckStatus> CreateMultiAck(out int id, TimeSpan? ackTimeout = default
{
return Task.FromResult(AckStatus.Ok);
}
var info = (IAckInfo<AckStatus>)_acks.GetOrAdd(id, _ => new MultiAckInfo(ackTimeout ?? _defaultAckTimeout));
if (info is SingleAckInfo)
var info = (IAckInfo<AckStatus>)_acks.GetOrAdd(id, _ => new MultiAckWithStatusInfo(ackTimeout ?? _defaultAckTimeout));
if (info is SingleAckInfo<AckStatus>)
{
throw new InvalidOperationException();
}
return info.Task;
}

public void TriggerAck(int id, AckStatus status = AckStatus.Ok)
public void TriggerAck(int id, AckStatus status = AckStatus.Ok, ReadOnlySequence<byte>? payload = default)
{
if (_acks.TryGetValue(id, out var info))
if (_acks.TryGetValue(id, out var info) && info.Ack(status, payload))
{
switch (info)
{
case IAckInfo<AckStatus> ackInfo:
if (ackInfo.Ack(status))
{
_acks.TryRemove(id, out _);
}
break;
default:
throw new InvalidCastException($"Expected: IAckInfo<{typeof(IAckInfo<AckStatus>).Name}>, actual type: {info.GetType().Name}");
}
_acks.TryRemove(id, out _);
}
}

Expand Down Expand Up @@ -125,16 +133,13 @@ private void CheckAcks()
{
if (_acks.TryRemove(id, out _))
{
if (ack is SingleAckInfo singleAckInfo)
{
singleAckInfo.Ack(AckStatus.Timeout);
}
else if (ack is MultiAckInfo multipleAckInfo)
if (ack is MultiAckWithStatusInfo multipleAckInfo)
{
multipleAckInfo.ForceAck(AckStatus.Timeout);
}
else
{
ack.Ack(AckStatus.Timeout);
ack.Cancel();
}
}
Expand Down Expand Up @@ -170,39 +175,68 @@ private interface IAckInfo
{
DateTime TimeoutAt { get; }
void Cancel();
bool Ack(AckStatus status, ReadOnlySequence<byte>? payload = null);
}

private interface IAckInfo<T> : IAckInfo
{
Task<T> Task { get; }
bool Ack(T status);
}

public interface IMultiAckInfo
{
bool SetExpectedCount(int expectedCount);
}

private sealed class SingleAckInfo : IAckInfo<AckStatus>
private abstract class SingleAckInfo<T> : IAckInfo<T>
{
public readonly TaskCompletionSource<AckStatus> _tcs = new(TaskCreationOptions.RunContinuationsAsynchronously);

public readonly TaskCompletionSource<T> _tcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
public DateTime TimeoutAt { get; }

public SingleAckInfo(TimeSpan timeout)
{
TimeoutAt = DateTime.UtcNow + timeout;
}
public abstract bool Ack(AckStatus status, ReadOnlySequence<byte>? payload = null);
public Task<T> Task => _tcs.Task;
public void Cancel() => _tcs.TrySetCanceled();
}

private class SingleStatusAck : SingleAckInfo<AckStatus>
{

public bool Ack(AckStatus status = AckStatus.Ok) =>
public SingleStatusAck(TimeSpan timeout) : base(timeout) { }

public override bool Ack(AckStatus status, ReadOnlySequence<byte>? payload = null) =>
_tcs.TrySetResult(status);
}

public Task<AckStatus> Task => _tcs.Task;
private sealed class SinglePayloadAck<T> : SingleAckInfo<IMessagePackSerializable> where T : IMessagePackSerializable, new()
{
public SinglePayloadAck(TimeSpan timeout) : base(timeout) { }
public override bool Ack(AckStatus status, ReadOnlySequence<byte>? payload = null)
{
if (status == AckStatus.Timeout)
{
return _tcs.TrySetException(new TimeoutException($"Waiting for a {typeof(T).Name} response timed out."));
}
if (payload == null)
{
return _tcs.TrySetException(new InvalidDataException($"The expected payload is null."));
}

public void Cancel() => _tcs.TrySetCanceled();
try
{
var result = _serviceProtocol.ParseMessagePayload<T>(payload.Value);
return _tcs.TrySetResult(result);
}
catch (Exception e)
{
return _tcs.TrySetException(e);
}
}
}

private sealed class MultiAckInfo : IAckInfo<AckStatus>, IMultiAckInfo
private sealed class MultiAckWithStatusInfo : IAckInfo<AckStatus>, IMultiAckInfo
{
public readonly TaskCompletionSource<AckStatus> _tcs = new(TaskCreationOptions.RunContinuationsAsynchronously);

Expand All @@ -211,7 +245,7 @@ private sealed class MultiAckInfo : IAckInfo<AckStatus>, IMultiAckInfo

public DateTime TimeoutAt { get; }

public MultiAckInfo(TimeSpan timeout)
public MultiAckWithStatusInfo(TimeSpan timeout)
{
TimeoutAt = DateTime.UtcNow + timeout;
}
Expand Down Expand Up @@ -239,7 +273,7 @@ public bool SetExpectedCount(int expectedCount)
return result;
}

public bool Ack(AckStatus status = AckStatus.Ok)
public bool Ack(AckStatus status = AckStatus.Ok, ReadOnlySequence<byte>? payload = null)
{
bool result;
lock (_tcs)
Expand Down
Loading

0 comments on commit 38c636e

Please sign in to comment.