Spaces:
Running
Running
File size: 5,963 Bytes
2a0aa5a 2667b32 fd9a72d 2a0aa5a a4e8fcb 43c8549 2a0aa5a 9e789e7 2a0aa5a a4e8fcb 2a0aa5a a4e8fcb fd9a72d a4e8fcb fd9a72d a4e8fcb 8d7d881 a4e8fcb 2667b32 fd9a72d a4e8fcb 2667b32 a4e8fcb fd9a72d a4e8fcb 9e789e7 fd9a72d 2a0aa5a fd9a72d a4e8fcb fd9a72d a4e8fcb 2a0aa5a 58ea8e3 2a0aa5a 9f68c4f da8eb3d 5640a96 4c0bb84 58ea8e3 4c0bb84 9f68c4f 4c0bb84 9f68c4f 4c0bb84 9f68c4f 2a0aa5a 8d7d881 2a0aa5a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
"""
This module contains functions to interact with the models.
"""
import json
import os
from typing import List, Optional, Tuple
import litellm
# Fallback system instructions used when a Model is constructed without
# custom summarize/translate instructions. The translate template contains
# {source_lang}/{target_lang} placeholders — presumably filled in by callers
# via str.format; confirm against call sites.
DEFAULT_SUMMARIZE_INSTRUCTION = "Summarize the given text without changing the language of it."  # pylint: disable=line-too-long
DEFAULT_TRANSLATE_INSTRUCTION = "Translate the given text from {source_lang} to {target_lang}."  # pylint: disable=line-too-long
class ContextWindowExceededError(Exception):
    """Raised when the prompt does not fit in the model's context window.

    Wraps litellm.ContextWindowExceededError so callers of this module do
    not need to depend on litellm's exception hierarchy.
    """
    pass
class Model:
    """A chat model accessed through litellm that returns JSON-wrapped results.

    The model is asked to reply with ``{"result": "..."}``; ``completion``
    parses that envelope and returns the inner result.
    """

    def __init__(
        self,
        name: str,
        provider: Optional[str] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        summarize_instruction: Optional[str] = None,
        translate_instruction: Optional[str] = None,
    ):
        # provider, api_key and api_base are optional; when provider is unset,
        # litellm infers it from the bare model name.
        self.name = name
        self.provider = provider
        self.api_key = api_key
        self.api_base = api_base
        self.summarize_instruction = summarize_instruction or DEFAULT_SUMMARIZE_INSTRUCTION  # pylint: disable=line-too-long
        self.translate_instruction = translate_instruction or DEFAULT_TRANSLATE_INSTRUCTION  # pylint: disable=line-too-long

    # Returns the parsed result or raw response, and whether parsing succeeded.
    def completion(self,
                   instruction: str,
                   prompt: str,
                   max_tokens: Optional[float] = None,
                   max_retries: int = 2) -> Tuple[str, bool]:
        """Run one completion and parse the model's JSON reply.

        Args:
            instruction: System-level instruction for the model.
            prompt: User content to act on.
            max_tokens: Optional completion-length cap passed to litellm.
            max_retries: Extra attempts after the first when parsing fails.

        Returns:
            (parsed result, True) on success, or (raw model reply, False)
            when the reply could not be parsed after all attempts.

        Raises:
            ContextWindowExceededError: If the prompt exceeds the model's
                context window.
        """
        messages = [{
            "role": "system",
            "content": instruction + """
Output following this JSON format without using code blocks:
{"result": "your result here"}"""
        }, {
            "role": "user",
            "content": prompt
        }]
        for attempt in range(max_retries + 1):
            try:
                response = litellm.completion(
                    model=self.provider + "/" +
                    self.name if self.provider else self.name,
                    api_key=self.api_key,
                    api_base=self.api_base,
                    messages=messages,
                    max_tokens=max_tokens,
                    **self._get_completion_kwargs())
                json_response = response.choices[0].message.content
                parsed_json = json.loads(json_response)
                return parsed_json["result"], True
            except litellm.ContextWindowExceededError as e:
                raise ContextWindowExceededError() from e
            # BUG FIX: a syntactically valid JSON reply may still lack the
            # "result" key; treat that KeyError like a parse failure and
            # retry instead of letting it propagate.
            except (json.JSONDecodeError, KeyError):
                if attempt == max_retries:
                    return json_response, False

    def _get_completion_kwargs(self):
        # Ask providers that support it to emit strict JSON.
        # Ref: https://litellm.vercel.app/docs/completion/input#optional-fields # pylint: disable=line-too-long
        return {
            "response_format": {
                "type": "json_object"
            }
        }
