import inspect
import logging
import re
from abc import ABCMeta
from copy import deepcopy
from functools import wraps
from typing import Callable, Optional, Type, get_args, get_origin

try:
    from typing import Annotated
except ImportError:
    from typing_extensions import Annotated

from griffe import Docstring

try:
    from griffe import DocstringSectionKind
except ImportError:
    from griffe.enumerations import DocstringSectionKind

from ..schema import ActionReturn, ActionStatusCode
from .parser import BaseParser, JsonParser, ParseError

logging.getLogger('griffe').setLevel(logging.ERROR)


def tool_api(func: Optional[Callable] = None,
             *,
             explode_return: bool = False,
             returns_named_value: bool = False,
             **kwargs):
    """Turn functions into tools. It will parse typehints as well as docstrings
    to build the tool description and attach it to functions via an attribute
    ``api_description``.

    Examples:

        .. code-block:: python

            # typehints has higher priority than docstrings
            from typing import Annotated

            @tool_api
            def add(a: Annotated[int, 'augend'], b: Annotated[int, 'addend'] = 1):
                '''Add operation

                Args:
                    x (int): a
                    y (int): b
                '''
                return a + b

            print(add.api_description)

    Args:
        func (Optional[Callable]): function to decorate. Defaults to ``None``.
        explode_return (bool): whether to flatten the dictionary or tuple return
            as the ``return_data`` field. When enabled, it is recommended to
            annotate the member in docstrings. Defaults to ``False``.

            .. code-block:: python

                @tool_api(explode_return=True)
                def foo(a, b):
                    '''A simple function

                    Args:
                        a (int): a
                        b (int): b

                    Returns:
                        dict: information of inputs
                            * x: value of a
                            * y: value of b
                    '''
                    return {'x': a, 'y': b}

                print(foo.api_description)

        returns_named_value (bool): whether to parse ``thing: Description`` in
            returns sections as a name and description, rather than a type and
            description. When true, type must be wrapped in parentheses:
            ``(int): Description``. When false, parentheses are optional but
            the items cannot be named: ``int: Description``.
            Defaults to ``False``.
        **kwargs: extra options forwarded to ``Docstring.parse`` when the
            decorated function's docstring is parsed.

    Returns:
        Callable: wrapped function or partial decorator

    Important:
        ``return_data`` field will be added to ``api_description`` only
        when ``explode_return`` or ``returns_named_value`` is enabled.
    """

    def _detect_type(string):
        """Map a type description string to one of the schema type names."""
        if 'list' in string:
            return 'Array'
        # Anything mentioning ``str`` stays a STRING, whatever else it says.
        if 'str' in string:
            return 'STRING'
        if 'float' in string:
            return 'FLOAT'
        if 'int' in string:
            return 'NUMBER'
        if 'bool' in string:
            return 'BOOLEAN'
        return 'STRING'

    def _annotation_to_type(item):
        """Replace the ``annotation`` entry of ``item`` by a ``type`` entry.

        The ``annotation`` key is always removed; ``type`` is only set when
        an annotation was actually documented.
        """
        annotation = item.pop('annotation', None)
        if annotation:
            item['type'] = _detect_type(annotation.lower())
        return item

    def _explode(desc):
        """Parse the bullet list below a return description into fields.

        The first line of ``desc`` is dropped and every remaining non-blank
        line is stripped of its leading bullet characters, then the lines are
        re-parsed as the entries of a synthetic ``Args:`` section.
        """
        kvs = []
        desc = '\nArgs:\n' + '\n'.join([
            ' ' + item.lstrip(' -+*#.') for item in desc.split('\n')[1:]
            if item.strip()
        ])
        docs = Docstring(desc).parse('google')
        if not docs:
            return kvs
        if docs[0].kind is DocstringSectionKind.parameters:
            for d in docs[0].value:
                kvs.append(_annotation_to_type(d.as_dict()))
        return kvs

    def _parse_tool(function):
        """Build the ``api_description`` dict of ``function``."""
        # remove rst syntax
        docs = Docstring(
            re.sub(':(.+?):`(.+?)`', '\\2', function.__doc__ or '')).parse(
                'google', returns_named_value=returns_named_value, **kwargs)
        # Guard the lookup so that a docstring yielding no section at all
        # results in an empty description instead of an IndexError.
        has_summary = bool(docs) and docs[0].kind is DocstringSectionKind.text
        desc = dict(
            name=function.__name__,
            description=docs[0].value if has_summary else '',
            parameters=[],
            required=[],
        )

        args_doc, returns_doc = {}, []
        for doc in docs:
            if doc.kind is DocstringSectionKind.parameters:
                for d in doc.value:
                    # A parameter documented without a type carries no
                    # annotation; it then falls back to 'STRING' below.
                    d = _annotation_to_type(d.as_dict())
                    args_doc[d['name']] = d
            if doc.kind is DocstringSectionKind.returns:
                for d in doc.value:
                    d = d.as_dict()
                    if not d.get('name'):
                        d.pop('name', None)
                    returns_doc.append(_annotation_to_type(d))

        sig = inspect.signature(function)
        for name, param in sig.parameters.items():
            if name == 'self':
                continue
            documented = args_doc.get(param.name, {})
            parameter = dict(
                name=param.name,
                type='STRING',
                description=documented.get('description', ''))
            annotation = param.annotation
            if annotation is inspect.Signature.empty:
                # No typehint: rely on the docstring, if it declared a type.
                parameter['type'] = documented.get('type', 'STRING')
            else:
                if get_origin(annotation) is Annotated:
                    # Only the first metadata item is used as description;
                    # additional metadata items are tolerated and ignored.
                    annotation, info, *_ = get_args(annotation)
                    if info:
                        parameter['description'] = info
                # NOTE(review): this replaces a generic alias by the tuple of
                # its type arguments, so the container itself is not part of
                # the string handed to ``_detect_type`` - confirm intended.
                while get_origin(annotation):
                    annotation = get_args(annotation)
                parameter['type'] = _detect_type(str(annotation))
            desc['parameters'].append(parameter)
            if param.default is inspect.Signature.empty:
                desc['required'].append(param.name)

        return_data = []
        if explode_return:
            # Nothing to explode when no Returns section was documented.
            if returns_doc:
                return_data = _explode(returns_doc[0]['description'])
        elif returns_named_value:
            return_data = returns_doc
        if return_data:
            desc['return_data'] = return_data
        return desc

    def decorate(function):
        """Wrap ``function`` and attach its parsed ``api_description``."""
        # The wrappers forward every argument untouched, so they work both
        # for methods and for plain functions called with keywords only.
        if inspect.iscoroutinefunction(function):

            @wraps(function)
            async def wrapper(*args, **kw):
                return await function(*args, **kw)

        else:

            @wraps(function)
            def wrapper(*args, **kw):
                return function(*args, **kw)

        wrapper.api_description = _parse_tool(function)
        return wrapper

    # ``@tool_api`` passes the function directly, ``@tool_api(...)`` does not.
    if callable(func):
        return decorate(func)
    return decorate


