"""
This module contains functions for generating responses using LLMs.
"""
import enum
import json
import os
from random import sample
from uuid import uuid4

from firebase_admin import firestore
from google.cloud import secretmanager
from google.oauth2 import service_account
import gradio as gr
from litellm import completion

from credentials import get_credentials_json
from leaderboard import db
# Project and secret identifiers are injected via the environment.
GOOGLE_CLOUD_PROJECT = os.environ.get("GOOGLE_CLOUD_PROJECT")
MODELS_SECRET = os.environ.get("MODELS_SECRET")

# Authenticate the Secret Manager client with the service-account
# credentials bundled with this deployment.
_service_credentials = service_account.Credentials.from_service_account_info(
    get_credentials_json())
secretmanager_client = secretmanager.SecretManagerServiceClient(
    credentials=_service_credentials)

# Fetch the latest version of the secret describing the supported models.
_secret_version_name = secretmanager_client.secret_version_path(
    GOOGLE_CLOUD_PROJECT, MODELS_SECRET, "latest")
models_secret = secretmanager_client.access_secret_version(
    name=_secret_version_name)

# The payload is JSON mapping each model name to its configuration
# (provider, apiKey, apiBase), consumed by get_responses below.
decoded_secret = models_secret.payload.data.decode("UTF-8")
supported_models = json.loads(decoded_secret)
def create_history(model_name: str, instruction: str, prompt: str,
                   response: str) -> str:
  """Persists one model interaction to the "arena-history" collection.

  Args:
    model_name: Name of the model that produced the response.
    instruction: System instruction that was sent to the model.
    prompt: User-provided prompt.
    response: Text generated by the model.

  Returns:
    The hex document ID of the stored record, so callers can reference it.
  """
  doc_id = uuid4().hex
  doc = {
      "id": doc_id,
      "model": model_name,
      "instruction": instruction,
      "prompt": prompt,
      "response": response,
      # Let Firestore stamp the write time server-side to avoid clock skew.
      "timestamp": firestore.SERVER_TIMESTAMP
  }
  db.collection("arena-history").document(doc_id).set(doc)
  return doc_id
class Category(enum.Enum):
  """Task categories a user can select in the arena UI."""

  SUMMARIZE = "Summarize"
  TRANSLATE = "Translate"
# TODO(#31): Let the model builders set the instruction.
def get_instruction(category, source_lang, target_lang):
  """Returns the system instruction for the given category.

  Args:
    category: A Category enum *value* (the UI passes the string value,
      not the enum member).
    source_lang: Source language; only used by the Translate category.
    target_lang: Target language; only used by the Translate category.

  Returns:
    The instruction string, or None for an unrecognized category.
  """
  if category == Category.SUMMARIZE.value:
    return "Summarize the following text, maintaining the original language of the text in the summary."  # pylint: disable=line-too-long
  if category == Category.TRANSLATE.value:
    return f"Translate the following text from {source_lang} to {target_lang}."
  # Make the fall-through explicit rather than an implicit None.
  return None
def get_responses(user_prompt, category, source_lang, target_lang):
  """Yields progressively longer prefixes of two anonymous model responses.

  Samples two models from the supported set, queries each with the
  instruction for the chosen category, and simulates concurrent streaming
  by yielding growing prefixes of both responses. Each yield is a flat
  list: the (partial) responses, the two sampled model names, and the
  instruction used.

  Args:
    user_prompt: The text the user wants processed.
    category: Selected Category value (string).
    source_lang: Source language (Translate category only).
    target_lang: Target language (Translate category only).

  Raises:
    gr.Error: If a required selection is missing or a model call fails.
  """
  if not category:
    raise gr.Error("Please select a category.")
  if category == Category.TRANSLATE.value and (not source_lang or
                                               not target_lang):
    raise gr.Error("Please select source and target languages.")

  models = sample(list(supported_models), 2)
  instruction = get_instruction(category, source_lang, target_lang)

  responses = []
  for model in models:
    model_config = supported_models[model]

    # litellm expects "<provider>/<model>" when a provider is configured.
    model_name = model_config[
        "provider"] + "/" + model if "provider" in model_config else model
    api_key = model_config.get("apiKey")
    api_base = model_config.get("apiBase")

    try:
      # TODO(#1): Allow user to set configuration.
      response = completion(model=model_name,
                            api_key=api_key,
                            api_base=api_base,
                            messages=[{
                                "content": instruction,
                                "role": "system"
                            }, {
                                "content": user_prompt,
                                "role": "user"
                            }])
      content = response.choices[0].message.content
      create_history(model, instruction, user_prompt, content)
      responses.append(content)

    # TODO(#1): Narrow down the exception type.
    except Exception as e:  # pylint: disable=broad-except
      print(f"Error with model {model}: {e}")
      # Chain the original exception so logs keep the root cause.
      raise gr.Error("Failed to get response. Please try again.") from e

  # It simulates concurrent stream response generation.
  max_response_length = max(len(response) for response in responses)
  for i in range(max_response_length):
    yield [response[:i + 1] for response in responses] + models + [instruction]

  yield responses + models + [instruction]