File size: 4,065 Bytes
3c495cc
 
 
 
 
9e789e7
3c495cc
2a0aa5a
17b1403
3c495cc
17b1403
3c495cc
 
5352a13
9e789e7
2a0aa5a
 
762adde
 
3c495cc
9e789e7
 
 
 
fd9a72d
5352a13
3230abf
 
 
 
 
 
 
 
 
 
17b1403
 
 
 
 
 
 
 
 
 
 
3230abf
17b1403
 
 
3c495cc
 
 
 
 
e4bf3a0
43c8549
 
3c495cc
43c8549
 
3c495cc
43c8549
 
3c495cc
 
43c8549
af3d965
3c495cc
 
 
 
 
 
 
762adde
 
 
 
 
 
 
 
 
 
 
 
 
 
2a0aa5a
25a095e
fd9a72d
3c495cc
43c8549
3c495cc
 
fd9a72d
3230abf
2a0aa5a
3c495cc
fd9a72d
 
 
9e789e7
 
 
 
 
 
 
 
3c495cc
fd9a72d
 
 
2a0aa5a
 
25a095e
 
 
2a0aa5a
 
3a85228
2a0aa5a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""
This module contains functions for generating responses using LLMs.
"""

import enum
import logging
from random import sample
from typing import List
from uuid import uuid4

from firebase_admin import firestore
import gradio as gr

from db import db
from model import ContextWindowExceededError
from model import Model
from model import supported_models
import rate_limit
from rate_limit import rate_limiter

# Install the default root handler (basicConfig does nothing if the root
# logger already has handlers) and let this module's logger emit records at
# INFO level and above.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


# TODO(#37): Move DB operations to db.py.
def get_history_collection(category: str):
  """Look up the history collection that belongs to a category.

  Args:
    category: A category string, compared against the `Category` values.

  Returns:
    The result of `db.collection(...)` for the summarization or translation
    history collection, or None when `category` matches neither.
  """
  if category == Category.SUMMARIZE.value:
    collection_name = "arena-summarization-history"
  elif category == Category.TRANSLATE.value:
    collection_name = "arena-translation-history"
  else:
    # Unknown categories have no history collection.
    return None

  return db.collection(collection_name)


def create_history(category: str, model_name: str, instruction: str,
                   prompt: str, response: str):
  """Store one model response in the history collection of a category.

  The record is written under a freshly generated hex UUID, which is also
  saved in the record's "id" field.

  Args:
    category: Category string used to pick the history collection.
    model_name: Name of the model that produced the response.
    instruction: Instruction that was given to the model.
    prompt: Prompt that was given to the model.
    response: Response returned by the model.
  """
  history_id = uuid4().hex

  record = {
      "id": history_id,
      "model": model_name,
      "instruction": instruction,
      "prompt": prompt,
      "response": response,
      # Sentinel from firebase_admin.firestore, stored in place of a
      # client-side time value.
      "timestamp": firestore.SERVER_TIMESTAMP
  }

  history = get_history_collection(category)
  history.document(history_id).set(record)


class Category(enum.Enum):
  """Task categories handled by this module.

  Each member's value is the string that the functions in this module compare
  against their `category` argument.
  """
  # Selects the summarization instruction and history collection.
  SUMMARIZE = "Summarize"
  # Selects the translation instruction and history collection.
  TRANSLATE = "Translate"


# TODO(#31): Let the model builders set the instruction.
def get_instruction(category: str, model: Model, source_lang: str,
                    target_lang: str):
  """Pick the instruction a model should receive for a category.

  Args:
    category: A category string, compared against the `Category` values.
    model: Model providing `summarize_instruction` and
      `translate_instruction`.
    source_lang: Substituted for `source_lang` in the translate instruction.
    target_lang: Substituted for `target_lang` in the translate instruction.

  Returns:
    The model's summarize instruction, its translate instruction formatted
    with both languages, or None when `category` matches neither category.
  """
  if category == Category.SUMMARIZE.value:
    return model.summarize_instruction

  if category != Category.TRANSLATE.value:
    # Unknown categories have no instruction.
    return None

  template = model.translate_instruction
  return template.format(source_lang=source_lang, target_lang=target_lang)


def get_responses(prompt: str, category: str, source_lang: str,
                  target_lang: str, token: str):
  """Query two randomly sampled models and stream their responses.

  This is a generator: both completions are fetched (and recorded through
  `create_history`) first, then the texts are replayed one character at a
  time.

  Args:
    prompt: Prompt passed to every sampled model.
    category: A category string; must be non-empty.
    source_lang: Source language; required when the category is translate.
    target_lang: Target language; required when the category is translate.
    token: Token handed to `rate_limiter.check_rate_limit`.

  Yields:
    A list holding the two (partial) responses, followed by the two model
    names, followed by the instruction used for the last sampled model.

  Raises:
    gr.Error: If the category or the translation languages are missing, the
      rate-limit check fails, a model's context window is exceeded, or
      fetching or recording a response fails for any other reason.
  """
  if not category:
    raise gr.Error("Please select a category.")

  is_translation = category == Category.TRANSLATE.value
  if is_translation and not (source_lang and target_lang):
    raise gr.Error("Please select source and target languages.")

  try:
    rate_limiter.check_rate_limit(token)
  except (rate_limit.InvalidTokenException,
          rate_limit.UserRateLimitException,
          rate_limit.SystemRateLimitException) as e:
    # Checked in the same order as separate except clauses would match.
    if isinstance(e, rate_limit.InvalidTokenException):
      message = "Your session has expired. Please refresh the page to continue."  # pylint: disable=line-too-long
    elif isinstance(e, rate_limit.UserRateLimitException):
      message = "You have made too many requests in a short period. Please try again later."  # pylint: disable=line-too-long
    else:
      message = "Our service is currently experiencing high traffic. Please try again later."  # pylint: disable=line-too-long
    raise gr.Error(message) from e

  chosen_models: List[Model] = sample(list(supported_models), 2)
  outputs = []
  saw_invalid_output = False
  for candidate in chosen_models:
    instruction = get_instruction(category, candidate, source_lang,
                                  target_lang)
    try:
      # TODO(#1): Allow user to set configuration.
      output, is_valid = candidate.completion(instruction, prompt)
      create_history(category, candidate.name, instruction, prompt, output)
      outputs.append(output)
      saw_invalid_output |= not is_valid
    except ContextWindowExceededError as e:
      logger.exception("Context window exceeded for model %s.",
                       candidate.name)
      raise gr.Error(
          "The prompt is too long. Please try again with a shorter prompt."
      ) from e
    except Exception as e:
      logger.exception("Failed to get response from model %s.",
                       candidate.name)
      raise gr.Error("Failed to get response. Please try again.") from e

  if saw_invalid_output:
    gr.Warning("An invalid response was received.")

  # Every yielded list ends with the model names and the instruction left
  # over from the final loop iteration above.
  trailer = [candidate.name for candidate in chosen_models] + [instruction]

  # It simulates concurrent stream response generation by replaying
  # ever-longer prefixes of each response.
  longest = max(len(text) for text in outputs)
  for end in range(1, longest + 1):
    yield [text[:end] for text in outputs] + trailer

  # Final yield with the complete texts; also covers all-empty responses.
  yield outputs + trailer