File size: 3,349 Bytes
762adde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af3d965
762adde
 
 
 
 
 
 
 
 
 
 
 
 
 
af3d965
 
 
762adde
 
 
 
 
 
 
 
 
 
 
 
af3d965
 
 
 
 
 
 
 
 
 
 
762adde
 
 
 
 
af3d965
762adde
af3d965
762adde
 
 
af3d965
 
 
 
 
762adde
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""
This module contains functions for rate limiting requests.

The rate limiting system operates on two levels:
1. User-level rate limiting: Each user (identified by a token) must wait a
   minimum interval between requests (5 seconds for the module-level limiter).

2. System-wide rate limiting: There is a global limit on the total number of
   requests across all users within a specified time period.
"""

from datetime import datetime
from datetime import timezone
import signal
import sys
import threading
from typing import Dict
from uuid import uuid4

from apscheduler.schedulers import background
import gradio as gr


class InvalidTokenException(Exception):
  """Raised when a request carries an empty or unregistered token."""


class UserRateLimitException(Exception):
  """Raised when a single token sends requests faster than its allowed rate."""


class SystemRateLimitException(Exception):
  """Raised when the global request limit for the current period is reached."""


class RateLimiter:
  """Enforces a per-token request interval and a system-wide request limit.

  Args:
    limit: Maximum number of requests accepted across all tokens within one
      period.
    period_in_seconds: Length of the system-wide period. The accepted-request
      count is reset to zero every `period_in_seconds` seconds.
    min_interval_in_seconds: Minimum number of seconds a single token must
      wait between two accepted requests.
  """

  def __init__(self,
               limit=10000,
               period_in_seconds=60 * 60 * 24,
               min_interval_in_seconds=5):
    # Maps tokens to the last time (timezone-aware, UTC) they made a request.
    # E.g, {"sometoken": datetime(2021, 8, 1, 0, 0, 0, tzinfo=timezone.utc)}
    self.last_request_times: Dict[str, datetime] = {}

    # The number of requests made.
    # This count is reset to zero at the end of each period.
    self.request_count = 0

    # The maximum number of requests allowed within the time period.
    self.limit = limit

    # The minimum number of seconds between two requests from one token.
    self.min_interval_in_seconds = min_interval_in_seconds

    # Guards all state above. The scheduler jobs below run on a background
    # thread while request handlers call check_rate_limit concurrently; without
    # the lock a token could be deleted between the existence check and the
    # lookup (KeyError), and count updates could be lost.
    self._lock = threading.Lock()

    self.scheduler = background.BackgroundScheduler()
    self.scheduler.add_job(self._remove_old_tokens,
                           "interval",
                           seconds=60 * 60 * 24)
    self.scheduler.add_job(self._reset_request_count,
                           "interval",
                           seconds=period_in_seconds)
    self.scheduler.start()

  @staticmethod
  def _now() -> datetime:
    """Returns the current time as a timezone-aware UTC datetime.

    UTC timestamps keep elapsed-time arithmetic correct across local DST
    transitions, which would otherwise make deltas jump by an hour.
    """
    return datetime.now(timezone.utc)

  def check_rate_limit(self, token: str) -> None:
    """Records a request for `token` if it is within all rate limits.

    Args:
      token: The token previously registered via `initialize_request`.

    Raises:
      InvalidTokenException: If `token` is empty or not registered.
      UserRateLimitException: If `token`'s previous accepted request was less
        than `min_interval_in_seconds` seconds ago.
      SystemRateLimitException: If `limit` requests have already been accepted
        in the current period.
    """
    with self._lock:
      # Membership is checked inline (not via token_exists) because the lock
      # is not reentrant.
      if not token or token not in self.last_request_times:
        raise InvalidTokenException()

      now = self._now()
      # total_seconds(), not .seconds: .seconds is only the 0-86399 seconds
      # component of the delta and ignores whole days, so e.g. a gap of
      # 1 day + 2 s would wrongly read as 2 s.
      elapsed = (now - self.last_request_times[token]).total_seconds()
      if elapsed < self.min_interval_in_seconds:
        raise UserRateLimitException()

      if self.request_count >= self.limit:
        raise SystemRateLimitException()

      self.last_request_times[token] = now
      self.request_count += 1

  def initialize_request(self, token: str) -> None:
    """Registers `token` so that its first request passes the interval check.

    Args:
      token: The token to register (or reset, if already registered).
    """
    with self._lock:
      # The earliest representable (UTC-aware) time means "never requested".
      self.last_request_times[token] = datetime.min.replace(
          tzinfo=timezone.utc)

  def token_exists(self, token: str) -> bool:
    """Returns whether `token` is currently registered."""
    with self._lock:
      return token in self.last_request_times

  def _remove_old_tokens(self) -> None:
    """Forgets tokens whose last request was at least one day ago.

    Scheduled to run once a day from `__init__`.
    """
    with self._lock:
      now = self._now()
      # NOTE(review): tokens registered by initialize_request() but never used
      # still hold datetime.min, so they are dropped on the next sweep no
      # matter how recently they were issued.
      stale_tokens = [
          token for token, last_request_time in self.last_request_times.items()
          if (now - last_request_time).days >= 1
      ]
      for token in stale_tokens:
        del self.last_request_times[token]

  def _reset_request_count(self) -> None:
    """Starts a new system-wide period by zeroing the accepted-request count."""
    with self._lock:
      self.request_count = 0


# Shared module-level limiter. Constructing it starts its background scheduler,
# which signal_handler shuts down on SIGINT.
rate_limiter = RateLimiter(limit=10000, period_in_seconds=60 * 60 * 24)


def set_token(app: gr.Blocks, token: gr.Textbox):
  """Wires a persistent per-browser rate-limit token into `app`.

  On page load, the token saved under "arena_token" in the browser's
  localStorage is sent to the server. The server keeps it if the rate limiter
  still knows it, and otherwise issues and registers a fresh one. Whenever the
  `token` textbox value changes, the new value is written back to localStorage.

  Args:
    app: The Gradio Blocks app whose load event performs the token handshake.
    token: Textbox component that carries the token between client and server.
  """
  read_stored_token_js = """
  function() {
    return localStorage.getItem("arena_token");
  }
  """
  persist_token_js = """
  function(newToken) {
    localStorage.setItem("arena_token", newToken);
  }
  """

  def resolve_server_token(existing_token):
    # Reuse the browser's token only when it is non-empty and still registered.
    if not (existing_token and rate_limiter.token_exists(existing_token)):
      issued = uuid4().hex
      rate_limiter.initialize_request(issued)
      return issued
    return existing_token

  def ignore_value(_value):
    # Server side is a no-op; the work happens in persist_token_js.
    return None

  app.load(inputs=[token],
           outputs=[token],
           fn=resolve_server_token,
           js=read_stored_token_js)
  token.change(inputs=[token], fn=ignore_value, js=persist_token_js)


def signal_handler(sig, frame):
  """Shuts down the rate limiter's scheduler, then exits with status 0.

  Args:
    sig: Signal number (required by the handler signature; unused).
    frame: Current stack frame (required by the handler signature; unused).
  """
  _ = (sig, frame)  # Unused.
  scheduler = rate_limiter.scheduler
  scheduler.shutdown()
  raise SystemExit(0)


if gr.NO_RELOAD:
  # Catch SIGINT (e.g. Ctrl+C) so signal_handler can shut down the rate
  # limiter's background scheduler before the process exits.
  # NOTE(review): the gr.NO_RELOAD guard presumably keeps Gradio's reload mode
  # from re-registering this handler on every hot reload — confirm against the
  # Gradio reload-mode docs.
  signal.signal(signal.SIGINT, signal_handler)