Spaces:
Running
Running
File size: 3,804 Bytes
3c495cc b695eaf 3c495cc 17b1403 3c495cc 17b1403 b695eaf 50a5912 3c495cc 50a5912 17b1403 50a5912 b695eaf 50a5912 b695eaf 3c495cc 17b1403 3c495cc e4bf3a0 3c495cc a44d1c3 3c495cc b695eaf 3c495cc 25a095e 3c495cc b695eaf 3c495cc b695eaf 3c495cc 25a095e 17b1403 3c495cc af115bf 3c495cc 25a095e 3a85228 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
"""
This module contains functions for generating responses using LLMs.
"""
import enum
import json
import os
from random import sample
from uuid import uuid4
from firebase_admin import firestore
from google.cloud import secretmanager
from google.oauth2 import service_account
import gradio as gr
from litellm import completion
from credentials import get_credentials_json
from leaderboard import db
# Deployment settings are read from the environment; either may be None when
# the corresponding variable is not set.
GOOGLE_CLOUD_PROJECT = os.environ.get("GOOGLE_CLOUD_PROJECT")
MODELS_SECRET = os.environ.get("MODELS_SECRET")

# Authenticate the Secret Manager client with the service-account info
# returned by get_credentials_json().
_service_account_credentials = (
    service_account.Credentials.from_service_account_info(
        get_credentials_json()))
secretmanager_client = secretmanager.SecretManagerServiceClient(
    credentials=_service_account_credentials)

# Fetch the "latest" version of the models secret and parse its UTF-8 payload
# as JSON. The parsed object is iterated and indexed by model key elsewhere in
# this module.
_models_secret_path = secretmanager_client.secret_version_path(
    GOOGLE_CLOUD_PROJECT, MODELS_SECRET, "latest")
models_secret = secretmanager_client.access_secret_version(
    name=_models_secret_path)
decoded_secret = models_secret.payload.data.decode("UTF-8")
supported_models = json.loads(decoded_secret)
def create_history(model_name: str, instruction: str, prompt: str,
                   response: str):
  """Stores one model interaction in the "arena-history" collection.

  A fresh UUID4 hex string is generated and used both as the document ID and
  as the value of the stored "id" field.

  Args:
    model_name: Saved under the "model" field.
    instruction: Saved under the "instruction" field.
    prompt: Saved under the "prompt" field.
    response: Saved under the "response" field.
  """
  record_id = uuid4().hex
  history_ref = db.collection("arena-history").document(record_id)
  history_ref.set({
      "id": record_id,
      "model": model_name,
      "instruction": instruction,
      "prompt": prompt,
      "response": response,
      # Sentinel value provided by the Firestore client library.
      "timestamp": firestore.SERVER_TIMESTAMP
  })
class Category(enum.Enum):
  """Task categories offered to the user.

  Code in this module compares the incoming category against each member's
  string value (e.g. Category.SUMMARIZE.value), not against the member itself.
  """
  # Selects the summarization instruction.
  SUMMARIZE = "Summarize"
  # Selects the translation instruction; source and target languages are
  # required for this category.
  TRANSLATE = "Translate"
# TODO(#31): Let the model builders set the instruction.
def get_instruction(category, source_lang, target_lang):
  """Builds the instruction text for the given category.

  Args:
    category: Value compared against Category.SUMMARIZE.value and
      Category.TRANSLATE.value.
    source_lang: Language to translate from; only used for translation.
    target_lang: Language to translate to; only used for translation.

  Returns:
    The instruction string, or None when the category matches neither of the
    known values.
  """
  instruction = None
  if category == Category.SUMMARIZE.value:
    instruction = "Summarize the following text, maintaining the original language of the text in the summary."  # pylint: disable=line-too-long
  elif category == Category.TRANSLATE.value:
    instruction = (
        f"Translate the following text from {source_lang} to {target_lang}.")
  return instruction
def get_responses(user_prompt, category, source_lang, target_lang):
  """Queries two randomly sampled models and streams their answers.

  This is a generator, so nothing (including input validation) runs until it
  is first iterated.

  Args:
    user_prompt: Text sent to each model as the user message.
    category: Selected category value; must be truthy.
    source_lang: Source language; required when translating.
    target_lang: Target language; required when translating.

  Yields:
    Lists of the form [text_a, text_b, model_a, model_b, instruction]. The two
    texts grow by one character per yield until the longer response is
    complete; a final yield carries both complete responses (this is also the
    only yield when both responses are empty).

  Raises:
    gr.Error: If no category is selected, if a language is missing for
      translation, or if requesting or recording a response fails. In the
      last case the underlying exception is attached as the cause.
  """
  if not category:
    raise gr.Error("Please select a category.")

  if category == Category.TRANSLATE.value and (not source_lang or
                                               not target_lang):
    raise gr.Error("Please select source and target languages.")

  models = sample(list(supported_models), 2)
  instruction = get_instruction(category, source_lang, target_lang)

  responses = []
  for model in models:
    model_config = supported_models[model]

    # Prefix the provider when one is configured: "<provider>/<model>".
    if "provider" in model_config:
      model_name = model_config["provider"] + "/" + model
    else:
      model_name = model

    api_key = model_config.get("apiKey", None)
    api_base = model_config.get("apiBase", None)

    try:
      # TODO(#1): Allow user to set configuration.
      response = completion(model=model_name,
                            api_key=api_key,
                            api_base=api_base,
                            messages=[{
                                "content": instruction,
                                "role": "system"
                            }, {
                                "content": user_prompt,
                                "role": "user"
                            }])

      # A model may answer with no text content (None). Treat that as an
      # empty string so the length computation and slicing below cannot fail
      # with a TypeError.
      content = response.choices[0].message.content or ""

      create_history(model, instruction, user_prompt, content)
      responses.append(content)

    # TODO(#1): Narrow down the exception type.
    except Exception as e:  # pylint: disable=broad-except
      print(f"Error with model {model}: {e}")
      # Chain the original exception so the root cause stays in tracebacks.
      raise gr.Error("Failed to get response. Please try again.") from e

  # It simulates concurrent stream response generation: both texts are
  # revealed one character at a time, in lockstep.
  max_response_length = max(len(text) for text in responses)
  for end in range(1, max_response_length + 1):
    yield [text[:end] for text in responses] + models + [instruction]

  yield responses + models + [instruction]
|