# Spaces:
# Runtime error
# Runtime error
import logging
from typing import Any, Dict, List, Optional

import torch
from crewai import Agent, Task
from pydantic import BaseModel, ConfigDict, Field
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from utils.log_manager import LogManager
class BaseWellnessAgent(Agent):
    """Base CrewAI agent that answers messages with a locally loaded causal LM.

    The model id, the prompt template and the sampling settings are read from
    ``config[agent_type]``. The expected keys, as used by the methods below:

    * ``model_id`` (required) -- identifier passed to ``from_pretrained``.
    * ``instruction_template`` (required) -- ``str.format`` template with an
      ``{input}`` placeholder.
    * ``max_length``, ``temperature``, ``top_p``, ``repetition_penalty``
      (optional) -- generation settings.
    * ``load_in_4bit`` (optional, default ``True``) -- whether to request
      4-bit quantized loading.
    """

    # Allow non-pydantic types (logger, log manager, model, tokenizer) as fields.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Fields populated in __init__ / _initialize_model.
    log_manager: LogManager = Field(default_factory=LogManager)
    logger: Optional[logging.Logger] = Field(default=None)
    config: Dict = Field(default_factory=dict)
    model: Any = Field(default=None)
    tokenizer: Any = Field(default=None)
    agent_type: str = Field(default="base")

    def __init__(self, model_config: Dict, agent_type: str, **kwargs):
        """Create the agent, set up its logger and load its model.

        Args:
            model_config: Mapping of agent type to that agent's settings
                (see the class docstring for the keys that are read).
            agent_type: Key into ``model_config`` selecting this agent's
                settings; also used to obtain the logger.
            **kwargs: Forwarded to the CrewAI ``Agent`` constructor. Values
                given here override the defaults below.

        Raises:
            Exception: Whatever ``_initialize_model`` raises while loading.
        """
        # Merge defaults with caller overrides into a single mapping. Passing
        # both ``role=...`` and ``**kwargs`` (which may contain ``role``) would
        # raise "got multiple values for keyword argument".
        agent_kwargs = {
            "role": "Wellness Support Agent",
            "goal": "Support mental wellness",
            "backstory": "I am an AI agent specialized in mental health support.",
            "verbose": True,
            "allow_delegation": False,
            "tools": [],
            **kwargs,
        }
        super().__init__(**agent_kwargs)

        self.config = model_config
        self.agent_type = agent_type
        self.logger = self.log_manager.get_agent_logger(agent_type)

        self._initialize_model()
        self.logger.info(f"{agent_type.capitalize()} Agent initialized")

    def _initialize_model(self):
        """Load the tokenizer and model named by ``config[agent_type]["model_id"]``.

        Raises:
            Exception: Any error from the config lookup or from loading is
                logged and re-raised unchanged.
        """
        try:
            agent_cfg = self.config[self.agent_type]
            model_id = agent_cfg["model_id"]
            self.logger.info(f"Initializing Mistral model: {model_id}")

            self.tokenizer = AutoTokenizer.from_pretrained(model_id)

            load_kwargs = {"torch_dtype": torch.float32, "device_map": "auto"}
            # Quantization is requested through ``quantization_config`` rather
            # than a bare ``load_in_4bit`` kwarg. It stays on by default and can
            # be disabled per agent via the ``load_in_4bit`` config key.
            # NOTE(review): 4-bit loading needs the bitsandbytes backend to be
            # installed and usable on the target hardware -- confirm.
            if agent_cfg.get("load_in_4bit", True):
                load_kwargs["quantization_config"] = BitsAndBytesConfig(
                    load_in_4bit=True
                )

            self.model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
            self.logger.info("Mistral model initialized successfully")
        except Exception as e:
            self.logger.error(
                f"Error initializing Mistral model: {str(e)}", exc_info=True
            )
            raise

    def _generate_response(self, input_text: str) -> str:
        """Generate a reply to ``input_text`` with the loaded model.

        Args:
            input_text: Text substituted for ``{input}`` in the configured
                instruction template.

        Returns:
            The decoded continuation with the prompt removed and surrounding
            whitespace stripped, or a fixed apology string if anything fails.
        """
        try:
            gen_cfg = self.config[self.agent_type]
            prompt = gen_cfg["instruction_template"].format(input=input_text)

            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

            # NOTE(review): ``max_length`` counts prompt tokens plus generated
            # tokens; switch to ``max_new_tokens`` if only the reply length
            # should be bounded.
            outputs = self.model.generate(
                **inputs,
                max_length=gen_cfg.get("max_length", 4096),
                temperature=gen_cfg.get("temperature", 0.7),
                top_p=gen_cfg.get("top_p", 0.95),
                repetition_penalty=gen_cfg.get("repetition_penalty", 1.1),
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

            # The returned sequence starts with the prompt tokens. Drop them by
            # position instead of string-replacing the prompt: decoding with
            # skip_special_tokens may not reproduce the prompt text exactly,
            # and a string replace would also remove later repeats of it.
            prompt_len = inputs["input_ids"].shape[-1]
            new_tokens = outputs[0][prompt_len:]
            return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        except Exception as e:
            self.logger.error(f"Error generating response: {str(e)}", exc_info=True)
            return "I apologize, but I encountered an error generating a response."

    def execute_task(
        self,
        task: Task,
        context: Optional[str] = None,
        tools: Optional[List[Any]] = None,
    ) -> str:
        """Execute a task by treating its description as a message.

        Args:
            task: Task whose ``description`` is sent to ``process_message``.
            context: Accepted so the override stays call-compatible with the
                base ``Agent.execute_task``; not used by this implementation.
            tools: Accepted for the same reason; not used by this
                implementation.

        Returns:
            The ``"message"`` entry of the ``process_message`` result, or a
            fixed apology string if processing raises.
        """
        self.logger.info(f"Executing task: {task.description}")
        try:
            result = self.process_message(task.description)
            return result["message"]
        except Exception as e:
            self.logger.error(f"Error executing task: {str(e)}", exc_info=True)
            return "I apologize, but I encountered an error processing your request."

    def process_message(self, message: str, context: Optional[Dict] = None) -> Dict:
        """Process a message and return a response envelope.

        Args:
            message: Text to answer.
            context: Optional extra data; not used by this base implementation.

        Returns:
            A dict with ``"message"``, ``"agent_type"`` and ``"task_type"``.
            ``"task_type"`` is ``"dialogue"`` on the normal path and
            ``"error_recovery"`` if generation raises.
        """
        self.logger.info("Processing message")
        try:
            response = self._generate_response(message)
        except Exception as e:
            self.logger.error(f"Error processing message: {str(e)}", exc_info=True)
            return {
                "message": "I apologize, but I encountered an error. Let me try a different approach.",
                "agent_type": self.agent_type,
                "task_type": "error_recovery",
            }
        return {
            "message": response,
            "agent_type": self.agent_type,
            "task_type": "dialogue",
        }

    def get_status(self) -> Dict:
        """Return the agent type, model readiness and tool count.

        Returns:
            A dict with ``"type"``, ``"ready"`` (True once both model and
            tokenizer are set) and ``"tools_available"``.
        """
        # Compare against None explicitly: truthiness of a tokenizer or model
        # object depends on whatever __len__/__bool__ it defines.
        is_ready = self.model is not None and self.tokenizer is not None
        return {
            "type": self.agent_type,
            "ready": is_ready,
            "tools_available": len(self.tools or []),
        }