import logging
from typing import Any, Dict, List, Optional

import torch
from crewai import Agent, Task
from pydantic import BaseModel, ConfigDict, Field
from transformers import AutoModelForCausalLM, AutoTokenizer

from utils.log_manager import LogManager


class BaseWellnessAgent(Agent):
    """Base CrewAI agent that answers using a locally loaded Mistral causal LM.

    Per-agent settings are read from ``self.config[self.agent_type]``. Keys used
    by this class:

    * ``model_id`` (required) - passed to ``from_pretrained``.
    * ``instruction_template`` (required) - a ``str.format`` template with an
      ``{input}`` placeholder.
    * ``max_new_tokens`` (optional) - if present, caps generated tokens only.
    * ``max_length`` (optional, default 4096) - total-length cap, used when
      ``max_new_tokens`` is absent.
    * ``temperature`` / ``top_p`` / ``repetition_penalty`` (optional) - sampling
      parameters with defaults 0.7 / 0.95 / 1.1.
    """

    # Allow arbitrary (non-pydantic) types such as logging.Logger and HF models.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Source of per-agent loggers (see __init__).
    log_manager: LogManager = Field(default_factory=LogManager)
    # Set in __init__ from log_manager; None until then.
    logger: Optional[logging.Logger] = Field(default=None)
    # Full configuration mapping; this agent reads the entry at config[agent_type].
    config: Dict = Field(default_factory=dict)
    # Hugging Face model / tokenizer, populated by _initialize_model().
    model: Any = Field(default=None)
    tokenizer: Any = Field(default=None)
    # Key into `config` and name used for the logger.
    agent_type: str = Field(default="base")

    def __init__(self, model_config: Dict, agent_type: str, **kwargs):
        """Create the agent, set up logging, and load the model.

        Args:
            model_config: Configuration mapping; must contain an entry for
                ``agent_type`` (see the class docstring for its keys).
            agent_type: Name of this agent; selects its config entry and logger.
            **kwargs: Forwarded to ``crewai.Agent``. ``role``, ``goal``,
                ``backstory``, ``verbose``, ``allow_delegation`` and ``tools``
                get wellness-specific defaults when omitted.

        Raises:
            Exception: Whatever model loading raises (re-raised after logging).
        """
        # Pop (not get) the keys we default, so that a caller-supplied value is
        # not passed to Agent.__init__ twice (which would raise TypeError).
        agent_kwargs = {
            "role": kwargs.pop("role", "Wellness Support Agent"),
            "goal": kwargs.pop("goal", "Support mental wellness"),
            "backstory": kwargs.pop(
                "backstory",
                "I am an AI agent specialized in mental health support.",
            ),
            "verbose": kwargs.pop("verbose", True),
            "allow_delegation": kwargs.pop("allow_delegation", False),
            "tools": kwargs.pop("tools", []),
        }
        # Initialize the CrewAI agent first so pydantic fields exist.
        super().__init__(**agent_kwargs, **kwargs)

        self.config = model_config
        self.agent_type = agent_type
        self.logger = self.log_manager.get_agent_logger(agent_type)

        self._initialize_model()
        self.logger.info("%s Agent initialized", agent_type.capitalize())

    def _settings(self) -> Dict:
        """Return this agent's entry in ``self.config`` (raises KeyError if missing)."""
        return self.config[self.agent_type]

    def _initialize_model(self) -> None:
        """Load the tokenizer and 4-bit quantized causal LM named by ``model_id``.

        Raises:
            KeyError: If the agent's config entry or ``model_id`` is missing.
            Exception: Any error from ``from_pretrained``; logged, then re-raised.
        """
        try:
            model_id = self._settings()["model_id"]
            self.logger.info("Initializing Mistral model: %s", model_id)

            self.tokenizer = AutoTokenizer.from_pretrained(model_id)
            # NOTE(review): float32 compute dtype with 4-bit weights is unusual
            # (float16/bfloat16 is typical), and newer transformers versions
            # prefer quantization_config=BitsAndBytesConfig(...) over the bare
            # load_in_4bit kwarg. Left unchanged to preserve loading behavior.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float32,
                device_map="auto",
                load_in_4bit=True,
            )
            self.logger.info("Mistral model initialized successfully")
        except Exception as e:
            self.logger.exception("Error initializing Mistral model: %s", e)
            raise

    def _generate_response(self, input_text: str) -> str:
        """Generate a reply to ``input_text`` with the loaded model.

        The input is wrapped in the configured ``instruction_template`` before
        tokenization. Only tokens produced after the prompt are decoded, so the
        prompt never leaks into the reply.

        Returns:
            The stripped model reply, or a fixed apology string if any step
            fails (the error is logged with its traceback).
        """
        try:
            settings = self._settings()
            prompt = settings["instruction_template"].format(input=input_text)

            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            prompt_length = inputs["input_ids"].shape[-1]

            # max_length counts prompt tokens too; max_new_tokens (if configured)
            # bounds only the reply, so long prompts cannot exhaust the budget.
            if "max_new_tokens" in settings:
                length_limit = {"max_new_tokens": settings["max_new_tokens"]}
            else:
                length_limit = {"max_length": settings.get("max_length", 4096)}

            outputs = self.model.generate(
                **inputs,
                **length_limit,
                temperature=settings.get("temperature", 0.7),
                top_p=settings.get("top_p", 0.95),
                repetition_penalty=settings.get("repetition_penalty", 1.1),
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

            # For a causal LM, generate() returns prompt + continuation; slice off
            # the prompt tokens rather than string-replacing the decoded prompt,
            # which breaks whenever detokenization does not round-trip exactly.
            new_tokens = outputs[0][prompt_length:]
            return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        except Exception as e:
            self.logger.exception("Error generating response: %s", e)
            return "I apologize, but I encountered an error generating a response."

    def execute_task(
        self,
        task: Task,
        context: Optional[str] = None,
        tools: Optional[List[Any]] = None,
    ) -> str:
        """Run a CrewAI task by treating its description as the user message.

        Args:
            task: The task to execute; its ``description`` becomes the message.
            context: Optional context supplied by CrewAI (e.g. earlier task
                outputs). When non-empty it is appended to the message.
            tools: Accepted for compatibility with ``Agent.execute_task``'s
                call signature; not used by this implementation.

        Returns:
            The reply text, or a fixed apology string on error.
        """
        self.logger.info("Executing task: %s", task.description)
        message = task.description
        if context:
            message = f"{message}\n\nContext:\n{context}"
        try:
            return self.process_message(message)["message"]
        except Exception as e:
            self.logger.exception("Error executing task: %s", e)
            return "I apologize, but I encountered an error processing your request."

    def process_message(self, message: str, context: Optional[Dict] = None) -> Dict:
        """Generate a reply to ``message``.

        Args:
            message: The user message.
            context: Accepted for interface compatibility; currently unused.

        Returns:
            A dict with keys ``message`` (reply text), ``agent_type`` and
            ``task_type`` (``"dialogue"`` on success, ``"error_recovery"`` if
            an unexpected error escapes generation).
        """
        self.logger.info("Processing message")
        try:
            response = self._generate_response(message)
            return {
                "message": response,
                "agent_type": self.agent_type,
                "task_type": "dialogue",
            }
        except Exception as e:
            self.logger.exception("Error processing message: %s", e)
            return {
                "message": "I apologize, but I encountered an error. Let me try a different approach.",
                "agent_type": self.agent_type,
                "task_type": "error_recovery",
            }

    def get_status(self) -> Dict:
        """Return a summary of the agent: its type, readiness, and tool count."""
        return {
            "type": self.agent_type,
            # Explicit None checks: readiness means "loaded", not "truthy".
            "ready": self.model is not None and self.tokenizer is not None,
            "tools_available": len(self.tools or []),
        }