diff --git a/changelog.d/13018.bugfix b/changelog.d/13018.bugfix new file mode 100644 index 000000000000..a84657f04f67 --- /dev/null +++ b/changelog.d/13018.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug which meant that rate limiting was not restrictive enough in some cases. \ No newline at end of file diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 849c18ceda16..54d13026c9e5 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -128,6 +128,9 @@ async def can_do_action( performed_count = action_count - time_delta * rate_hz if performed_count < 0: performed_count = 0 + + # Reset the start time and forgive all actions + action_count = 0 time_start = time_now_s # This check would be easier read as performed_count + n_actions > burst_count, @@ -140,7 +143,7 @@ async def can_do_action( else: # We haven't reached our limit yet allowed = True - action_count = performed_count + n_actions + action_count = action_count + n_actions if update: self.actions[key] = (action_count, time_start, rate_hz) diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index f661a9ff8e2f..18649c2c05dc 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -246,7 +246,7 @@ def test_multiple_actions(self): self.assertTrue(allowed) self.assertEqual(10.0, time_allowed) - # Test that, after doing these 3 actions, we can't do any more action without + # Test that, after doing these 3 actions, we can't do any more actions without # waiting. allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(None, key="test_id", n_actions=1, _time_now_s=0) @@ -254,7 +254,8 @@ def test_multiple_actions(self): self.assertFalse(allowed) self.assertEqual(10.0, time_allowed) - # Test that after waiting we can do only 1 action. + # Test that after waiting we would be able to do only 1 action. + # Note that we don't actually do it (update=False) here. allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action( None, @@ -265,23 +266,51 @@ def test_multiple_actions(self): ) ) self.assertTrue(allowed) - # The time allowed is the current time because we could still repeat the action - # once. - self.assertEqual(10.0, time_allowed) + # We would be able to do the 5th action at t=20. + self.assertEqual(20.0, time_allowed) + # Attempt (but fail) to perform TWO actions at t=10. + # Those would be the 4th and 5th actions. allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(None, key="test_id", n_actions=2, _time_now_s=10) ) self.assertFalse(allowed) - # The time allowed doesn't change despite allowed being False because, while we - # don't allow 2 actions, we could still do 1. + # The returned time allowed for the next action is now even though we weren't + # allowed to perform the action because whilst we don't allow 2 actions, + # we could still do 1. self.assertEqual(10.0, time_allowed) - # Test that after waiting a bit more we can do 2 actions. + # Test that after waiting until t=20, we can do perform 2 actions. + # These are the 4th and 5th actions. allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(None, key="test_id", n_actions=2, _time_now_s=20) ) self.assertTrue(allowed) - # The time allowed is the current time because we could still repeat the action - # once. - self.assertEqual(20.0, time_allowed) + # We would be able to do the 6th action at t=30. + self.assertEqual(30.0, time_allowed) + + def test_rate_limit_burst_only_given_once(self) -> None: + """ + Regression test against a bug that meant that you could build up + extra tokens by timing requests. + """ + limiter = Ratelimiter( + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 + ) + + def consume_at(time: float) -> bool: + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=time) + ) + return success + + # Use all our 3 burst tokens + self.assertTrue(consume_at(0.0)) + self.assertTrue(consume_at(0.1)) + self.assertTrue(consume_at(0.2)) + + # Wait to recover 1 token (10 seconds at 0.1 Hz). + self.assertTrue(consume_at(10.1)) + + # Check that we get rate limited after using that token. + self.assertFalse(consume_at(11.1))