Spaces:
Running
Running
File size: 2,567 Bytes
2a0aa5a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
"""
This module loads the supported model configurations from Google Cloud Secret
Manager and provides functions to query those models through litellm.
"""
import json
import os
from typing import List, Optional

from google.cloud import secretmanager
from google.oauth2 import service_account
import litellm

from credentials import get_credentials_json
# Google Cloud project that owns the secret, and the ID of the secret holding
# the model configuration. Both come from the environment and are required.
GOOGLE_CLOUD_PROJECT = os.environ.get("GOOGLE_CLOUD_PROJECT")
MODELS_SECRET = os.environ.get("MODELS_SECRET")
if not GOOGLE_CLOUD_PROJECT or not MODELS_SECRET:
    # Fail fast with an actionable message: otherwise a missing value would be
    # placed into the secret resource path below and the lookup would fail
    # with an error that does not name the misconfigured variable.
    raise RuntimeError(
        "The GOOGLE_CLOUD_PROJECT and MODELS_SECRET environment variables "
        "must both be set.")

# Secret Manager client authenticated with the service-account info returned
# by get_credentials_json().
secretmanager_client = secretmanager.SecretManagerServiceClient(
    credentials=service_account.Credentials.from_service_account_info(
        get_credentials_json()))

# Read the latest version of the secret. Its payload is UTF-8 encoded JSON
# whose top-level object maps each model name to that model's settings.
models_secret = secretmanager_client.access_secret_version(
    name=secretmanager_client.secret_version_path(GOOGLE_CLOUD_PROJECT,
                                                   MODELS_SECRET, "latest"))
decoded_secret = models_secret.payload.data.decode("UTF-8")
supported_models_json = json.loads(decoded_secret)
class Model:
    """A model entry from the supported-models configuration.

    Instances are built by unpacking one JSON entry of the models secret into
    the constructor, so the keyword names must match the JSON keys.
    """

    def __init__(
            self,
            name: str,
            provider: Optional[str] = None,
            # The JSON keys are in camelCase. To unpack these keys into
            # Model attributes, we need to use the same camelCase names.
            apiKey: Optional[str] = None,  # pylint: disable=invalid-name
            apiBase: Optional[str] = None):  # pylint: disable=invalid-name
        """Stores the model settings.

        Args:
          name: Model name; used (optionally prefixed by the provider) as the
            litellm model identifier.
          provider: Optional provider prefix, joined to ``name`` with "/".
          apiKey: Optional API key, exposed as ``api_key``.
          apiBase: Optional API base, exposed as ``api_base``.
        """
        self.name = name
        self.provider = provider
        self.api_key = apiKey
        self.api_base = apiBase

    def __repr__(self) -> str:
        # api_key is intentionally left out so that printing or logging a
        # Model never exposes the credential.
        return (f"{type(self).__name__}(name={self.name!r}, "
                f"provider={self.provider!r}, api_base={self.api_base!r})")
# One Model per secret entry: the JSON key becomes the model name and the
# entry's fields are forwarded as constructor keyword arguments.
supported_models: List[Model] = [
    Model(name=entry_name, **entry_fields)
    for entry_name, entry_fields in supported_models_json.items()
]
def completion(model: Model,
               messages: List,
               max_tokens: Optional[int] = None) -> str:
    """Sends a chat completion request for ``model`` through litellm.

    Args:
      model: Model to query. When ``model.provider`` is set the request
        targets ``"<provider>/<name>"``; otherwise just ``model.name``.
        ``model.api_key`` and ``model.api_base`` are forwarded as-is.
      messages: Chat messages forwarded unchanged to ``litellm.completion``.
      max_tokens: Optional integer cap on generated tokens; ``None`` (the
        default) passes no explicit limit.

    Returns:
      The content of the first choice's message.
    """
    # Build the model identifier explicitly rather than with a conditional
    # expression whose meaning depends on operator precedence.
    if model.provider:
        model_id = f"{model.provider}/{model.name}"
    else:
        model_id = model.name
    response = litellm.completion(model=model_id,
                                  api_key=model.api_key,
                                  api_base=model.api_base,
                                  messages=messages,
                                  max_tokens=max_tokens)
    # NOTE(review): the `str` annotation assumes a plain-text reply; confirm
    # whether litellm can return None content for the models configured here.
    return response.choices[0].message.content
def check_models(models: List[Model]):
    """Verifies that every model answers a minimal completion request.

    Each model, in order, is sent a single "Hello." user message capped at 5
    tokens, and its status is printed.

    Args:
      models: The models to probe.

    Raises:
      RuntimeError: For the first model whose request raises; the original
        exception is chained as the cause.
    """
    for model in models:
        print(f"Checking model {model.name}...")
        try:
            completion(model=model,
                       messages=[{
                           "content": "Hello.",
                           "role": "user"
                       }],
                       max_tokens=5)
        # This check only verifies that the models are reachable, so any
        # exception raised by the request counts as the model being
        # unavailable and is deliberately caught broadly.
        except Exception as e:  # pylint: disable=broad-except
            raise RuntimeError(f"Model {model.name} is not available: {e}") from e
        else:
            # Kept outside the try so only the request itself can be reported
            # as an availability failure.
            print(f"Model {model.name} is available.")
|