File size: 4,598 Bytes
e294914
 
 
f0ddded
e294914
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230ae1d
e294914
 
67f9da3
e294914
 
 
 
 
 
 
 
 
 
 
 
 
bda1e3c
e294914
 
 
 
 
 
 
bda1e3c
e294914
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import time

from llm.apimodels.gemini_model import Gemini
from llm.apimodels.hf_model import HF_Mistaril, HF_TinyLlama, HF_SmolLM135, HF_SmolLM360, HF_SmolLM, HF_Gemma2, HF_Qwen2

from typing import Optional, Any

from langchain.chains.conversation.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationChain

def prettify(raw_text: str) -> str:
    """Return *raw_text* with markdown bold markers removed and edges trimmed."""
    return raw_text.replace("**", "").strip()


# Sliding-window chat memory shared by every ConversationChain this module
# builds: k=3 keeps only the last 3 exchanges, and ai_prefix labels the
# assistant's turns as "Chelsea" in the buffered transcript.
memory: ConversationBufferWindowMemory = ConversationBufferWindowMemory(k=3, ai_prefix="Chelsea")

# Maximum acceptable prediction time in seconds; responses that take longer
# are discarded by has_delay().
DELAY: int = 300  # 5 minutes

def has_failed(conversation, prompt) -> Optional[str]:
    """
    Run a single prediction, returning the cleaned-up reply or None on error.

    Args:
        conversation: The LLM conversation object used for prediction.
        prompt: The prompt to be used for prediction.

    Returns:
        The prettified response string, or None if the prediction raised.
    """
    try:
        raw_reply = conversation.predict(input=prompt)
    except Exception as e:
        print(f"Error during prediction with conversation in has_failed function: {e}")
        return None
    print(f"response: {raw_reply}")
    return prettify(raw_text=raw_reply)


def has_delay(conversation, prompt) -> Optional[str]:
    """
    Run one prediction and reject responses that arrive too slowly.

    NOTE: the delay check is applied *after* the prediction completes, so a
    slow model is not interrupted mid-call -- its answer is simply discarded.

    Args:
        conversation: The LLM conversation object used for prediction.
        prompt: The prompt to be used for prediction.

    Returns:
        The prettified response if the prediction succeeds within DELAY
        seconds, otherwise None (on timeout or on any prediction error).
    """
    start_time = time.perf_counter()  # Start timer before prediction
    try:
        response = conversation.predict(input=prompt)
    except Exception as e:
        print(f"Error during prediction with conversation in has_delay function: {e}")
        return None  # Explicit: previously fell through and returned None implicitly

    execution_time = time.perf_counter() - start_time  # Wall-clock prediction time

    if execution_time > DELAY:
        return None  # Took too long; drop the (already-computed) response

    return prettify(raw_text=response)


class Conversation:
    """
    Walks an ordered list of LLM backends, falling over to the next model
    whenever the current one errors out or answers too slowly.
    """

    def __init__(self):
        """
        Initializes the conversation with the ordered list of model classes
        to try (preferred model first) and a cursor over that list.
        """
        self.model_classes = [Gemini, HF_Gemma2, HF_SmolLM, HF_SmolLM360, HF_Mistaril, HF_Qwen2, HF_TinyLlama, HF_SmolLM135]
        self.current_model_index = 0

    def _get_conversation(self) -> Optional[Any]:
        """
        Creates a ConversationChain for the current model class.

        Returns:
            The ConversationChain, or None if construction failed
            (callers must check for None before use).
        """
        try:
            current_model_class = self.model_classes[self.current_model_index]
            print("current model class is: ", current_model_class)
            return ConversationChain(llm=current_model_class().execution(), memory=memory, return_final_only=True)
        except Exception as e:
            print(f"Error during conversation chain in get_conversation function: {e}")
            return None  # Explicit: callers check this instead of crashing

    def chatting(self, prompt: str, is_own_model: bool) -> str:
        """
        Carries out the conversation, advancing to the next model on error
        or delay.

        Args:
            prompt (str): The prompt to be used for prediction.
            is_own_model (bool): Accepted for interface compatibility; not
                used by the fallback logic here.

        Returns:
            str: The final conversation response, or a failure message if
            every model fails.

        Raises:
            ValueError: If prompt is None or an empty string.
        """
        if prompt is None or prompt == "":
            # ValueError is a subclass of Exception, so existing callers
            # catching Exception keep working.
            raise ValueError(f"Prompt must be string not None or empty string: {prompt}")

        while self.current_model_index < len(self.model_classes):
            conversation = self._get_conversation()
            if conversation is None:
                # Chain construction failed; try the next model.
                self.current_model_index += 1
                continue

            # Single prediction per model: has_delay() both catches prediction
            # errors and rejects over-DELAY responses.  (The previous version
            # predicted twice per model -- once via has_failed() and again via
            # has_delay() -- doubling API usage and polluting shared memory,
            # returned fast-path results without any delay check, and had an
            # unreachable print after its return.)
            result = has_delay(conversation=conversation, prompt=prompt)
            if result is None:
                self.current_model_index += 1  # Switch to next model
                continue

            return result

        return "All models failed conversation. Please, try again"

    def __str__(self) -> str:
        # The previous implementation referenced self.prompt, which is never
        # set, so str() raised AttributeError; report real state instead.
        return f"{self.__class__.__name__}(current_model_index={self.current_model_index}, models={len(self.model_classes)})"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(current_model_index={self.current_model_index})"