From 5b63da2372ebc654442642226c96118e7bd4c3cb Mon Sep 17 00:00:00 2001 From: Benjamin Petit Date: Thu, 7 Apr 2022 13:23:01 +0200 Subject: [PATCH 1/3] Broadcast channel to replace SMS --- Orleans.sln | 7 + .../BroadcastChannel.cs | 93 +++++ .../BroadcastChannelConsumerExtension.cs | 99 ++++++ .../BroadcastChannelOptions.cs | 8 + .../BroadcastChannelProvider.cs | 55 +++ .../BroadcastChannelSubscription.cs | 43 +++ src/Orleans.BroadcastChannel/ChannelId.cs | 288 +++++++++++++++ .../Hosting/ChannelHostingExtensions.cs | 56 +++ .../IdMapping/DefaultChannelIdMapper.cs | 79 +++++ .../IdMapping/IChannelIdMapper.cs | 19 + .../Orleans.BroadcastChannel.csproj | 14 + .../ImplicitChannelSubscriberTable.cs | 329 ++++++++++++++++++ .../AllStreamNamespacesPredicate.cs | 17 + ...DefaultStreamNamespacePredicateProvider.cs | 94 +++++ .../ExactMatchStreamNamespacePredicate.cs | 35 ++ .../Predicates/IChannelNamespacePredicate.cs | 35 ++ .../ImplicitChannelSubscriptionAttribute.cs | 118 +++++++ .../RegexChannelNamespacePredicate.cs | 35 ++ .../Manifest/GrainProperties.cs | 15 + src/Orleans.Core/Properties/AssemblyInfo.cs | 1 + .../GrainStreamingExtensions.cs | 1 + .../SimpleMessageStreamProvider.cs | 1 + .../SimpleStreams/SimpleSubscriberGrain.cs | 79 +++++ .../TestGrains/StreamInterceptionGrain.cs | 2 +- test/Grains/TestGrains/TestGrains.csproj | 1 + test/Tester/GrainCallFilterTests.cs | 1 + .../BroadcastChannelTests.cs | 261 ++++++++++++++ .../StreamingTests/SampleStreamingTests.cs | 1 + test/Tester/Tester.csproj | 1 + 29 files changed, 1787 insertions(+), 1 deletion(-) create mode 100644 src/Orleans.BroadcastChannel/BroadcastChannel.cs create mode 100644 src/Orleans.BroadcastChannel/BroadcastChannelConsumerExtension.cs create mode 100644 src/Orleans.BroadcastChannel/BroadcastChannelOptions.cs create mode 100644 src/Orleans.BroadcastChannel/BroadcastChannelProvider.cs create mode 100644 src/Orleans.BroadcastChannel/BroadcastChannelSubscription.cs create mode 100644 src/Orleans.BroadcastChannel/ChannelId.cs create mode 100644 src/Orleans.BroadcastChannel/Hosting/ChannelHostingExtensions.cs create mode 100644 src/Orleans.BroadcastChannel/IdMapping/DefaultChannelIdMapper.cs create mode 100644 src/Orleans.BroadcastChannel/IdMapping/IChannelIdMapper.cs create mode 100644 src/Orleans.BroadcastChannel/Orleans.BroadcastChannel.csproj create mode 100644 src/Orleans.BroadcastChannel/SubscriberTable/ImplicitChannelSubscriberTable.cs create mode 100644 src/Orleans.BroadcastChannel/SubscriberTable/Predicates/AllStreamNamespacesPredicate.cs create mode 100644 src/Orleans.BroadcastChannel/SubscriberTable/Predicates/DefaultStreamNamespacePredicateProvider.cs create mode 100644 src/Orleans.BroadcastChannel/SubscriberTable/Predicates/ExactMatchStreamNamespacePredicate.cs create mode 100644 src/Orleans.BroadcastChannel/SubscriberTable/Predicates/IChannelNamespacePredicate.cs create mode 100644 src/Orleans.BroadcastChannel/SubscriberTable/Predicates/ImplicitChannelSubscriptionAttribute.cs create mode 100644 src/Orleans.BroadcastChannel/SubscriberTable/Predicates/RegexChannelNamespacePredicate.cs create mode 100644 test/Grains/TestGrains/SimpleStreams/SimpleSubscriberGrain.cs create mode 100644 test/Tester/StreamingTests/BroadcastChannels/BroadcastChannelTests.cs diff --git a/Orleans.sln b/Orleans.sln index 4d267e045b..be811b4a99 100644 --- a/Orleans.sln +++ b/Orleans.sln @@ -203,6 +203,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "DistributedTests.Server", " EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Orleans.Serialization.SystemTextJson", "src\Orleans.Serialization.SystemTextJson\Orleans.Serialization.SystemTextJson.csproj", "{5CFBC7AC-C9AE-4C6C-943C-4A157396E427}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Orleans.BroadcastChannel", "src\Orleans.BroadcastChannel\Orleans.BroadcastChannel.csproj", "{497D472A-0BA8-4306-A110-C4D871FD5918}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -545,6 +547,10 @@ Global {5CFBC7AC-C9AE-4C6C-943C-4A157396E427}.Debug|Any CPU.Build.0 = Debug|Any CPU {5CFBC7AC-C9AE-4C6C-943C-4A157396E427}.Release|Any CPU.ActiveCfg = Release|Any CPU {5CFBC7AC-C9AE-4C6C-943C-4A157396E427}.Release|Any CPU.Build.0 = Release|Any CPU + {497D472A-0BA8-4306-A110-C4D871FD5918}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {497D472A-0BA8-4306-A110-C4D871FD5918}.Debug|Any CPU.Build.0 = Debug|Any CPU + {497D472A-0BA8-4306-A110-C4D871FD5918}.Release|Any CPU.ActiveCfg = Release|Any CPU + {497D472A-0BA8-4306-A110-C4D871FD5918}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -646,6 +652,7 @@ Global {25D20278-8901-47CC-AD1D-F3C4BEB845BF} = {FFEC9FEE-FEDF-4510-B7D2-0B0B3374ED2F} {E8335DC9-9A7F-45C1-AFA3-0AA93ABD4FA5} = {FFEC9FEE-FEDF-4510-B7D2-0B0B3374ED2F} {5CFBC7AC-C9AE-4C6C-943C-4A157396E427} = {4CD3AA9E-D937-48CA-BB6C-158E12257D23} + {497D472A-0BA8-4306-A110-C4D871FD5918} = {4CD3AA9E-D937-48CA-BB6C-158E12257D23} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {7BFB3429-B5BB-4DB1-95B4-67D77A864952} diff --git a/src/Orleans.BroadcastChannel/BroadcastChannel.cs b/src/Orleans.BroadcastChannel/BroadcastChannel.cs new file mode 100644 index 0000000000..0bc8b46db6 --- /dev/null +++ b/src/Orleans.BroadcastChannel/BroadcastChannel.cs @@ -0,0 +1,93 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Orleans.BroadcastChannel.SubscriberTable; +using Orleans.Providers; +using Orleans.Runtime; + +namespace Orleans.BroadcastChannel +{ + public interface IBroadcastChannel + { + Task Publish(T item); + } + + internal class BroadcastChannel : IBroadcastChannel + { + private readonly InternalChannelId _channelId; + private readonly IGrainFactory _grainFactory; + private readonly ImplicitChannelSubscriberTable _subscriberTable; + private readonly bool _fireAndForgetDelivery; + private readonly ILogger _logger; + + public BroadcastChannel( + InternalChannelId channelId, + IGrainFactory grainFactory, + ImplicitChannelSubscriberTable subscriberTable, + bool fireAndForgetDelivery, + ILoggerFactory loggerFactory) + { + _channelId = channelId; + _grainFactory = grainFactory; + _subscriberTable = subscriberTable; + _fireAndForgetDelivery = fireAndForgetDelivery; + _logger = loggerFactory.CreateLogger($"{nameof(BroadcastChannel)}-{_channelId}"); + } + + public async Task Publish(T item) + { + var subscribers = _subscriberTable.GetImplicitSubscribers(_channelId, _grainFactory); + + if (subscribers.Count == 0) + { + if (_logger.IsEnabled(LogLevel.Debug)) _logger.LogDebug("No consumer found for {Item}", item); + return; + } + + if (_logger.IsEnabled(LogLevel.Debug)) _logger.LogDebug("Publishing item {Item} to {ConsumerCount} consumers", item, subscribers.Count); + + if (_fireAndForgetDelivery) + { + foreach (var sub in subscribers) + { + PublishToSubscriber(sub.Value, item).Ignore(); + } + } + else + { + var tasks = new List(); + foreach (var sub in subscribers) + { + tasks.Add(PublishToSubscriber(sub.Value, item)); + } + try + { + await Task.WhenAll(tasks); + } + catch (Exception) + { + throw new AggregateException(tasks.Select(t => t.Exception).Where(ex => ex != null)); + } + } + } + + private async Task PublishToSubscriber(IBroadcastChannelConsumerExtension consumer, T item) + { + try + { + await consumer.OnPublished(_channelId, item); + } + catch (Exception ex) + { + _logger.LogError(ex, "Exception when sending item to {GrainId}", consumer.GetGrainId()); + if (!_fireAndForgetDelivery) + { + throw; + } + } + } + } +} + diff --git a/src/Orleans.BroadcastChannel/BroadcastChannelConsumerExtension.cs b/src/Orleans.BroadcastChannel/BroadcastChannelConsumerExtension.cs new file mode 100644 index 0000000000..844b96feca --- /dev/null +++ b/src/Orleans.BroadcastChannel/BroadcastChannelConsumerExtension.cs @@ -0,0 +1,99 @@ +using System; +using System.Collections.Concurrent; +using System.Threading.Tasks; +using Orleans.Runtime; + +namespace Orleans.BroadcastChannel +{ + internal interface IBroadcastChannelConsumerExtension : IGrainExtension + { + Task OnError(InternalChannelId streamId, Exception exception); + Task OnPublished(InternalChannelId streamId, object item); + } + + internal class BroadcastChannelConsumerExtension : IBroadcastChannelConsumerExtension + { + private readonly ConcurrentDictionary _handlers = new(); + private readonly IOnBroadcastChannelSubscribed _subscriptionObserver; + private AsyncLock _lock = new AsyncLock(); + + private interface ICallback + { + Task OnError(Exception exception); + + Task OnPublished(object item); + } + + private class Callback : ICallback + { + private readonly Func _onPublished; + private readonly Func _onError; + + private static Task NoOp(Exception _) => Task.CompletedTask; + + public Callback(Func onPublished, Func onError) + { + _onPublished = onPublished; + _onError = onError ?? NoOp; + } + + public Task OnError(Exception exception) => _onError(exception); + + public Task OnPublished(object item) + { + return item is T typedItem + ? _onPublished(typedItem) + : _onError(new InvalidCastException($"Received an item of type {item.GetType().Name}, expected {typeof(T).FullName}")); + } + } + + public BroadcastChannelConsumerExtension(IGrainContextAccessor grainContextAccessor) + { + _subscriptionObserver = grainContextAccessor.GrainContext?.GrainInstance as IOnBroadcastChannelSubscribed; + } + + public async Task OnError(InternalChannelId streamId, Exception exception) + { + var callback = await GetStreamCallback(streamId); + if (callback != default) + { + await callback.OnError(exception); + } + } + + public async Task OnPublished(InternalChannelId streamId, object item) + { + var callback = await GetStreamCallback(streamId); + if (callback != default) + { + await callback.OnPublished(item); + } + } + + public void Attach(InternalChannelId streamId, Func onPublished, Func onError) + { + _handlers.TryAdd(streamId, new Callback(onPublished, onError)); + } + + private async ValueTask GetStreamCallback(InternalChannelId streamId) + { + ICallback callback; + if (_handlers.TryGetValue(streamId, out callback)) + { + return callback; + } + using (await _lock.LockAsync()) + { + if (_handlers.TryGetValue(streamId, out callback)) + { + return callback; + } + var subscription = new BroadcastChannelSubscription(this, streamId); + await _subscriptionObserver.OnSubscribed(subscription); + } + _handlers.TryGetValue(streamId, out callback); + return callback; + } + } +} + diff --git a/src/Orleans.BroadcastChannel/BroadcastChannelOptions.cs b/src/Orleans.BroadcastChannel/BroadcastChannelOptions.cs new file mode 100644 index 0000000000..6b919d2e2a --- /dev/null +++ b/src/Orleans.BroadcastChannel/BroadcastChannelOptions.cs @@ -0,0 +1,8 @@ +namespace Orleans.BroadcastChannel +{ + public class BroadcastChannelOptions + { + public bool FireAndForgetDelivery { get; set; } = true; + } +} + diff --git a/src/Orleans.BroadcastChannel/BroadcastChannelProvider.cs b/src/Orleans.BroadcastChannel/BroadcastChannelProvider.cs new file mode 100644 index 0000000000..dcde7413ec --- /dev/null +++ b/src/Orleans.BroadcastChannel/BroadcastChannelProvider.cs @@ -0,0 +1,55 @@ +using System; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Orleans.BroadcastChannel.SubscriberTable; +using Orleans.Providers; +using Orleans.Runtime; + +namespace Orleans.BroadcastChannel +{ + public interface IBroadcastChannelProvider + { + IBroadcastChannel GetChannel(ChannelId streamId); + } + + internal class BroadcastChannelProvider : IBroadcastChannelProvider + { + private readonly string _providerName; + private readonly BroadcastChannelOptions _options; + private readonly IGrainFactory _grainFactory; + private readonly ImplicitChannelSubscriberTable _subscriberTable; + private readonly ILoggerFactory _loggerFactory; + + public BroadcastChannelProvider( + string providerName, + BroadcastChannelOptions options, + IGrainFactory grainFactory, + ImplicitChannelSubscriberTable subscriberTable, + ILoggerFactory loggerFactory) + { + _providerName = providerName; + _options = options; + _grainFactory = grainFactory; + _subscriberTable = subscriberTable; + _loggerFactory = loggerFactory; + } + + public IBroadcastChannel GetChannel(ChannelId streamId) + { + return new BroadcastChannel( + new InternalChannelId(_providerName, streamId), + _grainFactory, + _subscriberTable, + _options.FireAndForgetDelivery, + _loggerFactory); + } + + public static IBroadcastChannelProvider Create(IServiceProvider sp, string name) + { + var opt = sp.GetOptionsByName(name); + return ActivatorUtilities.CreateInstance(sp, name, sp.GetOptionsByName(name)); + } + } +} + diff --git a/src/Orleans.BroadcastChannel/BroadcastChannelSubscription.cs b/src/Orleans.BroadcastChannel/BroadcastChannelSubscription.cs new file mode 100644 index 0000000000..afe6f32331 --- /dev/null +++ b/src/Orleans.BroadcastChannel/BroadcastChannelSubscription.cs @@ -0,0 +1,43 @@ +using System; +using System.Threading.Tasks; +using Orleans.Runtime; + +namespace Orleans.BroadcastChannel +{ + public interface IBroadcastChannelSubscription + { + public ChannelId ChannelId { get; } + + public string ProviderName { get; } + + Task Attach(Func onPublished, Func onError = null); + } + + public interface IOnBroadcastChannelSubscribed + { + public Task OnSubscribed(IBroadcastChannelSubscription streamSubscription); + } + + internal class BroadcastChannelSubscription : IBroadcastChannelSubscription + { + private readonly BroadcastChannelConsumerExtension _consumerExtension; + private readonly InternalChannelId _streamId; + + public ChannelId ChannelId => _streamId.ChannelId; + + public string ProviderName => _streamId.ProviderName; + + public BroadcastChannelSubscription(BroadcastChannelConsumerExtension consumerExtension, InternalChannelId streamId) + { + _consumerExtension = consumerExtension; + _streamId = streamId; + } + + public Task Attach(Func onPublished, Func onError = null) + { + _consumerExtension.Attach(_streamId, onPublished, onError); + return Task.CompletedTask; + } + } +} + diff --git a/src/Orleans.BroadcastChannel/ChannelId.cs b/src/Orleans.BroadcastChannel/ChannelId.cs new file mode 100644 index 0000000000..e132bf8038 --- /dev/null +++ b/src/Orleans.BroadcastChannel/ChannelId.cs @@ -0,0 +1,288 @@ +using System; +using System.Buffers.Text; +using System.Collections.Generic; +using System.Diagnostics; +using System.Runtime.InteropServices; +using System.Runtime.Serialization; +using System.Text; + +namespace Orleans.BroadcastChannel +{ + /// + /// Identifies a Channel within a provider + /// + [Immutable] + [Serializable] + [StructLayout(LayoutKind.Auto)] + [GenerateSerializer] + public readonly struct ChannelId : IEquatable, IComparable, ISerializable + { + [Id(0)] + private readonly byte[] fullKey; + + [Id(1)] + private readonly ushort keyIndex; + + [Id(2)] + private readonly int hash; + + /// + /// Gets the full key. + /// + /// The full key. + public ReadOnlyMemory FullKey => fullKey; + + /// + /// Gets the namespace. + /// + /// The namespace. + public ReadOnlyMemory Namespace => fullKey.AsMemory(0, this.keyIndex); + + /// + /// Gets the key. + /// + /// The key. + public ReadOnlyMemory Key => fullKey.AsMemory(this.keyIndex); + + private ChannelId(byte[] fullKey, ushort keyIndex, int hash) + { + this.fullKey = fullKey; + this.keyIndex = keyIndex; + this.hash = hash; + } + + internal ChannelId(byte[] fullKey, ushort keyIndex) + : this(fullKey, keyIndex, (int)JenkinsHash.ComputeHash(fullKey)) + { + } + + private ChannelId(SerializationInfo info, StreamingContext context) + { + fullKey = (byte[])info.GetValue("fk", typeof(byte[])); + this.keyIndex = info.GetUInt16("ki"); + this.hash = info.GetInt32("fh"); + } + + + /// + /// Initializes a new instance of the struct. + /// + /// The namespace. + /// The key. + public static ChannelId Create(byte[] ns, byte[] key) + { + if (key == null) + throw new ArgumentNullException(nameof(key)); + + if (ns != null) + { + var fullKeysBytes = new byte[ns.Length + key.Length]; + ns.CopyTo(fullKeysBytes.AsSpan()); + key.CopyTo(fullKeysBytes.AsSpan(ns.Length)); + return new ChannelId(fullKeysBytes, (ushort)ns.Length); + } + else + { + return new ChannelId((byte[])key.Clone(), 0); + } + } + + /// + /// Initializes a new instance of the struct. + /// + /// The namespace. + /// The key. + public static ChannelId Create(string ns, Guid key) + { + if (ns is null) + { + var buf = new byte[32]; + Utf8Formatter.TryFormat(key, buf, out var len, 'N'); + Debug.Assert(len == 32); + return new ChannelId(buf, 0); + } + else + { + var nsLen = Encoding.UTF8.GetByteCount(ns); + var buf = new byte[nsLen + 32]; + Encoding.UTF8.GetBytes(ns, 0, ns.Length, buf, 0); + Utf8Formatter.TryFormat(key, buf.AsSpan(nsLen), out var len, 'N'); + Debug.Assert(len == 32); + return new ChannelId(buf, (ushort)nsLen); + } + } + + /// + /// Initializes a new instance of the struct. + /// + /// The namespace. + /// The key. + public static ChannelId Create(string ns, string key) + { + if (ns is null) + return new ChannelId(Encoding.UTF8.GetBytes(key), 0); + + var nsLen = Encoding.UTF8.GetByteCount(ns); + var keyLen = Encoding.UTF8.GetByteCount(key); + var buf = new byte[nsLen + keyLen]; + Encoding.UTF8.GetBytes(ns, 0, ns.Length, buf, 0); + Encoding.UTF8.GetBytes(key, 0, key.Length, buf, nsLen); + return new ChannelId(buf, (ushort)nsLen); + } + + /// + public int CompareTo(ChannelId other) => fullKey.AsSpan().SequenceCompareTo(other.fullKey); + + /// + public bool Equals(ChannelId other) => fullKey.AsSpan().SequenceEqual(other.fullKey); + + /// + public override bool Equals(object obj) => obj is ChannelId other ? this.Equals(other) : false; + + /// + /// Compares two instances for equality. + /// + /// The first stream identity. + /// The second stream identity. + /// The result of the operator. + public static bool operator ==(ChannelId s1, ChannelId s2) => s1.Equals(s2); + + /// + /// Compares two instances for equality. + /// + /// The first stream identity. + /// The second stream identity. + /// The result of the operator. + public static bool operator !=(ChannelId s1, ChannelId s2) => !s2.Equals(s1); + + /// + public void GetObjectData(SerializationInfo info, StreamingContext context) + { + info.AddValue("fk", fullKey); + info.AddValue("ki", this.keyIndex); + info.AddValue("fh", this.hash); + } + + /// + public override string ToString() + { + var key = this.GetKeyAsString(); + return keyIndex == 0 ? "null/" + key : this.GetNamespace() + "/" + key; + } + + /// + /// Parses a instance from a . + /// + /// The value. + /// The parsed stream identity. + public static ChannelId Parse(string value) + { + if (string.IsNullOrWhiteSpace(value)) + { + ThrowInvalidInternalStreamId(value); + } + + var i = value.IndexOf('/'); + if (i < 0) + { + ThrowInvalidInternalStreamId(value); + } + + return Create(value.Substring(0, i), value.Substring(i + 1)); + } + + private static void ThrowInvalidInternalStreamId(string value) => throw new ArgumentException($"Unable to parse \"{value}\" as a stream id"); + + /// + public override int GetHashCode() => this.hash; + + /// + /// Returns the component of this instance as a string. + /// + /// The key component of this instance. + public string GetKeyAsString() => Encoding.UTF8.GetString(fullKey, keyIndex, fullKey.Length - keyIndex); + + /// + /// Returns the component of this instance as a string. + /// + /// The namespace component of this instance. + public string GetNamespace() => keyIndex == 0 ? null : Encoding.UTF8.GetString(fullKey, 0, keyIndex); + } + + [Immutable] + [Serializable] + [StructLayout(LayoutKind.Auto)] + [GenerateSerializer] + internal readonly struct InternalChannelId : IEquatable, IComparable, ISerializable + { + [Id(0)] + public ChannelId ChannelId { get; } + + [Id(1)] + public string ProviderName { get; } + + public InternalChannelId(string providerName, ChannelId streamId) + { + ProviderName = providerName; + ChannelId = streamId; + } + + private InternalChannelId(SerializationInfo info, StreamingContext context) + { + ProviderName = info.GetString("pvn"); + ChannelId = (ChannelId)info.GetValue("sid", typeof(ChannelId)); + } + + public static implicit operator ChannelId(InternalChannelId internalStreamId) => internalStreamId.ChannelId; + + public bool Equals(InternalChannelId other) => ChannelId.Equals(other) && ProviderName.Equals(other.ProviderName); + + public override bool Equals(object obj) => obj is InternalChannelId other ? this.Equals(other) : false; + + public static bool operator ==(InternalChannelId s1, InternalChannelId s2) => s1.Equals(s2); + + public static bool operator !=(InternalChannelId s1, InternalChannelId s2) => !s2.Equals(s1); + + public int CompareTo(InternalChannelId other) => ChannelId.CompareTo(other.ChannelId); + + public void GetObjectData(SerializationInfo info, StreamingContext context) + { + info.AddValue("pvn", ProviderName); + info.AddValue("sid", ChannelId, typeof(ChannelId)); + } + + public override int GetHashCode() + { + unchecked + { + return ProviderName.GetHashCode() * 43 ^ ChannelId.GetHashCode(); + } + } + + public override string ToString() + { + return $"{ProviderName}/{ChannelId.ToString()}"; + } + + public static InternalChannelId Parse(string value) + { + if (string.IsNullOrWhiteSpace(value)) + { + ThrowInvalidInternalStreamId(value); + } + + var i = value.IndexOf('/'); + if (i < 0) + { + ThrowInvalidInternalStreamId(value); + } + + return new InternalChannelId(value.Substring(0, i), ChannelId.Parse(value.Substring(i + 1))); + } + + private static void ThrowInvalidInternalStreamId(string value) => throw new ArgumentException($"Unable to parse \"{value}\" as a stream id"); + + + internal string GetNamespace() => ChannelId.GetNamespace(); + } +} \ No newline at end of file diff --git a/src/Orleans.BroadcastChannel/Hosting/ChannelHostingExtensions.cs b/src/Orleans.BroadcastChannel/Hosting/ChannelHostingExtensions.cs new file mode 100644 index 0000000000..2f44c0ce15 --- /dev/null +++ b/src/Orleans.BroadcastChannel/Hosting/ChannelHostingExtensions.cs @@ -0,0 +1,56 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using Orleans.BroadcastChannel; +using Orleans.BroadcastChannel.SubscriberTable; +using Orleans.Configuration; +using Orleans.Hosting; +using Orleans.Runtime; + +namespace Orleans.Hosting +{ + public static class ChannelHostingExtensions + { + public static ISiloBuilder AddBroadcastChannel(this ISiloBuilder @this, string name, Action configureOptions) + { + @this.Services.AddBroadcastChannel(name, ob => ob.Configure(configureOptions)); + return @this; + } + + public static ISiloBuilder AddBroadcastChannel(this ISiloBuilder @this, string name, Action> configureOptions = null) + { + @this.Services.AddBroadcastChannel(name, configureOptions); + @this.AddGrainExtension(); + return @this; + } + + public static IClientBuilder AddBroadcastChannel(this IClientBuilder @this, string name, Action configureOptions) + { + @this.Services.AddBroadcastChannel(name, ob => ob.Configure(configureOptions)); + return @this; + } + + public static IClientBuilder AddBroadcastChannel(this IClientBuilder @this, string name, Action> configureOptions = null) + { + @this.Services.AddBroadcastChannel(name, configureOptions); + return @this; + } + + public static IBroadcastChannelProvider GetBroadcaseChannelProvider(this IClusterClient @this, string name) + => @this.ServiceProvider.GetRequiredServiceByName(name); + + private static void AddBroadcastChannel(this IServiceCollection services, string name, Action> configureOptions) + { + configureOptions?.Invoke(services.AddOptions(name)); + services.ConfigureNamedOptionForLogging(name); + services + .AddSingleton() + .AddSingleton() + .AddSingleton() + .AddSingletonKeyedService(DefaultChannelIdMapper.Name) + .AddSingletonNamedService(name, BroadcastChannelProvider.Create); + } + } +} diff --git a/src/Orleans.BroadcastChannel/IdMapping/DefaultChannelIdMapper.cs b/src/Orleans.BroadcastChannel/IdMapping/DefaultChannelIdMapper.cs new file mode 100644 index 0000000000..60e0ded58d --- /dev/null +++ b/src/Orleans.BroadcastChannel/IdMapping/DefaultChannelIdMapper.cs @@ -0,0 +1,79 @@ +using System; +using System.Buffers.Text; +using System.Runtime.InteropServices; +using Orleans.Metadata; +using Orleans.Runtime; + +namespace Orleans.BroadcastChannel +{ + /// + /// The default implementation. + /// + public class DefaultChannelIdMapper : IChannelIdMapper + { + /// + /// The name of this stream identity mapper. + /// + public const string Name = "default"; + + /// + public IdSpan GetGrainKeyId(GrainBindings grainBindings, ChannelId streamId) + { + string keyType = null; + bool includeNamespaceInGrainId = false; + + foreach (var grainBinding in grainBindings.Bindings) + { + if (!grainBinding.TryGetValue(WellKnownGrainTypeProperties.BindingTypeKey, out var type) + || !string.Equals(type, WellKnownGrainTypeProperties.BroadcastChannelBindingTypeValue, StringComparison.Ordinal)) + { + continue; + } + + if (grainBinding.TryGetValue(WellKnownGrainTypeProperties.LegacyGrainKeyType, out keyType)) + { + if (grainBinding.TryGetValue(WellKnownGrainTypeProperties.StreamBindingIncludeNamespaceKey, out var value) + && string.Equals(value, "true", StringComparison.OrdinalIgnoreCase)) + { + includeNamespaceInGrainId = true; + } + } + } + + return keyType switch + { + nameof(Guid) => GetGuidKey(streamId, includeNamespaceInGrainId), + nameof(Int64) => GetIntegerKey(streamId, includeNamespaceInGrainId), + _ => GetKey(streamId), // null or string + }; + } + + private static IdSpan GetGuidKey(ChannelId streamId, bool includeNamespaceInGrainId) + { + var key = streamId.Key.Span; + if (!Utf8Parser.TryParse(key, out Guid guidKey, out var len, 'N') || len < key.Length) throw new ArgumentException(nameof(streamId)); + + return includeNamespaceInGrainId + ? GrainIdKeyExtensions.CreateGuidKey(guidKey, streamId.GetNamespace()) + : GrainIdKeyExtensions.CreateGuidKey(guidKey); + } + + private static IdSpan GetIntegerKey(ChannelId streamId, bool includeNamespaceInGrainId) + { + var key = streamId.Key.Span; + if (!Utf8Parser.TryParse(key, out int intKey, out var len) || len < key.Length) throw new ArgumentException(nameof(streamId)); + + return includeNamespaceInGrainId + ? GrainIdKeyExtensions.CreateIntegerKey(intKey, streamId.GetNamespace()) + : GrainIdKeyExtensions.CreateIntegerKey(intKey); + } + + private static IdSpan GetKey(ChannelId streamId) + { + var key = streamId.Key; + return MemoryMarshal.TryGetArray(key, out var seg) && seg.Offset == 0 && seg.Count == seg.Array.Length + ? new IdSpan(seg.Array) + : new IdSpan(key.ToArray()); + } + } +} diff --git a/src/Orleans.BroadcastChannel/IdMapping/IChannelIdMapper.cs b/src/Orleans.BroadcastChannel/IdMapping/IChannelIdMapper.cs new file mode 100644 index 0000000000..adfeeb8525 --- /dev/null +++ b/src/Orleans.BroadcastChannel/IdMapping/IChannelIdMapper.cs @@ -0,0 +1,19 @@ +using Orleans.Metadata; +using Orleans.Runtime; + +namespace Orleans.BroadcastChannel +{ + /// + /// Common interface for component that map a to a + /// + public interface IChannelIdMapper + { + /// + /// Get the which maps to the provided + /// + /// The grain bindings. + /// The stream identifier. + /// The component. + IdSpan GetGrainKeyId(GrainBindings grainBindings, ChannelId streamId); + } +} diff --git a/src/Orleans.BroadcastChannel/Orleans.BroadcastChannel.csproj b/src/Orleans.BroadcastChannel/Orleans.BroadcastChannel.csproj new file mode 100644 index 0000000000..7cfc41ade7 --- /dev/null +++ b/src/Orleans.BroadcastChannel/Orleans.BroadcastChannel.csproj @@ -0,0 +1,14 @@ + + + Microsoft.Orleans.BroadcastChannel + Microsoft Orleans Broadcast Channel Library + Broadcast Channel library for Microsoft Orleans used both on the client and server. + $(MultiTargetFrameworks) + true + false + + + + + + diff --git a/src/Orleans.BroadcastChannel/SubscriberTable/ImplicitChannelSubscriberTable.cs b/src/Orleans.BroadcastChannel/SubscriberTable/ImplicitChannelSubscriberTable.cs new file mode 100644 index 0000000000..8da5f617dc --- /dev/null +++ b/src/Orleans.BroadcastChannel/SubscriberTable/ImplicitChannelSubscriberTable.cs @@ -0,0 +1,329 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using Orleans.Metadata; +using Orleans.Runtime; + +namespace Orleans.BroadcastChannel.SubscriberTable +{ + internal class ImplicitChannelSubscriberTable + { + private readonly object _lockObj = new object(); + private readonly GrainBindingsResolver _bindings; + private readonly IChannelNamespacePredicateProvider[] _providers; + private readonly IServiceProvider _serviceProvider; + private Cache _cache; + + public ImplicitChannelSubscriberTable( + GrainBindingsResolver bindings, + IEnumerable providers, + IServiceProvider serviceProvider) + { + _bindings = bindings; + var initialBindings = bindings.GetAllBindings(); + _providers = providers.ToArray(); + _serviceProvider = serviceProvider; + _cache = BuildCache(initialBindings.Version, initialBindings.Bindings); + } + + private Cache GetCache() + { + var cache = _cache; + var bindings = _bindings.GetAllBindings(); + if (bindings.Version == cache.Version) + { + return cache; + } + + lock (_lockObj) + { + bindings = _bindings.GetAllBindings(); + if (bindings.Version == cache.Version) + { + return cache; + } + + return _cache = BuildCache(bindings.Version, bindings.Bindings); + } + } + + private Cache BuildCache(MajorMinorVersion version, ImmutableDictionary bindings) + { + var newPredicates = new List(); + + foreach (var binding in bindings.Values) + { + foreach (var grainBinding in binding.Bindings) + { + if (!grainBinding.TryGetValue(WellKnownGrainTypeProperties.BindingTypeKey, out var type) + || !string.Equals(type, WellKnownGrainTypeProperties.BroadcastChannelBindingTypeValue, StringComparison.Ordinal)) + { + continue; + } + + if (!grainBinding.TryGetValue(WellKnownGrainTypeProperties.BroadcastChannelBindingPatternKey, out var pattern)) + { + throw new KeyNotFoundException( + $"Channel binding for grain type {binding.GrainType} is missing a \"{WellKnownGrainTypeProperties.BroadcastChannelBindingPatternKey}\" value"); + } + + IChannelNamespacePredicate predicate = null; + foreach (var provider in _providers) + { + if (provider.TryGetPredicate(pattern, out predicate)) break; + } + + if (predicate is null) + { + throw new KeyNotFoundException( + $"Could not find an {nameof(IChannelNamespacePredicate)} for the pattern \"{pattern}\"." + + $" Ensure that a corresponding {nameof(IChannelNamespacePredicateProvider)} is registered"); + } + + if (!grainBinding.TryGetValue(WellKnownGrainTypeProperties.ChannelIdMapperKey, out var mapperName)) + { + throw new KeyNotFoundException( + $"Channel binding for grain type {binding.GrainType} is missing a \"{WellKnownGrainTypeProperties.ChannelIdMapperKey}\" value"); + } + + var channelIdMapper = _serviceProvider.GetServiceByName(string.IsNullOrWhiteSpace(mapperName) ? DefaultChannelIdMapper.Name : mapperName); + var subscriber = new BroadcastChannelSubscriber(binding, channelIdMapper); + newPredicates.Add(new BroadcastChannelSubscriberPredicate(subscriber, predicate)); + } + } + + return new Cache(version, newPredicates); + } + + /// + /// Retrieve a map of implicit subscriptionsIds to implicit subscribers, given a channel ID. This method throws an exception if there's no namespace associated with the channel ID. + /// + /// A channel ID. + /// The grain factory used to get consumer references. + /// A set of references to implicitly subscribed grains. They are expected to support the broadcast channel consumer extension. + /// The channel ID doesn't have an associated namespace. + /// Internal invariant violation. + internal IDictionary GetImplicitSubscribers(InternalChannelId channelId, IGrainFactory grainFactory) + { + if (!IsImplicitSubscribeEligibleNameSpace(channelId.GetNamespace())) + { + throw new ArgumentException("The channel ID doesn't have an associated namespace.", nameof(channelId)); + } + + var entries = GetOrAddImplicitSubscribers(channelId.GetNamespace()); + + var result = new Dictionary(); + foreach (var entry in entries) + { + var consumer = MakeConsumerReference(grainFactory, channelId, entry); + var subscriptionGuid = MakeSubscriptionGuid(entry.GrainType, channelId); + if (result.ContainsKey(subscriptionGuid)) + { + throw new InvalidOperationException( + $"Internal invariant violation: generated duplicate subscriber reference: {consumer}, subscriptionId: {subscriptionGuid}"); + } + result.Add(subscriptionGuid, consumer); + } + return result; + } + + private HashSet GetOrAddImplicitSubscribers(string channelNamespace) + { + var cache = GetCache(); + if (cache.Namespaces.TryGetValue(channelNamespace, out var result)) + { + return result; + } + + return cache.Namespaces[channelNamespace] = FindImplicitSubscribers(channelNamespace, cache.Predicates); + } + + /// + /// Determines whether the specified grain is an implicit subscriber of a given channel. + /// + /// The grain identifier. + /// The channel identifier. + /// true if the grain id describes an implicit subscriber of the channel described by the channel id. + internal bool IsImplicitSubscriber(GrainId grainId, InternalChannelId channelId) + { + return HasImplicitSubscription(channelId.GetNamespace(), grainId.Type); + } + + /// + /// Try to get the implicit subscriptionId. + /// If an implicit subscription exists, return a subscription Id that is unique per grain type, grainId, namespace combination. + /// + /// + /// + /// + /// + internal bool TryGetImplicitSubscriptionGuid(GrainId grainId, InternalChannelId channelId, out Guid subscriptionId) + { + subscriptionId = Guid.Empty; + + if (!IsImplicitSubscriber(grainId, channelId)) + { + return false; + } + + // make subscriptionId + subscriptionId = MakeSubscriptionGuid(grainId.Type, channelId); + + return true; + } + + /// + /// Create a subscriptionId that is unique per grainId, grainType, namespace combination. + /// + private Guid MakeSubscriptionGuid(GrainType grainType, InternalChannelId channelId) + { + // next 2 shorts inc guid are from namespace hash + var namespaceHash = JenkinsHash.ComputeHash(channelId.GetNamespace()); + var namespaceHashByes = BitConverter.GetBytes(namespaceHash); + var s1 = BitConverter.ToInt16(namespaceHashByes, 0); + var s2 = BitConverter.ToInt16(namespaceHashByes, 2); + + // Tailing 8 bytes of the guid are from the hash of the channelId Guid and a hash of the provider name. + // get channelId guid hash code + var channelIdGuidHash = JenkinsHash.ComputeHash(channelId.ChannelId.Key.Span); + // get provider name hash code + var providerHash = JenkinsHash.ComputeHash(channelId.ProviderName); + + // build guid tailing 8 bytes from grainIdHash and the hash of the provider name. + var tail = new List(); + tail.AddRange(BitConverter.GetBytes(channelIdGuidHash)); + tail.AddRange(BitConverter.GetBytes(providerHash)); + + // make guid. + // - First int is grain type + // - Two shorts from namespace hash + // - 8 byte tail from channelId Guid and provider name hash. + var id = new Guid((int)JenkinsHash.ComputeHash(grainType.ToString()), s1, s2, tail.ToArray()); + var result = MarkSubscriptionGuid(id, isImplicitSubscription: true); + return result; + } + + internal static bool IsImplicitSubscribeEligibleNameSpace(string channelNameSpace) + { + return !string.IsNullOrWhiteSpace(channelNameSpace); + } + + private bool HasImplicitSubscription(string channelNamespace, GrainType grainType) + { + if (!IsImplicitSubscribeEligibleNameSpace(channelNamespace)) + { + return false; + } + + var entry = GetOrAddImplicitSubscribers(channelNamespace); + return entry.Any(e => e.GrainType == grainType); + } + + /// + /// Finds all implicit subscribers for the given channel namespace. + /// + private static HashSet FindImplicitSubscribers(string channelNamespace, List predicates) + { + var result = new HashSet(); + foreach (var predicate in predicates) + { + if (predicate.Predicate.IsMatch(channelNamespace)) + { + result.Add(predicate.Subscriber); + } + } + + return result; + } + + private static Guid MarkSubscriptionGuid(Guid subscriptionGuid, bool isImplicitSubscription) + { + byte[] guidBytes = subscriptionGuid.ToByteArray(); + if (isImplicitSubscription) + { + // set high bit of last byte + guidBytes[guidBytes.Length - 1] = (byte)(guidBytes[guidBytes.Length - 1] | 0x80); + } + else + { + // clear high bit of last byte + guidBytes[guidBytes.Length - 1] = (byte)(guidBytes[guidBytes.Length - 1] & 0x7f); + } + + return new Guid(guidBytes); + } + + /// + /// Create a reference to a grain that we expect to support the broadcast channel consumer extension. + /// + /// The grain factory used to get consumer references. + /// The channel ID to use for the grain ID construction. + /// The GrainBindings for the grain to create + /// + private IBroadcastChannelConsumerExtension MakeConsumerReference( + IGrainFactory grainFactory, + InternalChannelId channelId, + BroadcastChannelSubscriber channelSubscriber) + { + var grainId = channelSubscriber.GetGrainId(channelId); + return grainFactory.GetGrain(grainId); + } + + private class BroadcastChannelSubscriberPredicate + { + public BroadcastChannelSubscriberPredicate(BroadcastChannelSubscriber subscriber, IChannelNamespacePredicate predicate) + { + Subscriber = subscriber; + Predicate = predicate; + } + + public BroadcastChannelSubscriber Subscriber { get; } + public IChannelNamespacePredicate Predicate { get; } + } + + private class BroadcastChannelSubscriber + { + public BroadcastChannelSubscriber(GrainBindings grainBindings, IChannelIdMapper channelIdMapper) + { + GrainBindings = grainBindings; + this.channelIdMapper = channelIdMapper; + } + + public GrainType GrainType => GrainBindings.GrainType; + + private GrainBindings GrainBindings { get; } + + private IChannelIdMapper channelIdMapper { get; } + + public override bool Equals(object obj) + { + return obj is BroadcastChannelSubscriber subscriber && + GrainType.Equals(subscriber.GrainType); + } + + public override int GetHashCode() => GrainType.GetHashCode(); + + internal GrainId GetGrainId(InternalChannelId channelId) + { + var grainKeyId = channelIdMapper.GetGrainKeyId(GrainBindings, channelId); + return GrainId.Create(GrainType, grainKeyId); + } + } + + private class Cache + { + public Cache(MajorMinorVersion version, List predicates) + { + Version = version; + Predicates = predicates; + Namespaces = new ConcurrentDictionary>(); + } + + public MajorMinorVersion Version { get; } + public ConcurrentDictionary> Namespaces { get; } + public List Predicates { get; } + } + } +} \ No newline at end of file diff --git a/src/Orleans.BroadcastChannel/SubscriberTable/Predicates/AllStreamNamespacesPredicate.cs b/src/Orleans.BroadcastChannel/SubscriberTable/Predicates/AllStreamNamespacesPredicate.cs new file mode 100644 index 0000000000..f738fd4d92 --- /dev/null +++ b/src/Orleans.BroadcastChannel/SubscriberTable/Predicates/AllStreamNamespacesPredicate.cs @@ -0,0 +1,17 @@ +namespace Orleans.BroadcastChannel +{ + /// + /// A stream namespace predicate which matches all namespaces. + /// + internal class AllStreamNamespacesPredicate : IChannelNamespacePredicate + { + /// + public string PredicatePattern => "*"; + + /// + public bool IsMatch(string streamNamespace) + { + return true; + } + } +} \ No newline at end of file diff --git a/src/Orleans.BroadcastChannel/SubscriberTable/Predicates/DefaultStreamNamespacePredicateProvider.cs b/src/Orleans.BroadcastChannel/SubscriberTable/Predicates/DefaultStreamNamespacePredicateProvider.cs new file mode 100644 index 0000000000..1ca7be79b1 --- /dev/null +++ b/src/Orleans.BroadcastChannel/SubscriberTable/Predicates/DefaultStreamNamespacePredicateProvider.cs @@ -0,0 +1,94 @@ +using System; +using Orleans.BroadcastChannel; +using Orleans.Serialization.TypeSystem; +using Orleans.Utilities; + +namespace Orleans.BroadcastChannel +{ + /// + /// Default implementation of for internally supported stream predicates. + /// + public class DefaultChannelNamespacePredicateProvider : IChannelNamespacePredicateProvider + { + /// + public bool TryGetPredicate(string predicatePattern, out IChannelNamespacePredicate predicate) + { + switch (predicatePattern) + { + case "*": + predicate = new AllStreamNamespacesPredicate(); + return true; + case var regex when regex.StartsWith(RegexChannelNamespacePredicate.Prefix, StringComparison.Ordinal): + predicate = new RegexChannelNamespacePredicate(regex.Substring(RegexChannelNamespacePredicate.Prefix.Length)); + return true; + case var ns when ns.StartsWith(ExactMatchChannelNamespacePredicate.Prefix, StringComparison.Ordinal): + predicate = new ExactMatchChannelNamespacePredicate(ns.Substring(ExactMatchChannelNamespacePredicate.Prefix.Length)); + return true; + } + + predicate = null; + return false; + } + } + + /// + /// Stream namespace predicate provider which supports objects which can be constructed and optionally accept a string as a constructor argument. + /// + public class ConstructorChannelNamespacePredicateProvider : IChannelNamespacePredicateProvider + { + /// + /// The prefix used to identify this predicate provider. + /// + public const string Prefix = "ctor"; + + /// + /// Formats a stream namespace predicate which indicates a concrete type to be constructed, along with an optional argument. + /// + public static string FormatPattern(Type predicateType, string constructorArgument) + { + if (constructorArgument is null) + { + return $"{Prefix}:{RuntimeTypeNameFormatter.Format(predicateType)}"; + } + + return $"{Prefix}:{RuntimeTypeNameFormatter.Format(predicateType)}:{constructorArgument}"; + } + + /// + public bool TryGetPredicate(string predicatePattern, out IChannelNamespacePredicate predicate) + { + if (!predicatePattern.StartsWith(Prefix, StringComparison.Ordinal)) + { + predicate = null; + return false; + } + + var start = Prefix.Length + 1; + string typeName; + string arg; + var index = predicatePattern.IndexOf(':', start); + if (index < 0) + { + typeName = predicatePattern.Substring(start); + arg = null; + } + else + { + typeName = predicatePattern.Substring(start, index - start); + arg = predicatePattern.Substring(index + 1); + } + + var type = Type.GetType(typeName, throwOnError: true); + if (string.IsNullOrEmpty(arg)) + { + predicate = (IChannelNamespacePredicate)Activator.CreateInstance(type); + } + else + { + predicate = (IChannelNamespacePredicate)Activator.CreateInstance(type, arg); + } + + return true; + } + } +} \ No newline at end of file diff --git a/src/Orleans.BroadcastChannel/SubscriberTable/Predicates/ExactMatchStreamNamespacePredicate.cs b/src/Orleans.BroadcastChannel/SubscriberTable/Predicates/ExactMatchStreamNamespacePredicate.cs new file mode 100644 index 0000000000..8dd01aced7 --- /dev/null +++ b/src/Orleans.BroadcastChannel/SubscriberTable/Predicates/ExactMatchStreamNamespacePredicate.cs @@ -0,0 +1,35 @@ +using System; + +namespace Orleans.BroadcastChannel +{ + /// + /// Stream namespace predicate which matches exactly one, specified + /// + [Serializable] + [GenerateSerializer] + internal class ExactMatchChannelNamespacePredicate : IChannelNamespacePredicate + { + internal const string Prefix = "namespace:"; + + [Id(1)] + private readonly string targetStreamNamespace; + + /// + /// Initializes a new instance of the class. + /// + /// The target stream namespace. + public ExactMatchChannelNamespacePredicate(string targetStreamNamespace) + { + this.targetStreamNamespace = targetStreamNamespace; + } + + /// + public string PredicatePattern => $"{Prefix}{this.targetStreamNamespace}"; + + /// + public bool IsMatch(string streamNamespace) + { + return string.Equals(targetStreamNamespace, streamNamespace?.Trim()); + } + } +} \ No newline at end of file diff --git a/src/Orleans.BroadcastChannel/SubscriberTable/Predicates/IChannelNamespacePredicate.cs b/src/Orleans.BroadcastChannel/SubscriberTable/Predicates/IChannelNamespacePredicate.cs new file mode 100644 index 0000000000..80dc6f9e16 --- /dev/null +++ b/src/Orleans.BroadcastChannel/SubscriberTable/Predicates/IChannelNamespacePredicate.cs @@ -0,0 +1,35 @@ +namespace Orleans.BroadcastChannel +{ + /// + /// Stream namespace predicate used for filtering implicit subscriptions using + /// . + /// + /// All implementations must be serializable. + public interface IChannelNamespacePredicate + { + /// + /// Defines if the consumer grain should subscribe to the specified namespace. + /// + /// The target stream namespace to check. + /// true, if the grain should subscribe to the specified namespace; false, otherwise. + /// + bool IsMatch(string streamNamespace); + + /// + /// Gets a pattern to describe this predicate. This value is passed to instances of to recreate this predicate. + /// + string PredicatePattern { get; } + } + + /// + /// Converts predicate pattern strings to instances. + /// + /// + public interface IChannelNamespacePredicateProvider + { + /// + /// Get the predicate matching the provided pattern. Returns if this provider cannot match the predicate. + /// + bool TryGetPredicate(string predicatePattern, out IChannelNamespacePredicate predicate); + } +} \ No newline at end of file diff --git a/src/Orleans.BroadcastChannel/SubscriberTable/Predicates/ImplicitChannelSubscriptionAttribute.cs b/src/Orleans.BroadcastChannel/SubscriberTable/Predicates/ImplicitChannelSubscriptionAttribute.cs new file mode 100644 index 0000000000..b5590faade --- /dev/null +++ b/src/Orleans.BroadcastChannel/SubscriberTable/Predicates/ImplicitChannelSubscriptionAttribute.cs @@ -0,0 +1,118 @@ +using System; +using System.Collections.Generic; +using Orleans.BroadcastChannel; +using Orleans.Metadata; +using Orleans.Runtime; + +namespace Orleans +{ + /// + /// The [Orleans.ImplicitStreamSubscription] attribute is used to mark grains as implicit stream subscriptions. + /// + [AttributeUsage(AttributeTargets.Class, AllowMultiple = true)] + public class ImplicitChannelSubscriptionAttribute : Attribute, IGrainBindingsProviderAttribute + { + /// + /// Gets the stream namespace filter predicate. + /// + public IChannelNamespacePredicate Predicate { get; } + + /// + /// Gets the name of the channel identifier mapper. + /// + /// The name of the channel identifier mapper. + public string ChannelIdMapper { get; } + + /// + /// Used to subscribe to all stream namespaces. + /// + public ImplicitChannelSubscriptionAttribute() + { + Predicate = new AllStreamNamespacesPredicate(); + } + + /// + /// Used to subscribe to the specified stream namespace. + /// + /// The stream namespace to subscribe. + /// The name of the stream identity mapper. + public ImplicitChannelSubscriptionAttribute(string streamNamespace, string channelIdMapper = null) + { + Predicate = new ExactMatchChannelNamespacePredicate(streamNamespace.Trim()); + ChannelIdMapper = channelIdMapper; + } + + /// + /// Allows to pass an arbitrary predicate type to filter stream namespaces to subscribe. The predicate type + /// must have a constructor without parameters. + /// + /// The stream namespace predicate type. + /// The name of the stream identity mapper. + public ImplicitChannelSubscriptionAttribute(Type predicateType, string channelIdMapper = null) + { + Predicate = (IChannelNamespacePredicate) Activator.CreateInstance(predicateType); + ChannelIdMapper = channelIdMapper; + } + + /// + /// Allows to pass an instance of the stream namespace predicate. To be used mainly as an extensibility point + /// via inheriting attributes. + /// + /// The stream namespace predicate. + /// The name of the stream identity mapper. + public ImplicitChannelSubscriptionAttribute(IChannelNamespacePredicate predicate, string channelIdMapper = null) + { + Predicate = predicate; + ChannelIdMapper = channelIdMapper; + } + + /// + public IEnumerable> GetBindings(IServiceProvider services, Type grainClass, GrainType grainType) + { + var binding = new Dictionary + { + [WellKnownGrainTypeProperties.BindingTypeKey] = WellKnownGrainTypeProperties.BroadcastChannelBindingTypeValue, + [WellKnownGrainTypeProperties.BroadcastChannelBindingPatternKey] = this.Predicate.PredicatePattern, + [WellKnownGrainTypeProperties.ChannelIdMapperKey] = this.ChannelIdMapper, + }; + + if (LegacyGrainId.IsLegacyGrainType(grainClass)) + { + string keyType; + + if (typeof(IGrainWithGuidKey).IsAssignableFrom(grainClass) || typeof(IGrainWithGuidCompoundKey).IsAssignableFrom(grainClass)) + keyType = nameof(Guid); + else if (typeof(IGrainWithIntegerKey).IsAssignableFrom(grainClass) || typeof(IGrainWithIntegerCompoundKey).IsAssignableFrom(grainClass)) + keyType = nameof(Int64); + else // fallback to string + keyType = nameof(String); + + binding[WellKnownGrainTypeProperties.LegacyGrainKeyType] = keyType; + } + + if (LegacyGrainId.IsLegacyKeyExtGrainType(grainClass)) + { + binding[WellKnownGrainTypeProperties.StreamBindingIncludeNamespaceKey] = "true"; + } + + yield return binding; + } + } + + /// + /// The [Orleans.RegexImplicitStreamSubscription] attribute is used to mark grains as implicit stream + /// subscriptions by filtering stream namespaces to subscribe using a regular expression. + /// + [AttributeUsage(AttributeTargets.Class, AllowMultiple = true)] + public sealed class RegexImplicitChannelSubscriptionAttribute : ImplicitChannelSubscriptionAttribute + { + /// + /// Allows to pass a regular expression to filter stream namespaces to subscribe to. + /// + /// The stream namespace regular expression filter. + public RegexImplicitChannelSubscriptionAttribute(string pattern) + : base(new RegexChannelNamespacePredicate(pattern)) + { + } + } +} \ No newline at end of file diff --git a/src/Orleans.BroadcastChannel/SubscriberTable/Predicates/RegexChannelNamespacePredicate.cs b/src/Orleans.BroadcastChannel/SubscriberTable/Predicates/RegexChannelNamespacePredicate.cs new file mode 100644 index 0000000000..a08435baa5 --- /dev/null +++ b/src/Orleans.BroadcastChannel/SubscriberTable/Predicates/RegexChannelNamespacePredicate.cs @@ -0,0 +1,35 @@ +using System; +using System.Text.RegularExpressions; + +namespace Orleans.BroadcastChannel +{ + /// + /// implementation allowing to filter stream namespaces by regular + /// expression. + /// + public class RegexChannelNamespacePredicate : IChannelNamespacePredicate + { + internal const string Prefix = "regex:"; + private readonly Regex regex; + + /// + /// Returns a pattern used to describe this instance. The pattern will be parsed by an instance on each node. + /// + public string PredicatePattern => $"{Prefix}{regex}"; + + /// + /// Creates an instance of with the specified regular expression. + /// + /// The stream namespace regular expression. + public RegexChannelNamespacePredicate(string regex) + { + this.regex = new Regex(regex, RegexOptions.Compiled) ?? throw new ArgumentNullException(nameof(regex)); + } + + /// + public bool IsMatch(string streamNameSpace) + { + return regex.IsMatch(streamNameSpace); + } + } +} \ No newline at end of file diff --git a/src/Orleans.Core.Abstractions/Manifest/GrainProperties.cs b/src/Orleans.Core.Abstractions/Manifest/GrainProperties.cs index 3a0bfcb7d1..201079399a 100644 --- a/src/Orleans.Core.Abstractions/Manifest/GrainProperties.cs +++ b/src/Orleans.Core.Abstractions/Manifest/GrainProperties.cs @@ -119,16 +119,31 @@ public static class WellKnownGrainTypeProperties /// public const string StreamBindingTypeValue = "stream"; + /// + /// The binding type for Broadcast Channels. + /// + public const string BroadcastChannelBindingTypeValue = "broadcast-channel"; + /// /// The key to specify a stream binding pattern. /// public const string StreamBindingPatternKey = "pattern"; + /// + /// The key to specify a channel binding pattern. + /// + public const string BroadcastChannelBindingPatternKey = "channel-pattern"; + /// /// The key to specify a stream id mapper /// public const string StreamIdMapperKey = "streamid-mapper"; + /// + /// The key to specify a channel id mapper + /// + public const string ChannelIdMapperKey = "channelid-mapper"; + /// /// Whether to include the namespace name in the grain id. /// diff --git a/src/Orleans.Core/Properties/AssemblyInfo.cs b/src/Orleans.Core/Properties/AssemblyInfo.cs index 5695ea9428..2da2620e8e 100644 --- a/src/Orleans.Core/Properties/AssemblyInfo.cs +++ b/src/Orleans.Core/Properties/AssemblyInfo.cs @@ -1,5 +1,6 @@ using System.Runtime.CompilerServices; +[assembly: InternalsVisibleTo("Orleans.BroadcastChannel")] [assembly: InternalsVisibleTo("Orleans.CodeGeneration")] [assembly: InternalsVisibleTo("Orleans.CodeGeneration.Build")] [assembly: InternalsVisibleTo("Orleans.Runtime")] diff --git a/src/Orleans.Streaming/GrainStreamingExtensions.cs b/src/Orleans.Streaming/GrainStreamingExtensions.cs index 0fe446cb98..5b1adeb46f 100644 --- a/src/Orleans.Streaming/GrainStreamingExtensions.cs +++ b/src/Orleans.Streaming/GrainStreamingExtensions.cs @@ -1,4 +1,5 @@ using System; +using Orleans.Providers.Streams.SimpleMessageStream; using Orleans.Runtime; using Orleans.Streams; diff --git a/src/Orleans.Streaming/SimpleMessageStream/SimpleMessageStreamProvider.cs b/src/Orleans.Streaming/SimpleMessageStream/SimpleMessageStreamProvider.cs index eeea1c2dec..85d41219a0 100644 --- a/src/Orleans.Streaming/SimpleMessageStream/SimpleMessageStreamProvider.cs +++ b/src/Orleans.Streaming/SimpleMessageStream/SimpleMessageStreamProvider.cs @@ -8,6 +8,7 @@ using Orleans.Configuration; using Orleans.Streams.Filtering; using Orleans.Serialization; +using System.Threading.Tasks; namespace Orleans.Providers.Streams.SimpleMessageStream { diff --git a/test/Grains/TestGrains/SimpleStreams/SimpleSubscriberGrain.cs b/test/Grains/TestGrains/SimpleStreams/SimpleSubscriberGrain.cs new file mode 100644 index 0000000000..5e737368bf --- /dev/null +++ b/test/Grains/TestGrains/SimpleStreams/SimpleSubscriberGrain.cs @@ -0,0 +1,79 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Newtonsoft.Json.Serialization; +using Orleans; +using Orleans.BroadcastChannel; +using Orleans.Runtime; + +namespace UnitTests.Grains.BroadcastChannel +{ + public interface ISubscriberGrain : IGrainWithStringKey + { + Task> GetErrors(ChannelId streamId); + + Task> GetValues(ChannelId streamId); + + Task ThrowsOnReceive(bool throwsOnReceive); + } + + public interface ISimpleSubscriberGrain : ISubscriberGrain { } + + public interface IRegexNamespaceSubscriberGrain : ISubscriberGrain { } + + public abstract class SubscriberGrainBase : Grain, ISubscriberGrain, IOnBroadcastChannelSubscribed + { + private Dictionary> _values = new(); + private Dictionary> _errors = new(); + private bool _throwsOnReceive = false; + + public Task> GetErrors(ChannelId streamId) => _errors.TryGetValue(streamId, out var errors) ? Task.FromResult(errors) : Task.FromResult(new List()); + + public Task> GetValues(ChannelId streamId) => _values.TryGetValue(streamId, out var values) ? Task.FromResult(values) : Task.FromResult(new List()); + + public Task OnSubscribed(IBroadcastChannelSubscription streamSubscription) + { + streamSubscription.Attach(item => OnPublished(streamSubscription.ChannelId, item), ex => OnError(streamSubscription.ChannelId, ex)); + return Task.CompletedTask; + + Task OnPublished(ChannelId id, int item) + { + if (_throwsOnReceive) + { + throw new Exception("Some error message here"); + } + if (!_values.TryGetValue(id, out var values)) + { + _values[id] = values = new List(); + } + values.Add(item); + return Task.CompletedTask; + } + + Task OnError(ChannelId id, Exception ex) + { + if (!_errors.TryGetValue(id, out var errors)) + { + _errors[id] = errors = new List(); + } + errors.Add(ex); + return Task.CompletedTask; + } + } + + public Task ThrowsOnReceive(bool throwsOnReceive) + { + _throwsOnReceive = throwsOnReceive; + return Task.CompletedTask; + } + } + + [ImplicitChannelSubscription] + public class SimpleSubscriberGrain : SubscriberGrainBase, ISimpleSubscriberGrain { } + + [RegexImplicitChannelSubscription("multiple-namespaces-(.)+")] + public class RegexNamespaceSubscriberGrain : SubscriberGrainBase, IRegexNamespaceSubscriberGrain { } +} diff --git a/test/Grains/TestGrains/StreamInterceptionGrain.cs b/test/Grains/TestGrains/StreamInterceptionGrain.cs index c1286896ba..cd1e074201 100644 --- a/test/Grains/TestGrains/StreamInterceptionGrain.cs +++ b/test/Grains/TestGrains/StreamInterceptionGrain.cs @@ -1,4 +1,4 @@ -using System.Threading; +using System.Threading; using System.Threading.Tasks; using Orleans; using Orleans.Streams; diff --git a/test/Grains/TestGrains/TestGrains.csproj b/test/Grains/TestGrains/TestGrains.csproj index 552a3690e7..88074884ca 100644 --- a/test/Grains/TestGrains/TestGrains.csproj +++ b/test/Grains/TestGrains/TestGrains.csproj @@ -11,5 +11,6 @@ + diff --git a/test/Tester/GrainCallFilterTests.cs b/test/Tester/GrainCallFilterTests.cs index c38efbc7cd..ccacb1baee 100644 --- a/test/Tester/GrainCallFilterTests.cs +++ b/test/Tester/GrainCallFilterTests.cs @@ -14,6 +14,7 @@ using Xunit; using Orleans.Hosting; using Orleans.Serialization; +using Orleans.Providers.Streams.SimpleMessageStream; namespace UnitTests.General { diff --git a/test/Tester/StreamingTests/BroadcastChannels/BroadcastChannelTests.cs b/test/Tester/StreamingTests/BroadcastChannels/BroadcastChannelTests.cs new file mode 100644 index 0000000000..695ca607c2 --- /dev/null +++ b/test/Tester/StreamingTests/BroadcastChannels/BroadcastChannelTests.cs @@ -0,0 +1,261 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Configuration; +using Orleans; +using Orleans.BroadcastChannel; +using Orleans.Hosting; +using Orleans.Runtime; +using Orleans.TestingHost; +using TestExtensions; +using UnitTests.Grains.BroadcastChannel; +using Xunit; +using Xunit.Abstractions; + +namespace Tester.StreamingTests.BroadcastChannel +{ + [TestCategory("BVT")] + public class BroadcastChannelTests : OrleansTestingBase, IClassFixture + { + private const string ProviderName = "BroadcastChannel"; + private const string ProviderNameNonFireAndForget = "BroadcastChannelNonFireAndForget"; + private const int CallTimeoutMs = 500; + private readonly Fixture _fixture; + private IBroadcastChannelProvider _provider => _fixture.Client.GetBroadcaseChannelProvider(ProviderName); + private IBroadcastChannelProvider _providerNonFireAndForget => _fixture.Client.GetBroadcaseChannelProvider(ProviderNameNonFireAndForget); + + public class Fixture : BaseTestClusterFixture + { + protected override void CheckPreconditionsOrThrow() + { + base.CheckPreconditionsOrThrow(); + } + + protected override void ConfigureTestCluster(TestClusterBuilder builder) + { + builder.AddClientBuilderConfigurator(); + builder.AddSiloBuilderConfigurator(); + } + public class SiloConfigurator : ISiloConfigurator + { + public void Configure(ISiloBuilder hostBuilder) + { + hostBuilder.AddBroadcastChannel(ProviderName); + hostBuilder.AddBroadcastChannel(ProviderNameNonFireAndForget, options => options.FireAndForgetDelivery = false); + } + } + public class ClientConfigurator : IClientBuilderConfigurator + { + public void Configure(IConfiguration configuration, IClientBuilder clientBuilder) + { + clientBuilder.AddBroadcastChannel(ProviderName); + clientBuilder.AddBroadcastChannel(ProviderNameNonFireAndForget, options => options.FireAndForgetDelivery = false); + } + } + } + + public BroadcastChannelTests(Fixture fixture) + { + fixture.EnsurePreconditionsMet(); + _fixture = fixture; + } + + [Fact] + public async Task ClientPublishSingleChannelTest() => await ClientPublishSingleChannelTestImpl(_provider); + + [Fact] + public async Task ClientPublishSingleChannelMultipleConsumersTest() => await MultipleSubscribersChannelTestImpl(_provider); + + [Fact] + public async Task ClientPublishMultipleChannelTest() => await ClientPublishMultipleChannelTestImpl(_provider); + + [Fact] + public async Task MultipleSubscribersOneBadActorChannelTest() => await MultipleSubscribersOneBadActorChannelTestImpl(_provider); + + [Fact] + public async Task NonFireAndForgetClientPublishSingleChannelTest() => await ClientPublishSingleChannelTestImpl(_providerNonFireAndForget, false); + + [Fact] + public async Task NonFireAndForgetClientPublishMultipleChannelTest() => await ClientPublishMultipleChannelTestImpl(_providerNonFireAndForget); + + [Fact] + public async Task NonFireAndForgetClientPublishSingleChannelMultipleConsumersTest() => await MultipleSubscribersChannelTestImpl(_providerNonFireAndForget, false); + + [Fact] + public async Task NonFireAndForgetMultipleSubscribersOneBadActorChannelTest() => await MultipleSubscribersOneBadActorChannelTestImpl(_providerNonFireAndForget, false); + + private async Task ClientPublishSingleChannelTestImpl(IBroadcastChannelProvider provider, bool fireAndForget = true) + { + var grainKey = Guid.NewGuid().ToString("N"); + var channelId = ChannelId.Create("some-namespace", grainKey); + var stream = provider.GetChannel(channelId); + + await stream.Publish(1); + await stream.Publish(2); + await stream.Publish(3); + + var grain = _fixture.Client.GetGrain(grainKey); + var values = await Get(() => grain.GetValues(channelId), 3); + + Assert.Equal(3, values.Count); + if (fireAndForget) + { + Assert.Contains(1, values); + Assert.Contains(2, values); + Assert.Contains(3, values); + } + else + { + Assert.Equal(1, values[0]); + Assert.Equal(2, values[1]); + Assert.Equal(3, values[2]); + } + } + + private async Task ClientPublishMultipleChannelTestImpl(IBroadcastChannelProvider provider) + { + var grainKey = Guid.NewGuid().ToString("N"); + var channels = new List<(ChannelId ChannelId, int ExpectedValue)>(); + + for (var i = 0; i < 10; i++) + { + var id = ChannelId.Create($"some-namespace{i}", grainKey); + var value = i + 50; + + channels.Add((id, value)); + + await provider.GetChannel(id).Publish(value); + } + + var grain = _fixture.Client.GetGrain(grainKey); + + foreach (var channel in channels) + { + var values = await Get(() => grain.GetValues(channel.ChannelId), 1); + + Assert.Single(values); + Assert.Equal(channel.ExpectedValue, values[0]); + } + } + + private async Task MultipleSubscribersChannelTestImpl(IBroadcastChannelProvider provider, bool fireAndForget = true) + { + var grainKey = Guid.NewGuid().ToString("N"); + var channelId = ChannelId.Create("multiple-namespaces-0", grainKey); + var stream = provider.GetChannel(channelId); + + await stream.Publish(1); + await stream.Publish(2); + await stream.Publish(3); + + var grains = new ISubscriberGrain[] + { + _fixture.Client.GetGrain(grainKey), + _fixture.Client.GetGrain(grainKey) + }; + + foreach (var grain in grains) + { + var values = await Get(() => grain.GetValues(channelId), 3); + + Assert.Equal(3, values.Count); + if (fireAndForget) + { + Assert.Contains(1, values); + Assert.Contains(2, values); + Assert.Contains(3, values); + } + else + { + Assert.Equal(1, values[0]); + Assert.Equal(2, values[1]); + Assert.Equal(3, values[2]); + } + } + } + + private async Task MultipleSubscribersOneBadActorChannelTestImpl(IBroadcastChannelProvider provider, bool fireAndForget = true) + { + var grainKey = Guid.NewGuid().ToString("N"); + var channelId = ChannelId.Create("multiple-namespaces-0", grainKey); + var stream = provider.GetChannel(channelId); + + var badGrain = _fixture.Client.GetGrain(grainKey); + var goodGrain = _fixture.Client.GetGrain(grainKey); + + await stream.Publish(1); + if (fireAndForget) + { + var values = await Get(() => badGrain.GetValues(channelId), 1); + Assert.Single(values); + } + await badGrain.ThrowsOnReceive(true); + if (fireAndForget) + { + await stream.Publish(2); + } + else + { + var ex = await Assert.ThrowsAsync(() => stream.Publish(2)); + Assert.Single(ex.InnerExceptions); + } + await badGrain.ThrowsOnReceive(false); + await stream.Publish(3); + + var goodValues = await Get(() => goodGrain.GetValues(channelId), 3); + + Assert.Equal(3, goodValues.Count); + if (fireAndForget) + { + Assert.Contains(1, goodValues); + Assert.Contains(2, goodValues); + Assert.Contains(3, goodValues); + } + else + { + Assert.Equal(1, goodValues[0]); + Assert.Equal(2, goodValues[1]); + Assert.Equal(3, goodValues[2]); + } + + var badValues = await Get(() => badGrain.GetValues(channelId), 2); + + Assert.Equal(2, badValues.Count); + if (fireAndForget) + { + Assert.Contains(1, badValues); + Assert.Contains(3, badValues); + } + else + { + Assert.Equal(1, badValues[0]); + Assert.Equal(3, badValues[1]); + } + } + + private static async Task> Get(Func>> func, int expectedCount, int timeoutMs = CallTimeoutMs) + { + var cts = new CancellationTokenSource(timeoutMs); + while (!cts.IsCancellationRequested) + { + try + { + var values = await func(); + if (values.Count == expectedCount) + { + return values; + } + await Task.Delay(10); + } + catch (Exception) + { + // Ignore + } + } + return await func(); + } + } +} \ No newline at end of file diff --git a/test/Tester/StreamingTests/SampleStreamingTests.cs b/test/Tester/StreamingTests/SampleStreamingTests.cs index c153edf5b7..f73ae1055e 100644 --- a/test/Tester/StreamingTests/SampleStreamingTests.cs +++ b/test/Tester/StreamingTests/SampleStreamingTests.cs @@ -5,6 +5,7 @@ using Microsoft.Extensions.Logging; using Orleans; using Orleans.Hosting; +using Orleans.Providers.Streams.SimpleMessageStream; using Orleans.Runtime; using Orleans.Streams; using Orleans.TestingHost; diff --git a/test/Tester/Tester.csproj b/test/Tester/Tester.csproj index 385d1fe087..92da26470a 100644 --- a/test/Tester/Tester.csproj +++ b/test/Tester/Tester.csproj @@ -20,5 +20,6 @@ + From a542c7a2f1550d4073d293776c8de36c34b45a7b Mon Sep 17 00:00:00 2001 From: Benjamin Petit Date: Fri, 27 May 2022 11:03:13 +0200 Subject: [PATCH 2/3] Address comments --- .../BroadcastChannelConsumerExtension.cs | 5 +++++ .../BroadcastChannelProvider.cs | 6 +++--- .../{BroadcastChannel.cs => BroadcastChannelWriter.cs} | 10 ++++++---- test/Tester/GrainCallFilterTests.cs | 2 -- .../BroadcastChannels/BroadcastChannelTests.cs | 8 ++++---- 5 files changed, 18 insertions(+), 13 deletions(-) rename src/Orleans.BroadcastChannel/{BroadcastChannel.cs => BroadcastChannelWriter.cs} (89%) diff --git a/src/Orleans.BroadcastChannel/BroadcastChannelConsumerExtension.cs b/src/Orleans.BroadcastChannel/BroadcastChannelConsumerExtension.cs index 844b96feca..25fa1b79ae 100644 --- a/src/Orleans.BroadcastChannel/BroadcastChannelConsumerExtension.cs +++ b/src/Orleans.BroadcastChannel/BroadcastChannelConsumerExtension.cs @@ -50,6 +50,10 @@ public Task OnPublished(object item) public BroadcastChannelConsumerExtension(IGrainContextAccessor grainContextAccessor) { _subscriptionObserver = grainContextAccessor.GrainContext?.GrainInstance as IOnBroadcastChannelSubscribed; + if (_subscriptionObserver == null) + { + throw new ArgumentException($"The grain doesn't implement interface {nameof(IOnBroadcastChannelSubscribed)}"); + } } public async Task OnError(InternalChannelId streamId, Exception exception) @@ -88,6 +92,7 @@ private async ValueTask GetStreamCallback(InternalChannelId streamId) { return callback; } + // Give a chance to the grain to attach a handler for this streamId var subscription = new BroadcastChannelSubscription(this, streamId); await _subscriptionObserver.OnSubscribed(subscription); } diff --git a/src/Orleans.BroadcastChannel/BroadcastChannelProvider.cs b/src/Orleans.BroadcastChannel/BroadcastChannelProvider.cs index dcde7413ec..79cd3933d1 100644 --- a/src/Orleans.BroadcastChannel/BroadcastChannelProvider.cs +++ b/src/Orleans.BroadcastChannel/BroadcastChannelProvider.cs @@ -10,7 +10,7 @@ namespace Orleans.BroadcastChannel { public interface IBroadcastChannelProvider { - IBroadcastChannel GetChannel(ChannelId streamId); + IBroadcastChannelWriter GetChannelWriter(ChannelId streamId); } internal class BroadcastChannelProvider : IBroadcastChannelProvider @@ -35,9 +35,9 @@ public BroadcastChannelProvider( _loggerFactory = loggerFactory; } - public IBroadcastChannel GetChannel(ChannelId streamId) + public IBroadcastChannelWriter GetChannelWriter(ChannelId streamId) { - return new BroadcastChannel( + return new BroadcastChannelWriter( new InternalChannelId(_providerName, streamId), _grainFactory, _subscriberTable, diff --git a/src/Orleans.BroadcastChannel/BroadcastChannel.cs b/src/Orleans.BroadcastChannel/BroadcastChannelWriter.cs similarity index 89% rename from src/Orleans.BroadcastChannel/BroadcastChannel.cs rename to src/Orleans.BroadcastChannel/BroadcastChannelWriter.cs index 0bc8b46db6..989a9a095c 100644 --- a/src/Orleans.BroadcastChannel/BroadcastChannel.cs +++ b/src/Orleans.BroadcastChannel/BroadcastChannelWriter.cs @@ -9,20 +9,22 @@ namespace Orleans.BroadcastChannel { - public interface IBroadcastChannel + public interface IBroadcastChannelWriter { Task Publish(T item); } - internal class BroadcastChannel : IBroadcastChannel + internal class BroadcastChannelWriter : IBroadcastChannelWriter { + private static readonly string LoggingCategory = typeof(BroadcastChannelWriter<>).FullName; + private readonly InternalChannelId _channelId; private readonly IGrainFactory _grainFactory; private readonly ImplicitChannelSubscriberTable _subscriberTable; private readonly bool _fireAndForgetDelivery; private readonly ILogger _logger; - public BroadcastChannel( + public BroadcastChannelWriter( InternalChannelId channelId, IGrainFactory grainFactory, ImplicitChannelSubscriberTable subscriberTable, @@ -33,7 +35,7 @@ public BroadcastChannel( _grainFactory = grainFactory; _subscriberTable = subscriberTable; _fireAndForgetDelivery = fireAndForgetDelivery; - _logger = loggerFactory.CreateLogger($"{nameof(BroadcastChannel)}-{_channelId}"); + _logger = loggerFactory.CreateLogger(LoggingCategory); } public async Task Publish(T item) diff --git a/test/Tester/GrainCallFilterTests.cs b/test/Tester/GrainCallFilterTests.cs index ccacb1baee..39b11a3c0e 100644 --- a/test/Tester/GrainCallFilterTests.cs +++ b/test/Tester/GrainCallFilterTests.cs @@ -13,8 +13,6 @@ using UnitTests.Grains; using Xunit; using Orleans.Hosting; -using Orleans.Serialization; -using Orleans.Providers.Streams.SimpleMessageStream; namespace UnitTests.General { diff --git a/test/Tester/StreamingTests/BroadcastChannels/BroadcastChannelTests.cs b/test/Tester/StreamingTests/BroadcastChannels/BroadcastChannelTests.cs index 695ca607c2..94c814c982 100644 --- a/test/Tester/StreamingTests/BroadcastChannels/BroadcastChannelTests.cs +++ b/test/Tester/StreamingTests/BroadcastChannels/BroadcastChannelTests.cs @@ -91,7 +91,7 @@ private async Task ClientPublishSingleChannelTestImpl(IBroadcastChannelProvider { var grainKey = Guid.NewGuid().ToString("N"); var channelId = ChannelId.Create("some-namespace", grainKey); - var stream = provider.GetChannel(channelId); + var stream = provider.GetChannelWriter(channelId); await stream.Publish(1); await stream.Publish(2); @@ -127,7 +127,7 @@ private async Task ClientPublishMultipleChannelTestImpl(IBroadcastChannelProvide channels.Add((id, value)); - await provider.GetChannel(id).Publish(value); + await provider.GetChannelWriter(id).Publish(value); } var grain = _fixture.Client.GetGrain(grainKey); @@ -145,7 +145,7 @@ private async Task MultipleSubscribersChannelTestImpl(IBroadcastChannelProvider { var grainKey = Guid.NewGuid().ToString("N"); var channelId = ChannelId.Create("multiple-namespaces-0", grainKey); - var stream = provider.GetChannel(channelId); + var stream = provider.GetChannelWriter(channelId); await stream.Publish(1); await stream.Publish(2); @@ -181,7 +181,7 @@ private async Task MultipleSubscribersOneBadActorChannelTestImpl(IBroadcastChann { var grainKey = Guid.NewGuid().ToString("N"); var channelId = ChannelId.Create("multiple-namespaces-0", grainKey); - var stream = provider.GetChannel(channelId); + var stream = provider.GetChannelWriter(channelId); var badGrain = _fixture.Client.GetGrain(grainKey); var goodGrain = _fixture.Client.GetGrain(grainKey); From d17398031156932fd1877fcf8a5e1ad5a9a0a2da Mon Sep 17 00:00:00 2001 From: Benjamin Petit Date: Thu, 23 Jun 2022 15:30:14 +0200 Subject: [PATCH 3/3] Make MultipleSubscribersOneBadActorChannelTest more reliable --- .../SimpleStreams/SimpleSubscriberGrain.cs | 6 +++++- .../BroadcastChannels/BroadcastChannelTests.cs | 12 +++++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/test/Grains/TestGrains/SimpleStreams/SimpleSubscriberGrain.cs b/test/Grains/TestGrains/SimpleStreams/SimpleSubscriberGrain.cs index 5e737368bf..091d71f72c 100644 --- a/test/Grains/TestGrains/SimpleStreams/SimpleSubscriberGrain.cs +++ b/test/Grains/TestGrains/SimpleStreams/SimpleSubscriberGrain.cs @@ -17,6 +17,8 @@ public interface ISubscriberGrain : IGrainWithStringKey Task> GetValues(ChannelId streamId); + Task GetOnPublishedCounter(); + Task ThrowsOnReceive(bool throwsOnReceive); } @@ -28,11 +30,12 @@ public abstract class SubscriberGrainBase : Grain, ISubscriberGrain, IOnBroadcas { private Dictionary> _values = new(); private Dictionary> _errors = new(); + private int _onPublishedCounter = 0; private bool _throwsOnReceive = false; public Task> GetErrors(ChannelId streamId) => _errors.TryGetValue(streamId, out var errors) ? Task.FromResult(errors) : Task.FromResult(new List()); - public Task> GetValues(ChannelId streamId) => _values.TryGetValue(streamId, out var values) ? Task.FromResult(values) : Task.FromResult(new List()); + public Task GetOnPublishedCounter() => Task.FromResult(_onPublishedCounter); public Task OnSubscribed(IBroadcastChannelSubscription streamSubscription) { @@ -41,6 +44,7 @@ public Task OnSubscribed(IBroadcastChannelSubscription streamSubscription) Task OnPublished(ChannelId id, int item) { + _onPublishedCounter++; if (_throwsOnReceive) { throw new Exception("Some error message here"); diff --git a/test/Tester/StreamingTests/BroadcastChannels/BroadcastChannelTests.cs b/test/Tester/StreamingTests/BroadcastChannels/BroadcastChannelTests.cs index 94c814c982..edea57c3d6 100644 --- a/test/Tester/StreamingTests/BroadcastChannels/BroadcastChannelTests.cs +++ b/test/Tester/StreamingTests/BroadcastChannels/BroadcastChannelTests.cs @@ -196,6 +196,16 @@ private async Task MultipleSubscribersOneBadActorChannelTestImpl(IBroadcastChann if (fireAndForget) { await stream.Publish(2); + // Wait to be sure that published event reached the grain + var counter = 0; + var cts = new CancellationTokenSource(CallTimeoutMs); + while (!cts.IsCancellationRequested) + { + counter = await badGrain.GetOnPublishedCounter(); + if (counter == 1) break; + await Task.Delay(10); + } + Assert.Equal(2, counter); } else { @@ -236,7 +246,7 @@ private async Task MultipleSubscribersOneBadActorChannelTestImpl(IBroadcastChann } } - private static async Task> Get(Func>> func, int expectedCount, int timeoutMs = CallTimeoutMs) + private static async Task> Get(Func>> func, int expectedCount, int timeoutMs = CallTimeoutMs) { var cts = new CancellationTokenSource(timeoutMs); while (!cts.IsCancellationRequested)