class AnthropicModel(Model):
    """Model variant that extracts results from an XML-tagged, prefilled reply
    instead of relying on JSON mode.
    """

    def completion(self,
                   instruction: str,
                   prompt: str,
                   max_tokens: Optional[float] = None,
                   max_retries: int = 2) -> Tuple[str, bool]:
        """Run one completion and extract the <result>-tagged body.

        Returns (extracted result, True) on success, or (raw reply, False)
        when the closing tag never appears after all attempts. Raises
        ContextWindowExceededError when the prompt is too long.
        """
        # Ref: https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/increase-consistency#prefill-claudes-response # pylint: disable=line-too-long
        open_tag = "<result>"
        close_tag = "</result>"
        request = f"""{instruction}
Output following this format:
{open_tag}...{close_tag}
Text:
{prompt}"""
        # Prefill the assistant turn with the opening tag so the model's reply
        # is just the tagged body plus the closing tag.
        messages = [
            {"role": "user", "content": request},
            {"role": "assistant", "content": open_tag},
        ]
        for attempt in range(max_retries + 1):
            try:
                response = litellm.completion(
                    model=self.provider + "/" +
                    self.name if self.provider else self.name,
                    api_key=self.api_key,
                    api_base=self.api_base,
                    messages=messages,
                    max_tokens=max_tokens,
                )
            except litellm.ContextWindowExceededError as e:
                raise ContextWindowExceededError() from e
            reply = response.choices[0].message.content
            if not reply.endswith(close_tag):
                if attempt == max_retries:
                    return reply, False
                continue  # malformed reply; try again
            return reply.removesuffix(close_tag).strip(), True
class VertexModel(Model):
    """Model hosted on Google Vertex AI, authenticated with a credentials string."""

    def __init__(self, name: str, vertex_credentials: Optional[str]):
        # vertex_credentials may be None (e.g. os.getenv on an unset variable);
        # requests will then fail at call time rather than at construction.
        super().__init__(name, provider="vertex_ai")
        self.vertex_credentials = vertex_credentials

    def _get_completion_kwargs(self):
        # CONSISTENCY FIX: reuse the base class's kwargs (JSON response_format)
        # instead of duplicating them, so the two stay in sync if the base
        # implementation ever changes.
        return {
            **super()._get_completion_kwargs(),
            "vertex_credentials": self.vertex_credentials,
        }
# Registry of all models the app can use.
supported_models: List[Model] = [
    # OpenAI models (litellm infers the provider from the name).
    Model("gpt-4o-2024-08-06"),
    Model("gpt-4o-mini-2024-07-18"),
    # Anthropic models use XML-tag prefill extraction instead of JSON mode.
    AnthropicModel("claude-3-5-sonnet-20241022"),
    AnthropicModel("claude-3-5-haiku-20241022"),
    # NOTE(review): os.getenv returns None if VERTEX_CREDENTIALS is unset —
    # confirm the deployment environment always provides it.
    VertexModel("gemini-1.5-pro-002",
                vertex_credentials=os.getenv("VERTEX_CREDENTIALS")),
    VertexModel("gemini-1.5-flash-002",
                vertex_credentials=os.getenv("VERTEX_CREDENTIALS")),
    # Open-weight models served through DeepInfra.
    Model("google/gemma-2-9b-it", provider="deepinfra"),
    Model("google/gemma-2-27b-it", provider="deepinfra"),
    Model("meta-llama/Meta-Llama-3.1-8B-Instruct", provider="deepinfra"),
    Model("meta-llama/Meta-Llama-3.1-70B-Instruct", provider="deepinfra"),
    Model("meta-llama/Meta-Llama-3.1-405B-Instruct", provider="deepinfra"),
    Model("meta-llama/Llama-3.2-3B-Instruct", provider="deepinfra"),
    Model("meta-llama/Llama-3.2-1B-Instruct", provider="deepinfra"),
    Model("Qwen/Qwen2.5-72B-Instruct", provider="deepinfra"),
]
def check_models(models: List[Model]):
    """Probe every model with a trivial completion; raise if any is unusable.

    Raises:
        RuntimeError: On the first model whose probe fails, chained to the
            underlying exception.
    """
    probe_instruction = """Output following this JSON format without using code blocks:
{"result": "your result here"}"""
    for model in models:
        print(f"Checking model {model.name}...")
        try:
            model.completion(probe_instruction, "How are you?")
            print(f"Model {model.name} is available.")
        # This check is designed to verify the availability of the models
        # without any issues. Therefore, we need to catch all exceptions.
        except Exception as e:  # pylint: disable=broad-except
            raise RuntimeError(f"Model {model.name} is not available: {e}") from e
|