Spaces:
Running
Running
File size: 4,065 Bytes
3c495cc 9e789e7 3c495cc 2a0aa5a 17b1403 3c495cc 17b1403 3c495cc 5352a13 9e789e7 2a0aa5a 762adde 3c495cc 9e789e7 fd9a72d 5352a13 3230abf 17b1403 3230abf 17b1403 3c495cc e4bf3a0 43c8549 3c495cc 43c8549 3c495cc 43c8549 3c495cc 43c8549 af3d965 3c495cc 762adde 2a0aa5a 25a095e fd9a72d 3c495cc 43c8549 3c495cc fd9a72d 3230abf 2a0aa5a 3c495cc fd9a72d 9e789e7 3c495cc fd9a72d 2a0aa5a 25a095e 2a0aa5a 3a85228 2a0aa5a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
"""
This module contains functions for generating responses using LLMs.
"""
import enum
import logging
from random import sample
from typing import List
from uuid import uuid4
from firebase_admin import firestore
import gradio as gr
from db import db
from model import ContextWindowExceededError
from model import Model
from model import supported_models
import rate_limit
from rate_limit import rate_limiter
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# TODO(#37): Move DB operations to db.py.
def get_history_collection(category: str):
  """Returns the history collection in `db` for the given category.

  Args:
    category: A category value, i.e. Category.SUMMARIZE.value or
      Category.TRANSLATE.value.

  Returns:
    The result of db.collection() for the collection that holds the history
    of the given category.

  Raises:
    ValueError: If the category is not one of the known Category values.
  """
  if category == Category.SUMMARIZE.value:
    return db.collection("arena-summarization-history")

  if category == Category.TRANSLATE.value:
    return db.collection("arena-translation-history")

  # Fail loudly here: falling through would return None, which only surfaces
  # later as an opaque AttributeError when the caller uses the result.
  raise ValueError(f"Unsupported category: {category!r}")
def create_history(category: str, model_name: str, instruction: str,
                   prompt: str, response: str):
  """Stores one model response as a new document in the history collection.

  The document is keyed by a freshly generated UUID hex string, which is also
  written to its "id" field.

  Args:
    category: Category value used to pick the history collection.
    model_name: Name of the model that produced the response.
    instruction: Instruction that was given to the model.
    prompt: Prompt that was given to the model.
    response: Response returned by the model.
  """
  history_id = uuid4().hex
  record = dict(
      id=history_id,
      model=model_name,
      instruction=instruction,
      prompt=prompt,
      response=response,
      # Sentinel from firebase_admin.firestore; presumably replaced with the
      # server-side write time -- confirm against the Firestore docs.
      timestamp=firestore.SERVER_TIMESTAMP,
  )
  collection = get_history_collection(category)
  collection.document(history_id).set(record)
class Category(enum.Enum):
  """Task categories handled by this module.

  The member values are the strings that the `category` arguments of the
  functions in this module are compared against.
  """
  # Selects model.summarize_instruction and the summarization history.
  SUMMARIZE = "Summarize"
  # Selects model.translate_instruction and the translation history.
  TRANSLATE = "Translate"
# TODO(#31): Let the model builders set the instruction.
def get_instruction(category: str, model: Model, source_lang: str,
                    target_lang: str):
  """Builds the instruction to send to a model for the given category.

  Args:
    category: A category value, i.e. Category.SUMMARIZE.value or
      Category.TRANSLATE.value.
    model: Model whose instruction templates are used.
    source_lang: Source language; only used for the translate category.
    target_lang: Target language; only used for the translate category.

  Returns:
    model.summarize_instruction for the summarize category, or
    model.translate_instruction formatted with source_lang and target_lang
    for the translate category. None for any other category.
  """
  if category == Category.SUMMARIZE.value:
    instruction = model.summarize_instruction
  elif category == Category.TRANSLATE.value:
    template = model.translate_instruction
    instruction = template.format(source_lang=source_lang,
                                  target_lang=target_lang)
  else:
    # Unknown categories yield no instruction.
    instruction = None
  return instruction
def get_responses(prompt: str, category: str, source_lang: str,
                  target_lang: str, token: str):
  """Streams the responses of two randomly sampled models for a prompt.

  This is a generator: validation, rate limiting and the model calls all run
  when iteration starts. Both completions are fetched in full and stored via
  create_history() before anything is yielded; the streaming is simulated by
  yielding progressively longer prefixes of the finished responses.

  Args:
    prompt: Text passed to each model together with its instruction.
    category: Category value selecting the task.
    source_lang: Source language; must be set for the translate category.
    target_lang: Target language; must be set for the translate category.
    token: Token handed to rate_limiter.check_rate_limit().

  Yields:
    A list holding the current prefix of each response, then the names of
    the sampled models, then the instruction built for the last sampled
    model.

  Raises:
    gr.Error: If no category is given, a translation language is missing,
      the rate-limit check fails, the prompt exceeds a model's context
      window, or getting or storing a response fails for any other reason.
  """
  if not category:
    raise gr.Error("Please select a category.")

  is_translation = category == Category.TRANSLATE.value
  if is_translation and not (source_lang and target_lang):
    raise gr.Error("Please select source and target languages.")

  # Map each rate-limiter failure to a message suitable for the UI.
  try:
    rate_limiter.check_rate_limit(token)
  except rate_limit.InvalidTokenException as e:
    message = "Your session has expired. Please refresh the page to continue."
    raise gr.Error(message) from e
  except rate_limit.UserRateLimitException as e:
    message = ("You have made too many requests in a short period. "
               "Please try again later.")
    raise gr.Error(message) from e
  except rate_limit.SystemRateLimitException as e:
    message = ("Our service is currently experiencing high traffic. "
               "Please try again later.")
    raise gr.Error(message) from e

  chosen_models: List[Model] = sample(list(supported_models), 2)
  outputs = []
  validity_flags = []

  for model in chosen_models:
    instruction = get_instruction(category, model, source_lang, target_lang)

    try:
      # TODO(#1): Allow user to set configuration.
      text, is_valid = model.completion(instruction, prompt)
      create_history(category, model.name, instruction, prompt, text)
      outputs.append(text)
      validity_flags.append(bool(is_valid))

    except ContextWindowExceededError as e:
      logger.exception("Context window exceeded for model %s.", model.name)
      message = "The prompt is too long. Please try again with a shorter prompt."
      raise gr.Error(message) from e

    except Exception as e:
      logger.exception("Failed to get response from model %s.", model.name)
      raise gr.Error("Failed to get response. Please try again.") from e

  if not all(validity_flags):
    gr.Warning("An invalid response was received.")

  # Everything after the response texts is identical in every yielded list.
  suffix = [model.name for model in chosen_models] + [instruction]

  # Simulate concurrent streaming by revealing one more character of every
  # response per step, until the longest response is fully shown.
  longest = max(map(len, outputs))
  for end in range(1, longest + 1):
    yield [text[:end] for text in outputs] + suffix

  # Always emit the complete responses last; this is the only yield when
  # every response is empty.
  yield outputs + suffix
|