File size: 5,379 Bytes
e679d69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import json
import re
from ast import literal_eval
from typing import Any, List, Union


class ParseError(Exception):
    """Parsing exception class.

    Args:
        err_msg (:class:`str`): human-readable description of the failure

    Attributes:
        err_msg (:class:`str`): the message given at construction time
    """

    def __init__(self, err_msg: str):
        # Forward the message to ``Exception`` so that ``str(exc)`` and
        # ``exc.args`` carry it even when it is passed by keyword
        # (``ParseError(err_msg=...)``), not only positionally.
        super().__init__(err_msg)
        self.err_msg = err_msg


class BaseParser:
    """Default parser translating between LLM text and action arguments.

    On construction the action's description is checked (every required
    parameter must be declared) and two lookup tables are built: API name
    to parameter list and API name to required-parameter list.

    Args:
        action (:class:`BaseAction`): action to validate

    Attributes:
        PARAMETER_DESCRIPTION (:class:`str`): text telling the LLM which
            argument format to produce. When non-empty it is written into
            every API description under ``'parameter_description'``.
    """

    PARAMETER_DESCRIPTION: str = ''

    def __init__(self, action):
        self.action = action
        self._api2param = {}
        self._api2required = {}

        description = action.description
        if not description:
            # nothing to validate or register
            return

        # A description either carries an 'api_list' or is itself the
        # single API specification.
        for spec in description.get('api_list', [description]):
            is_toolkit = self.action.is_toolkit
            if is_toolkit:
                qualified_name = f'{action.name}.{spec["name"]}'
            else:
                qualified_name = spec['name']

            undeclared = set(spec['required']) - {
                param['name']
                for param in spec['parameters']
            }
            if undeclared:
                raise ValueError(
                    f'unknown parameters for function "{qualified_name}": '
                    f'{undeclared}')

            if self.PARAMETER_DESCRIPTION:
                spec['parameter_description'] = self.PARAMETER_DESCRIPTION

            # non-toolkit actions are always registered under 'run'
            key = spec['name'] if is_toolkit else 'run'
            self._api2param[key] = spec['parameters']
            self._api2required[key] = spec['required']

    def parse_inputs(self, inputs: str, name: str = 'run') -> dict:
        """Bind the raw input to the first declared parameter of the API.

        Args:
            inputs (:class:`str`): input string extracted from responses
            name (:class:`str`): key of the API in the registered tables.
                Defaults to ``'run'``.

        Returns:
            :class:`dict`: single-entry mapping from the first parameter's
                name to ``inputs``
        """
        first_parameter = self._api2param[name][0]
        return {first_parameter['name']: inputs}

    def parse_outputs(self, outputs: Any) -> List[dict]:
        """Turn the raw result of the action into a list of text items.

        Args:
            outputs (:class:`Any`): raw output of the action

        Returns:
            :class:`List[dict]`: one-element list whose member is a
                dictionary with two keys - 'type' and 'content'.
        """
        if isinstance(outputs, dict):
            text = json.dumps(outputs, ensure_ascii=False)
        elif isinstance(outputs, str):
            text = outputs
        else:
            text = str(outputs)
        # Round trip through GBK; characters GBK cannot encode are dropped.
        sanitized = text.encode('gbk', 'ignore').decode('gbk')
        return [{'type': 'text', 'content': sanitized}]


