import time
from typing import Optional

from langchain.chains import ConversationChain
from langchain.chains.conversation.memory import ConversationBufferWindowMemory

from llm.apimodels.gemini_model import Gemini
from llm.apimodels.hf_model import (HF_Mistaril, HF_TinyLlama, HF_SmolLM135,
                                    HF_SmolLM360, HF_SmolLM, HF_Gemma2, HF_Qwen2)

def prettify(raw_text: str) -> str:
    """Removes Markdown bold markers and surrounding whitespace from model output."""
    pretty = raw_text.replace("**", "")
    return pretty.strip()

# Shared sliding-window memory: keeps the last 3 exchanges across model fallbacks.
memory: ConversationBufferWindowMemory = ConversationBufferWindowMemory(k=3, ai_prefix="Chelsea")

DELAY: int = 300  # Maximum acceptable response time in seconds (5 minutes)

def has_failed(conversation, prompt) -> Optional[str]:
    """
    Runs a prediction and treats any exception as a failure.

    Args:
        conversation: The LLM conversation object used for prediction.
        prompt: The prompt to be used for prediction.

    Returns:
        The prettified response, or None if the prediction raised an exception.
    """
    try:
        response = conversation.predict(input=prompt)
        print(f"response: {response}")
        return prettify(raw_text=response)
    except Exception as e:
        print(f"Error during prediction with conversation in has_failed function: {e}")
        return None

def has_delay(conversation, prompt) -> Optional[str]:
    """
    Runs a prediction and rejects the response if it took longer than DELAY.

    The call itself is not interrupted; the elapsed time is only checked
    after the prediction returns.

    Args:
        conversation: The LLM conversation object used for prediction.
        prompt: The prompt to be used for prediction.

    Returns:
        None if the execution time exceeds DELAY or the prediction raised,
        otherwise the prettified response from the conversation object.
    """
    start_time = time.perf_counter()  # Start timer before prediction
    try:
        response = conversation.predict(input=prompt)
        execution_time = time.perf_counter() - start_time
        if execution_time > DELAY:
            return None  # Discard responses that arrived too late
        return prettify(raw_text=response)
    except Exception as e:
        print(f"Error during prediction with conversation in has_delay function: {e}")
        return None  # Explicitly signal failure instead of falling through

class Conversation:
    def __init__(self, model_classes: Optional[list] = None):
        """
        Initializes the Conversation class with a list of LLM model classes.

        Args:
            model_classes (list, optional): LLM model classes to try in sequence.
                Defaults to [Gemini, HF_Gemma2, HF_SmolLM, HF_SmolLM360,
                HF_Mistaril, HF_Qwen2, HF_TinyLlama, HF_SmolLM135].
        """
        self.model_classes = model_classes or [Gemini, HF_Gemma2, HF_SmolLM, HF_SmolLM360,
                                               HF_Mistaril, HF_Qwen2, HF_TinyLlama, HF_SmolLM135]
        self.current_model_index = 0

    def _get_conversation(self) -> Optional[ConversationChain]:
        """
        Creates a ConversationChain object using the current model class.

        Returns:
            A ConversationChain, or None if the chain could not be built.
        """
        try:
            current_model_class = self.model_classes[self.current_model_index]
            print(f"current model class is: {current_model_class}")
            return ConversationChain(llm=current_model_class().execution(), memory=memory, return_final_only=True)
        except Exception as e:
            print(f"Error during conversation chain in get_conversation function: {e}")
            return None

    def chatting(self, prompt: str, is_own_model: bool) -> str:
        """
        Carries out the conversation with the user, handling errors and delays.

        Falls back to the next model class whenever the current one fails or
        responds too slowly.

        Args:
            prompt (str): The prompt to be used for prediction.
            is_own_model (bool): Flag for a locally hosted model (currently unused).

        Returns:
            str: The final conversation response, or a fallback message if all models fail.
        """
        if not prompt:
            raise ValueError(f"Prompt must be a non-empty string, got: {prompt!r}")
        while self.current_model_index < len(self.model_classes):
            conversation = self._get_conversation()
            if conversation is None:
                self.current_model_index += 1  # Chain could not be built; try the next model
                continue
            result = has_failed(conversation=conversation, prompt=prompt)
            print(f"chat - chatting result: {result}")
            if result is not None:
                return result
            result = has_delay(conversation=conversation, prompt=prompt)
            if result is None:
                self.current_model_index += 1  # Switch to the next model after a failure or delay
                continue
            return result
        return "All models failed conversation. Please, try again"

    def __str__(self) -> str:
        return f"model classes: {[cls.__name__ for cls in self.model_classes]}"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(model_classes={self.model_classes!r})"