# Lagent / lagent/actions/action_executor.py
# Uploaded by Superkingjcj ("Upload 111 files"), commit e679d69 (verified).
import inspect
from collections import OrderedDict
from typing import Callable, Dict, List, Union
from lagent.actions.base_action import BaseAction
from lagent.actions.builtin_actions import FinishAction, InvalidAction, NoAction
from lagent.hooks import Hook, RemovableHandle
from lagent.schema import ActionReturn, ActionValidCode, AgentMessage, FunctionCall
from lagent.utils import create_object
class ActionExecutor:
    """Route function calls produced by an agent to registered actions.

    Args:
        actions (Union[BaseAction, List[BaseAction], Dict, List[Dict]]): The
            action or actions to register. Each entry may be an action
            instance or a config dict accepted by ``create_object``. A list
            or tuple is copied, so the caller's container is never modified.
        invalid_action (Union[BaseAction, Dict], optional): Fallback called
            when the requested name matches nothing. Defaults to
            ``dict(type=InvalidAction)``.
        no_action (Union[BaseAction, Dict], optional): Called when the
            requested name is not registered and equals ``no_action.name``.
            Defaults to ``dict(type=NoAction)``.
        finish_action (Union[BaseAction, Dict], optional): Called when the
            requested name is not registered and equals
            ``finish_action.name``. Defaults to ``dict(type=FinishAction)``.
        finish_in_action (bool, optional): Whether to also register the
            finish action as a regular action. Defaults to False.
        hooks (List[Dict], optional): Hooks (instances or configs) to attach
            through :meth:`register_hook`. Defaults to None.
    """

    def __init__(
        self,
        actions: Union[BaseAction, List[BaseAction], Dict, List[Dict]],
        invalid_action: Union[BaseAction, Dict] = dict(type=InvalidAction),
        no_action: Union[BaseAction, Dict] = dict(type=NoAction),
        finish_action: Union[BaseAction, Dict] = dict(type=FinishAction),
        finish_in_action: bool = False,
        hooks: List[Dict] = None,
    ):
        # Work on a private copy. Appending the finish action and building
        # instances from configs must not leak back into the caller's list;
        # otherwise reusing that list for another executor would see extra
        # finish actions and already-built, shared instances.
        if isinstance(actions, (list, tuple)):
            actions = list(actions)
        else:
            actions = [actions]
        finish_action = create_object(finish_action)
        if finish_in_action:
            actions.append(finish_action)
        instances = [create_object(action) for action in actions]
        # Keyed by each action's own ``name``; a later duplicate wins.
        self.actions = {action.name: action for action in instances}
        self.invalid_action = create_object(invalid_action)
        self.no_action = create_object(no_action)
        self.finish_action = finish_action
        self._hooks: Dict[int, Hook] = OrderedDict()
        for hook in hooks or ():
            self.register_hook(create_object(hook))

    def description(self) -> List[Dict]:
        """Return the descriptions of all registered actions.

        A toolkit action (``is_toolkit`` true) contributes one entry per item
        of ``description['api_list']``, with the name qualified as
        ``"<action>.<api>"`` so :meth:`forward` can route it. Every entry is
        a shallow copy, so qualifying the name leaves the action's own
        description untouched.
        """
        actions = []
        for action_name, action in self.actions.items():
            if action.is_toolkit:
                for api in action.description['api_list']:
                    api_desc = api.copy()
                    api_desc['name'] = f"{action_name}.{api_desc['name']}"
                    actions.append(api_desc)
            else:
                actions.append(action.description.copy())
        return actions

    def __contains__(self, name: str):
        """Return whether an action is registered under ``name``."""
        return name in self.actions

    def keys(self):
        """Return the registered action names as a new list."""
        return list(self.actions.keys())

    def __setitem__(self, name: str, action: Union[BaseAction, Dict]):
        """Register ``action`` (an instance or a config dict).

        NOTE(review): ``name`` is ignored; the action is stored under its own
        ``action.name``. Confirm this is intended before relying on the key.
        """
        action = create_object(action)
        self.actions[action.name] = action

    def __delitem__(self, name: str):
        """Unregister the action stored under ``name`` (KeyError if absent)."""
        del self.actions[name]

    def forward(self, name, parameters, **kwargs) -> ActionReturn:
        """Execute the action addressed by ``name``.

        Args:
            name (str): ``"<action>"`` to call the action's ``run`` API, or
                ``"<action>.<api>"`` for a toolkit API. Only the first dot
                separates the two parts.
            parameters: Passed unchanged to the selected action.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            ActionReturn: The action's result. Results of registered actions
            are marked ``ActionValidCode.OPEN``. Results of the no/finish/
            invalid fallbacks are returned untouched.
        """
        # ``maxsplit=1`` keeps a malformed name such as ``"a.b.c"`` from
        # raising ``ValueError`` while unpacking.
        action_name, api_name = (
            name.split('.', 1) if '.' in name else (name, 'run'))
        if action_name in self:
            action_return = self.actions[action_name](parameters, api_name)
            action_return.valid = ActionValidCode.OPEN
        elif name == self.no_action.name:
            action_return = self.no_action(parameters)
        elif name == self.finish_action.name:
            action_return = self.finish_action(parameters)
        else:
            action_return = self.invalid_action(parameters)
        return action_return

    def __call__(self,
                 message: AgentMessage,
                 session_id=0,
                 **kwargs) -> AgentMessage:
        """Run the function call carried by ``message`` and wrap the result.

        Each hook's ``before_action`` runs first, in registration order. A
        hook may replace ``message`` by returning a truthy value.
        ``after_action`` may replace the response in the same way.

        Args:
            message (AgentMessage): Its ``content`` must be a
                ``FunctionCall`` or a dict with ``'name'`` and
                ``'parameters'`` keys.
            session_id: Forwarded to every hook.
            **kwargs: Forwarded to :meth:`forward`.

        Returns:
            AgentMessage: The response. A non-``AgentMessage`` result from
            :meth:`forward` is wrapped, with ``sender`` set to this class's
            name.
        """
        for hook in self._hooks.values():
            result = hook.before_action(self, message, session_id)
            if result:
                message = result
        content = message.content
        # NOTE: ``assert`` is stripped under ``python -O``.
        assert isinstance(content, FunctionCall) or (
            isinstance(content, dict) and 'name' in content
            and 'parameters' in content), (
                'action message content must be a FunctionCall or a dict '
                f"with 'name' and 'parameters', got {type(content).__name__}")
        if isinstance(content, dict):
            name = content.get('name')
            parameters = content.get('parameters')
        else:
            name = content.name
            parameters = content.parameters
        response_message = self.forward(
            name=name, parameters=parameters, **kwargs)
        if not isinstance(response_message, AgentMessage):
            response_message = AgentMessage(
                sender=self.__class__.__name__,
                content=response_message,
            )
        for hook in self._hooks.values():
            result = hook.after_action(self, response_message, session_id)
            if result:
                response_message = result
        return response_message

    def register_hook(self, hook: Callable):
        """Attach ``hook`` and return a handle for it.

        The hook is stored in ``self._hooks`` under ``handle.id``. Hooks run
        in registration order because ``_hooks`` is an ``OrderedDict``.

        Returns:
            RemovableHandle: Handle created over ``self._hooks``. Presumably
            its ``remove()`` detaches the hook; confirm in ``lagent.hooks``.
        """
        handle = RemovableHandle(self._hooks)
        self._hooks[handle.id] = hook
        return handle
