Skip to content
This repository has been archived by the owner on Jun 29, 2022. It is now read-only.

Rate limiter implementation #203

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions src/Middleware/RateLimiter.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
<?php
declare(strict_types=1);

namespace Yiisoft\Yii\Web\Middleware;

use Psr\Http\Message\ResponseFactoryInterface;
use Psr\Http\Message\ResponseInterface;
use Psr\Http\Message\ServerRequestInterface;
use Psr\Http\Server\MiddlewareInterface;
use Psr\Http\Server\RequestHandlerInterface;
use Psr\SimpleCache\CacheInterface;

/**
* Rate limiter limits the number of requests that could be made within a certain period of time
*/
final class RateLimiter implements MiddlewareInterface
samdark marked this conversation as resolved.
Show resolved Hide resolved
{
private int $limit = 1000;

private ?string $cacheKey = null;

/**
* @var callable
*/
private $cacheKeyCallback;

private int $cacheTtl = 360;

private CacheInterface $cache;

private ResponseFactoryInterface $responseFactory;

private bool $autoincrement = true;

public function __construct(CacheInterface $cache, ResponseFactoryInterface $responseFactory)
{
$this->cache = $cache;
$this->responseFactory = $responseFactory;
}

public function process(ServerRequestInterface $request, RequestHandlerInterface $handler): ResponseInterface
{
$this->setupCacheParams($request);

if (!$this->isAllowed()) {
return $this->createErrorResponse();
}

if ($this->autoincrement) {
$this->increment();
}

return $handler->handle($request);
}

public function setLimit(int $limit): self
romkatsu marked this conversation as resolved.
Show resolved Hide resolved
{
$this->limit = $limit;

return $this;
}

public function setCacheKey(string $key): self
{
$this->cacheKey = $key;

return $this;
}

public function setCacheKeyByCallback(callable $callback): self
romkatsu marked this conversation as resolved.
Show resolved Hide resolved
{
$this->cacheKeyCallback = $callback;

return $this;
}

public function setCacheTtl(int $ttl): self
romkatsu marked this conversation as resolved.
Show resolved Hide resolved
{
$this->cacheTtl = $ttl;

return $this;
}

public function setAutoIncrement(bool $increment): self
romkatsu marked this conversation as resolved.
Show resolved Hide resolved
{
$this->autoincrement = $increment;

return $this;
}

private function createErrorResponse(): ResponseInterface
{
$response = $this->responseFactory->createResponse(429);
$response->getBody()->write('Too Many Requests');

return $response;
}

private function isAllowed(): bool
{
return $this->getCounterValue() < $this->limit;
}

private function increment(): void
{
$value = $this->getCounterValue();
$value++;

$this->setCounterValue($value);
}

private function setupCacheParams(ServerRequestInterface $request): void
{
$this->cacheKey = $this->setupCacheKey($request);

if (!$this->hasCounterValue()) {
romkatsu marked this conversation as resolved.
Show resolved Hide resolved
$this->setCounterValue(0);
}
}

private function setupCacheKey(ServerRequestInterface $request): string
{
if ($this->cacheKeyCallback !== null) {
return \call_user_func($this->cacheKeyCallback, $request);
}

return $this->cacheKey ?? $this->generateCacheKey($request);
}

private function generateCacheKey(ServerRequestInterface $request): string
{
return strtolower('rate-limiter-' . $request->getMethod() . '-' . $request->getUri()->getPath());
}

private function getCounterValue(): int
{
return $this->cache->get($this->cacheKey, 0);
}

private function setCounterValue(int $value): void
{
$this->cache->set($this->cacheKey, $value, $this->cacheTtl);
}

private function hasCounterValue(): bool
{
return $this->cache->has($this->cacheKey);
}
}
161 changes: 161 additions & 0 deletions tests/Middleware/RateLimiterTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
<?php

namespace Yiisoft\Yii\Web\Tests\Middleware;

use Nyholm\Psr7\Factory\Psr17Factory;
use Nyholm\Psr7\Response;
use Nyholm\Psr7\ServerRequest;
use PHPUnit\Framework\TestCase;
use Psr\Http\Message\ResponseInterface;
use Psr\Http\Message\ServerRequestInterface;
use Psr\Http\Server\RequestHandlerInterface;
use Psr\SimpleCache\CacheInterface;
use Yiisoft\Cache\ArrayCache;
use Yiisoft\Http\Method;
use Yiisoft\Yii\Web\Middleware\RateLimiter;

