diff --git a/src/RateLimiter/Counter.php b/src/RateLimiter/Counter.php new file mode 100644 index 000000000..753b36260 --- /dev/null +++ b/src/RateLimiter/Counter.php @@ -0,0 +1,154 @@ +limit = $limit; + $this->period = $period * self::MILLISECONDS_PER_SECOND; + $this->storage = $storage; + + $this->incrementInterval = (float)($this->period / $this->limit); + } + + public function setId(string $id): void + { + $this->id = $id; + } + + /** + * @param int $ttl cache TTL that is used to store counter values + * Default is one day. + * Note that period can not exceed TTL. + */ + public function setTtl(int $ttl): void + { + $this->ttl = $ttl; + } + + public function getCacheKey(): string + { + return self::ID_PREFIX . $this->id; + } + + public function incrementAndGetState(): CounterState + { + if ($this->id === null) { + throw new \LogicException('The counter ID should be set'); + } + + $this->lastIncrementTime = $this->getCurrentTime(); + $theoreticalNextIncrementTime = $this->calculateTheoreticalNextIncrementTime( + $this->getLastStoredTheoreticalNextIncrementTime() + ); + $remaining = $this->calculateRemaining($theoreticalNextIncrementTime); + $resetAfter = $this->calculateResetAfter($theoreticalNextIncrementTime); + + if ($remaining >= 1) { + $this->storeTheoreticalNextIncrementTime($theoreticalNextIncrementTime); + } + + return new CounterState($this->limit, $remaining, $resetAfter); + } + + /** + * @param float $storedTheoreticalNextIncrementTime + * @return float theoretical increment time that would be expected from equally spaced increments at exactly rate limit + * In GCRA it is known as TAT, theoretical arrival time. + */ + private function calculateTheoreticalNextIncrementTime(float $storedTheoreticalNextIncrementTime): float + { + return max($this->lastIncrementTime, $storedTheoreticalNextIncrementTime) + $this->incrementInterval; + } + + /** + * @param float $theoreticalNextIncrementTime + * @return int the number of remaining requests in the current time period + */ + private function calculateRemaining(float $theoreticalNextIncrementTime): int + { + $incrementAllowedAt = $theoreticalNextIncrementTime - $this->period; + + return (int)(round($this->lastIncrementTime - $incrementAllowedAt) / $this->incrementInterval); + } + + private function getLastStoredTheoreticalNextIncrementTime(): float + { + return $this->storage->get($this->getCacheKey(), (float)$this->lastIncrementTime); + } + + private function storeTheoreticalNextIncrementTime(float $theoreticalNextIncrementTime): void + { + $this->storage->set($this->getCacheKey(), $theoreticalNextIncrementTime, $this->ttl); + } + + /** + * @param float $theoreticalNextIncrementTime + * @return int timestamp to wait until the rate limit resets + */ + private function calculateResetAfter(float $theoreticalNextIncrementTime): int + { + return (int)($theoreticalNextIncrementTime / self::MILLISECONDS_PER_SECOND); + } + + private function getCurrentTime(): int + { + return (int)round(microtime(true) * self::MILLISECONDS_PER_SECOND); + } +} diff --git a/src/RateLimiter/CounterInterface.php b/src/RateLimiter/CounterInterface.php new file mode 100644 index 000000000..a7b01c697 --- /dev/null +++ b/src/RateLimiter/CounterInterface.php @@ -0,0 +1,23 @@ +limit = $limit; + $this->remaining = $remaining; + $this->reset = $reset; + } + + /** + * @return int the maximum number of requests allowed with a time period + */ + public function getLimit(): int + { + return $this->limit; + } + + /** + * @return int the number of remaining requests in the current time period + */ + public function getRemaining(): int + { + return $this->remaining; + } + + /** + * @return int timestamp to wait until the rate limit resets + */ + public function getResetTime(): int + { + return $this->reset; + } + + /** + * @return bool if requests limit is reached + */ + public function isLimitReached(): bool + { + return $this->remaining === 0; + } +} diff --git a/src/RateLimiter/RateLimiterMiddleware.php b/src/RateLimiter/RateLimiterMiddleware.php new file mode 100644 index 000000000..5778dead7 --- /dev/null +++ b/src/RateLimiter/RateLimiterMiddleware.php @@ -0,0 +1,99 @@ +counter = $counter; + $this->responseFactory = $responseFactory; + } + + public function process(ServerRequestInterface $request, RequestHandlerInterface $handler): ResponseInterface + { + $this->counter->setId($this->generateId($request)); + $result = $this->counter->incrementAndGetState(); + + if ($result->isLimitReached()) { + $response = $this->createErrorResponse(); + } else { + $response = $handler->handle($request); + } + + return $this->addHeaders($response, $result); + } + + public function withCounterIdCallback(?callable $callback): self + { + $new = clone $this; + $new->counterIdCallback = $callback; + + return $new; + } + + public function withCounterId(string $id): self + { + $new = clone $this; + $new->counterId = $id; + + return $new; + } + + private function createErrorResponse(): ResponseInterface + { + $response = $this->responseFactory->createResponse(Status::TOO_MANY_REQUESTS); + $response->getBody()->write(Status::TEXTS[Status::TOO_MANY_REQUESTS]); + + return $response; + } + + private function generateId(ServerRequestInterface $request): string + { + if ($this->counterIdCallback !== null) { + return \call_user_func($this->counterIdCallback, $request); + } + + return $this->counterId ?? $this->generateIdFromRequest($request); + } + + private function generateIdFromRequest(ServerRequestInterface $request): string + { + return strtolower($request->getMethod() . '-' . $request->getUri()->getPath()); + } + + private function addHeaders(ResponseInterface $response, CounterState $result): ResponseInterface + { + return $response + ->withHeader('X-Rate-Limit-Limit', $result->getLimit()) + ->withHeader('X-Rate-Limit-Remaining', $result->getRemaining()) + ->withHeader('X-Rate-Limit-Reset', $result->getResetTime()); + } +} diff --git a/tests/RateLimiter/CounterTest.php b/tests/RateLimiter/CounterTest.php new file mode 100644 index 000000000..e605f6d35 --- /dev/null +++ b/tests/RateLimiter/CounterTest.php @@ -0,0 +1,93 @@ +setId('key'); + + $statistics = $counter->incrementAndGetState(); + $this->assertEquals(2, $statistics->getLimit()); + $this->assertEquals(1, $statistics->getRemaining()); + $this->assertGreaterThanOrEqual(time(), $statistics->getResetTime()); + $this->assertFalse($statistics->isLimitReached()); + } + + /** + * @test + */ + public function statisticsShouldBeCorrectWhenLimitIsReached(): void + { + $counter = new Counter(2, 4, new ArrayCache()); + $counter->setId('key'); + + $statistics = $counter->incrementAndGetState(); + $this->assertEquals(2, $statistics->getLimit()); + $this->assertEquals(1, $statistics->getRemaining()); + $this->assertGreaterThanOrEqual(time(), $statistics->getResetTime()); + $this->assertFalse($statistics->isLimitReached()); + + $statistics = $counter->incrementAndGetState(); + $this->assertEquals(2, $statistics->getLimit()); + $this->assertEquals(0, $statistics->getRemaining()); + $this->assertGreaterThanOrEqual(time(), $statistics->getResetTime()); + $this->assertTrue($statistics->isLimitReached()); + } + + /** + * @test + */ + public function shouldNotBeAbleToSetInvalidId(): void + { + $this->expectException(\LogicException::class); + (new Counter(10, 60, new ArrayCache()))->incrementAndGetState(); + } + + /** + * @test + */ + public function shouldNotBeAbleToSetInvalidLimit(): void + { + $this->expectException(InvalidArgumentException::class); + new Counter(0, 60, new ArrayCache()); + } + + /** + * @test + */ + public function shouldNotBeAbleToSetInvalidPeriod(): void + { + $this->expectException(InvalidArgumentException::class); + new Counter(10, 0, new ArrayCache()); + } + + /** + * @test + */ + public function incrementMustBeUniformAfterLimitIsReached(): void + { + $counter = new Counter(10, 1, new ArrayCache()); + $counter->setId('key'); + + for ($i = 0; $i < 10; $i++) { + $counter->incrementAndGetState(); + } + + for ($i = 0; $i < 5; $i++) { + usleep(110000); // period(microseconds) / limit + 10ms(cost work) + $statistics = $counter->incrementAndGetState(); + $this->assertEquals(1, $statistics->getRemaining()); + } + } +} diff --git a/tests/RateLimiter/FakeCounter.php b/tests/RateLimiter/FakeCounter.php new file mode 100644 index 000000000..60959df9e --- /dev/null +++ b/tests/RateLimiter/FakeCounter.php @@ -0,0 +1,37 @@ +reset = $reset; + $this->limit = $limit; + $this->remaining = $limit; + } + + public function setId(string $id): void + { + $this->id = $id; + } + + public function getId(): ?string + { + return $this->id; + } + + public function incrementAndGetState(): CounterState + { + $this->remaining--; + return new CounterState($this->limit, $this->remaining, $this->reset); + } +} diff --git a/tests/RateLimiter/RateLimiterMiddlewareTest.php b/tests/RateLimiter/RateLimiterMiddlewareTest.php new file mode 100644 index 000000000..0c6d7b170 --- /dev/null +++ b/tests/RateLimiter/RateLimiterMiddlewareTest.php @@ -0,0 +1,116 @@ +createRateLimiter($counter)->process($this->createRequest(), $this->createRequestHandler()); + $this->assertEquals(200, $response->getStatusCode()); + + $this->assertEquals( + [ + 'X-Rate-Limit-Limit' => ['100'], + 'X-Rate-Limit-Remaining' => ['99'], + 'X-Rate-Limit-Reset' => ['100'], + ], + $response->getHeaders() + ); + } + + /** + * @test + */ + public function limitingIsStartedWhenExpected(): void + { + $counter = new FakeCounter(2, 100); + $middleware = $this->createRateLimiter($counter); + + // last allowed request + $response = $middleware->process($this->createRequest(), $this->createRequestHandler()); + $this->assertEquals(200, $response->getStatusCode()); + $this->assertEquals( + [ + 'X-Rate-Limit-Limit' => ['2'], + 'X-Rate-Limit-Remaining' => ['1'], + 'X-Rate-Limit-Reset' => ['100'], + ], + $response->getHeaders() + ); + + // first denied request + $response = $middleware->process($this->createRequest(), $this->createRequestHandler()); + $this->assertEquals(429, $response->getStatusCode()); + $this->assertEquals( + [ + 'X-Rate-Limit-Limit' => ['2'], + 'X-Rate-Limit-Remaining' => ['0'], + 'X-Rate-Limit-Reset' => ['100'], + ], + $response->getHeaders() + ); + } + + /** + * @test + */ + public function counterIdCouldBeSet(): void + { + $counter = new FakeCounter(100, 100); + $middleware = $this->createRateLimiter($counter)->withCounterId('custom-id'); + $middleware->process($this->createRequest(), $this->createRequestHandler()); + $this->assertEquals('custom-id', $counter->getId()); + } + + /** + * @test + */ + public function counterIdCouldBeSetWithCallback(): void + { + $counter = new FakeCounter(100, 100); + $middleware = $this->createRateLimiter($counter)->withCounterIdCallback( + static function (ServerRequestInterface $request) { + return $request->getMethod(); + } + ); + + $middleware->process($this->createRequest(), $this->createRequestHandler()); + $this->assertEquals('GET', $counter->getId()); + } + + private function createRequestHandler(): RequestHandlerInterface + { + return new class implements RequestHandlerInterface { + public function handle(ServerRequestInterface $request): ResponseInterface + { + return new Response(200); + } + }; + } + + private function createRequest(string $method = Method::GET, string $uri = '/'): ServerRequestInterface + { + return new ServerRequest($method, $uri); + } + + private function createRateLimiter(CounterInterface $counter): RateLimiterMiddleware + { + return new RateLimiterMiddleware($counter, new Psr17Factory()); + } +}