File size: 4,876 Bytes
e679d69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import json
from enum import IntEnum

# import re
from typing import Any, Callable, List, Optional

from lagent.prompts.parsers import StrParser
from lagent.utils import create_object, load_class_from_string


def default_plugin_validate(plugin: str):
    """Parse a plugin action string as a JSON object.

    Surrounding whitespace is stripped first; the remainder must start with
    ``{`` and end with ``}`` before it is handed to :func:`json.loads`.

    Args:
        plugin (str): Raw action text extracted between the plugin markers.

    Returns:
        dict: The decoded JSON object.

    Raises:
        json.JSONDecodeError: If the stripped text is not brace-delimited or
            is not valid JSON.
    """
    plugin = plugin.strip()
    if not (plugin.startswith('{') and plugin.endswith('}')):
        # JSONDecodeError requires (msg, doc, pos); raising the bare class
        # would instead surface a TypeError from its constructor.
        raise json.JSONDecodeError(
            'Expecting a JSON object enclosed in braces', plugin, 0)
    return json.loads(plugin)


class ToolStatusCode(IntEnum):
    """Outcome of scanning a model response for a tool call.

    Used as the ``status`` field of the dicts returned by the
    ``parse_response`` methods in this module.
    """
    NO_TOOL = 0  # no tool-call begin marker was found in the response
    VALID_TOOL = 1  # a call was extracted (and validated, if a validator is set)
    PARSING_ERROR = -1  # a call was found but the validator raised


class ToolParser(StrParser):
    """Parse and format a single kind of tool call embedded in model output.

    A tool call is the text between a ``begin`` and an ``end`` marker; the
    text before ``begin`` is treated as the model's free-form "thought".

    Args:
        tool_type (str): Name reported as ``tool_type`` for calls parsed by
            this parser.
        template (str): Instruction template forwarded to ``StrParser``.
        begin (str): Marker that opens a tool call.
        end (str): Marker that closes a tool call.
        validate (Callable[[str], Any] | str, optional): Post-processor
            applied to the raw action text. A string is resolved with
            ``load_class_from_string``. If the validator raises, the call is
            reported with ``ToolStatusCode.PARSING_ERROR``.
        **kwargs: Extra format fields forwarded to ``StrParser``.
    """

    def __init__(self,
                 tool_type: str,
                 template: str = '',
                 begin: str = '<tool>\n',
                 end: str = '</tool>\n',
                 validate: Optional[Callable[[str], Any]] = None,
                 **kwargs):
        # begin/end are read back below via self.format_field, which
        # presumably StrParser populates from these keyword arguments.
        super().__init__(template, begin=begin, end=end, **kwargs)
        self.template = template
        self.tool_type = tool_type
        # Accept a dotted-path string as well as a callable.
        self.validate = (load_class_from_string(validate)
                         if isinstance(validate, str) else validate)

    def parse_response(self, data: str) -> dict:
        """Split ``data`` into a thought and a tool-call action.

        Args:
            data (str): Raw model output.

        Returns:
            dict: With keys ``tool_type``, ``thought``, ``action`` and
            ``status``. Without the begin marker, the whole text is the
            thought, ``action`` is None and ``status`` is
            ``ToolStatusCode.NO_TOOL``. Otherwise the action is the text
            after the first begin marker. It is cut at any second begin
            marker and then at the first end marker; a missing end marker
            is tolerated. If ``validate`` raises, the raw action text is kept
            and ``status`` is ``ToolStatusCode.PARSING_ERROR``.
        """
        begin = self.format_field['begin']
        if begin not in data:
            return dict(
                tool_type=None,
                thought=data,
                action=None,
                status=ToolStatusCode.NO_TOOL)
        # Only the first tool call is kept; anything after a second begin
        # marker is discarded.
        thought, action, *_ = data.split(begin)
        action = action.split(self.format_field['end'])[0]
        status = ToolStatusCode.VALID_TOOL
        if self.validate:
            try:
                action = self.validate(action)
            except Exception:
                # Best effort: report the failure through ``status`` instead
                # of raising, leaving the unvalidated text as the action.
                status = ToolStatusCode.PARSING_ERROR
        return dict(
            tool_type=self.tool_type,
            thought=thought,
            action=action,
            status=status)

    def format_response(self, parsed: dict) -> str:
        """Render a parsed dict back into raw text.

        Returns the thought alone when ``action`` is None. Otherwise it
        returns ``thought + begin + action + end``. Dict actions are
        JSON-encoded with non-ASCII characters kept as-is; any other action
        is passed through ``str``.
        """
        if parsed['action'] is None:
            return parsed['thought']
        assert parsed['tool_type'] == self.tool_type
        action = parsed['action']
        if isinstance(action, dict):
            action = json.dumps(action, ensure_ascii=False)
        else:
            action = str(action)
        return (parsed['thought'] + self.format_field['begin'] + action +
                self.format_field['end'])