final class RateLimiterTest extends TestCase
{
/**
* @test
*/
public function singleRequestIsAllowed(): void
{
$middleware = $this->createRateLimiter($this->getCache());
$response = $middleware->process($this->createRequest(), $this->createRequestHandler());
$this->assertEquals(200, $response->getStatusCode());
}

/**
* @test
*/
public function moreThanDefaultNumberOfRequestsIsNotAllowed(): void
{
$cache = $this->getCache();
$this->setRateLimiterCurrentRequestNumber($cache, 1000);
romkatsu marked this conversation as resolved.
Show resolved Hide resolved

$middleware = $this->createRateLimiter($cache);
$response = $middleware->process($this->createRequest(), $this->createRequestHandler());
$this->assertEquals(429, $response->getStatusCode());
}

/**
* @test
*/
public function customLimitWorksAsExpected(): void
{
$cache = $this->getCache();
$this->setRateLimiterCurrentRequestNumber($cache, 10);

$middleware = $this->createRateLimiter($cache)->setLimit(11);

$response = $middleware->process($this->createRequest(), $this->createRequestHandler());
$this->assertEquals(200, $response->getStatusCode());

$response = $middleware->process($this->createRequest(), $this->createRequestHandler());
$this->assertEquals(429, $response->getStatusCode());
}

/**
* @test
*/
public function customCacheKey(): void
{
$cache = $this->getCache();
$cache->set('custom-cache-key', 999);

$middleware = $this->createRateLimiter($cache)->setCacheKey('custom-cache-key');

$response = $middleware->process($this->createRequest(), $this->createRequestHandler());
$this->assertEquals(200, $response->getStatusCode());

$response = $middleware->process($this->createRequest(), $this->createRequestHandler());
$this->assertEquals(429, $response->getStatusCode());
}

/**
* @test
*/
public function customCacheKeyCallback(): void
{
$cache = $this->getCache();
$cache->set('POST', 1000);

$middleware = $this->createRateLimiter($cache)
->setCacheKeyByCallback(
static function (ServerRequestInterface $request) {
return $request->getMethod();
}
);

$response = $middleware->process($this->createRequest(), $this->createRequestHandler());
$this->assertEquals(200, $response->getStatusCode());

$response = $middleware->process($this->createRequest(Method::POST), $this->createRequestHandler());
$this->assertEquals(429, $response->getStatusCode());
}

/**
* @test
*/
public function customCacheTtl(): void
{
$middleware = $this->createRateLimiter($this->getCache())
->setLimit(1)
->setCacheTtl(1);

$response = $middleware->process($this->createRequest(), $this->createRequestHandler());
$this->assertEquals(200, $response->getStatusCode());

$response = $middleware->process($this->createRequest(), $this->createRequestHandler());
$this->assertEquals(429, $response->getStatusCode());

sleep(1);

$response = $middleware->process($this->createRequest(), $this->createRequestHandler());
$this->assertEquals(200, $response->getStatusCode());
}

/**
* @test
*/
public function disableAutoIncrement(): void
{
$cache = $this->getCache();

$middleware = $this->createRateLimiter($cache)->setAutoIncrement(false);
$response = $middleware->process($this->createRequest(), $this->createRequestHandler());
$this->assertEquals(200, $response->getStatusCode());
$this->assertEquals(0, $cache->get('rate-limiter-get-/'));
}

private function setRateLimiterCurrentRequestNumber(CacheInterface $cache, int $number, $method = 'get', $path = '/'): void
{
$cache->set("rate-limiter-$method-$path", 1000);
}

private function getCache(): CacheInterface
{
return new ArrayCache();
}

private function createRequestHandler(): RequestHandlerInterface
{
return new class implements RequestHandlerInterface {
public function handle(ServerRequestInterface $request): ResponseInterface
{
return new Response(200);
}
};
}

private function createRequest(string $method = Method::GET, string $uri = '/'): ServerRequestInterface
{
return new ServerRequest($method, $uri);
}

private function createRateLimiter(CacheInterface $cache): RateLimiter
{
return new RateLimiter($cache, new Psr17Factory());
}
}