class AsyncActionExecutor(ActionExecutor):
    """Asynchronous variant of :class:`ActionExecutor`.

    Actions (including the no/finish/invalid fallbacks) and hooks may
    implement their entry points either synchronously or as coroutines. Any
    awaitable they return is awaited before its result is used.
    """

    @staticmethod
    async def _maybe_await(value):
        """Await ``value`` if it is awaitable; otherwise return it as is."""
        if inspect.isawaitable(value):
            return await value
        return value

    async def forward(self, name, parameters, **kwargs) -> ActionReturn:
        """Execute the action addressed by ``name``, awaiting it if needed.

        Args:
            name (str): ``"<action>"`` to call the action's ``run`` API, or
                ``"<action>.<api>"`` for a toolkit API. Only the first dot
                separates the two parts.
            parameters: Passed unchanged to the selected action.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            ActionReturn: The action's result. Results of registered actions
            are marked ``ActionValidCode.OPEN``. Results of the no/finish/
            invalid fallbacks are returned untouched.
        """
        # ``maxsplit=1`` keeps a malformed name such as ``"a.b.c"`` from
        # raising ``ValueError`` while unpacking.
        action_name, api_name = (
            name.split('.', 1) if '.' in name else (name, 'run'))
        if action_name in self:
            action = self.actions[action_name]
            action_return = await self._maybe_await(
                action(parameters, api_name))
            action_return.valid = ActionValidCode.OPEN
        elif name == self.no_action.name:
            action_return = await self._maybe_await(
                self.no_action(parameters))
        elif name == self.finish_action.name:
            action_return = await self._maybe_await(
                self.finish_action(parameters))
        else:
            action_return = await self._maybe_await(
                self.invalid_action(parameters))
        return action_return

    async def __call__(self,
                       message: AgentMessage,
                       session_id=0,
                       **kwargs) -> AgentMessage:
        """Run the function call carried by ``message`` and wrap the result.

        Each hook's ``before_action`` / ``after_action`` may be sync or async.
        A truthy return value replaces the message / response respectively.

        Args:
            message (AgentMessage): Its ``content`` must be a
                ``FunctionCall`` or a dict with ``'name'`` and
                ``'parameters'`` keys.
            session_id: Forwarded to every hook.
            **kwargs: Forwarded to :meth:`forward`.

        Returns:
            AgentMessage: The response. A non-``AgentMessage`` result from
            :meth:`forward` is wrapped, with ``sender`` set to this class's
            name.
        """
        for hook in self._hooks.values():
            result = await self._maybe_await(
                hook.before_action(self, message, session_id))
            if result:
                message = result
        content = message.content
        # NOTE: ``assert`` is stripped under ``python -O``.
        assert isinstance(content, FunctionCall) or (
            isinstance(content, dict) and 'name' in content
            and 'parameters' in content), (
                'action message content must be a FunctionCall or a dict '
                f"with 'name' and 'parameters', got {type(content).__name__}")
        if isinstance(content, dict):
            name = content.get('name')
            parameters = content.get('parameters')
        else:
            name = content.name
            parameters = content.parameters
        response_message = await self.forward(
            name=name, parameters=parameters, **kwargs)
        if not isinstance(response_message, AgentMessage):
            response_message = AgentMessage(
                sender=self.__class__.__name__,
                content=response_message,
            )
        for hook in self._hooks.values():
            result = await self._maybe_await(
                hook.after_action(self, response_message, session_id))
            if result:
                response_message = result
        return response_message