Spaces:
Sleeping
Sleeping
import inspect | |
import logging | |
import re | |
from abc import ABCMeta | |
from copy import deepcopy | |
from functools import wraps | |
from typing import Callable, Optional, Type, get_args, get_origin | |
try: | |
from typing import Annotated | |
except ImportError: | |
from typing_extensions import Annotated | |
from griffe import Docstring | |
try: | |
from griffe import DocstringSectionKind | |
except ImportError: | |
from griffe.enumerations import DocstringSectionKind | |
from ..schema import ActionReturn, ActionStatusCode | |
from .parser import BaseParser, JsonParser, ParseError | |
# Only let ERROR-level records from griffe through; presumably this silences
# the warnings griffe can emit while parsing loosely formatted docstrings.
logging.getLogger(name='griffe').setLevel(level=logging.ERROR)
def tool_api(func: Optional[Callable] = None,
             *,
             explode_return: bool = False,
             returns_named_value: bool = False,
             **kwargs):
    """Turn functions into tools.

    It parses type hints as well as docstrings to build the tool description
    and attaches it to the function via an ``api_description`` attribute.

    Examples:

        .. code-block:: python

            # type hints have higher priority than docstrings
            from typing import Annotated

            @tool_api
            def add(a: Annotated[int, 'augend'], b: Annotated[int, 'addend'] = 1):
                '''Add operation

                Args:
                    a (int): first operand
                    b (int): second operand
                '''
                return a + b

            print(add.api_description)

    Args:
        func (Optional[Callable]): function to decorate. Defaults to ``None``,
            in which case a decorator is returned.
        explode_return (bool): whether to flatten the dictionary or tuple
            return as the ``return_data`` field. When enabled, it is
            recommended to annotate the members in the docstring. Defaults
            to ``False``.

            .. code-block:: python

                @tool_api(explode_return=True)
                def foo(a, b):
                    '''A simple function

                    Args:
                        a (int): a
                        b (int): b

                    Returns:
                        dict: information of inputs
                            * x: value of a
                            * y: value of b
                    '''
                    return {'x': a, 'y': b}

                print(foo.api_description)

        returns_named_value (bool): whether to parse ``thing: Description`` in
            returns sections as a name and description, rather than a type and
            description. When true, type must be wrapped in parentheses:
            ``(int): Description``. When false, parentheses are optional but
            the items cannot be named: ``int: Description``. Defaults to
            ``False``.
        **kwargs: extra options forwarded to griffe's Google-style docstring
            parser (``Docstring.parse('google', ...)``).

    Returns:
        Callable: the wrapped function, or a decorator when ``func`` is not
        a callable.

    Important:
        The ``return_data`` field is added to ``api_description`` only when
        ``explode_return`` or ``returns_named_value`` is enabled and the
        parsed return data is non-empty.
    """

    def _detect_type(string):
        """Map a type string to the coarse field type used in descriptions."""
        field_type = 'STRING'
        if 'list' in string:
            field_type = 'Array'
        elif 'str' not in string:
            if 'float' in string:
                field_type = 'FLOAT'
            elif 'int' in string:
                field_type = 'NUMBER'
            elif 'bool' in string:
                field_type = 'BOOLEAN'
        return field_type

    def _apply_type(element):
        """Replace the ``annotation`` of a parsed docstring element dict with
        a detected ``type``; drop it entirely when it is empty/missing."""
        annotation = element.pop('annotation', None)
        if annotation:
            # str() also copes with non-string annotation objects.
            element['type'] = _detect_type(str(annotation).lower())
        return element

    def _explode(desc):
        """Parse the bullet lines of a return description as named fields."""
        # Re-indent every line after the first under a synthetic "Args:"
        # header so griffe parses ``name (type): description`` items.
        items = [
            '    ' + item.lstrip(' -+*#.')
            for item in desc.split('\n')[1:] if item.strip()
        ]
        docs = Docstring('\nArgs:\n' + '\n'.join(items)).parse('google')
        if not docs or docs[0].kind is not DocstringSectionKind.parameters:
            return []
        return [_apply_type(d.as_dict()) for d in docs[0].value]

    def _parse_tool(function):
        """Build the ``api_description`` dict of ``function``."""
        # Remove reST roles such as :class:`Foo`, keeping only their content.
        docstring = re.sub(r':(.+?):`(.+?)`', r'\2', function.__doc__ or '')
        docs = Docstring(docstring).parse(
            'google', returns_named_value=returns_named_value, **kwargs)
        desc = dict(
            name=function.__name__,
            # An empty docstring parses to no sections at all.
            description=docs[0].value
            if docs and docs[0].kind is DocstringSectionKind.text else '',
            parameters=[],
            required=[],
        )
        args_doc, returns_doc = {}, []
        for doc in docs:
            if doc.kind is DocstringSectionKind.parameters:
                for d in doc.value:
                    d = _apply_type(d.as_dict())
                    args_doc[d['name']] = d
            elif doc.kind is DocstringSectionKind.returns:
                for d in doc.value:
                    d = d.as_dict()
                    if not d['name']:
                        d.pop('name')
                    returns_doc.append(_apply_type(d))

        for param in inspect.signature(function).parameters.values():
            if param.name == 'self':
                continue
            arg_doc = args_doc.get(param.name, {})
            parameter = dict(
                name=param.name,
                type='STRING',
                description=arg_doc.get('description', ''))
            annotation = param.annotation
            if annotation is inspect.Signature.empty:
                parameter['type'] = arg_doc.get('type', 'STRING')
            else:
                if get_origin(annotation) is Annotated:
                    # Annotated[T, info, ...]: the first metadata item, when
                    # truthy, overrides the docstring description.
                    annotation, *metadata = get_args(annotation)
                    if metadata and metadata[0]:
                        parameter['description'] = metadata[0]
                # NOTE(review): get_args() returns a tuple, which has no
                # origin, so this unwraps at most one level and drops the
                # container itself (e.g. List[int] is reported as NUMBER,
                # not Array). Kept as-is for backward compatibility.
                while get_origin(annotation):
                    annotation = get_args(annotation)
                parameter['type'] = _detect_type(str(annotation))
            desc['parameters'].append(parameter)
            if param.default is inspect.Signature.empty:
                desc['required'].append(param.name)

        return_data = []
        if explode_return:
            # Without a Returns section there is nothing to explode.
            if returns_doc:
                return_data = _explode(returns_doc[0]['description'])
        elif returns_named_value:
            return_data = returns_doc
        if return_data:
            desc['return_data'] = return_data
        return desc

    def _wrap(function):
        """Wrap ``function`` (sync or async) and attach ``api_description``.

        The wrapper forwards all arguments unchanged, so it works for both
        methods and plain functions; ``wraps`` keeps the original name,
        docstring and signature visible on the wrapper.
        """
        if inspect.iscoroutinefunction(function):

            @wraps(function)
            async def wrapper(*call_args, **call_kwargs):
                return await function(*call_args, **call_kwargs)

        else:

            @wraps(function)
            def wrapper(*call_args, **call_kwargs):
                return function(*call_args, **call_kwargs)

        wrapper.api_description = _parse_tool(function)
        return wrapper

    if callable(func):
        return _wrap(func)
    # Used as ``@tool_api(...)``: return the decorator itself.
    return _wrap
