Spaces:
Sleeping
Sleeping
from typing import Dict, List, Optional, Union | |
from lagent.agents.aggregator.default_aggregator import DefaultAggregator | |
from lagent.memory.base_memory import Memory | |
from lagent.prompts.parsers.tool_parser import MixedToolParser, ToolParser, ToolStatusCode | |
class InternLMToolAggregator(DefaultAggregator):
    """Aggregate an agent's memory into a list of chat messages.

    The produced list contains, in order:

    1. the system instruction (if given),
    2. the tool instruction returned by ``parser.format_instruction()``,
    3. the configured few-shot conversations,
    4. the conversation history stored in memory.

    Args:
        environment_role (str): role assigned to messages whose sender is
            neither the agent itself nor one of ``user_names``.
        environment_begin (str): text prepended to environment messages.
        environment_end (str): text appended to environment messages.
        user_names (Optional[List[str]]): sender names treated as users.
            Defaults to ``['user']``.
        few_shot (Optional[List[List[dict]]]): few-shot conversations. Each
            message is a dict with at least ``role`` and ``content``;
            ``tool`` messages additionally need ``name``.
    """

    def __init__(self,
                 environment_role='environment',
                 environment_begin='',
                 environment_end='',
                 user_names: Optional[List[str]] = None,
                 few_shot: Optional[List[List[dict]]] = None):
        self.environment_role = environment_role
        self.environment_begin = environment_begin
        self.environment_end = environment_end
        self.user_names = user_names or ['user']
        self.few_shot = few_shot or []

    def aggregate(self,
                  messages: Memory,
                  name: str,
                  parser: Union[ToolParser, MixedToolParser],
                  system_instruction: Optional[str] = None
                  ) -> List[Dict[str, str]]:
        """Build the message list sent to the model.

        Args:
            messages (Memory): memory whose ``get_memory()`` result is
                replayed as conversation history.
            name (str): sender name identifying the agent's own messages.
            parser (Union[ToolParser, MixedToolParser]): used to render the
                tool instruction and assistant tool-call responses.
            system_instruction (Optional[str]): forwarded to
                ``aggregate_system_intruction`` when truthy.

        Returns:
            List[Dict[str, str]]: the aggregated messages.

        Raises:
            KeyError: if a few-shot message has an unrecognized role.
        """
        history = messages.get_memory()
        _message = []
        if system_instruction:
            _message.extend(
                self.aggregate_system_intruction(system_instruction))
        _message.extend(self._aggregate_tool_instruction(parser))
        _message.extend(self._aggregate_few_shot(parser))
        _message.extend(self._aggregate_history(history, name, parser))
        return _message

    @staticmethod
    def _aggregate_tool_instruction(
            parser: Union[ToolParser, MixedToolParser]) -> List[dict]:
        """Normalize ``parser.format_instruction()`` into a list of messages."""
        tool_instruction = parser.format_instruction()
        if not tool_instruction:
            return []
        if isinstance(tool_instruction, str):
            tool_instruction = dict(role='system', content=tool_instruction)
            # Only a dict built here is tagged with the tool name; a dict or
            # list returned by the parser is passed through as-is (indexing a
            # list with 'name' would raise).
            if parser.tool_type:
                tool_instruction['name'] = parser.tool_type
        if isinstance(tool_instruction, dict):
            tool_instruction = [tool_instruction]
        return list(tool_instruction)

    def _aggregate_few_shot(
            self, parser: Union[ToolParser, MixedToolParser]) -> List[dict]:
        """Render ``self.few_shot`` into chat messages.

        Messages are copied rather than reused, so neither this call nor a
        caller mutating the returned list can alter the stored examples.
        """
        result = []
        for shot in self.few_shot:
            i = 0
            while i < len(shot):
                msg = shot[i]
                role = msg['role']
                if role in ('assistant', 'user', 'system'):
                    result.append(dict(msg))
                elif role == self.environment_role:
                    content = msg['content']
                    # Wrap only if not already wrapped, so pre-wrapped
                    # examples are not double-wrapped.
                    if not content.startswith(self.environment_begin):
                        content = self.environment_begin + content
                    if not content.endswith(self.environment_end):
                        content += self.environment_end
                    result.append({**msg, 'content': content})
                elif role in ('thought', 'language'):
                    next_msg = shot[i + 1] if i + 1 < len(shot) else None
                    if next_msg is not None and next_msg['role'] == 'tool':
                        # Merge the thought and the following tool call into
                        # a single assistant turn.
                        response = dict(
                            tool_type=next_msg['name'],
                            thought=msg['content'],
                            action=next_msg['content'],
                            status=None)
                        i += 1  # the tool message has been consumed
                    else:
                        response = dict(
                            tool_type=None,
                            thought=msg['content'],
                            action=None,
                            status=None)
                    result.append(
                        dict(
                            role='assistant',
                            content=parser.format_response(response)))
                else:
                    raise KeyError(f'Unknown role: {role}')
                i += 1
        return result

    def _aggregate_history(self, history, name: str,
                           parser: Union[ToolParser, MixedToolParser]
                           ) -> List[dict]:
        """Convert memory messages into user/assistant/environment turns.

        ``history`` is the result of ``Memory.get_memory()``; each item is
        read via its ``sender``, ``content`` and ``formatted`` attributes.
        """
        result = []
        # Tool type of the most recent parsed assistant message. It is used to
        # tag subsequent environment messages with ``name``.
        tool_type = None
        for message in history:
            if message.sender == name:
                if isinstance(message.formatted, dict):
                    parsed = message.formatted
                    # NOTE(review): turns that failed to parse are dropped
                    # from the history, presumably so malformed tool calls
                    # are not replayed to the model — confirm intent.
                    if parsed['status'] == ToolStatusCode.PARSING_ERROR:
                        continue
                    result.append(
                        dict(
                            role='assistant',
                            content=parser.format_response(parsed)))
                    tool_type = parsed['tool_type']
                else:
                    # NOTE(review): tool_type is not reset here, so a later
                    # environment message keeps the previous tool's name.
                    result.append(
                        dict(role='assistant', content=str(message.content)))
            elif message.sender in self.user_names:
                result.append(dict(role='user', content=message.content))
            else:
                msg = dict(
                    role=self.environment_role,
                    content=self.environment_begin + str(message.content) +
                    self.environment_end)
                if tool_type:
                    msg['name'] = tool_type
                result.append(msg)
        return result