Spaces:
Running
Running
import inspect | |
from collections import OrderedDict | |
from typing import Callable, Dict, List, Union | |
from lagent.actions.base_action import BaseAction | |
from lagent.actions.builtin_actions import FinishAction, InvalidAction, NoAction | |
from lagent.hooks import Hook, RemovableHandle | |
from lagent.schema import ActionReturn, ActionValidCode, AgentMessage, FunctionCall | |
from lagent.utils import create_object | |
class ActionExecutor:
    """Dispatch function calls to a registry of actions and run hooks.

    Args:
        actions (Union[BaseAction, List[BaseAction], Dict, List[Dict]]): A
            single action or a list of actions. Every entry is passed
            through ``create_object`` and registered under its ``name``.
        invalid_action (BaseAction, optional): Fallback used when the
            requested name matches nothing. Defaults to
            ``dict(type=InvalidAction)``.
        no_action (BaseAction, optional): Fallback used when the requested
            name equals this action's name. Defaults to
            ``dict(type=NoAction)``.
        finish_action (BaseAction, optional): Fallback used when the
            requested name equals this action's name. Defaults to
            ``dict(type=FinishAction)``.
        finish_in_action (bool, optional): Whether the finish action is also
            registered in the action list. Defaults to False.
        hooks (List[Dict], optional): Hooks to register at construction
            time; each entry is passed through ``create_object`` first.
            Defaults to None.
    """

    def __init__(
        self,
        actions: Union[BaseAction, List[BaseAction], Dict, List[Dict]],
        invalid_action: BaseAction = dict(type=InvalidAction),
        no_action: BaseAction = dict(type=NoAction),
        finish_action: BaseAction = dict(type=FinishAction),
        finish_in_action: bool = False,
        hooks: List[Dict] = None,
    ):
        # NOTE(review): the dict defaults above are shared between calls;
        # this presumes ``create_object`` does not mutate them -- confirm.

        # Work on a private list so the caller's list is neither extended
        # with the finish action nor has its configs replaced by instances.
        actions = list(actions) if isinstance(actions, list) else [actions]
        finish_action = create_object(finish_action)
        if finish_in_action:
            actions.append(finish_action)
        actions = [create_object(action) for action in actions]

        self.actions = {action.name: action for action in actions}
        self.invalid_action = create_object(invalid_action)
        self.no_action = create_object(no_action)
        self.finish_action = finish_action
        self._hooks: Dict[int, Hook] = OrderedDict()
        if hooks:
            for hook in hooks:
                self.register_hook(create_object(hook))

    def description(self) -> List[Dict]:
        """Return a flat list of description dicts for all registered actions.

        For an action whose ``is_toolkit`` is truthy, one entry per item of
        ``description['api_list']`` is produced and its name is prefixed
        with ``"<action_name>."``. Other actions contribute a shallow copy
        of their ``description``.
        """
        actions = []
        for action_name, action in self.actions.items():
            if action.is_toolkit:
                for api in action.description['api_list']:
                    # Copy so the action's own description is not modified.
                    api_desc = api.copy()
                    api_desc['name'] = f"{action_name}.{api_desc['name']}"
                    actions.append(api_desc)
            else:
                actions.append(action.description.copy())
        return actions

    def __contains__(self, name: str):
        """Return True if an action is registered under ``name``."""
        return name in self.actions

    def keys(self):
        """Return the names of all registered actions as a list."""
        return list(self.actions.keys())

    def __setitem__(self, name: str, action: Union[BaseAction, Dict]):
        """Register ``action`` (passed through ``create_object``).

        The registry key is taken from ``action.name``; the ``name``
        argument itself is not used.
        """
        action = create_object(action)
        self.actions[action.name] = action

    def __delitem__(self, name: str):
        """Remove the action registered under ``name``.

        Raises:
            KeyError: If no action is registered under ``name``.
        """
        del self.actions[name]

    def forward(self, name, parameters, **kwargs) -> ActionReturn:
        """Execute the action identified by ``name``.

        Args:
            name: Either ``"<action>"`` (the ``run`` API is used) or
                ``"<action>.<api>"``.
            parameters: Passed unchanged to the selected action.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            ActionReturn: The result of the selected action. When the action
            name is not registered, the result comes from ``no_action``,
            ``finish_action`` or ``invalid_action`` instead.
        """
        # Split on the first dot only: a name with several dots used to
        # raise ValueError during tuple unpacking. Anything after the first
        # dot is now handed to the action as the API name.
        action_name, sep, api_name = name.partition('.')
        if not sep:
            api_name = 'run'

        if action_name in self:
            action_return = self.actions[action_name](parameters, api_name)
            # NOTE(review): reconstructed as part of the registered-action
            # branch (source indentation was lost) -- confirm fallbacks are
            # not meant to be marked OPEN as well.
            action_return.valid = ActionValidCode.OPEN
            return action_return

        # Fallbacks are matched against the full, unsplit name.
        if name == self.no_action.name:
            return self.no_action(parameters)
        if name == self.finish_action.name:
            return self.finish_action(parameters)
        return self.invalid_action(parameters)

    def __call__(self,
                 message: AgentMessage,
                 session_id=0,
                 **kwargs) -> AgentMessage:
        """Run hooks and execute the function call carried by ``message``.

        Args:
            message: Message whose ``content`` is a ``FunctionCall`` or a
                dict with ``'name'`` and ``'parameters'`` keys.
            session_id: Forwarded to every hook. Defaults to 0.
            **kwargs: Forwarded to :meth:`forward`.

        Returns:
            AgentMessage: The action result wrapped in a message (or the
            value returned by an ``after_action`` hook).

        Raises:
            AssertionError: If ``message.content`` has neither accepted form.
        """
        # message.receiver = self.name
        for hook in self._hooks.values():
            result = hook.before_action(self, message, session_id)
            # A truthy hook result replaces the message.
            if result:
                message = result

        content = message.content
        assert isinstance(content, FunctionCall) or (
            isinstance(content, dict) and 'name' in content
            and 'parameters' in content)
        if isinstance(content, dict):
            name = content.get('name')
            parameters = content.get('parameters')
        else:
            name = content.name
            parameters = content.parameters

        response_message = self.forward(
            name=name, parameters=parameters, **kwargs)
        if not isinstance(response_message, AgentMessage):
            response_message = AgentMessage(
                sender=self.__class__.__name__,
                content=response_message,
            )

        for hook in self._hooks.values():
            result = hook.after_action(self, response_message, session_id)
            if result:
                response_message = result
        return response_message

    def register_hook(self, hook: Callable):
        """Register ``hook`` and return the ``RemovableHandle`` created for it.

        The hook is stored in ``self._hooks`` under the handle's ``id``.
        """
        handle = RemovableHandle(self._hooks)
        self._hooks[handle.id] = hook
        return handle
