import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from swarms.agents.message import Message


class Mistral:
    """
    Mistral

    model = Mistral(device="cuda", use_flash_attention=True, temperature=0.7, max_length=200)
    task = "My favourite condiment is"
    result = model.run(task)
    print(result)
    """
    def __init__(
        self,
        ai_name: str = "Node Model Agent",
        system_prompt: str = None,
        model_name: str = "mistralai/Mistral-7B-v0.1",
        device: str = "cuda",
        use_flash_attention: bool = False,
        temperature: float = 1.0,
        max_length: int = 100,
        do_sample: bool = True,
    ):
        self.ai_name = ai_name
        self.system_prompt = system_prompt
        self.model_name = model_name
        self.device = device
        self.use_flash_attention = use_flash_attention
        self.temperature = temperature
        self.max_length = max_length
        self.do_sample = do_sample

        # Check if the specified device is available
        if not torch.cuda.is_available() and device == "cuda":
            raise ValueError("CUDA is not available. Please choose a different device.")

        # Load the model and tokenizer
        self.model = None
        self.tokenizer = None
        self.load_model()

        self.history = []
    def load_model(self):
        try:
            self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model.to(self.device)
        except Exception as e:
            raise ValueError(f"Error loading the Mistral model: {str(e)}")
    def run(self, task: str):
        """Run the model on a given task."""
        try:
            model_inputs = self.tokenizer([task], return_tensors="pt").to(self.device)
            generated_ids = self.model.generate(
                **model_inputs,
                do_sample=self.do_sample,
                temperature=self.temperature,
                max_new_tokens=self.max_length,
            )
            output_text = self.tokenizer.batch_decode(generated_ids)[0]
            return output_text
        except Exception as e:
            raise ValueError(f"Error running the model: {str(e)}")
    def chat(self, msg: str = None, streaming: bool = False):
        """
        Run chat

        Args:
            msg (str, optional): Message to send to the agent. Defaults to None.
            streaming (bool, optional): Whether to stream the response. Defaults to False.

        Returns:
            str: Response from the agent

        Usage:
        --------------
        model = Mistral()
        model.chat("Hello")
        """
        # Add the user's message to the history
        self.history.append(Message("User", msg))

        # Process the message
        try:
            response = self.run(msg)

            # Add the agent's response to the history
            self.history.append(Message("Agent", response))

            # Stream the response word by word if requested
            if streaming:
                return self._stream_response(response)
            else:
                return response
        except Exception as error:
            error_message = f"Error processing message: {str(error)}"

            # Add the error to the history
            self.history.append(Message("Agent", error_message))

            return error_message
    def _stream_response(self, response: str = None):
        """
        Yield the response token by token (word by word)

        Usage:
        --------------
        for token in _stream_response(response):
            print(token)
        """
        for token in response.split():
            yield token
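
# Minimal usage sketch (an assumption, not part of the original module): it
# presumes a CUDA-capable machine and that the "mistralai/Mistral-7B-v0.1"
# weights can be fetched via transformers; adjust device/model_name as needed.
if __name__ == "__main__":
    model = Mistral(device="cuda", temperature=0.7, max_length=200)

    # One-shot completion
    print(model.run("My favourite condiment is"))

    # Chat with word-by-word streaming
    for token in model.chat("Tell me about Mistral.", streaming=True):
        print(token, end=" ")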