def _copy_run_description(tool_desc, api_desc):
    """Copy the description of a ``run`` API into the tool description.

    ``parameters`` and ``required`` are always copied, ``description`` and
    ``return_data`` only when they are non-empty.
    """
    tool_desc['parameters'] = api_desc['parameters']
    tool_desc['required'] = api_desc['required']
    if api_desc['description']:
        tool_desc['description'] = api_desc['description']
    if api_desc.get('return_data'):
        tool_desc['return_data'] = api_desc['return_data']


class ToolMeta(ABCMeta):
    """Metaclass of tools."""

    def __new__(mcs, name, base, attrs):
        """Create the class and attach its tool description.

        Two attributes are injected into the class namespace:
        ``__tool_description__`` (the description dict) and ``_is_toolkit``
        (``True`` when the class body defines decorated APIs other than
        ``run``).

        Raises:
            KeyError: if the class body defines both a decorated ``run`` and
                other decorated APIs.
        """
        # ``__doc__`` may be absent or ``None``; both are parsed as ''.
        sections = Docstring(attrs.get('__doc__') or '').parse('google')
        is_toolkit, tool_desc = True, dict(
            name=name, description=sections[0].value if sections else '')
        for key, value in attrs.items():
            if callable(value) and hasattr(value, 'api_description'):
                api_desc = getattr(value, 'api_description')
                if key == 'run':
                    _copy_run_description(tool_desc, api_desc)
                    is_toolkit = False
                else:
                    tool_desc.setdefault('api_list', []).append(api_desc)
        if not is_toolkit and 'api_list' in tool_desc:
            raise KeyError('`run` and other tool APIs can not be implemented '
                           'at the same time')
        if is_toolkit and 'api_list' not in tool_desc:
            # No decorated API was found in the class body: treat the class
            # as a simple tool and decorate a plain ``run`` if there is one.
            is_toolkit = False
            if callable(attrs.get('run')):
                run_api = tool_api(attrs['run'])
                _copy_run_description(tool_desc, run_api.api_description)
                attrs['run'] = run_api
            else:
                tool_desc['parameters'], tool_desc['required'] = [], []

        attrs['_is_toolkit'] = is_toolkit
        attrs['__tool_description__'] = tool_desc
        return super().__new__(mcs, name, base, attrs)


