-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrate_limit.py
125 lines (88 loc) · 3.27 KB
/
rate_limit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
This module contains functions for rate limiting requests.
The rate limiting system operates on two levels:
1. User-level rate limiting: Each user (identified by a token) has a
configurable minimum interval between requests.
2. System-wide rate limiting: There is a global limit on the total number of
requests across all users within a specified time period.
"""
from datetime import datetime
import signal
import sys
from typing import Dict
from uuid import uuid4
from apscheduler.schedulers import background
import gradio as gr
class InvalidTokenException(Exception):
    """Raised when a request carries an empty token or one the limiter does not know."""
class UserRateLimitException(Exception):
    """Raised when a token makes a request sooner than its per-token minimum interval allows."""
class SystemRateLimitException(Exception):
    """Raised when the system-wide request count has reached its limit for the current period."""
class RateLimiter:
    """Enforce a per-token minimum interval and a system-wide request cap.

    ``check_rate_limit`` applies two limits:

    * Per token: at least ``min_interval_in_seconds`` must elapse between
      two accepted requests made with the same token.
    * System-wide: at most ``limit`` requests are accepted between two resets
      of the counter, which a background job performs every
      ``period_in_seconds``.

    A second background job runs once a day and forgets every token whose
    last recorded request time is at least one day old.
    """

    def __init__(self,
                 limit=10000,
                 period_in_seconds=60 * 60 * 24,
                 min_interval_in_seconds=5):
        """Create the limiter and start its background scheduler.

        Args:
            limit: Maximum number of requests accepted per period, across all
                tokens.
            period_in_seconds: How often the system-wide counter is reset.
            min_interval_in_seconds: Minimum time that must pass between two
                accepted requests made with the same token.
        """
        # Maps tokens to the last time they made a request.
        # E.g., {"sometoken": datetime(2021, 8, 1, 0, 0, 0)}
        self.last_request_times: Dict[str, datetime] = {}

        # The number of requests made.
        # This count is reset to zero at the end of each period.
        self.request_count = 0

        # The maximum number of requests allowed within the time period.
        self.limit = limit

        # The minimum number of seconds between two requests of one token.
        self.min_interval_in_seconds = min_interval_in_seconds

        self.scheduler = background.BackgroundScheduler()
        self.scheduler.add_job(self._remove_old_tokens,
                               "interval",
                               seconds=60 * 60 * 24)
        self.scheduler.add_job(self._reset_request_count,
                               "interval",
                               seconds=period_in_seconds)
        self.scheduler.start()

    def check_rate_limit(self, token: str):
        """Record a request for ``token`` or raise if it is not allowed.

        Args:
            token: The token identifying the requesting user.

        Raises:
            InvalidTokenException: ``token`` is empty or unknown.
            UserRateLimitException: The token's previous accepted request was
                less than ``min_interval_in_seconds`` ago.
            SystemRateLimitException: The system-wide limit has been reached.
        """
        # A single lookup (instead of "exists, then index") cannot fail with a
        # KeyError if the cleanup job drops the token in between.
        last_request_time = (self.last_request_times.get(token)
                             if token else None)
        if last_request_time is None:
            raise InvalidTokenException()

        now = datetime.now()
        # total_seconds() is required here: timedelta.seconds is only the
        # seconds *component* (0..86399) and ignores whole days, which would
        # throttle a token whose last request was e.g. 1 day and 2 seconds ago.
        elapsed_seconds = (now - last_request_time).total_seconds()
        if elapsed_seconds < self.min_interval_in_seconds:
            raise UserRateLimitException()

        if self.request_count >= self.limit:
            raise SystemRateLimitException()

        self.last_request_times[token] = now
        self.request_count += 1

    def initialize_request(self, token: str):
        """Register ``token`` so that its first request passes the interval check."""
        # datetime.min is "infinitely long ago" for the interval check.
        # NOTE(review): it also makes a never-used token look at least one day
        # old to _remove_old_tokens, so it is dropped at the next cleanup run.
        self.last_request_times[token] = datetime.min

    def token_exists(self, token: str):
        """Return True if ``token`` is currently known to the limiter."""
        return token in self.last_request_times

    def _remove_old_tokens(self):
        """Forget every token whose last request time is at least one day old."""
        now = datetime.now()
        # Iterate over a snapshot because entries are removed while looping.
        for token, last_request_time in list(self.last_request_times.items()):
            if (now - last_request_time).days >= 1:
                self.last_request_times.pop(token, None)

    def _reset_request_count(self):
        """Reset the system-wide request counter at the end of a period."""
        self.request_count = 0
# Module-level limiter instance used by set_token() and signal_handler().
# Constructing it starts the limiter's background scheduler as a side effect.
rate_limiter = RateLimiter()
def set_token(app: gr.Blocks, token: gr.Textbox):
    """Keep the rate-limiting token in sync between the browser and the server.

    When the app loads, the value stored in the browser's ``localStorage``
    under ``arena_token`` is handed to the server. If that value is empty or
    unknown to the rate limiter, a new token is generated and registered.
    Whenever the ``token`` component changes, its value is written back to
    ``localStorage``.

    Args:
        app: The Gradio app whose load event triggers the token lookup.
        token: The component that holds the token value.
    """
    # Client side: read the stored token, if any.
    read_stored_token_js = """
function() {
return localStorage.getItem("arena_token");
}
"""

    # Client side: persist the token chosen by the server.
    write_stored_token_js = """
function(newToken) {
localStorage.setItem("arena_token", newToken);
}
"""

    def set_server_token(existing_token):
        # Anything the limiter does not recognize is replaced by a new token.
        if not existing_token or not rate_limiter.token_exists(existing_token):
            fresh_token = uuid4().hex
            rate_limiter.initialize_request(fresh_token)
            return fresh_token
        return existing_token

    app.load(
        fn=set_server_token,
        js=read_stored_token_js,
        inputs=[token],
        outputs=[token],
    )
    token.change(
        fn=lambda _: None,
        js=write_stored_token_js,
        inputs=[token],
    )
def signal_handler(sig, frame):
    """Shut down the rate limiter's scheduler, then exit with status 0.

    Args:
        sig: The signal number passed to the handler (unused).
        frame: The stack frame passed to the handler (unused).
    """
    # Both arguments are required by the handler signature but not needed.
    del frame, sig
    rate_limiter.scheduler.shutdown()
    sys.exit(0)
if gr.NO_RELOAD:
    # Catch SIGINT so that signal_handler shuts the scheduler down and exits
    # when the server is interrupted.
    # NOTE(review): presumably gr.NO_RELOAD keeps this from being re-run when
    # Gradio hot-reloads the module; confirm against the Gradio docs.
    signal.signal(signal.SIGINT, signal_handler)