class ToolMeta(ABCMeta):
    """Metaclass of tools.

    At class-creation time it builds ``__tool_description__`` from the class
    docstring and the class's tool APIs, and sets ``_is_toolkit``:

    * a class whose ``run`` carries an ``api_description`` (or defines a plain
      callable ``run``, which is wrapped with :func:`tool_api` here) is a
      single tool whose parameters come from ``run``;
    * a class with other ``api_description``-carrying callables (and no such
      ``run``) is a toolkit listing them under ``api_list``.

    Raises:
        KeyError: if a decorated ``run`` and other tool APIs coexist.
    """

    def __new__(mcs, name, base, attrs):
        # ``__doc__`` is absent from ``attrs`` when the class has no
        # docstring (or may be None); an empty docstring parses to no
        # sections, so guard before indexing and only accept a text section.
        sections = Docstring(attrs.get('__doc__') or '').parse('google')
        class_description = (
            sections[0].value if sections
            and sections[0].kind is DocstringSectionKind.text else '')
        is_toolkit, tool_desc = True, dict(
            name=name, description=class_description)

        def _use_run_api(api_desc):
            """Copy the ``run`` API description onto the tool description."""
            tool_desc['parameters'] = api_desc['parameters']
            tool_desc['required'] = api_desc['required']
            if api_desc['description']:
                tool_desc['description'] = api_desc['description']
            if api_desc.get('return_data'):
                tool_desc['return_data'] = api_desc['return_data']

        for key, value in attrs.items():
            if callable(value) and hasattr(value, 'api_description'):
                api_desc = getattr(value, 'api_description')
                if key == 'run':
                    _use_run_api(api_desc)
                    is_toolkit = False
                else:
                    tool_desc.setdefault('api_list', []).append(api_desc)
        if not is_toolkit and 'api_list' in tool_desc:
            raise KeyError('`run` and other tool APIs can not be implemented '
                           'at the same time')
        if is_toolkit and 'api_list' not in tool_desc:
            # Neither a decorated ``run`` nor other tool APIs were found.
            is_toolkit = False
            if callable(attrs.get('run')):
                # A plain ``run`` becomes the tool API implicitly.
                run_api = tool_api(attrs['run'])
                _use_run_api(run_api.api_description)
                attrs['run'] = run_api
            else:
                tool_desc['parameters'], tool_desc['required'] = [], []
        attrs['_is_toolkit'] = is_toolkit
        attrs['__tool_description__'] = tool_desc
        return super().__new__(mcs, name, base, attrs)
