Skip to content

Commit

Permalink
Add throttle to rate limit the API calls
Browse files Browse the repository at this point in the history
  • Loading branch information
kochol committed Jan 14, 2021
1 parent f32fb40 commit 5ba8517
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 2 deletions.
9 changes: 7 additions & 2 deletions Server/Controllers/AuthController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
using System.Text;
using System.Threading.Tasks;
using Server.Data;
using Server.Filters;
using Microsoft.Extensions.Caching.Memory;

namespace Server.Controllers
{
Expand All @@ -16,10 +18,10 @@ public class AuthController : ControllerBase
{
private readonly IConfiguration _config;

public AuthController(IConfiguration configuration)
public AuthController(IConfiguration configuration)
{
_config = configuration;
if (LobbyManager._config == null)
if (LobbyManager._config == null)
{
LobbyManager._config = _config;
}
Expand All @@ -33,6 +35,7 @@ public AuthController(IConfiguration configuration)
/// <param name="Password"></param>
/// <returns></returns>
[HttpGet("Register/{deviceId}/{Username}/{Password}")]
[Throttle(TimeUnit = TimeUnit.Minute, Count = 1)]
public async Task<ActionResult<string>> Register(long deviceId, string Username, string Password)
{
var player = await DataContext.Players.GetPlayerByDeviceId(deviceId.ToString());
Expand All @@ -57,6 +60,7 @@ public async Task<ActionResult<string>> Register(long deviceId, string Username,
}

[HttpGet("{Username}/{Password}")]
[Throttle(TimeUnit = TimeUnit.Minute, Count = 5)]
public async Task<ActionResult<string>> Get(string Username, string Password)
{
var player = await DataContext.Players.GetPlayerByUserName(Username);
Expand All @@ -67,6 +71,7 @@ public async Task<ActionResult<string>> Get(string Username, string Password)
}

[HttpGet("{deviceId}/{platformName}/{deviceInfo}")]
[Throttle(TimeUnit = TimeUnit.Minute, Count = 1)]
public async Task<ActionResult<string>> Get(long deviceId, string platformName, string deviceInfo)
{
var player = await DataContext.Players.GetPlayerByDeviceId(deviceId.ToString());
Expand Down
3 changes: 3 additions & 0 deletions Server/Controllers/PlayerController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Mvc;
using Server.Data;
using Server.Filters;

namespace Server.Controllers
{
Expand All @@ -15,6 +16,7 @@ namespace Server.Controllers
public class PlayerController : ControllerBase
{
[HttpGet]
[Throttle(TimeUnit = TimeUnit.Minute, Count = 5)]
public async Task<ActionResult<Player>> Get()
{
var player = await DataContext.Players.GetPlayerById(long.Parse(User.Identity.Name));
Expand Down Expand Up @@ -43,6 +45,7 @@ public async Task<ActionResult<List<Game>>> GetGames(int offset, int count)
}

[HttpGet("name/{player_id}")]
[Throttle(TimeUnit = TimeUnit.Minute, Count = 100)]
public async Task<ActionResult<string>> GetPlayerName(long player_id)
{
var player = await DataContext.Players.GetPlayerById(player_id);
Expand Down
59 changes: 59 additions & 0 deletions Server/Filters/ThrottleAttribute.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.Mvc.Controllers;
using Microsoft.AspNetCore.Mvc.Filters;
using Microsoft.Extensions.Caching.Memory;
using System;
using Microsoft.Extensions.DependencyInjection;

namespace Server.Filters
{
public enum TimeUnit
{
Second = 1,
Minute = 60,
Hour = 3600,
Day = 86400
}

[AttributeUsage(AttributeTargets.Method, AllowMultiple = true)]
public class ThrottleAttribute : ActionFilterAttribute
{
public TimeUnit TimeUnit { get; set; }
public int Count { get; set; }

public override void OnActionExecuting(ActionExecutingContext filterContext)
{
var seconds = Convert.ToInt32(TimeUnit);

var controller = filterContext.ActionDescriptor as ControllerActionDescriptor;
Console.WriteLine(filterContext.HttpContext.Connection.RemoteIpAddress);

var key = string.Join(
"-",
seconds,
filterContext.HttpContext.Request.Method,
controller.ControllerName,
controller.ActionName,
filterContext.HttpContext.Connection.RemoteIpAddress
);

// increment the cache value
var cnt = 1;
IMemoryCache cache = filterContext.HttpContext.RequestServices.GetService<IMemoryCache>();
if (cache.TryGetValue(key, out cnt))
{
cnt++;
}
cache.Set(key, cnt, DateTime.UtcNow.AddSeconds(seconds));

if (cnt > Count)
{
filterContext.Result = new ContentResult
{
Content = "You are allowed to make only " + Count + " requests per " + TimeUnit.ToString().ToLower()
};
filterContext.HttpContext.Response.StatusCode = 429;
}
}
}
}
2 changes: 2 additions & 0 deletions Server/Startup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ public void ConfigureServices(IServiceCollection services)

services.AddControllers();

services.AddMemoryCache();

//add the Swagger services
services.AddSwaggerDocument(document =>
{
Expand Down

0 comments on commit 5ba8517

Please sign in to comment.