class BaseAction(metaclass=ToolMeta):
    """Base class for all actions.

    Args:
        description (:class:`Optional[dict]`): The description of the action.
            Defaults to ``None``.
        parser (:class:`Type[BaseParser]`): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.

    Examples:

        * simple tool

        .. code-block:: python

            class Bold(BaseAction):
                '''Make text bold'''

                def run(self, text: str):
                    '''
                    Args:
                        text (str): input text

                    Returns:
                        str: bold text
                    '''
                    return '**' + text + '**'

            action = Bold()

        * toolkit with multiple APIs

        .. code-block:: python

            class Calculator(BaseAction):
                '''Calculator'''

                @tool_api
                def add(self, a, b):
                    '''Add operation

                    Args:
                        a (int): augend
                        b (int): addend

                    Returns:
                        int: sum
                    '''
                    return a + b

                @tool_api
                def sub(self, a, b):
                    '''Subtraction operation

                    Args:
                        a (int): minuend
                        b (int): subtrahend

                    Returns:
                        int: difference
                    '''
                    return a - b

            action = Calculator()
    """

    def __init__(
        self,
        description: Optional[dict] = None,
        parser: Type[BaseParser] = JsonParser,
    ):
        # Work on a copy so that neither the caller's dict nor the
        # class-level ``__tool_description__`` is shared with the instance.
        self._description = deepcopy(description or self.__tool_description__)
        self._name = self._description['name']
        self._parser = parser(self)

    def __call__(self, inputs: str, name='run') -> ActionReturn:
        """Parse ``inputs``, call the API called ``name`` and wrap the result.

        Args:
            inputs (str): raw arguments, handed to the parser.
            name (str): attribute name of the API to call.
                Defaults to ``'run'``.

        Returns:
            ActionReturn: the outcome of the call. Failures are reported
            through ``state`` and ``errmsg`` rather than raised:
            ``API_ERROR`` for an unknown API or an exception raised by it,
            ``ARGS_ERROR`` when the parser rejects the inputs.
        """
        fallback_args = {'inputs': inputs, 'name': name}
        if not hasattr(self, name):
            return ActionReturn(
                fallback_args,
                type=self.name,
                errmsg=f'invalid API: {name}',
                state=ActionStatusCode.API_ERROR)
        try:
            inputs = self._parser.parse_inputs(inputs, name)
        except ParseError as exc:
            return ActionReturn(
                fallback_args,
                type=self.name,
                errmsg=exc.err_msg,
                state=ActionStatusCode.ARGS_ERROR)
        try:
            outputs = getattr(self, name)(**inputs)
        except Exception as exc:
            # Deliberate catch-all: any failure of the API itself is
            # reported to the caller as an API_ERROR result.
            return ActionReturn(
                inputs,
                type=self.name,
                errmsg=str(exc),
                state=ActionStatusCode.API_ERROR)
        if isinstance(outputs, ActionReturn):
            # The API built its own result; only fill in what it left empty.
            action_return = outputs
            if not action_return.args:
                action_return.args = inputs
            if not action_return.type:
                action_return.type = self.name
        else:
            result = self._parser.parse_outputs(outputs)
            action_return = ActionReturn(inputs, type=self.name, result=result)
        return action_return

    @property
    def name(self):
        """Name of the tool, taken from its description."""
        return self._name

    @property
    def is_toolkit(self):
        """Value of the ``_is_toolkit`` flag set by the metaclass."""
        return self._is_toolkit

    @property
    def description(self) -> dict:
        """Description of the tool."""
        return self._description

    def __repr__(self):
        return f'{self.description}'

    __str__ = __repr__


class AsyncActionMixin:
    """Mixin providing an awaitable ``__call__`` that awaits the called API.

    It reads ``self.name`` and ``self._parser``, which are expected to be
    supplied by the class it is mixed into.
    """

    async def __call__(self, inputs: str, name='run') -> ActionReturn:
        """Parse ``inputs``, await the API called ``name`` and wrap the result.

        Args:
            inputs (str): raw arguments, handed to the parser.
            name (str): attribute name of the API to call.
                Defaults to ``'run'``.

        Returns:
            ActionReturn: the outcome of the call. Failures are reported
            through ``state`` and ``errmsg`` rather than raised:
            ``API_ERROR`` for an unknown API or an exception raised by it,
            ``ARGS_ERROR`` when the parser rejects the inputs.
        """
        fallback_args = {'inputs': inputs, 'name': name}
        if not hasattr(self, name):
            return ActionReturn(
                fallback_args,
                type=self.name,
                errmsg=f'invalid API: {name}',
                state=ActionStatusCode.API_ERROR)
        try:
            inputs = self._parser.parse_inputs(inputs, name)
        except ParseError as exc:
            return ActionReturn(
                fallback_args,
                type=self.name,
                errmsg=exc.err_msg,
                state=ActionStatusCode.ARGS_ERROR)
        try:
            outputs = await getattr(self, name)(**inputs)
        except Exception as exc:
            # Deliberate catch-all: any failure of the API itself is
            # reported to the caller as an API_ERROR result.
            return ActionReturn(
                inputs,
                type=self.name,
                errmsg=str(exc),
                state=ActionStatusCode.API_ERROR)
        if isinstance(outputs, ActionReturn):
            # The API built its own result; only fill in what it left empty.
            action_return = outputs
            if not action_return.args:
                action_return.args = inputs
            if not action_return.type:
                action_return.type = self.name
        else:
            result = self._parser.parse_outputs(outputs)
            action_return = ActionReturn(inputs, type=self.name, result=result)
        return action_return