class BaseAction(metaclass=ToolMeta):
    """Base class for all actions.

    Args:
        description (:class:`Optional[dict]`): The description of the action.
            Defaults to ``None``.
        parser (:class:`Type[BaseParser]`): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.

    Examples:

        * simple tool

        .. code-block:: python

            class Bold(BaseAction):
                '''Make text bold'''

                def run(self, text: str):
                    '''
                    Args:
                        text (str): input text

                    Returns:
                        str: bold text
                    '''
                    return '**' + text + '**'

            action = Bold()

        * toolkit with multiple APIs

        .. code-block:: python

            class Calculator(BaseAction):
                '''Calculator'''

                @tool_api
                def add(self, a, b):
                    '''Add operation

                    Args:
                        a (int): augend
                        b (int): addend

                    Returns:
                        int: sum
                    '''
                    return a + b

                @tool_api
                def sub(self, a, b):
                    '''Subtraction operation

                    Args:
                        a (int): minuend
                        b (int): subtrahend

                    Returns:
                        int: difference
                    '''
                    return a - b

            action = Calculator()
    """

    def __init__(
        self,
        description: Optional[dict] = None,
        parser: Type[BaseParser] = JsonParser,
    ):
        # Copy so per-instance edits never leak into the class-level
        # ``__tool_description__`` built by ToolMeta.
        self._description = deepcopy(description or self.__tool_description__)
        self._name = self._description['name']
        self._parser = parser(self)

    def __call__(self, inputs: str, name='run') -> ActionReturn:
        """Parse ``inputs``, invoke API ``name`` and wrap the result.

        Args:
            inputs (str): raw inputs, decoded by the parser.
            name (str): attribute name of the API to call. Defaults to
                ``'run'``.

        Returns:
            ActionReturn: the API's result, or an error return with state
            ``API_ERROR`` (unknown API / exception raised by the API) or
            ``ARGS_ERROR`` (inputs could not be parsed).
        """
        fallback_args = {'inputs': inputs, 'name': name}
        if not hasattr(self, name):
            return ActionReturn(
                fallback_args,
                type=self.name,
                errmsg=f'invalid API: {name}',
                state=ActionStatusCode.API_ERROR)
        try:
            inputs = self._parser.parse_inputs(inputs, name)
        except ParseError as exc:
            return ActionReturn(
                fallback_args,
                type=self.name,
                errmsg=exc.err_msg,
                state=ActionStatusCode.ARGS_ERROR)
        try:
            outputs = getattr(self, name)(**inputs)
        except Exception as exc:
            return ActionReturn(
                inputs,
                type=self.name,
                errmsg=str(exc),
                state=ActionStatusCode.API_ERROR)
        if isinstance(outputs, ActionReturn):
            # The API built its own return; only fill in missing fields.
            action_return = outputs
            if not action_return.args:
                action_return.args = inputs
            if not action_return.type:
                action_return.type = self.name
        else:
            result = self._parser.parse_outputs(outputs)
            action_return = ActionReturn(inputs, type=self.name, result=result)
        return action_return

    @property
    def name(self):
        """Name of the tool, taken from its description."""
        return self._name

    @property
    def is_toolkit(self):
        """Whether the tool exposes several APIs instead of a single ``run``."""
        return self._is_toolkit

    @property
    def description(self) -> dict:
        """Description of the tool."""
        return self._description

    def __repr__(self):
        return f'{self.description}'

    __str__ = __repr__
class AsyncActionMixin:
    """Mixin providing an awaitable ``__call__`` for actions.

    It follows the same steps as ``BaseAction.__call__`` but awaits the
    selected API. It defines neither ``self._parser`` nor ``self.name``;
    presumably it is combined with ``BaseAction``, which provides them.
    """

    async def __call__(self, inputs: str, name='run') -> ActionReturn:
        """Parse ``inputs``, await API ``name`` and wrap the result.

        Args:
            inputs (str): raw inputs, decoded by the parser.
            name (str): attribute name of the API to await. Defaults to
                ``'run'``.

        Returns:
            ActionReturn: the API's result, or an error return with state
            ``API_ERROR`` or ``ARGS_ERROR``.
        """
        raw_args = {'inputs': inputs, 'name': name}

        # Guard: unknown API name.
        if not hasattr(self, name):
            return ActionReturn(
                raw_args,
                type=self.name,
                errmsg=f'invalid API: {name}',
                state=ActionStatusCode.API_ERROR,
            )

        # Guard: inputs the parser rejects.
        try:
            parsed = self._parser.parse_inputs(inputs, name)
        except ParseError as err:
            return ActionReturn(
                raw_args,
                type=self.name,
                errmsg=err.err_msg,
                state=ActionStatusCode.ARGS_ERROR,
            )

        # Guard: any exception raised while running the API.
        try:
            outputs = await getattr(self, name)(**parsed)
        except Exception as err:
            return ActionReturn(
                parsed,
                type=self.name,
                errmsg=str(err),
                state=ActionStatusCode.API_ERROR,
            )

        if not isinstance(outputs, ActionReturn):
            parsed_result = self._parser.parse_outputs(outputs)
            return ActionReturn(parsed, type=self.name, result=parsed_result)

        # The API returned its own ActionReturn: fill in what it left empty.
        if not outputs.args:
            outputs.args = parsed
        if not outputs.type:
            outputs.type = self.name
        return outputs