class InterpreterParser(ToolParser):
    """``ToolParser`` preset for code-interpreter calls.

    The default tool type is ``'interpreter'``. The default markers are
    ``<|action_start|><|interpreter|>`` and ``<|action_end|>``, each
    followed by a newline. No validator is applied unless one is given,
    so the action is the raw text between the markers.
    """

    def __init__(self,
                 tool_type: str = 'interpreter',
                 template: str = '',
                 begin: str = '<|action_start|><|interpreter|>\n',
                 end: str = '<|action_end|>\n',
                 validate: Optional[Callable[[str], Any]] = None,
                 **kwargs):
        super().__init__(
            tool_type=tool_type,
            template=template,
            begin=begin,
            end=end,
            validate=validate,
            **kwargs)


class PluginParser(ToolParser):
    """``ToolParser`` preset for plugin (function-call) actions.

    The default tool type is ``'plugin'``. The default markers are
    ``<|action_start|><|plugin|>`` and ``<|action_end|>``, each followed by
    a newline. By default the action is checked with
    ``default_plugin_validate``, so a well-formed call is returned as a
    decoded JSON object.
    """

    def __init__(self,
                 tool_type: str = 'plugin',
                 template: str = '',
                 begin: str = '<|action_start|><|plugin|>\n',
                 end: str = '<|action_end|>\n',
                 validate: Callable[[str], Any] = default_plugin_validate,
                 **kwargs):
        # Every preset value goes to the parent by keyword; the parent also
        # resolves a dotted-path string ``validate`` into a callable.
        super().__init__(
            tool_type=tool_type,
            template=template,
            begin=begin,
            end=end,
            validate=validate,
            **kwargs)


class MixedToolParser(StrParser):
    """Dispatch tool-call parsing and formatting across several sub-parsers.

    Args:
        tool_type (str, optional): If set, used as the ``name`` of the
            system message built from this parser's own template.
        template (str): Instruction template forwarded to ``StrParser``.
        parsers (list, optional): Sub-parsers. Each is passed through
            ``create_object`` (presumably so config dicts are accepted as
            well as instances; confirm) and keyed by its ``tool_type``. A
            later parser with the same ``tool_type`` replaces an earlier one.
        **format_field: Extra format fields forwarded to ``StrParser``.
    """

    def __init__(self,
                 tool_type: Optional[str] = None,
                 template: str = '',
                 parsers: Optional[List[ToolParser]] = None,
                 **format_field):
        self.parsers = {}
        self.tool_type = tool_type
        for parser in parsers or []:
            parser = create_object(parser)
            self.parsers[parser.tool_type] = parser
        super().__init__(template, **format_field)

    def format_instruction(self) -> List[dict]:
        """Build system messages describing every available tool format.

        Returns:
            list[dict]: One ``role='system'`` message for this parser's own
            template, then one per sub-parser, in that order. Any message
            whose instruction text is blank is skipped. Sub-parser messages
            carry ``name=<tool_type>``; the first message carries
            ``name=self.tool_type`` only when that is set.
        """
        inst = []
        content = super().format_instruction()
        if content.strip():
            msg = dict(role='system', content=content)
            if self.tool_type:
                msg['name'] = self.tool_type
            inst.append(msg)
        for name, parser in self.parsers.items():
            content = parser.format_instruction()
            if content.strip():
                inst.append(dict(role='system', content=content, name=name))
        return inst

    def parse_response(self, data: str) -> dict:
        """Parse ``data`` with the first sub-parser that recognises a call.

        Sub-parsers are tried in registration order, not in the order their
        markers appear in ``data``. If none matches, the last sub-parser's
        result is returned. With no sub-parsers, the result has
        ``ToolStatusCode.NO_TOOL`` status and ``data`` as the thought.
        """
        res = dict(
            tool_type=None,
            thought=data,
            action=None,
            status=ToolStatusCode.NO_TOOL)
        for name, parser in self.parsers.items():
            res = parser.parse_response(data)
            if res['tool_type'] == name:
                break
        return res

    def format_response(self, parsed: dict) -> str:
        """Render ``parsed`` with the sub-parser matching its ``tool_type``.

        Returns the thought alone when ``action`` is None.
        """
        if parsed['action'] is None:
            return parsed['thought']
        assert parsed['tool_type'] in self.parsers
        return self.parsers[parsed['tool_type']].format_response(parsed)