diff --git a/src/Polly/RateLimit/RateLimitPolicy.cs b/src/Polly/RateLimit/RateLimitPolicy.cs index 95c891fc14..34a540cf8d 100644 --- a/src/Polly/RateLimit/RateLimitPolicy.cs +++ b/src/Polly/RateLimit/RateLimitPolicy.cs @@ -4,7 +4,6 @@ namespace Polly.RateLimit; /// <summary> /// A rate-limit policy that can be applied to synchronous delegates. /// </summary> -#pragma warning disable CA1062 // Validate arguments of public methods public class RateLimitPolicy : Policy, IRateLimitPolicy { private readonly IRateLimiter _rateLimiter; @@ -14,8 +13,15 @@ internal RateLimitPolicy(IRateLimiter rateLimiter) => /// <inheritdoc/> [DebuggerStepThrough] - protected override TResult Implementation<TResult>(Func<Context, CancellationToken, TResult> action, Context context, CancellationToken cancellationToken) => - RateLimitEngine.Implementation(_rateLimiter, null, action, context, cancellationToken); + protected override TResult Implementation<TResult>(Func<Context, CancellationToken, TResult> action, Context context, CancellationToken cancellationToken) + { + if (action is null) + { + throw new ArgumentNullException(nameof(action)); + } + + return RateLimitEngine.Implementation(_rateLimiter, null, action, context, cancellationToken); + } } /// <summary> @@ -37,6 +43,13 @@ internal RateLimitPolicy( /// <inheritdoc/> [DebuggerStepThrough] - protected override TResult Implementation(Func<Context, CancellationToken, TResult> action, Context context, CancellationToken cancellationToken) => - RateLimitEngine.Implementation(_rateLimiter, _retryAfterFactory, action, context, cancellationToken); + protected override TResult Implementation(Func<Context, CancellationToken, TResult> action, Context context, CancellationToken cancellationToken) + { + if (action is null) + { + throw new ArgumentNullException(nameof(action)); + } + + return RateLimitEngine.Implementation(_rateLimiter, _retryAfterFactory, action, context, cancellationToken); + } } diff --git a/test/Polly.Specs/RateLimit/RateLimitPolicySpecs.cs b/test/Polly.Specs/RateLimit/RateLimitPolicySpecs.cs index eb7329856e..d5b7a25ed6 100644 --- a/test/Polly.Specs/RateLimit/RateLimitPolicySpecs.cs +++ b/test/Polly.Specs/RateLimit/RateLimitPolicySpecs.cs @@ -31,4 +31,30 @@ protected override (bool, TimeSpan) TryExecuteThroughPolicy(IRateLimitPolicy pol throw new InvalidOperationException("Unexpected policy type in test construction."); } } + + [Fact] + public void Should_throw_when_action_is_null() + { + var flags = BindingFlags.NonPublic | BindingFlags.Instance; + Func<Context, CancellationToken, EmptyStruct> action = null!; + IRateLimiter rateLimiter = RateLimiterFactory.Create(TimeSpan.FromSeconds(1), 1); + + var instance = Activator.CreateInstance( + typeof(RateLimitPolicy), + flags, + null, + [rateLimiter], + null)!; + var instanceType = instance.GetType(); + var methods = instanceType.GetMethods(flags); + var methodInfo = methods.First(method => method is { Name: "Implementation", ReturnType.Name: "TResult" }); + var generic = methodInfo.MakeGenericMethod(typeof(EmptyStruct)); + + var func = () => generic.Invoke(instance, [action, new Context(), CancellationToken.None]); + + var exceptionAssertions = func.Should().Throw<TargetInvocationException>(); + exceptionAssertions.And.Message.Should().Be("Exception has been thrown by the target of an invocation."); + exceptionAssertions.And.InnerException.Should().BeOfType<ArgumentNullException>() + .Which.ParamName.Should().Be("action"); + } } diff --git a/test/Polly.Specs/RateLimit/RateLimitPolicyTResultSpecs.cs b/test/Polly.Specs/RateLimit/RateLimitPolicyTResultSpecs.cs index d94142b748..99fc2a224c 100644 --- a/test/Polly.Specs/RateLimit/RateLimitPolicyTResultSpecs.cs +++ b/test/Polly.Specs/RateLimit/RateLimitPolicyTResultSpecs.cs @@ -47,4 +47,30 @@ protected override TResult TryExecuteThroughPolicy<TResult>(IRateLimitPolicy<TRe throw new InvalidOperationException("Unexpected policy type in test construction."); } } + + [Fact] + public void Should_throw_when_action_is_null() + { + var flags = BindingFlags.NonPublic | BindingFlags.Instance; + Func<Context, CancellationToken, EmptyStruct> action = null!; + IRateLimiter rateLimiter = RateLimiterFactory.Create(TimeSpan.FromSeconds(1), 1); + Func<TimeSpan, Context, EmptyStruct>? retryAfterFactory = null; + + var instance = Activator.CreateInstance( + typeof(RateLimitPolicy<EmptyStruct>), + flags, + null, + [rateLimiter, retryAfterFactory], + null)!; + var instanceType = instance.GetType(); + var methods = instanceType.GetMethods(flags); + var methodInfo = methods.First(method => method is { Name: "Implementation", ReturnType.Name: "EmptyStruct" }); + + var func = () => methodInfo.Invoke(instance, [action, new Context(), CancellationToken.None]); + + var exceptionAssertions = func.Should().Throw<TargetInvocationException>(); + exceptionAssertions.And.Message.Should().Be("Exception has been thrown by the target of an invocation."); + exceptionAssertions.And.InnerException.Should().BeOfType<ArgumentNullException>() + .Which.ParamName.Should().Be("action"); + } }