Spaces:
Running
Running
from copy import deepcopy | |
from lagent.schema import ActionReturn, ActionStatusCode, FunctionCall | |
from .hook import Hook | |
class ActionPreprocessor(Hook):
    """Hook that normalizes messages before and after an action is executed.

    ``before_action`` rewrites ``message.content`` into a plain
    ``dict(name=..., parameters=...)`` built from ``message.formatted``.
    ``after_action`` replaces an ``ActionReturn`` held in ``message.content``
    with its formatted result (on success) or its error message.
    """

    def before_action(self, executor, message, session_id):
        """Extract the action name and parameters from ``message.formatted``.

        Accepted shapes of ``message.formatted``:

        * a ``FunctionCall`` (its ``name`` / ``parameters`` attributes are
          used);
        * a dict with top-level ``'name'`` and ``'parameters'`` keys;
        * a container with an ``'action'`` entry that itself holds
          ``'name'`` and ``'parameters'``.

        For dicts, a top-level key takes precedence over the one nested under
        ``'action'``.

        Args:
            executor: Unused here; part of the hook signature.
            message: Message whose ``formatted`` attribute is read and whose
                ``content`` attribute is overwritten.
            session_id: Unused here; part of the hook signature.

        Returns:
            The same ``message`` object with ``content`` set to
            ``dict(name=name, parameters=parameters)``.

        Raises:
            AssertionError: If ``message.formatted`` matches none of the
                accepted shapes.
        """
        formatted = message.formatted
        # Validate against ``formatted`` itself: ``message.content`` is the
        # field being produced here, so it cannot be used to validate input.
        assert isinstance(formatted, FunctionCall) or (
            isinstance(formatted, dict) and 'name' in formatted
            and 'parameters' in formatted) or (
                'action' in formatted
                and 'parameters' in formatted['action']
                and 'name' in formatted['action']), (
                    'message.formatted must be a FunctionCall or provide '
                    "'name' and 'parameters' (optionally under 'action')")

        if isinstance(formatted, dict):
            # Look up the nested fallback lazily. Passing
            # ``formatted['action'][...]`` as the default of ``dict.get``
            # would evaluate it eagerly and raise ``KeyError`` for flat dicts
            # that have no 'action' entry.
            if 'name' in formatted:
                name = formatted['name']
            else:
                name = formatted['action']['name']
            if 'parameters' in formatted:
                parameters = formatted['parameters']
            else:
                parameters = formatted['action']['parameters']
        else:
            name = formatted.name
            parameters = formatted.parameters

        message.content = dict(name=name, parameters=parameters)
        return message

    def after_action(self, executor, message, session_id):
        """Turn the action output in ``message.content`` into a response.

        If ``message.content`` is an ``ActionReturn``, it is replaced by
        ``format_result()`` when its state is ``ActionStatusCode.SUCCESS``
        and by its ``errmsg`` otherwise. Any other content is left as is.

        Args:
            executor: Unused here; part of the hook signature.
            message: Message whose ``content`` is read and overwritten.
            session_id: Unused here; part of the hook signature.

        Returns:
            The same ``message`` object with the updated ``content``.
        """
        action_return = message.content
        if isinstance(action_return, ActionReturn):
            if action_return.state == ActionStatusCode.SUCCESS:
                response = action_return.format_result()
            else:
                response = action_return.errmsg
        else:
            response = action_return
        message.content = response
        return message
class InternLMActionProcessor(ActionPreprocessor):
    """Action preprocessor for messages whose ``formatted`` field is a dict
    containing at least ``tool_type``, ``thought``, ``action`` and ``status``.

    When ``formatted['action']`` is a plain string, it is wrapped into a
    ``{'name': ..., 'parameters': ...}`` dict before the parent class
    processes the message.
    """

    def __init__(self, code_parameter: str = 'command'):
        """Store the parameter name under which a string action is passed.

        Args:
            code_parameter: Key used for the string action inside the
                generated ``parameters`` dict.
        """
        self.code_parameter = code_parameter

    def before_action(self, executor, message, session_id):
        """Wrap a string action into a call dict, then defer to the parent.

        Works on a deep copy of ``message``, so the caller's object is not
        modified by this method.

        Args:
            executor: Object whose ``actions`` attribute is iterated; the
                first item yielded is used as the action name.
            message: Message whose ``formatted`` dict is inspected.
            session_id: Added to the parameters when the action name is
                ``'AsyncIPythonInterpreter'``; also forwarded to the parent.

        Returns:
            Whatever the parent ``before_action`` returns for the copied
            message.

        Raises:
            AssertionError: If ``formatted`` is not a dict or lacks one of
                the required keys.
        """
        message = deepcopy(message)
        formatted = message.formatted
        required_keys = {'tool_type', 'thought', 'action', 'status'}
        assert isinstance(formatted,
                          dict) and required_keys.issubset(formatted)

        raw_action = formatted['action']
        if isinstance(raw_action, str):
            # encapsulate code interpreter arguments
            tool_name = next(iter(executor.actions))
            call_args = {self.code_parameter: raw_action}
            if tool_name in ('AsyncIPythonInterpreter', ):
                call_args['session_id'] = session_id
            formatted['action'] = {
                'name': tool_name,
                'parameters': call_args,
            }
        return super().before_action(executor, message, session_id)