arena / model.py
Kang Suhyun
[#67] Check models at the start of the app (#68)
2a0aa5a unverified
raw
history blame
2.57 kB
"""
This module contains functions to interact with the models.
"""
import json
import os
from typing import List, Optional

from google.cloud import secretmanager
from google.oauth2 import service_account
import litellm

from credentials import get_credentials_json
# GCP project and Secret Manager secret that hold the model configuration.
# Both are read from the environment and are required.
GOOGLE_CLOUD_PROJECT = os.environ.get("GOOGLE_CLOUD_PROJECT")
MODELS_SECRET = os.environ.get("MODELS_SECRET")

# Fail fast with a readable message. Without this check, a missing variable
# is formatted into the secret path as the string "None". The request below
# then fails with an opaque API error instead of naming the real problem.
_missing_env_vars = [
    var_name
    for var_name, var_value in (("GOOGLE_CLOUD_PROJECT", GOOGLE_CLOUD_PROJECT),
                                ("MODELS_SECRET", MODELS_SECRET))
    if not var_value
]
if _missing_env_vars:
    raise RuntimeError("Missing required environment variable(s): " +
                       ", ".join(_missing_env_vars))

# Secret Manager client authenticated with the service-account info returned
# by get_credentials_json(). This, and the fetch below, run at import time.
secretmanager_client = secretmanager.SecretManagerServiceClient(
    credentials=service_account.Credentials.from_service_account_info(
        get_credentials_json()))

# Fetch the latest version of the models secret and parse it as JSON. The
# result is used as a mapping of model name -> Model constructor kwargs.
models_secret = secretmanager_client.access_secret_version(
    name=secretmanager_client.secret_version_path(GOOGLE_CLOUD_PROJECT,
                                                  MODELS_SECRET, "latest"))
decoded_secret = models_secret.payload.data.decode("UTF-8")
supported_models_json = json.loads(decoded_secret)
class Model:
  """Configuration for one chat model that can be queried via LiteLLM.

  Instances are built from the models secret: the entry key becomes `name`
  and the entry's JSON keys (provider, apiKey, apiBase) are unpacked as
  keyword arguments.
  """

  def __init__(
      self,
      name: str,
      provider: Optional[str] = None,
      # The JSON keys are in camelCase. To unpack these keys into
      # Model attributes, we need to use the same camelCase names.
      apiKey: Optional[str] = None,  # pylint: disable=invalid-name
      apiBase: Optional[str] = None):  # pylint: disable=invalid-name
    """Store the model configuration.

    Args:
      name: Model name.
      provider: Optional provider prefix; when set, the LiteLLM model id is
        "<provider>/<name>".
      apiKey: Optional API key, stored as `api_key`.
      apiBase: Optional API base URL, stored as `api_base`.
    """
    self.name = name
    self.provider = provider
    self.api_key = apiKey
    self.api_base = apiBase

  def __repr__(self) -> str:
    # Never show the API key itself: reprs end up in logs and tracebacks.
    masked_key = "***" if self.api_key else None
    return (f"{type(self).__name__}(name={self.name!r}, "
            f"provider={self.provider!r}, api_key={masked_key}, "
            f"api_base={self.api_base!r})")
# Build one Model per entry of the secret's JSON object: the entry key is the
# model name, and the entry's own keys become Model keyword arguments.
supported_models: List[Model] = [
    Model(name=entry_name, **entry_config)
    for entry_name, entry_config in supported_models_json.items()
]
def completion(model: Model,
               messages: List,
               max_tokens: Optional[int] = None) -> str:
  """Send a chat completion request for `model` and return the reply text.

  Args:
    model: The model to query. Its `api_key` and `api_base` are forwarded
      to LiteLLM.
    messages: Chat messages passed straight to `litellm.completion`, e.g.
      [{"content": "Hello.", "role": "user"}].
    max_tokens: Optional cap on generated tokens. It is forwarded as-is,
      including None.

  Returns:
    The `content` of the first choice's message.
    NOTE(review): some providers may return None content (e.g. tool
    calls); confirm whether callers need to handle that.
  """
  # LiteLLM routes on "<provider>/<name>" when a provider is configured.
  # Spelled out explicitly rather than relying on the precedence of
  # `a + b if cond else c`.
  if model.provider:
    model_id = model.provider + "/" + model.name
  else:
    model_id = model.name

  response = litellm.completion(model=model_id,
                                api_key=model.api_key,
                                api_base=model.api_base,
                                messages=messages,
                                max_tokens=max_tokens)
  return response.choices[0].message.content
def check_models(models: List[Model]):
  """Probe each model with a tiny completion request, stopping at the first failure.

  Models are checked one at a time, in order, by sending "Hello." with a
  5-token cap. Progress is printed to stdout.

  Raises:
    RuntimeError: for the first model whose probe raises any exception. The
      original exception is chained as the cause. Remaining models are not
      checked.
  """
  greeting = "Hello."
  for candidate in models:
    print(f"Checking model {candidate.name}...")
    try:
      completion(model=candidate,
                 messages=[{"content": greeting, "role": "user"}],
                 max_tokens=5)
      print(f"Model {candidate.name} is available.")
    # Any failure at all (network, auth, provider error...) means the model
    # is unusable, so deliberately catch every exception here.
    except Exception as error:  # pylint: disable=broad-except
      raise RuntimeError(
          f"Model {candidate.name} is not available: {error}") from error