Spaces:
Running
Running
File size: 2,958 Bytes
3c495cc 2a0aa5a 17b1403 3c495cc 17b1403 3c495cc 17b1403 2a0aa5a 3c495cc 3230abf 17b1403 3230abf 17b1403 3c495cc e4bf3a0 43c8549 3c495cc 43c8549 3c495cc 43c8549 3c495cc 43c8549 3c495cc 2a0aa5a 25a095e 3c495cc 43c8549 3c495cc 43c8549 3230abf 2a0aa5a 3c495cc 2a0aa5a af115bf 3c495cc 2a0aa5a 25a095e 2a0aa5a 3a85228 2a0aa5a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
"""
This module contains functions for generating responses using LLMs.
"""
import enum
from random import sample
from typing import List
from uuid import uuid4
from firebase_admin import firestore
import gradio as gr
from leaderboard import db
from model import Model
from model import supported_models
def get_history_collection(category: str):
  """Return the Firestore collection that stores history for `category`.

  Args:
    category: The value of a `Category` member, e.g. "Summarize".

  Returns:
    The collection reference obtained from `db.collection(...)` for that
    category.

  Raises:
    ValueError: If `category` matches no known `Category` value.
  """
  if category == Category.SUMMARIZE.value:
    return db.collection("arena-summarization-history")
  if category == Category.TRANSLATE.value:
    return db.collection("arena-translation-history")
  # Falling through used to return None, which only surfaced later as an
  # opaque AttributeError when the caller invoked `.document(...)` on it.
  raise ValueError(f"Unknown category: {category}")
def create_history(category: str, model_name: str, instruction: str,
                   prompt: str, response: str):
  """Store one model response in the history collection for `category`.

  A new random hex id is used both as the document id and as the
  document's "id" field. The "timestamp" field is set to
  `firestore.SERVER_TIMESTAMP`.
  """
  history_id = uuid4().hex
  record = dict(
      id=history_id,
      model=model_name,
      instruction=instruction,
      prompt=prompt,
      response=response,
      timestamp=firestore.SERVER_TIMESTAMP,
  )
  get_history_collection(category).document(history_id).set(record)
class Category(enum.Enum):
  """Task categories supported by this module.

  Functions here compare their `category` string argument against these
  members' `.value` strings.
  """

  SUMMARIZE = "Summarize"
  TRANSLATE = "Translate"
# TODO(#31): Let the model builders set the instruction.
def get_instruction(category: str, model: Model, source_lang: str,
                    target_lang: str):
  """Build the system instruction that `model` should receive for `category`.

  Args:
    category: The value of a `Category` member.
    model: The model whose instruction templates are used.
    source_lang: Substituted into the translation template only.
    target_lang: Substituted into the translation template only.

  Returns:
    `model.summarize_instruction` for summarization, or
    `model.translate_instruction` formatted with the two languages for
    translation. Returns None when `category` matches neither.
  """
  if category == Category.TRANSLATE.value:
    template = model.translate_instruction
    return template.format(source_lang=source_lang, target_lang=target_lang)
  if category == Category.SUMMARIZE.value:
    return model.summarize_instruction
  return None
def get_responses(prompt: str, category: str, source_lang: str,
                  target_lang: str):
  """Query two randomly sampled models and stream their responses.

  This is a generator. Each yielded value is a list laid out as
  `[response_a, response_b, model_name_a, model_name_b, instruction]`.
  Intermediate yields carry progressively longer prefixes of each response
  to simulate concurrent streaming; the last yield carries the full
  responses. `instruction` is the one built for the second sampled model.

  Args:
    prompt: The user's input text.
    category: The value of a `Category` member.
    source_lang: Source language; required when translating.
    target_lang: Target language; required when translating.

  Raises:
    gr.Error: If the category or languages are missing, or if calling a
      model or saving its history fails.
  """
  if not category:
    raise gr.Error("Please select a category.")
  if category == Category.TRANSLATE.value and (not source_lang or
                                                not target_lang):
    raise gr.Error("Please select source and target languages.")

  models: List[Model] = sample(list(supported_models), 2)

  responses = []
  for model in models:
    instruction = get_instruction(category, model, source_lang, target_lang)
    try:
      # TODO(#1): Allow user to set configuration.
      response = model.completion(messages=[{
          "role": "system",
          "content": instruction
      }, {
          "role": "user",
          "content": prompt
      }])
      create_history(category, model.name, instruction, prompt, response)
      responses.append(response)
    # TODO(#1): Narrow down the exception type.
    except Exception as e:  # pylint: disable=broad-except
      print(f"Error with model {model.name}: {e}")
      # Chain explicitly so the underlying model/storage error stays
      # attached to the user-facing error for debugging.
      raise gr.Error("Failed to get response. Please try again.") from e

  model_names = [model.name for model in models]

  # It simulates concurrent stream response generation.
  max_response_length = max(len(response) for response in responses)
  for i in range(max_response_length):
    yield [response[:i + 1] for response in responses
          ] + model_names + [instruction]
  # Final yield guarantees the full responses are emitted even when both
  # responses are empty and the loop above never runs.
  yield responses + model_names + [instruction]
|