class JsonParser(BaseParser):
    """Json parser to convert input string into a dictionary.

    Args:
        action (:class:`BaseAction`): action to validate
    """

    PARAMETER_DESCRIPTION = (
        'If you call this tool, you must pass arguments in '
        'the JSON format {key: value}, where the key is the parameter name.')

    def parse_inputs(self,
                     inputs: Union[str, dict],
                     name: str = 'run') -> dict:
        """Decode JSON arguments and validate them against the API schema.

        Args:
            inputs (:class:`Union[str, dict]`): either an already decoded
                dictionary or a JSON string, optionally wrapped in a
                Markdown code fence (a trailing ```` ``` ```` line, with an
                optional leading ```` ```json ```` line).
            name (:class:`str`): key of the API in the registered tables.
                Defaults to ``'run'``.

        Returns:
            :class:`dict`: mapping from parameter names to values

        Raises:
            ParseError: if the input is not valid JSON, does not decode to
                a JSON object, contains unknown keys or lacks required
                keys.
        """
        if not isinstance(inputs, dict):
            raw = inputs
            try:
                # strip a surrounding Markdown code fence if there is one
                match = re.search(r'^\s*(```json\n)?(.*)\n```\s*$', raw, re.S)
                if match:
                    raw = match.group(2).strip()
                inputs = json.loads(raw)
            except (json.JSONDecodeError, TypeError) as exc:
                # TypeError covers non-string inputs (e.g. ``None``), which
                # neither ``re.search`` nor ``json.loads`` accepts.
                raise ParseError(f'invalid json format: {raw}') from exc
            if not isinstance(inputs, dict):
                # ``json.loads`` also yields lists, numbers, strings, ...;
                # only an object can map parameter names to values.
                raise ParseError(
                    f'invalid json format: expected a JSON object, got {raw}')
        input_keys = set(inputs)
        all_keys = {param['name'] for param in self._api2param[name]}
        if not input_keys.issubset(all_keys):
            raise ParseError(f'unknown arguments: {input_keys - all_keys}')
        required_keys = set(self._api2required[name])
        if not input_keys.issuperset(required_keys):
            raise ParseError(
                f'missing required arguments: {required_keys - input_keys}')
        return inputs


class TupleParser(BaseParser):
    """Tuple parser to convert input string into a tuple.

    Args:
        action (:class:`BaseAction`): action to validate
    """

    PARAMETER_DESCRIPTION = (
        'If you call this tool, you must pass arguments in the tuple format '
        'like (arg1, arg2, arg3), and the arguments are ordered.')

    def parse_inputs(self,
                     inputs: Union[str, tuple],
                     name: str = 'run') -> dict:
        """Map positional arguments onto the API's declared parameters.

        Args:
            inputs (:class:`Union[str, tuple]`): either a tuple of values
                or a string holding a Python literal. A literal that
                evaluates to a list is used as the argument sequence; any
                other non-tuple literal (e.g. ``"('query')"`` or ``"5"``)
                is treated as a single argument.
            name (:class:`str`): key of the API in the registered tables.
                Defaults to ``'run'``.

        Returns:
            :class:`dict`: mapping from parameter names to values, assigned
                in declaration order

        Raises:
            ParseError: if the string is not a valid literal, or if too few
                or too many arguments are given.
        """
        if not isinstance(inputs, tuple):
            try:
                inputs = literal_eval(inputs)
            except Exception as exc:
                # literal_eval fails in several ways on malformed input
                # (ValueError, SyntaxError, TypeError, ...); report them all
                # uniformly as a parsing failure.
                raise ParseError(f'invalid tuple format: {inputs}') from exc
            if isinstance(inputs, list):
                inputs = tuple(inputs)
            elif not isinstance(inputs, tuple):
                # "(arg1)" is just ``arg1`` in Python, so a single argument
                # arrives as a bare value: wrap it instead of iterating over
                # it (which would split a string into characters) or calling
                # ``len`` on a number.
                inputs = (inputs, )
        if len(inputs) < len(self._api2required[name]):
            raise ParseError(
                f'API takes {len(self._api2required[name])} required positional '
                f'arguments but {len(inputs)} were given')
        if len(inputs) > len(self._api2param[name]):
            raise ParseError(
                f'API takes {len(self._api2param[name])} positional arguments '
                f'but {len(inputs)} were given')
        inputs = {
            self._api2param[name][i]['name']: item
            for i, item in enumerate(inputs)
        }
        return inputs