# mentalwellness/agents/base_agent.py
# Last commit 3e274d5 by invincible-jha: "Implement Mistral LLM support for all agents"
from typing import Dict, List, Optional, Any
from crewai import Agent, Task
import logging
from utils.log_manager import LogManager
from pydantic import Field, BaseModel, ConfigDict
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
class BaseWellnessAgent(Agent):
    """Base CrewAI agent that answers using a locally loaded Mistral causal LM.

    The agent is built with a ``model_config`` mapping keyed by agent type.
    The entry for ``agent_type`` must provide ``model_id`` (used to load the
    model) and ``instruction_template`` (a ``str.format`` template with an
    ``{input}`` placeholder). It may also provide loading and generation
    settings; see ``_initialize_model`` and ``_generate_response``.
    """

    # Allow arbitrary (non-pydantic) types such as loggers, models and tokenizers.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Supplies the per-agent logger created in __init__.
    log_manager: LogManager = Field(default_factory=LogManager)
    # Assigned in __init__; None until then.
    logger: Optional[logging.Logger] = Field(default=None)
    # Mapping of agent type -> model / generation settings.
    config: Dict = Field(default_factory=dict)
    # Hugging Face model and tokenizer, populated by _initialize_model().
    model: Any = Field(default=None)
    tokenizer: Any = Field(default=None)
    # Key into ``config`` selecting this agent's settings.
    agent_type: str = Field(default="base")

    def __init__(self, model_config: Dict, agent_type: str, **kwargs):
        """Create the agent, set up logging and load the model.

        Args:
            model_config: Mapping of agent type to settings.
                ``model_config[agent_type]`` is read when loading the model
                and when generating responses.
            agent_type: Key selecting this agent's settings. It is also passed
                to ``LogManager.get_agent_logger``.
            **kwargs: Extra CrewAI ``Agent`` fields. ``role``, ``goal``,
                ``backstory``, ``verbose``, ``allow_delegation`` and ``tools``
                fall back to wellness-specific defaults when omitted.

        Raises:
            Exception: Whatever ``_initialize_model`` re-raises when loading fails.
        """
        # Merge caller overrides over the defaults so every field reaches
        # Agent.__init__ exactly once. Passing e.g. ``role=...`` explicitly
        # AND inside ``**kwargs`` would raise TypeError.
        agent_kwargs = {
            "role": "Wellness Support Agent",
            "goal": "Support mental wellness",
            "backstory": "I am an AI agent specialized in mental health support.",
            "verbose": True,
            "allow_delegation": False,
            "tools": [],
        }
        agent_kwargs.update(kwargs)
        super().__init__(**agent_kwargs)

        # Initialize logging and configuration
        self.config = model_config
        self.agent_type = agent_type
        self.logger = self.log_manager.get_agent_logger(agent_type)

        # Initialize Mistral model
        self._initialize_model()
        self.logger.info(f"{agent_type.capitalize()} Agent initialized")

    def _initialize_model(self):
        """Load the tokenizer and causal LM named by this agent's ``model_id``.

        Optional keys in ``self.config[self.agent_type]``:
            ``torch_dtype``: a torch dtype or its name (e.g. ``"float16"``);
                defaults to ``torch.float32``.
            ``device_map``: defaults to ``"auto"``.
            ``load_in_4bit``: defaults to ``True``.

        Raises:
            KeyError: if the agent type or ``model_id`` is missing from the
                config (logged, then re-raised).
            Exception: any other loading error (logged, then re-raised).
        """
        try:
            settings = self.config[self.agent_type]
            model_id = settings["model_id"]
            self.logger.info(f"Initializing Mistral model: {model_id}")

            torch_dtype = settings.get("torch_dtype", torch.float32)
            if isinstance(torch_dtype, str):
                # Let config files name the dtype, e.g. "float16".
                torch_dtype = getattr(torch, torch_dtype)

            load_kwargs = {
                "torch_dtype": torch_dtype,
                "device_map": settings.get("device_map", "auto"),
            }
            if settings.get("load_in_4bit", True):
                # Passing ``load_in_4bit`` straight to from_pretrained is
                # deprecated; the supported route is a BitsAndBytesConfig.
                from transformers import BitsAndBytesConfig

                load_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)

            # Initialize tokenizer and model
            self.tokenizer = AutoTokenizer.from_pretrained(model_id)
            self.model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
            self.logger.info("Mistral model initialized successfully")
        except Exception as e:
            self.logger.error(f"Error initializing Mistral model: {str(e)}", exc_info=True)
            raise

    def _generate_response(self, input_text: str) -> str:
        """Generate a reply to ``input_text`` with the loaded model.

        The prompt is ``instruction_template.format(input=input_text)``.
        Optional generation keys in ``self.config[self.agent_type]``:
        ``max_new_tokens`` (if present, it limits only the generated tokens),
        otherwise ``max_length`` (default 4096, which counts the prompt too),
        ``temperature`` (0.7), ``top_p`` (0.95) and ``repetition_penalty`` (1.1).

        Returns:
            The decoded generated text with the prompt excluded. On any error,
            a fixed apology string (the error is logged, not raised).
        """
        try:
            settings = self.config[self.agent_type]
            prompt = settings["instruction_template"].format(input=input_text)

            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            prompt_length = inputs["input_ids"].shape[-1]

            generation_kwargs = {
                "temperature": settings.get("temperature", 0.7),
                "top_p": settings.get("top_p", 0.95),
                "repetition_penalty": settings.get("repetition_penalty", 1.1),
                "do_sample": True,
                "pad_token_id": self.tokenizer.eos_token_id,
            }
            if "max_new_tokens" in settings:
                generation_kwargs["max_new_tokens"] = settings["max_new_tokens"]
            else:
                generation_kwargs["max_length"] = settings.get("max_length", 4096)

            outputs = self.model.generate(**inputs, **generation_kwargs)

            # generate() returns prompt + continuation for causal LMs. Slice off
            # the prompt ids rather than string-replacing the prompt text:
            # decoding does not always round-trip the prompt exactly.
            new_tokens = outputs[0][prompt_length:]
            return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        except Exception as e:
            self.logger.error(f"Error generating response: {str(e)}", exc_info=True)
            return "I apologize, but I encountered an error generating a response."

    def execute_task(
        self,
        task: Task,
        context: Optional[str] = None,
        tools: Optional[List[Any]] = None,
    ) -> str:
        """Execute a CrewAI task by answering its description.

        Args:
            task: The task whose ``description`` is sent to the model.
            context: Optional extra context (presumably output from earlier
                tasks when called by CrewAI; verify). When non-empty it is
                appended to the prompt.
            tools: Accepted so CrewAI's keyword-style call works. This base
                implementation does not call tools.

        Returns:
            The generated message, or an apology string on error.
        """
        self.logger.info(f"Executing task: {task.description}")
        try:
            message = task.description
            if context:
                message = f"{message}\n\nContext:\n{context}"
            result = self.process_message(message)
            return result["message"]
        except Exception as e:
            self.logger.error(f"Error executing task: {str(e)}", exc_info=True)
            return "I apologize, but I encountered an error processing your request."

    def process_message(self, message: str, context: Optional[Dict] = None) -> Dict:
        """Generate a reply to ``message``.

        Args:
            message: User/task text forwarded to the model.
            context: Accepted for subclasses; not used by this base implementation.

        Returns:
            A dict with ``message`` (reply text), ``agent_type``, and
            ``task_type``. ``task_type`` is ``"dialogue"`` normally and
            ``"error_recovery"`` if an exception escapes generation.
        """
        self.logger.info("Processing message")
        context = context or {}

        try:
            # Generate response using Mistral
            response = self._generate_response(message)
            return {
                "message": response,
                "agent_type": self.agent_type,
                "task_type": "dialogue"
            }
        except Exception as e:
            self.logger.error(f"Error processing message: {str(e)}", exc_info=True)
            return {
                "message": "I apologize, but I encountered an error. Let me try a different approach.",
                "agent_type": self.agent_type,
                "task_type": "error_recovery"
            }

    def get_status(self) -> Dict:
        """Report the agent type, whether model and tokenizer are loaded, and tool count."""
        return {
            "type": self.agent_type,
            # Explicit None checks: truthiness of model/tokenizer objects is not
            # a reliable "loaded" signal (e.g. objects defining __len__).
            "ready": self.model is not None and self.tokenizer is not None,
            "tools_available": len(self.tools or [])
        }