Spaces:
Running
Running
File size: 3,170 Bytes
2a0aa5a 43c8549 2a0aa5a 43c8549 2a0aa5a 43c8549 2a0aa5a 43c8549 2a0aa5a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
"""
This module contains functions to interact with the models.
"""
import json
import os
from typing import List, Optional

from google.cloud import secretmanager
from google.oauth2 import service_account
import litellm

from credentials import get_credentials_json
GOOGLE_CLOUD_PROJECT = os.environ.get("GOOGLE_CLOUD_PROJECT")
MODELS_SECRET = os.environ.get("MODELS_SECRET")

# Fail fast with an actionable message. Without these variables the secret
# path below would be built with the literal string "None" (or an empty
# segment), and the failure would only surface later as a much less obvious
# Secret Manager API error.
_missing_env_vars = [
    env_name
    for env_name, env_value in (("GOOGLE_CLOUD_PROJECT", GOOGLE_CLOUD_PROJECT),
                                ("MODELS_SECRET", MODELS_SECRET))
    if not env_value
]
if _missing_env_vars:
    raise RuntimeError("Missing required environment variable(s): " +
                       ", ".join(_missing_env_vars))

# Authenticate to Secret Manager with the service-account info returned by
# get_credentials_json().
secretmanager_client = secretmanager.SecretManagerServiceClient(
    credentials=service_account.Credentials.from_service_account_info(
        get_credentials_json()))

# Always read the "latest" version of the models secret.
models_secret = secretmanager_client.access_secret_version(
    name=secretmanager_client.secret_version_path(GOOGLE_CLOUD_PROJECT,
                                                  MODELS_SECRET, "latest"))
decoded_secret = models_secret.payload.data.decode("UTF-8")

# Parsed models configuration. NOTE(review): expected to be a JSON object
# mapping model names to Model keyword arguments (camelCase keys) — confirm
# against the secret's contents.
supported_models_json = json.loads(decoded_secret)
DEFAULT_SUMMARIZE_INSTRUCTION = "Summarize the following text, maintaining the language of the text."  # pylint: disable=line-too-long
# Contains {source_lang}/{target_lang} placeholders; presumably filled in by
# callers via str.format — confirm at the call sites.
DEFAULT_TRANSLATE_INSTRUCTION = "Translate the following text from {source_lang} to {target_lang}."  # pylint: disable=line-too-long


class Model:
    """A chat model reachable through litellm, plus its prompt instructions."""

    def __init__(
            self,
            name: str,
            provider: Optional[str] = None,
            # The JSON keys are in camelCase. To unpack these keys into
            # Model attributes, we need to use the same camelCase names.
            apiKey: Optional[str] = None,  # pylint: disable=invalid-name
            apiBase: Optional[str] = None,  # pylint: disable=invalid-name
            summarizeInstruction: Optional[str] = None,  # pylint: disable=invalid-name
            translateInstruction: Optional[str] = None):  # pylint: disable=invalid-name
        """Store the model's connection settings and instructions.

        Args:
            name: Model name. It is sent to litellm as-is, or as
                "provider/name" when `provider` is set.
            provider: Optional litellm provider prefix.
            apiKey: Forwarded to litellm.completion as `api_key`.
            apiBase: Forwarded to litellm.completion as `api_base`.
            summarizeInstruction: Summarization prompt. Any falsy value
                (None or "") falls back to DEFAULT_SUMMARIZE_INSTRUCTION.
            translateInstruction: Translation prompt. Any falsy value
                (None or "") falls back to DEFAULT_TRANSLATE_INSTRUCTION.
        """
        self.name = name
        self.provider = provider
        self.api_key = apiKey
        self.api_base = apiBase
        # `or` rather than an `is None` check, so an empty string in the
        # config also gets the default instruction.
        self.summarize_instruction = summarizeInstruction or DEFAULT_SUMMARIZE_INSTRUCTION  # pylint: disable=line-too-long
        self.translate_instruction = translateInstruction or DEFAULT_TRANSLATE_INSTRUCTION  # pylint: disable=line-too-long

    def _litellm_model(self) -> str:
        """Return the model id for litellm: "provider/name", or "name" if no provider."""
        if self.provider:
            return f"{self.provider}/{self.name}"
        return self.name

    def completion(self,
                   messages: List,
                   max_tokens: Optional[int] = None) -> str:
        """Run a chat completion and return the first choice's message text.

        Args:
            messages: Chat messages passed straight to litellm.completion;
                presumably OpenAI-style dicts with "role" and "content".
            max_tokens: Integer cap on generated tokens, forwarded unchanged
                (None is passed through as None).

        Returns:
            `response.choices[0].message.content`. NOTE(review): providers
            may return None content (e.g. tool-call-only replies) — confirm
            callers handle that.
        """
        response = litellm.completion(model=self._litellm_model(),
                                      api_key=self.api_key,
                                      api_base=self.api_base,
                                      messages=messages,
                                      max_tokens=max_tokens)
        return response.choices[0].message.content
def _build_model(model_name: str, model_config: dict) -> Model:
    """Create a Model from one config entry, naming the entry if it is malformed."""
    try:
        return Model(name=model_name, **model_config)
    except TypeError as err:
        # Unknown or misspelled keys, a duplicate "name" key, or a non-object
        # entry otherwise surface as a bare TypeError with no hint of which
        # model's configuration is broken.
        raise TypeError(
            f"Invalid configuration for model {model_name!r}: {err}") from err


# One Model per entry of the models config, in the config's key order.
supported_models: List[Model] = [
    _build_model(model_name, model_config)
    for model_name, model_config in supported_models_json.items()
]
def check_models(models: List[Model]):
    """Smoke-test each model in order with a tiny completion request.

    For each model, prints a "Checking..." line and sends a single
    "Hello." user message capped at 5 tokens. Prints an "available" line
    when the request succeeds.

    Raises:
        RuntimeError: For the first model whose request raises anything.
            The original exception is chained as the cause, and the
            remaining models are not checked.
    """
    for candidate in models:
        label = candidate.name
        print(f"Checking model {label}...")
        try:
            ping = [{"role": "user", "content": "Hello."}]
            candidate.completion(messages=ping, max_tokens=5)
            print(f"Model {label} is available.")
        # Any failure at all means the model is unusable for this check,
        # so every exception is deliberately caught and re-raised uniformly.
        except Exception as exc:  # pylint: disable=broad-except
            raise RuntimeError(
                f"Model {label} is not available: {exc}") from exc
|