Spaces:
Running
Running
File size: 5,723 Bytes
2a0aa5a 2667b32 2a0aa5a a4e8fcb 43c8549 2a0aa5a 9e789e7 2a0aa5a a4e8fcb 2a0aa5a a4e8fcb 2667b32 9e789e7 2667b32 a4e8fcb 9e789e7 a4e8fcb 2667b32 a4e8fcb 9e789e7 2a0aa5a a4e8fcb 2a0aa5a 58ea8e3 2667b32 2a0aa5a a4e8fcb 58ea8e3 a4e8fcb 2667b32 2a0aa5a a4e8fcb 2a0aa5a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
"""
This module contains functions to interact with the models.
"""
import json
import os
from typing import List
import litellm
DEFAULT_SUMMARIZE_INSTRUCTION = "Summarize the given text without changing the language of it." # pylint: disable=line-too-long
DEFAULT_TRANSLATE_INSTRUCTION = "Translate the given text from {source_lang} to {target_lang}." # pylint: disable=line-too-long
class ContextWindowExceededError(Exception):
pass
class Model:
def __init__(
self,
name: str,
provider: str = None,
api_key: str = None,
api_base: str = None,
summarize_instruction: str = None,
translate_instruction: str = None,
):
self.name = name
self.provider = provider
self.api_key = api_key
self.api_base = api_base
self.summarize_instruction = summarize_instruction or DEFAULT_SUMMARIZE_INSTRUCTION # pylint: disable=line-too-long
self.translate_instruction = translate_instruction or DEFAULT_TRANSLATE_INSTRUCTION # pylint: disable=line-too-long
def completion(self,
instruction: str,
prompt: str,
max_tokens: float = None) -> str:
messages = [{
"role":
"system",
"content":
instruction + """
Output following this JSON format:
{"result": "your result here"}"""
}, {
"role": "user",
"content": prompt
}]
try:
response = litellm.completion(model=self.provider + "/" +
self.name if self.provider else self.name,
api_key=self.api_key,
api_base=self.api_base,
messages=messages,
max_tokens=max_tokens,
**self._get_completion_kwargs())
json_response = response.choices[0].message.content
parsed_json = json.loads(json_response)
return parsed_json["result"]
except litellm.ContextWindowExceededError as e:
raise ContextWindowExceededError() from e
except json.JSONDecodeError as e:
raise RuntimeError(f"Failed to get JSON response: {e}") from e
def _get_completion_kwargs(self):
return {
# Ref: https://litellm.vercel.app/docs/completion/input#optional-fields # pylint: disable=line-too-long
"response_format": {
"type": "json_object"
}
}
class AnthropicModel(Model):
def completion(self,
instruction: str,
prompt: str,
max_tokens: float = None) -> str:
# Ref: https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/increase-consistency#prefill-claudes-response # pylint: disable=line-too-long
prefix = "<result>"
suffix = "</result>"
messages = [{
"role":
"user",
"content":
f"""{instruction}
Output following this format:
{prefix}...{suffix}
Text:
{prompt}"""
}, {
"role": "assistant",
"content": prefix
}]
try:
response = litellm.completion(
model=self.provider + "/" + self.name if self.provider else self.name,
api_key=self.api_key,
api_base=self.api_base,
messages=messages,
max_tokens=max_tokens,
)
except litellm.ContextWindowExceededError as e:
raise ContextWindowExceededError() from e
result = response.choices[0].message.content
if not result.endswith(suffix):
raise RuntimeError(f"Failed to get the formatted response: {result}")
return result.removesuffix(suffix).strip()
class VertexModel(Model):
def __init__(self, name: str, vertex_credentials: str):
super().__init__(name, provider="vertex_ai")
self.vertex_credentials = vertex_credentials
def _get_completion_kwargs(self):
return {
"response_format": {
"type": "json_object"
},
"vertex_credentials": self.vertex_credentials
}
class EeveModel(Model):
def _get_completion_kwargs(self):
json_template = {
"type": "object",
"properties": {
"result": {
"type": "string"
}
}
}
return {
"extra_body": {
"guided_json": json.dumps(json_template),
"guided_decoding_backend": "lm-format-enforcer"
}
}
supported_models: List[Model] = [
Model("gpt-4o-2024-05-13"),
Model("gpt-4-turbo-2024-04-09"),
Model("gpt-4-0125-preview"),
Model("gpt-3.5-turbo-0125"),
AnthropicModel("claude-3-opus-20240229"),
AnthropicModel("claude-3-sonnet-20240229"),
AnthropicModel("claude-3-haiku-20240307"),
VertexModel("gemini-1.5-pro-001",
vertex_credentials=os.getenv("VERTEX_CREDENTIALS")),
Model("mistral-small-2402", provider="mistral"),
Model("mistral-large-2402", provider="mistral"),
Model("llama3-8b-8192", provider="groq"),
Model("llama3-70b-8192", provider="groq"),
EeveModel("yanolja/EEVE-Korean-Instruct-10.8B-v1.0",
provider="openai",
api_base=os.getenv("EEVE_API_BASE"),
api_key=os.getenv("EEVE_API_KEY")),
]
def check_models(models: List[Model]):
for model in models:
print(f"Checking model {model.name}...")
try:
model.completion("You are an AI model.", "Hello, world!")
print(f"Model {model.name} is available.")
# This check is designed to verify the availability of the models
# without any issues. Therefore, we need to catch all exceptions.
except Exception as e: # pylint: disable=broad-except
raise RuntimeError(f"Model {model.name} is not available: {e}") from e
|