class AsyncActionExecutor(ActionExecutor):
    """Asynchronous variant of ``ActionExecutor``.

    ``forward`` and ``__call__`` are coroutines. Actions and hooks may be
    synchronous or asynchronous: whenever a call returns an awaitable, it is
    awaited before the result is used.
    """

    async def forward(self, name, parameters, **kwargs) -> ActionReturn:
        """Execute the action identified by ``name``.

        Args:
            name: Either ``"<action>"`` (the ``run`` API is used) or
                ``"<action>.<api>"``.
            parameters: Passed unchanged to the selected action.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            ActionReturn: The result of the selected action. When the action
            name is not registered, the result comes from ``no_action``,
            ``finish_action`` or ``invalid_action`` instead.
        """
        # Split on the first dot only: a name with several dots used to
        # raise ValueError during tuple unpacking.
        action_name, sep, api_name = name.partition('.')
        if not sep:
            api_name = 'run'

        if action_name in self:
            action_return = self.actions[action_name](parameters, api_name)
            # Checking the result (instead of the callable) also covers
            # callables that return an awaitable without being declared
            # with ``async def``.
            if inspect.isawaitable(action_return):
                action_return = await action_return
            # NOTE(review): reconstructed as part of the registered-action
            # branch (source indentation was lost) -- confirm fallbacks are
            # not meant to be marked OPEN as well.
            action_return.valid = ActionValidCode.OPEN
            return action_return

        # Fallbacks are matched against the full, unsplit name.
        if name == self.no_action.name:
            fallback = self.no_action
        elif name == self.finish_action.name:
            fallback = self.finish_action
        else:
            fallback = self.invalid_action
        action_return = fallback(parameters)
        # Previously an asynchronous fallback action was never awaited.
        if inspect.isawaitable(action_return):
            action_return = await action_return
        return action_return

    async def __call__(self,
                       message: AgentMessage,
                       session_id=0,
                       **kwargs) -> AgentMessage:
        """Run hooks and execute the function call carried by ``message``.

        Args:
            message: Message whose ``content`` is a ``FunctionCall`` or a
                dict with ``'name'`` and ``'parameters'`` keys.
            session_id: Forwarded to every hook. Defaults to 0.
            **kwargs: Forwarded to :meth:`forward`.

        Returns:
            AgentMessage: The action result wrapped in a message (or the
            value returned by an ``after_action`` hook).

        Raises:
            AssertionError: If ``message.content`` has neither accepted form.
        """
        # message.receiver = self.name
        for hook in self._hooks.values():
            result = hook.before_action(self, message, session_id)
            if inspect.isawaitable(result):
                result = await result
            # A truthy hook result replaces the message.
            if result:
                message = result

        content = message.content
        assert isinstance(content, FunctionCall) or (
            isinstance(content, dict) and 'name' in content
            and 'parameters' in content)
        if isinstance(content, dict):
            name = content.get('name')
            parameters = content.get('parameters')
        else:
            name = content.name
            parameters = content.parameters

        response_message = await self.forward(
            name=name, parameters=parameters, **kwargs)
        if not isinstance(response_message, AgentMessage):
            response_message = AgentMessage(
                sender=self.__class__.__name__,
                content=response_message,
            )

        for hook in self._hooks.values():
            result = hook.after_action(self, response_message, session_id)
            if inspect.isawaitable(result):
                result = await result
            if result:
                response_message = result
        return response_message