File size: 4,068 Bytes
75309ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Union, TypeVar, Generic
from typing_extensions import TypedDict
from datetime import datetime
from termcolor import colored
from models.llms import (
    OllamaModel,
    OpenAIModel,
    GroqModel,
    GeminiModel,
    ClaudeModel,
    VllmModel,
    MistralModel
)

# Set up logging.
# NOTE(review): basicConfig() runs as an import-time side effect; it installs a
# root handler at INFO level only if the root logger has no handlers yet, so an
# application that configures logging before importing this module wins.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Type variable for the state object threaded through an agent; it is bound
# to a dict keyed by strings.
StateT = TypeVar('StateT', bound=Dict[str, Any])


class BaseAgent(ABC, Generic[StateT]):
    """Abstract base class for agents that wrap one LLM backend.

    Subclasses supply the prompt, conversation history, user input and
    response handling; ``invoke`` ties them together into a single model call
    and merges the result into the state dict.
    """

    # How many times ``final_answer`` is repeated at the top of the user
    # message. The reason for the repetition is not visible in this file
    # (presumably emphasis for the model) -- the value is kept as-is.
    _FINAL_ANSWER_REPEATS = 10

    def __init__(self, model: str = None, server: str = None, temperature: float = 0,
                 model_endpoint: str = None, stop: str = None, location: str = "us", hyrbid: bool = False):
        """Store the configuration and build the LLM client.

        Args:
            model: Model name passed to the backend wrapper.
            server: Backend selector; see ``get_llm`` for supported values.
            temperature: Sampling temperature passed to the backend wrapper.
            model_endpoint: Endpoint forwarded to the vllm wrapper only.
            stop: Stop value forwarded to the vllm wrapper only.
            location: Stored on the instance as ``self.location``.
            hyrbid: Stored on the instance as ``self.hybrid``. The keyword is
                misspelled; it is kept so existing callers keep working.

        Raises:
            ValueError: If ``server`` is not a supported backend (this
                includes the default ``None``).
        """
        self.model = model
        self.server = server
        self.temperature = temperature
        self.model_endpoint = model_endpoint
        self.stop = stop
        self.location = location
        self.hybrid = hyrbid
        # Build the client last so an overriding get_llm() can rely on every
        # configuration attribute already being set.
        self.llm = self.get_llm()

    def get_llm(self, json_model: bool = False):
        """Return a model wrapper for ``self.server``.

        Args:
            json_model: Forwarded to the wrapper as ``json_response``.

        Raises:
            ValueError: If ``self.server`` is not one of 'openai', 'ollama',
                'vllm', 'groq', 'claude', 'mistral' or 'gemini'.
        """
        # Keyword arguments every backend wrapper receives.
        common = dict(model=self.model, temperature=self.temperature, json_response=json_model)
        server = self.server

        if server == 'openai':
            return OpenAIModel(**common)
        if server == 'ollama':
            return OllamaModel(**common)
        if server == 'vllm':
            # vllm is the only backend that also takes an endpoint and a stop value.
            return VllmModel(**common, model_endpoint=self.model_endpoint, stop=self.stop)
        if server == 'groq':
            return GroqModel(**common)
        if server == 'claude':
            return ClaudeModel(**common)
        if server == 'mistral':
            return MistralModel(**common)
        if server == 'gemini':
            return GeminiModel(**common)
        raise ValueError(f"Unsupported server: {self.server}")

    @abstractmethod
    def get_prompt(self, state: StateT = None) -> str:
        """Return the system prompt text used by ``invoke``."""

    @abstractmethod
    def get_guided_json(self, state: StateT = None) -> Dict[str, Any]:
        """Return the guided-JSON argument passed to the LLM for the 'vllm' server."""

    def update_state(self, key: str, value: Union[str, dict], state: StateT = None) -> StateT:
        """Set ``state[key] = value`` in place and return the state.

        A missing state (``None``, the declared default) is replaced by a new
        empty dict instead of failing with ``TypeError``.
        """
        if state is None:
            state = {}
        state[key] = value
        return state

    @abstractmethod
    def process_response(self, response: Any, user_input: str = None, state: StateT = None) -> Dict[str, Union[str, dict]]:
        """Turn the LLM response into a dict of state updates (key -> value)."""

    @abstractmethod
    def get_conv_history(self, state: StateT = None) -> str:
        """Return the conversation history text placed in the user message."""

    @abstractmethod
    def get_user_input(self) -> str:
        """Return the user's input; called by ``invoke`` when ``human_in_loop`` is true."""

    @abstractmethod
    def use_tool(self) -> Any:
        """Tool hook for subclasses; not called anywhere in this class."""

    def invoke(self, state: StateT = None, human_in_loop: bool = False, user_input: str = None, final_answer: str = None) -> StateT:
        """Run one LLM call and merge the processed response into ``state``.

        Args:
            state: State dict passed to the prompt/history hooks and updated
                with the keys returned by ``process_response``.
            human_in_loop: When true, ``user_input`` is replaced by the value
                of ``get_user_input()``.
            user_input: Text appended after the conversation history.
            final_answer: Optional text that is printed in green and repeated
                at the top of the user message.

        Returns:
            The state after applying every update from ``process_response``.
        """
        prompt = self.get_prompt(state)
        conversation_history = self.get_conv_history(state)

        if final_answer:
            print(colored(f"\n\n{final_answer}\n\n", "green"))

        if human_in_loop:
            user_input = self.get_user_input()

        # Leave absent values out of the message; interpolating them directly
        # would send the literal text "None" to the model.
        answer_block = f"\n{final_answer}\n" * self._FINAL_ANSWER_REPEATS if final_answer else ""
        user_text = "" if user_input is None else user_input

        messages = [
            {"role": "system", "content": f"{prompt}\n Today's date is {datetime.now()}"},
            {"role": "user", "content": f"{answer_block}{conversation_history}\n{user_text}"}
        ]

        if self.server == 'vllm':
            # Only the vllm path passes a guided-JSON argument to the LLM.
            guided_json = self.get_guided_json(state)
            response = self.llm.invoke(messages, guided_json)
        else:
            response = self.llm.invoke(messages)

        # process_response still receives the raw user_input (possibly None).
        updates = self.process_response(response, user_input, state)
        for key, value in updates.items():
            state = self.update_state(key, value, state)
        return state