diff --git a/lagent/__init__.py b/lagent/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9e7b4a71d5a54e4f6a3bff362dfa29016277be4e --- /dev/null +++ b/lagent/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .version import __version__, version_info + +__all__ = ['__version__', 'version_info'] diff --git a/lagent/__pycache__/__init__.cpython-310.pyc b/lagent/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56c98f2137896168a11059447471a13f838218c4 Binary files /dev/null and b/lagent/__pycache__/__init__.cpython-310.pyc differ diff --git a/lagent/__pycache__/schema.cpython-310.pyc b/lagent/__pycache__/schema.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e37fd64230bac564c054add83896356262bd68af Binary files /dev/null and b/lagent/__pycache__/schema.cpython-310.pyc differ diff --git a/lagent/__pycache__/version.cpython-310.pyc b/lagent/__pycache__/version.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9520e8e7e4ea74e697dfc2ea5a0165e513176ca Binary files /dev/null and b/lagent/__pycache__/version.cpython-310.pyc differ diff --git a/lagent/actions/__init__.py b/lagent/actions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6f777710682ce566d813dfcff56a1f93e37e9845 --- /dev/null +++ b/lagent/actions/__init__.py @@ -0,0 +1,26 @@ +from .action_executor import ActionExecutor, AsyncActionExecutor +from .arxiv_search import ArxivSearch, AsyncArxivSearch +from .base_action import BaseAction, tool_api +from .bing_map import AsyncBINGMap, BINGMap +from .builtin_actions import FinishAction, InvalidAction, NoAction +from .google_scholar_search import AsyncGoogleScholar, GoogleScholar +from .google_search import AsyncGoogleSearch, GoogleSearch +from .ipython_interactive import AsyncIPythonInteractive, IPythonInteractive +from .ipython_interpreter import AsyncIPythonInterpreter, IPythonInterpreter +from .ipython_manager import IPythonInteractiveManager +from .parser import BaseParser, JsonParser, TupleParser +from .ppt import PPT, AsyncPPT +from .python_interpreter import AsyncPythonInterpreter, PythonInterpreter +from .web_browser import AsyncWebBrowser, WebBrowser +from .weather_query import WeatherQuery + +__all__ = [ + 'BaseAction', 'ActionExecutor', 'AsyncActionExecutor', 'InvalidAction', + 'FinishAction', 'NoAction', 'BINGMap', 'AsyncBINGMap', 'ArxivSearch', + 'AsyncArxivSearch', 'GoogleSearch', 'AsyncGoogleSearch', 'GoogleScholar', + 'AsyncGoogleScholar', 'IPythonInterpreter', 'AsyncIPythonInterpreter', + 'IPythonInteractive', 'AsyncIPythonInteractive', + 'IPythonInteractiveManager', 'PythonInterpreter', 'AsyncPythonInterpreter', + 'PPT', 'AsyncPPT', 'WebBrowser', 'AsyncWebBrowser', 'BaseParser', + 'JsonParser', 'TupleParser', 'tool_api', 'WeatherQuery' +] diff --git a/lagent/actions/__pycache__/__init__.cpython-310.pyc b/lagent/actions/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e53c7d4a7f94ec2d1e0622007f7e6606ee919a32 Binary files /dev/null and b/lagent/actions/__pycache__/__init__.cpython-310.pyc differ diff --git a/lagent/actions/__pycache__/action_executor.cpython-310.pyc b/lagent/actions/__pycache__/action_executor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2402315169a49a2804bc6aa94bf6ca881c51a872 Binary files /dev/null and 
b/lagent/actions/__pycache__/action_executor.cpython-310.pyc differ diff --git a/lagent/actions/__pycache__/arxiv_search.cpython-310.pyc b/lagent/actions/__pycache__/arxiv_search.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22ffc427ab14a64a787114dbec76f7b1ebd52f38 Binary files /dev/null and b/lagent/actions/__pycache__/arxiv_search.cpython-310.pyc differ diff --git a/lagent/actions/__pycache__/base_action.cpython-310.pyc b/lagent/actions/__pycache__/base_action.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b42a3be7c90300f9d60226039c5cc59ac35c5fcc Binary files /dev/null and b/lagent/actions/__pycache__/base_action.cpython-310.pyc differ diff --git a/lagent/actions/__pycache__/bing_map.cpython-310.pyc b/lagent/actions/__pycache__/bing_map.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a25535ed755a119470edbe617715189a0cf4c283 Binary files /dev/null and b/lagent/actions/__pycache__/bing_map.cpython-310.pyc differ diff --git a/lagent/actions/__pycache__/builtin_actions.cpython-310.pyc b/lagent/actions/__pycache__/builtin_actions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..979466ae2c69b7d9a7572b3df6f1f7d9dc13dfa1 Binary files /dev/null and b/lagent/actions/__pycache__/builtin_actions.cpython-310.pyc differ diff --git a/lagent/actions/__pycache__/google_scholar_search.cpython-310.pyc b/lagent/actions/__pycache__/google_scholar_search.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5381f81b21869f1727c69c8e69b71032c74e32c6 Binary files /dev/null and b/lagent/actions/__pycache__/google_scholar_search.cpython-310.pyc differ diff --git a/lagent/actions/__pycache__/google_search.cpython-310.pyc b/lagent/actions/__pycache__/google_search.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac9439e46b062e3a63e36a0e331101add4eb2e49 Binary files /dev/null and b/lagent/actions/__pycache__/google_search.cpython-310.pyc differ diff --git a/lagent/actions/__pycache__/ipython_interactive.cpython-310.pyc b/lagent/actions/__pycache__/ipython_interactive.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38a65522ac9ea3a6b8a7f8106a5412845da094bc Binary files /dev/null and b/lagent/actions/__pycache__/ipython_interactive.cpython-310.pyc differ diff --git a/lagent/actions/__pycache__/ipython_interpreter.cpython-310.pyc b/lagent/actions/__pycache__/ipython_interpreter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a4f3c643f1cc8e15ad88adbe6cc4b66649cc73e Binary files /dev/null and b/lagent/actions/__pycache__/ipython_interpreter.cpython-310.pyc differ diff --git a/lagent/actions/__pycache__/ipython_manager.cpython-310.pyc b/lagent/actions/__pycache__/ipython_manager.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..294cc2cd3ed382224fc8fb8d61cef22877e48ff6 Binary files /dev/null and b/lagent/actions/__pycache__/ipython_manager.cpython-310.pyc differ diff --git a/lagent/actions/__pycache__/parser.cpython-310.pyc b/lagent/actions/__pycache__/parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e993bd2baf2911876944ec517e4b6c455150f4e0 Binary files /dev/null and b/lagent/actions/__pycache__/parser.cpython-310.pyc differ diff --git a/lagent/actions/__pycache__/ppt.cpython-310.pyc b/lagent/actions/__pycache__/ppt.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..0fe4a3658e45fb6bb510b6756c775f7f8ab00dbe Binary files /dev/null and b/lagent/actions/__pycache__/ppt.cpython-310.pyc differ diff --git a/lagent/actions/__pycache__/python_interpreter.cpython-310.pyc b/lagent/actions/__pycache__/python_interpreter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1c78d48bfc7d8366884895b8792f399d9ca7890 Binary files /dev/null and b/lagent/actions/__pycache__/python_interpreter.cpython-310.pyc differ diff --git a/lagent/actions/__pycache__/weather_query.cpython-310.pyc b/lagent/actions/__pycache__/weather_query.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86414c8849ad99fca06860ef1fac9cd11cdb6134 Binary files /dev/null and b/lagent/actions/__pycache__/weather_query.cpython-310.pyc differ diff --git a/lagent/actions/__pycache__/web_browser.cpython-310.pyc b/lagent/actions/__pycache__/web_browser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03b6a55f0e6b1ac79211abd5de7be479469b5729 Binary files /dev/null and b/lagent/actions/__pycache__/web_browser.cpython-310.pyc differ diff --git a/lagent/actions/action_executor.py b/lagent/actions/action_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..e28f32d57766be5d3c54da5fb0db7c0cc5a918ae --- /dev/null +++ b/lagent/actions/action_executor.py @@ -0,0 +1,198 @@ +import inspect +from collections import OrderedDict +from typing import Callable, Dict, List, Union + +from lagent.actions.base_action import BaseAction +from lagent.actions.builtin_actions import FinishAction, InvalidAction, NoAction +from lagent.hooks import Hook, RemovableHandle +from lagent.schema import ActionReturn, ActionValidCode, AgentMessage, FunctionCall +from lagent.utils import create_object + + +class ActionExecutor: + """The action executor class. + + Args: + actions (Union[BaseAction, List[BaseAction]]): The action or actions. + invalid_action (BaseAction, optional): The invalid action. Defaults to + InvalidAction(). + no_action (BaseAction, optional): The no action. + Defaults to NoAction(). + finish_action (BaseAction, optional): The finish action. Defaults to + FinishAction(). + finish_in_action (bool, optional): Whether the finish action is in the + action list. Defaults to False. 
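+        hooks (List[Dict], optional): Configurations of the hooks to register
+            on the executor. Defaults to ``None``.
+
+    Examples:
+
+        A minimal usage sketch. ``PythonInterpreter`` and its ``command``
+        parameter, as well as the ``AgentMessage`` fields used below, are
+        defined outside this file and assumed here for illustration:
+
+        .. code-block:: python
+
+            from lagent.actions import ActionExecutor, PythonInterpreter
+            from lagent.schema import AgentMessage
+
+            executor = ActionExecutor(actions=[PythonInterpreter()])
+            # the message content only needs ``name`` and ``parameters``
+            msg = AgentMessage(
+                sender='user',
+                content=dict(
+                    name='PythonInterpreter',
+                    parameters=dict(command='print(1 + 1)'),
+                ),
+            )
+            reply = executor(msg)  # an AgentMessage wrapping an ActionReturn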
+ """ + + def __init__( + self, + actions: Union[BaseAction, List[BaseAction], Dict, List[Dict]], + invalid_action: BaseAction = dict(type=InvalidAction), + no_action: BaseAction = dict(type=NoAction), + finish_action: BaseAction = dict(type=FinishAction), + finish_in_action: bool = False, + hooks: List[Dict] = None, + ): + + if not isinstance(actions, list): + actions = [actions] + finish_action = create_object(finish_action) + if finish_in_action: + actions.append(finish_action) + for i, action in enumerate(actions): + actions[i] = create_object(action) + self.actions = {action.name: action for action in actions} + + self.invalid_action = create_object(invalid_action) + self.no_action = create_object(no_action) + self.finish_action = finish_action + self._hooks: Dict[int, Hook] = OrderedDict() + if hooks: + for hook in hooks: + hook = create_object(hook) + self.register_hook(hook) + + def description(self) -> List[Dict]: + actions = [] + for action_name, action in self.actions.items(): + if action.is_toolkit: + for api in action.description['api_list']: + api_desc = api.copy() + api_desc['name'] = f"{action_name}.{api_desc['name']}" + actions.append(api_desc) + else: + action_desc = action.description.copy() + actions.append(action_desc) + return actions + + def __contains__(self, name: str): + return name in self.actions + + def keys(self): + return list(self.actions.keys()) + + def __setitem__(self, name: str, action: Union[BaseAction, Dict]): + action = create_object(action) + self.actions[action.name] = action + + def __delitem__(self, name: str): + del self.actions[name] + + def forward(self, name, parameters, **kwargs) -> ActionReturn: + action_name, api_name = ( + name.split('.') if '.' in name else (name, 'run')) + action_return: ActionReturn = ActionReturn() + if action_name not in self: + if name == self.no_action.name: + action_return = self.no_action(parameters) + elif name == self.finish_action.name: + action_return = self.finish_action(parameters) + else: + action_return = self.invalid_action(parameters) + else: + action_return = self.actions[action_name](parameters, api_name) + action_return.valid = ActionValidCode.OPEN + return action_return + + def __call__(self, + message: AgentMessage, + session_id=0, + **kwargs) -> AgentMessage: + # message.receiver = self.name + for hook in self._hooks.values(): + result = hook.before_action(self, message, session_id) + if result: + message = result + + assert isinstance(message.content, FunctionCall) or ( + isinstance(message.content, dict) and 'name' in message.content + and 'parameters' in message.content) + if isinstance(message.content, dict): + name = message.content.get('name') + parameters = message.content.get('parameters') + else: + name = message.content.name + parameters = message.content.parameters + + response_message = self.forward( + name=name, parameters=parameters, **kwargs) + if not isinstance(response_message, AgentMessage): + response_message = AgentMessage( + sender=self.__class__.__name__, + content=response_message, + ) + + for hook in self._hooks.values(): + result = hook.after_action(self, response_message, session_id) + if result: + response_message = result + return response_message + + def register_hook(self, hook: Callable): + handle = RemovableHandle(self._hooks) + self._hooks[handle.id] = hook + return handle + + +class AsyncActionExecutor(ActionExecutor): + + async def forward(self, name, parameters, **kwargs) -> ActionReturn: + action_name, api_name = ( + name.split('.') if '.' 
in name else (name, 'run')) + action_return: ActionReturn = ActionReturn() + if action_name not in self: + if name == self.no_action.name: + action_return = self.no_action(parameters) + elif name == self.finish_action.name: + action_return = self.finish_action(parameters) + else: + action_return = self.invalid_action(parameters) + else: + action = self.actions[action_name] + if inspect.iscoroutinefunction(action.__call__): + action_return = await action(parameters, api_name) + else: + action_return = action(parameters, api_name) + action_return.valid = ActionValidCode.OPEN + return action_return + + async def __call__(self, + message: AgentMessage, + session_id=0, + **kwargs) -> AgentMessage: + # message.receiver = self.name + for hook in self._hooks.values(): + if inspect.iscoroutinefunction(hook.before_action): + result = await hook.before_action(self, message, session_id) + else: + result = hook.before_action(self, message, session_id) + if result: + message = result + + assert isinstance(message.content, FunctionCall) or ( + isinstance(message.content, dict) and 'name' in message.content + and 'parameters' in message.content) + if isinstance(message.content, dict): + name = message.content.get('name') + parameters = message.content.get('parameters') + else: + name = message.content.name + parameters = message.content.parameters + + response_message = await self.forward( + name=name, parameters=parameters, **kwargs) + if not isinstance(response_message, AgentMessage): + response_message = AgentMessage( + sender=self.__class__.__name__, + content=response_message, + ) + + for hook in self._hooks.values(): + if inspect.iscoroutinefunction(hook.after_action): + result = await hook.after_action(self, response_message, + session_id) + else: + result = hook.after_action(self, response_message, session_id) + if result: + response_message = result + return response_message diff --git a/lagent/actions/arxiv_search.py b/lagent/actions/arxiv_search.py new file mode 100644 index 0000000000000000000000000000000000000000..43ba7c056fa583054d17eb2b6661b37fab41de6f --- /dev/null +++ b/lagent/actions/arxiv_search.py @@ -0,0 +1,79 @@ +from typing import Optional, Type + +from asyncer import asyncify + +from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api +from lagent.actions.parser import BaseParser, JsonParser +from lagent.schema import ActionReturn, ActionStatusCode + + +class ArxivSearch(BaseAction): + """Search information from Arxiv.org. \ +Useful for when you need to answer questions about Physics, Mathematics, \ +Computer Science, Quantitative Biology, Quantitative Finance, Statistics, \ +Electrical Engineering, and Economics from scientific articles on arxiv.org. + """ + + def __init__( + self, + top_k_results: int = 3, + max_query_len: int = 300, + doc_content_chars_max: int = 1500, + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + ): + super().__init__(description, parser) + self.top_k_results = top_k_results + self.max_query_len = max_query_len + self.doc_content_chars_max = doc_content_chars_max + + @tool_api(explode_return=True) + def get_arxiv_article_information(self, query: str) -> dict: + """Run Arxiv search and get the article meta information. 
+ + Args: + query (:class:`str`): the content of search query + + Returns: + :class:`dict`: article information + * content (str): a list of 3 arxiv search papers + """ + import arxiv + + try: + results = arxiv.Search( # type: ignore + query[: self.max_query_len], max_results=self.top_k_results + ).results() + except Exception as exc: + return ActionReturn(errmsg=f'Arxiv exception: {exc}', state=ActionStatusCode.HTTP_ERROR) + docs = [ + f'Published: {result.updated.date()}\nTitle: {result.title}\n' + f'Authors: {", ".join(a.name for a in result.authors)}\n' + f'Summary: {result.summary[:self.doc_content_chars_max]}' + for result in results + ] + if docs: + return {'content': '\n\n'.join(docs)} + return {'content': 'No good Arxiv Result was found'} + + +class AsyncArxivSearch(AsyncActionMixin, ArxivSearch): + """Search information from Arxiv.org. \ +Useful for when you need to answer questions about Physics, Mathematics, \ +Computer Science, Quantitative Biology, Quantitative Finance, Statistics, \ +Electrical Engineering, and Economics from scientific articles on arxiv.org. + """ + + @tool_api(explode_return=True) + @asyncify + def get_arxiv_article_information(self, query: str) -> dict: + """Run Arxiv search and get the article meta information. + + Args: + query (:class:`str`): the content of search query + + Returns: + :class:`dict`: article information + * content (str): a list of 3 arxiv search papers + """ + return super().get_arxiv_article_information(query) diff --git a/lagent/actions/base_action.py b/lagent/actions/base_action.py new file mode 100644 index 0000000000000000000000000000000000000000..2e546f75bb251c73991fee124c00a222f53bbce5 --- /dev/null +++ b/lagent/actions/base_action.py @@ -0,0 +1,434 @@ +import inspect +import logging +import re +from abc import ABCMeta +from copy import deepcopy +from functools import wraps +from typing import Callable, Optional, Type, get_args, get_origin + +try: + from typing import Annotated +except ImportError: + from typing_extensions import Annotated + +from griffe import Docstring + +try: + from griffe import DocstringSectionKind +except ImportError: + from griffe.enumerations import DocstringSectionKind + +from ..schema import ActionReturn, ActionStatusCode +from .parser import BaseParser, JsonParser, ParseError + +logging.getLogger('griffe').setLevel(logging.ERROR) + + +def tool_api(func: Optional[Callable] = None, + *, + explode_return: bool = False, + returns_named_value: bool = False, + **kwargs): + """Turn functions into tools. It will parse typehints as well as docstrings + to build the tool description and attach it to functions via an attribute + ``api_description``. + + Examples: + + .. code-block:: python + + # typehints has higher priority than docstrings + from typing import Annotated + + @tool_api + def add(a: Annotated[int, 'augend'], b: Annotated[int, 'addend'] = 1): + '''Add operation + + Args: + x (int): a + y (int): b + ''' + return a + b + + print(add.api_description) + + Args: + func (Optional[Callable]): function to decorate. Defaults to ``None``. + explode_return (bool): whether to flatten the dictionary or tuple return + as the ``return_data`` field. When enabled, it is recommended to + annotate the member in docstrings. Defaults to ``False``. + + .. 
code-block:: python + + @tool_api(explode_return=True) + def foo(a, b): + '''A simple function + + Args: + a (int): a + b (int): b + + Returns: + dict: information of inputs + * x: value of a + * y: value of b + ''' + return {'x': a, 'y': b} + + print(foo.api_description) + + returns_named_value (bool): whether to parse ``thing: Description`` in + returns sections as a name and description, rather than a type and + description. When true, type must be wrapped in parentheses: + ``(int): Description``. When false, parentheses are optional but + the items cannot be named: ``int: Description``. Defaults to ``False``. + + Returns: + Callable: wrapped function or partial decorator + + Important: + ``return_data`` field will be added to ``api_description`` only + when ``explode_return`` or ``returns_named_value`` is enabled. + """ + + def _detect_type(string): + field_type = 'STRING' + if 'list' in string: + field_type = 'Array' + elif 'str' not in string: + if 'float' in string: + field_type = 'FLOAT' + elif 'int' in string: + field_type = 'NUMBER' + elif 'bool' in string: + field_type = 'BOOLEAN' + return field_type + + def _explode(desc): + kvs = [] + desc = '\nArgs:\n' + '\n'.join([ + ' ' + item.lstrip(' -+*#.') + for item in desc.split('\n')[1:] if item.strip() + ]) + docs = Docstring(desc).parse('google') + if not docs: + return kvs + if docs[0].kind is DocstringSectionKind.parameters: + for d in docs[0].value: + d = d.as_dict() + if not d['annotation']: + d.pop('annotation') + else: + d['type'] = _detect_type(d.pop('annotation').lower()) + kvs.append(d) + return kvs + + def _parse_tool(function): + # remove rst syntax + docs = Docstring( + re.sub(':(.+?):`(.+?)`', '\\2', function.__doc__ or '')).parse( + 'google', returns_named_value=returns_named_value, **kwargs) + desc = dict( + name=function.__name__, + description=docs[0].value + if docs[0].kind is DocstringSectionKind.text else '', + parameters=[], + required=[], + ) + args_doc, returns_doc = {}, [] + for doc in docs: + if doc.kind is DocstringSectionKind.parameters: + for d in doc.value: + d = d.as_dict() + d['type'] = _detect_type(d.pop('annotation').lower()) + args_doc[d['name']] = d + if doc.kind is DocstringSectionKind.returns: + for d in doc.value: + d = d.as_dict() + if not d['name']: + d.pop('name') + if not d['annotation']: + d.pop('annotation') + else: + d['type'] = _detect_type(d.pop('annotation').lower()) + returns_doc.append(d) + + sig = inspect.signature(function) + for name, param in sig.parameters.items(): + if name == 'self': + continue + parameter = dict( + name=param.name, + type='STRING', + description=args_doc.get(param.name, + {}).get('description', '')) + annotation = param.annotation + if annotation is inspect.Signature.empty: + parameter['type'] = args_doc.get(param.name, + {}).get('type', 'STRING') + else: + if get_origin(annotation) is Annotated: + annotation, info = get_args(annotation) + if info: + parameter['description'] = info + while get_origin(annotation): + annotation = get_args(annotation) + parameter['type'] = _detect_type(str(annotation)) + desc['parameters'].append(parameter) + if param.default is inspect.Signature.empty: + desc['required'].append(param.name) + + return_data = [] + if explode_return: + return_data = _explode(returns_doc[0]['description']) + elif returns_named_value: + return_data = returns_doc + if return_data: + desc['return_data'] = return_data + return desc + + if callable(func): + + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def wrapper(self, *args, 
**kwargs): + return await func(self, *args, **kwargs) + + else: + + @wraps(func) + def wrapper(self, *args, **kwargs): + return func(self, *args, **kwargs) + + wrapper.api_description = _parse_tool(func) + return wrapper + + def decorate(func): + + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def wrapper(self, *args, **kwargs): + return await func(self, *args, **kwargs) + + else: + + @wraps(func) + def wrapper(self, *args, **kwargs): + return func(self, *args, **kwargs) + + wrapper.api_description = _parse_tool(func) + return wrapper + + return decorate + + +class ToolMeta(ABCMeta): + """Metaclass of tools.""" + + def __new__(mcs, name, base, attrs): + is_toolkit, tool_desc = True, dict( + name=name, + description=Docstring(attrs.get('__doc__', + '')).parse('google')[0].value) + for key, value in attrs.items(): + if callable(value) and hasattr(value, 'api_description'): + api_desc = getattr(value, 'api_description') + if key == 'run': + tool_desc['parameters'] = api_desc['parameters'] + tool_desc['required'] = api_desc['required'] + if api_desc['description']: + tool_desc['description'] = api_desc['description'] + if api_desc.get('return_data'): + tool_desc['return_data'] = api_desc['return_data'] + is_toolkit = False + else: + tool_desc.setdefault('api_list', []).append(api_desc) + if not is_toolkit and 'api_list' in tool_desc: + raise KeyError('`run` and other tool APIs can not be implemented ' + 'at the same time') + if is_toolkit and 'api_list' not in tool_desc: + is_toolkit = False + if callable(attrs.get('run')): + run_api = tool_api(attrs['run']) + api_desc = run_api.api_description + tool_desc['parameters'] = api_desc['parameters'] + tool_desc['required'] = api_desc['required'] + if api_desc['description']: + tool_desc['description'] = api_desc['description'] + if api_desc.get('return_data'): + tool_desc['return_data'] = api_desc['return_data'] + attrs['run'] = run_api + else: + tool_desc['parameters'], tool_desc['required'] = [], [] + attrs['_is_toolkit'] = is_toolkit + attrs['__tool_description__'] = tool_desc + return super().__new__(mcs, name, base, attrs) + + +class BaseAction(metaclass=ToolMeta): + """Base class for all actions. + + Args: + description (:class:`Optional[dict]`): The description of the action. + Defaults to ``None``. + parser (:class:`Type[BaseParser]`): The parser class to process the + action's inputs and outputs. Defaults to :class:`JsonParser`. + + Examples: + + * simple tool + + .. code-block:: python + + class Bold(BaseAction): + '''Make text bold''' + + def run(self, text: str): + ''' + Args: + text (str): input text + + Returns: + str: bold text + ''' + return '**' + text + '**' + + action = Bold() + + * toolkit with multiple APIs + + .. 
code-block:: python + + class Calculator(BaseAction): + '''Calculator''' + + @tool_api + def add(self, a, b): + '''Add operation + + Args: + a (int): augend + b (int): addend + + Returns: + int: sum + ''' + return a + b + + @tool_api + def sub(self, a, b): + '''Subtraction operation + + Args: + a (int): minuend + b (int): subtrahend + + Returns: + int: difference + ''' + return a - b + + action = Calculator() + """ + + def __init__( + self, + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + ): + self._description = deepcopy(description or self.__tool_description__) + self._name = self._description['name'] + self._parser = parser(self) + + def __call__(self, inputs: str, name='run') -> ActionReturn: + fallback_args = {'inputs': inputs, 'name': name} + if not hasattr(self, name): + return ActionReturn( + fallback_args, + type=self.name, + errmsg=f'invalid API: {name}', + state=ActionStatusCode.API_ERROR) + try: + inputs = self._parser.parse_inputs(inputs, name) + except ParseError as exc: + return ActionReturn( + fallback_args, + type=self.name, + errmsg=exc.err_msg, + state=ActionStatusCode.ARGS_ERROR) + try: + outputs = getattr(self, name)(**inputs) + except Exception as exc: + return ActionReturn( + inputs, + type=self.name, + errmsg=str(exc), + state=ActionStatusCode.API_ERROR) + if isinstance(outputs, ActionReturn): + action_return = outputs + if not action_return.args: + action_return.args = inputs + if not action_return.type: + action_return.type = self.name + else: + result = self._parser.parse_outputs(outputs) + action_return = ActionReturn(inputs, type=self.name, result=result) + return action_return + + @property + def name(self): + return self._name + + @property + def is_toolkit(self): + return self._is_toolkit + + @property + def description(self) -> dict: + """Description of the tool.""" + return self._description + + def __repr__(self): + return f'{self.description}' + + __str__ = __repr__ + + +class AsyncActionMixin: + + async def __call__(self, inputs: str, name='run') -> ActionReturn: + fallback_args = {'inputs': inputs, 'name': name} + if not hasattr(self, name): + return ActionReturn( + fallback_args, + type=self.name, + errmsg=f'invalid API: {name}', + state=ActionStatusCode.API_ERROR) + try: + inputs = self._parser.parse_inputs(inputs, name) + except ParseError as exc: + return ActionReturn( + fallback_args, + type=self.name, + errmsg=exc.err_msg, + state=ActionStatusCode.ARGS_ERROR) + try: + outputs = await getattr(self, name)(**inputs) + except Exception as exc: + return ActionReturn( + inputs, + type=self.name, + errmsg=str(exc), + state=ActionStatusCode.API_ERROR) + if isinstance(outputs, ActionReturn): + action_return = outputs + if not action_return.args: + action_return.args = inputs + if not action_return.type: + action_return.type = self.name + else: + result = self._parser.parse_outputs(outputs) + action_return = ActionReturn(inputs, type=self.name, result=result) + return action_return diff --git a/lagent/actions/bing_map.py b/lagent/actions/bing_map.py new file mode 100644 index 0000000000000000000000000000000000000000..75949be819a746f967cc43720e046e54b09581fa --- /dev/null +++ b/lagent/actions/bing_map.py @@ -0,0 +1,268 @@ +# flake8: noqa: E501 +import json +import os +from typing import Optional, Type + +import aiohttp +import requests + +from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api +from lagent.actions.parser import BaseParser, JsonParser + + +class BINGMap(BaseAction): + """BING Map plugin 
for looking up map information.""" + + def __init__( + self, + key: Optional[str] = None, + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + ) -> None: + super().__init__(description, parser) + key = os.environ.get('BING_MAP_KEY', key) + if key is None: + raise ValueError( + 'Please set BING Map API key either in the environment ' + 'as BING_MAP_KEY or pass it as `key` parameter.') + self.key = key + self.base_url = 'http://dev.virtualearth.net/REST/V1/' + + @tool_api(explode_return=True) + def get_distance(self, start: str, end: str) -> dict: + """Get the distance between two locations in km. + + Args: + start (:class:`str`): The start location + end (:class:`str`): The end location + + Returns: + :class:`dict`: distance information + * distance (str): the distance in km. + """ + # Request URL + url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key + # GET request + r = requests.get(url) + # TODO check request status? + data = json.loads(r.text) + # Extract route information + route = data['resourceSets'][0]['resources'][0] + # Extract distance in miles + distance = route['travelDistance'] + return dict(distance=distance) + + @tool_api(explode_return=True) + def get_route(self, start: str, end: str) -> dict: + """Get the route between two locations in km. + + Args: + start (:class:`str`): The start location + end (:class:`str`): The end location + + Returns: + :class:`dict`: route information + * route (list): the route, a list of actions. + """ + # Request URL + url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key + # GET request + r = requests.get(url) + data = json.loads(r.text) + # Extract route information + route = data['resourceSets'][0]['resources'][0] + itinerary = route['routeLegs'][0]['itineraryItems'] + # Extract route text information + route_text = [] + for item in itinerary: + if 'instruction' in item: + route_text.append(item['instruction']['text']) + return dict(route=route_text) + + @tool_api(explode_return=True) + def get_coordinates(self, location: str) -> dict: + """Get the coordinates of a location. + + Args: + location (:class:`str`): the location need to get coordinates. + + Returns: + :class:`dict`: coordinates information + * latitude (float): the latitude of the location. + * longitude (float): the longitude of the location. + """ + url = self.base_url + 'Locations' + params = {'query': location, 'key': self.key} + response = requests.get(url, params=params) + json_data = response.json() + coordinates = json_data['resourceSets'][0]['resources'][0]['point'][ + 'coordinates'] + return dict(latitude=coordinates[0], longitude=coordinates[1]) + + @tool_api(explode_return=True) + def search_nearby(self, + search_term: str, + places: str = 'unknown', + latitude: float = 0.0, + longitude: float = 0.0, + radius: int = 5000) -> dict: + """Search for places nearby a location, within a given radius, and return the results into a list. You can use either the places name or the latitude and longitude. + + Args: + search_term (:class:`str`): the place name. + places (:class:`str`): the name of the location. Defaults to ``'unknown'``. + latitude (:class:`float`): the latitude of the location. Defaults to ``0.0``. + longitude (:class:`float`): the longitude of the location. Defaults to ``0.0``. + radius (:class:`int`): radius in meters. Defaults to ``5000``. 
+
+        Returns:
+            :class:`dict`: places information
+                * places (list): the list of places, each place is a dict with name and address, at most 5 places.
+        """
+        url = self.base_url + 'LocalSearch'
+        if places != 'unknown':
+            pos = self.get_coordinates(location=places)
+            latitude, longitude = pos['latitude'], pos['longitude']
+        # Build the request query string
+        params = {
+            'query': search_term,
+            'userLocation': f'{latitude},{longitude}',
+            'radius': radius,
+            'key': self.key
+        }
+        # Make the request
+        response = requests.get(url, params=params)
+        # Parse the response
+        response_data = json.loads(response.content)
+        # Get the results
+        results = response_data['resourceSets'][0]['resources']
+        addresses = []
+        for result in results:
+            name = result['name']
+            address = result['Address']['formattedAddress']
+            addresses.append(dict(name=name, address=address))
+            if len(addresses) == 5:
+                break
+        return dict(place=addresses)
+
+
+class AsyncBINGMap(AsyncActionMixin, BINGMap):
+    """BING Map plugin for looking up map information."""
+
+    @tool_api(explode_return=True)
+    async def get_distance(self, start: str, end: str) -> dict:
+        """Get the distance between two locations in km.
+
+        Args:
+            start (:class:`str`): The start location
+            end (:class:`str`): The end location
+
+        Returns:
+            :class:`dict`: distance information
+                * distance (str): the distance in km.
+        """
+        # Request URL
+        url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
+        # GET request
+        async with aiohttp.ClientSession() as session:
+            async with session.get(url) as resp:
+                # TODO check request status?
+                data = await resp.json()
+                # Extract route information
+                route = data['resourceSets'][0]['resources'][0]
+                # Extract travel distance (km)
+                distance = route['travelDistance']
+                return dict(distance=distance)
+
+    @tool_api(explode_return=True)
+    async def get_route(self, start: str, end: str) -> dict:
+        """Get the route between two locations in km.
+
+        Args:
+            start (:class:`str`): The start location
+            end (:class:`str`): The end location
+
+        Returns:
+            :class:`dict`: route information
+                * route (list): the route, a list of actions.
+        """
+        # Request URL
+        url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
+        # GET request
+        async with aiohttp.ClientSession() as session:
+            async with session.get(url) as resp:
+                data = await resp.json()
+                # Extract route information
+                route = data['resourceSets'][0]['resources'][0]
+                itinerary = route['routeLegs'][0]['itineraryItems']
+                # Extract route text information
+                route_text = []
+                for item in itinerary:
+                    if 'instruction' in item:
+                        route_text.append(item['instruction']['text'])
+                return dict(route=route_text)
+
+    @tool_api(explode_return=True)
+    async def get_coordinates(self, location: str) -> dict:
+        """Get the coordinates of a location.
+
+        Args:
+            location (:class:`str`): the location to get coordinates for.
+
+        Returns:
+            :class:`dict`: coordinates information
+                * latitude (float): the latitude of the location.
+                * longitude (float): the longitude of the location.
+        """
+        url = self.base_url + 'Locations'
+        params = {'query': location, 'key': self.key}
+        async with aiohttp.ClientSession() as session:
+            async with session.get(url, params=params) as resp:
+                data = await resp.json()
+                coordinates = data['resourceSets'][0]['resources'][0]['point'][
+                    'coordinates']
+                return dict(latitude=coordinates[0], longitude=coordinates[1])
+
+    @tool_api(explode_return=True)
+    async def search_nearby(self,
+                            search_term: str,
+                            places: str = 'unknown',
+                            latitude: float = 0.0,
+                            longitude: float = 0.0,
+                            radius: int = 5000) -> dict:
+        """Search for places nearby a location, within a given radius, and return the results into a list. You can use either the places name or the latitude and longitude.
+
+        Args:
+            search_term (:class:`str`): the place name.
+            places (:class:`str`): the name of the location. Defaults to ``'unknown'``.
+            latitude (:class:`float`): the latitude of the location. Defaults to ``0.0``.
+            longitude (:class:`float`): the longitude of the location. Defaults to ``0.0``.
+            radius (:class:`int`): radius in meters. Defaults to ``5000``.
+
+        Returns:
+            :class:`dict`: places information
+                * places (list): the list of places, each place is a dict with name and address, at most 5 places.
+        """
+        url = self.base_url + 'LocalSearch'
+        if places != 'unknown':
+            pos = await self.get_coordinates(location=places)
+            latitude, longitude = pos['latitude'], pos['longitude']
+        # Build the request query string
+        params = {
+            'query': search_term,
+            'userLocation': f'{latitude},{longitude}',
+            'radius': radius,
+            'key': self.key
+        }
+        async with aiohttp.ClientSession() as session:
+            async with session.get(url, params=params) as resp:
+                data = await resp.json()
+                results = data['resourceSets'][0]['resources']
+                addresses = []
+                for result in results:
+                    name = result['name']
+                    address = result['Address']['formattedAddress']
+                    addresses.append(dict(name=name, address=address))
+                    if len(addresses) == 5:
+                        break
+                return dict(place=addresses)
diff --git a/lagent/actions/builtin_actions.py b/lagent/actions/builtin_actions.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bc4a49525f1aae5cf84e55c0dd5c88930fd3ecc
--- /dev/null
+++ b/lagent/actions/builtin_actions.py
@@ -0,0 +1,109 @@
+from typing import Optional
+
+from lagent.actions.base_action import BaseAction, tool_api
+from lagent.actions.parser import BaseParser
+from lagent.schema import ActionReturn, ActionStatusCode, ActionValidCode
+
+
+class InvalidAction(BaseAction):
+    """This is an invalid action class, which is used to return an error
+    message when the action is invalid.
+
+    Args:
+        err_msg (str): The error message. Defaults to 'The action is invalid,
+            please check the action name'.
+
+    Returns:
+        ActionReturn: The action return.
+    """
+
+    def __init__(self,
+                 err_msg:
+                 str = 'The action is invalid, please check the action name.',
+                 description: Optional[dict] = None,
+                 parser=BaseParser) -> None:
+        super().__init__(description, parser)
+        self._err_msg = err_msg
+
+    @tool_api
+    def run(self, err_msg: Optional[str] = None) -> ActionReturn:
+        """Return the error message.
+
+        Args:
+            err_msg (str, optional): The error message. If err_msg is not None,
+                it will be returned, otherwise the default error message will
+                be returned. Defaults to None.
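+
+        Returns:
+            ActionReturn: The action return.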
+ """ + action_return = ActionReturn( + url=None, + args=dict(text=err_msg), + errmsg=err_msg or self._err_msg, + type=self.name, + valid=ActionValidCode.INVALID, + state=ActionStatusCode.API_ERROR) + return action_return + + +class NoAction(BaseAction): + """This is a no action class, which is used to return error message when + the response does not follow the format. + + Args: + err_msg (str): The error message. Defaults to + 'Please follow the format'. + """ + + def __init__(self, + err_msg: str = 'Please follow the format', + description: Optional[dict] = None, + parser=BaseParser): + super().__init__(description, parser) + self._err_msg = err_msg + + @tool_api + def run(self, err_msg: Optional[str] = None) -> ActionReturn: + """Return the error message. + + Args: + err_msg (str, optional): The error message. If err_msg is not None, + it will be returned, otherwise the default error message will + be returned. Defaults to None. + + Returns: + ActionReturn: The action return. + """ + action_return = ActionReturn( + url=None, + args=dict(text=err_msg), + type=self.name, + errmsg=err_msg or self._err_msg, + valid=ActionValidCode.INVALID, + state=ActionStatusCode.API_ERROR) + return action_return + + +class FinishAction(BaseAction): + """This is a finish action class, which is used to return the final + result.""" + + def __init__(self, description: Optional[dict] = None, parser=BaseParser): + super().__init__(description, parser) + + @tool_api + def run(self, response: str) -> ActionReturn: + """Return the final result. + + Args: + response (str): The final result. + + Returns: + ActionReturn: The action return. + """ + action_return = ActionReturn( + url=None, + args=dict(text=response), + result=[dict(type='text', content=response)], + type=self.name, + valid=ActionValidCode.FINISH, + state=ActionStatusCode.SUCCESS) + return action_return diff --git a/lagent/actions/google_scholar_search.py b/lagent/actions/google_scholar_search.py new file mode 100644 index 0000000000000000000000000000000000000000..6247f14777bb1e8c13bfb0f4eefee77df8ca47e6 --- /dev/null +++ b/lagent/actions/google_scholar_search.py @@ -0,0 +1,438 @@ +# flake8: noqa: E501 +import os +from typing import Optional, Type + +from asyncer import asyncify + +from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api +from lagent.schema import ActionReturn, ActionStatusCode +from .parser import BaseParser, JsonParser + + +class GoogleScholar(BaseAction): + """Plugin for google scholar search. + + Args: + api_key (str): API KEY to use serper google search API, + You can create a free API key at https://serper.dev. + description (dict): The description of the action. Defaults to ``None``. + parser (Type[BaseParser]): The parser class to process the + action's inputs and outputs. Defaults to :class:`JsonParser`. + """ + + def __init__( + self, + api_key: Optional[str] = None, + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + ): + super().__init__(description, parser) + api_key = os.environ.get('SERPER_API_KEY', api_key) + if api_key is None: + raise ValueError( + 'Please set Serper API key either in the environment ' + 'as SERPER_API_KEY or pass it as `api_key` parameter.' 
+ ) + self.api_key = api_key + + @tool_api(explode_return=True) + def search_google_scholar( + self, + query: str, + cites: Optional[str] = None, + as_ylo: Optional[int] = None, + as_yhi: Optional[int] = None, + scisbd: Optional[int] = None, + cluster: Optional[str] = None, + hl: Optional[str] = None, + lr: Optional[str] = None, + start: Optional[int] = None, + num: Optional[int] = None, + as_sdt: Optional[str] = None, + safe: Optional[str] = None, + filter: Optional[str] = None, + as_vis: Optional[str] = None, + ) -> dict: + """Search for scholarly articles based on a query according to the google scholar. + + Args: + query (str): The query to search for. + cites (Optional[str]): The unique ID of an article for triggering "Cited By" searches. + as_ylo (Optional[int]): The starting year for results (e.g., if as_ylo=2018, results before this year will be omitted). + as_yhi (Optional[int]): The ending year for results (e.g., if as_yhi=2018, results after this year will be omitted). + scisbd (Optional[int]): Defines articles added in the last year, sorted by date. It can be set to 1 to include only abstracts, or 2 to include everything. + cluster (Optional[str]): The unique ID of an article for triggering "All Versions" searches. + hl (Optional[str]): The language to use for the Google Scholar search. + lr (Optional[str]): One or multiple languages to limit the search to. + start (Optional[int]): The result offset for pagination (0 is the first page of results, 10 is the 2nd page, etc.) + num (Optional[int]): The maximum number of results to return, limited to 20. + as_sdt (Optional[str]): Can be used either as a search type or a filter. + safe (Optional[str]): The level of filtering for adult content. + filter (Optional[str]): Defines if the filters for 'Similar Results' and 'Omitted Results' are on or off. + as_vis (Optional[str]): Defines whether to include citations or not. 
+ + Returns: + :class:`dict`: article information + - title: a list of the titles of the three selected papers + - cited_by: a list of the citation numbers of the three selected papers + - organic_id: a list of the organic results' ids of the three selected papers + - pub_info: publication information of selected papers + """ + from serpapi import GoogleSearch + + params = { + 'q': query, + 'engine': 'google_scholar', + 'api_key': self.api_key, + 'cites': cites, + 'as_ylo': as_ylo, + 'as_yhi': as_yhi, + 'scisbd': scisbd, + 'cluster': cluster, + 'hl': hl, + 'lr': lr, + 'start': start, + 'num': num, + 'as_sdt': as_sdt, + 'safe': safe, + 'filter': filter, + 'as_vis': as_vis, + } + search = GoogleSearch(params) + try: + r = search.get_dict() + results = r['organic_results'] + title = [] + snippets = [] + cited_by = [] + organic_id = [] + pub_info = [] + for item in results[:3]: + title.append(item['title']) + pub_info.append(item['publication_info']['summary']) + citation = item['inline_links'].get('cited_by', {'total': ''}) + cited_by.append(citation['total']) + snippets.append(item['snippet']) + organic_id.append(item['result_id']) + return dict(title=title, cited_by=cited_by, organic_id=organic_id, snippets=snippets) + except Exception as e: + return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR) + + @tool_api(explode_return=True) + def get_author_information( + self, + author_id: str, + hl: Optional[str] = None, + view_op: Optional[str] = None, + sort: Optional[str] = None, + citation_id: Optional[str] = None, + start: Optional[int] = None, + num: Optional[int] = None, + no_cache: Optional[bool] = None, + async_req: Optional[bool] = None, + output: Optional[str] = None, + ) -> dict: + """Search for an author's information by author's id provided by get_author_id. + + Args: + author_id (str): Required. The ID of an author. + hl (Optional[str]): The language to use for the Google Scholar Author search. Default is 'en'. + view_op (Optional[str]): Used for viewing specific parts of a page. + sort (Optional[str]): Used for sorting and refining articles. + citation_id (Optional[str]): Used for retrieving individual article citation. + start (Optional[int]): Defines the result offset. Default is 0. + num (Optional[int]): Defines the number of results to return. Default is 20. + no_cache (Optional[bool]): Forces SerpApi to fetch the results even if a cached version is already present. Default is False. + async_req (Optional[bool]): Defines the way you want to submit your search to SerpApi. Default is False. + output (Optional[str]): Defines the final output you want. Default is 'json'. 
+ + Returns: + :class:`dict`: author information + * name: author's name + * affliation: the affliation of the author + * articles: at most 3 articles by the author + * website: the author's homepage url + """ + from serpapi import GoogleSearch + + params = { + 'engine': 'google_scholar_author', + 'author_id': author_id, + 'api_key': self.api_key, + 'hl': hl, + 'view_op': view_op, + 'sort': sort, + 'citation_id': citation_id, + 'start': start, + 'num': num, + 'no_cache': no_cache, + 'async': async_req, + 'output': output, + } + try: + search = GoogleSearch(params) + results = search.get_dict() + author = results['author'] + articles = results.get('articles', []) + return dict( + name=author['name'], + affiliations=author.get('affiliations', ''), + website=author.get('website', ''), + articles=[dict(title=article['title'], authors=article['authors']) for article in articles[:3]], + ) + except Exception as e: + return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR) + + @tool_api(explode_return=True) + def get_citation_format( + self, + q: str, + no_cache: Optional[bool] = None, + async_: Optional[bool] = None, + output: Optional[str] = 'json', + ) -> dict: + """Function to get MLA citation format by an identification of organic_result's id provided by search_google_scholar. + + Args: + q (str): ID of an individual Google Scholar organic search result. + no_cache (Optional[bool]): If set to True, will force SerpApi to fetch the Google Scholar Cite results even if a cached version is already present. Defaults to None. + async_ (Optional[bool]): If set to True, will submit search to SerpApi and retrieve results later. Defaults to None. + output (Optional[str]): Final output format. Set to 'json' to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'. + + Returns: + :class:`dict`: citation format + * authors: the authors of the article + * citation: the citation format of the article + """ + from serpapi import GoogleSearch + + params = { + 'q': q, + 'engine': 'google_scholar_cite', + 'api_key': self.api_key, + 'no_cache': no_cache, + 'async': async_, + 'output': output, + } + try: + search = GoogleSearch(params) + results = search.get_dict() + citation = results['citations'] + citation_info = citation[0]['snippet'] + return citation_info + except Exception as e: + return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR) + + @tool_api(explode_return=True) + def get_author_id( + self, + mauthors: str, + hl: Optional[str] = 'en', + after_author: Optional[str] = None, + before_author: Optional[str] = None, + no_cache: Optional[bool] = False, + _async: Optional[bool] = False, + output: Optional[str] = 'json', + ) -> dict: + """The getAuthorId function is used to get the author's id by his or her name. + + Args: + mauthors (str): Defines the author you want to search for. + hl (Optional[str]): Defines the language to use for the Google Scholar Profiles search. It's a two-letter language code. (e.g., 'en' for English, 'es' for Spanish, or 'fr' for French). Defaults to 'en'. + after_author (Optional[str]): Defines the next page token. It is used for retrieving the next page results. The parameter has the precedence over before_author parameter. Defaults to None. + before_author (Optional[str]): Defines the previous page token. It is used for retrieving the previous page results. Defaults to None. 
+ no_cache (Optional[bool]): Will force SerpApi to fetch the Google Scholar Profiles results even if a cached version is already present. Defaults to False. + _async (Optional[bool]): Defines the way you want to submit your search to SerpApi. Defaults to False. + output (Optional[str]): Defines the final output you want. It can be set to 'json' (default) to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'. + + Returns: + :class:`dict`: author id + * author_id: the author_id of the author + """ + from serpapi import GoogleSearch + + params = { + 'mauthors': mauthors, + 'engine': 'google_scholar_profiles', + 'api_key': self.api_key, + 'hl': hl, + 'after_author': after_author, + 'before_author': before_author, + 'no_cache': no_cache, + 'async': _async, + 'output': output, + } + try: + search = GoogleSearch(params) + results = search.get_dict() + profile = results['profiles'] + author_info = dict(author_id=profile[0]['author_id']) + return author_info + except Exception as e: + return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR) + + +class AsyncGoogleScholar(AsyncActionMixin, GoogleScholar): + """Plugin for google scholar search. + + Args: + api_key (str): API KEY to use serper google search API, + You can create a free API key at https://serper.dev. + description (dict): The description of the action. Defaults to ``None``. + parser (Type[BaseParser]): The parser class to process the + action's inputs and outputs. Defaults to :class:`JsonParser`. + """ + + @tool_api(explode_return=True) + @asyncify + def search_google_scholar( + self, + query: str, + cites: Optional[str] = None, + as_ylo: Optional[int] = None, + as_yhi: Optional[int] = None, + scisbd: Optional[int] = None, + cluster: Optional[str] = None, + hl: Optional[str] = None, + lr: Optional[str] = None, + start: Optional[int] = None, + num: Optional[int] = None, + as_sdt: Optional[str] = None, + safe: Optional[str] = None, + filter: Optional[str] = None, + as_vis: Optional[str] = None, + ) -> dict: + """Search for scholarly articles based on a query according to the google scholar. + + Args: + query (str): The query to search for. + cites (Optional[str]): The unique ID of an article for triggering "Cited By" searches. + as_ylo (Optional[int]): The starting year for results (e.g., if as_ylo=2018, results before this year will be omitted). + as_yhi (Optional[int]): The ending year for results (e.g., if as_yhi=2018, results after this year will be omitted). + scisbd (Optional[int]): Defines articles added in the last year, sorted by date. It can be set to 1 to include only abstracts, or 2 to include everything. + cluster (Optional[str]): The unique ID of an article for triggering "All Versions" searches. + hl (Optional[str]): The language to use for the Google Scholar search. + lr (Optional[str]): One or multiple languages to limit the search to. + start (Optional[int]): The result offset for pagination (0 is the first page of results, 10 is the 2nd page, etc.) + num (Optional[int]): The maximum number of results to return, limited to 20. + as_sdt (Optional[str]): Can be used either as a search type or a filter. + safe (Optional[str]): The level of filtering for adult content. + filter (Optional[str]): Defines if the filters for 'Similar Results' and 'Omitted Results' are on or off. + as_vis (Optional[str]): Defines whether to include citations or not. 
+ + Returns: + :class:`dict`: article information + - title: a list of the titles of the three selected papers + - cited_by: a list of the citation numbers of the three selected papers + - organic_id: a list of the organic results' ids of the three selected papers + - pub_info: publication information of selected papers + """ + return super().search_google_scholar( + query, + cites, + as_ylo, + as_yhi, + scisbd, + cluster, + hl, + lr, + start, + num, + as_sdt, + safe, + filter, + as_vis, + ) + + @tool_api(explode_return=True) + @asyncify + def get_author_information( + self, + author_id: str, + hl: Optional[str] = None, + view_op: Optional[str] = None, + sort: Optional[str] = None, + citation_id: Optional[str] = None, + start: Optional[int] = None, + num: Optional[int] = None, + no_cache: Optional[bool] = None, + async_req: Optional[bool] = None, + output: Optional[str] = None, + ) -> dict: + """Search for an author's information by author's id provided by get_author_id. + + Args: + author_id (str): Required. The ID of an author. + hl (Optional[str]): The language to use for the Google Scholar Author search. Default is 'en'. + view_op (Optional[str]): Used for viewing specific parts of a page. + sort (Optional[str]): Used for sorting and refining articles. + citation_id (Optional[str]): Used for retrieving individual article citation. + start (Optional[int]): Defines the result offset. Default is 0. + num (Optional[int]): Defines the number of results to return. Default is 20. + no_cache (Optional[bool]): Forces SerpApi to fetch the results even if a cached version is already present. Default is False. + async_req (Optional[bool]): Defines the way you want to submit your search to SerpApi. Default is False. + output (Optional[str]): Defines the final output you want. Default is 'json'. + + Returns: + :class:`dict`: author information + * name: author's name + * affliation: the affliation of the author + * articles: at most 3 articles by the author + * website: the author's homepage url + """ + return super().get_author_information( + author_id, hl, view_op, sort, citation_id, start, num, no_cache, async_req, output + ) + + @tool_api(explode_return=True) + @asyncify + def get_citation_format( + self, + q: str, + no_cache: Optional[bool] = None, + async_: Optional[bool] = None, + output: Optional[str] = 'json', + ) -> dict: + """Function to get MLA citation format by an identification of organic_result's id provided by search_google_scholar. + + Args: + q (str): ID of an individual Google Scholar organic search result. + no_cache (Optional[bool]): If set to True, will force SerpApi to fetch the Google Scholar Cite results even if a cached version is already present. Defaults to None. + async_ (Optional[bool]): If set to True, will submit search to SerpApi and retrieve results later. Defaults to None. + output (Optional[str]): Final output format. Set to 'json' to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'. 
+ + Returns: + :class:`dict`: citation format + * authors: the authors of the article + * citation: the citation format of the article + """ + return super().get_citation_format(q, no_cache, async_, output) + + @tool_api(explode_return=True) + @asyncify + def get_author_id( + self, + mauthors: str, + hl: Optional[str] = 'en', + after_author: Optional[str] = None, + before_author: Optional[str] = None, + no_cache: Optional[bool] = False, + _async: Optional[bool] = False, + output: Optional[str] = 'json', + ) -> dict: + """The getAuthorId function is used to get the author's id by his or her name. + + Args: + mauthors (str): Defines the author you want to search for. + hl (Optional[str]): Defines the language to use for the Google Scholar Profiles search. It's a two-letter language code. (e.g., 'en' for English, 'es' for Spanish, or 'fr' for French). Defaults to 'en'. + after_author (Optional[str]): Defines the next page token. It is used for retrieving the next page results. The parameter has the precedence over before_author parameter. Defaults to None. + before_author (Optional[str]): Defines the previous page token. It is used for retrieving the previous page results. Defaults to None. + no_cache (Optional[bool]): Will force SerpApi to fetch the Google Scholar Profiles results even if a cached version is already present. Defaults to False. + _async (Optional[bool]): Defines the way you want to submit your search to SerpApi. Defaults to False. + output (Optional[str]): Defines the final output you want. It can be set to 'json' (default) to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'. + + Returns: + :class:`dict`: author id + * author_id: the author_id of the author + """ + return super().get_author_id(mauthors, hl, after_author, before_author, no_cache, _async, output) diff --git a/lagent/actions/google_search.py b/lagent/actions/google_search.py new file mode 100644 index 0000000000000000000000000000000000000000..27b314565cdb62ba330e785a6bde313c781c33f7 --- /dev/null +++ b/lagent/actions/google_search.py @@ -0,0 +1,244 @@ +import os +from typing import List, Optional, Tuple, Type, Union + +import aiohttp +import requests + +from lagent.schema import ActionReturn, ActionStatusCode +from .base_action import AsyncActionMixin, BaseAction, tool_api +from .parser import BaseParser, JsonParser + + +class GoogleSearch(BaseAction): + """Wrapper around the Serper.dev Google Search API. + + To use, you should pass your serper API key to the constructor. + + Code is modified from lang-chain GoogleSerperAPIWrapper + (https://github.com/langchain-ai/langchain/blob/ba5f + baba704a2d729a4b8f568ed70d7c53e799bb/libs/langchain/ + langchain/utilities/google_serper.py) + + Args: + api_key (str): API KEY to use serper google search API, + You can create a free API key at https://serper.dev. + timeout (int): Upper bound of waiting time for a serper request. + search_type (str): Serper API support ['search', 'images', 'news', + 'places'] types of search, currently we only support 'search'. + description (dict): The description of the action. Defaults to ``None``. + parser (Type[BaseParser]): The parser class to process the + action's inputs and outputs. Defaults to :class:`JsonParser`. 
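+
+    Examples:
+
+        A minimal usage sketch; the API key and query below are illustrative
+        placeholders:
+
+        .. code-block:: python
+
+            searcher = GoogleSearch(api_key='your-serper-key')
+            ret = searcher.run('InternLM lagent', k=3)
+            # ``ret.result`` holds the snippets on success,
+            # ``ret.errmsg`` the error otherwise
+            print(ret.result or ret.errmsg)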
+ """ + result_key_for_type = { + 'news': 'news', + 'places': 'places', + 'images': 'images', + 'search': 'organic', + } + + def __init__( + self, + api_key: Optional[str] = None, + timeout: int = 5, + search_type: str = 'search', + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + ): + super().__init__(description, parser) + api_key = os.environ.get('SERPER_API_KEY', api_key) + if api_key is None: + raise ValueError( + 'Please set Serper API key either in the environment ' + 'as SERPER_API_KEY or pass it as `api_key` parameter.') + self.api_key = api_key + self.timeout = timeout + self.search_type = search_type + + @tool_api + def run(self, query: str, k: int = 10) -> ActionReturn: + """一个可以从谷歌搜索结果的API。当你需要对于一个特定问题找到简短明了的回答时,可以使用它。输入应该是一个搜索查询。 + + Args: + query (str): the search content + k (int): select first k results in the search results as response + """ + tool_return = ActionReturn(type=self.name) + status_code, response = self._search(query, k=k) + # convert search results to ToolReturn format + if status_code == -1: + tool_return.errmsg = response + tool_return.state = ActionStatusCode.HTTP_ERROR + elif status_code == 200: + parsed_res = self._parse_results(response, k) + tool_return.result = [dict(type='text', content=str(parsed_res))] + tool_return.state = ActionStatusCode.SUCCESS + else: + tool_return.errmsg = str(status_code) + tool_return.state = ActionStatusCode.API_ERROR + return tool_return + + def _parse_results(self, results: dict, k: int) -> Union[str, List[str]]: + """Parse the search results from Serper API. + + Args: + results (dict): The search content from Serper API + in json format. + + Returns: + List[str]: The parsed search results. + """ + + snippets = [] + + if results.get('answerBox'): + answer_box = results.get('answerBox', {}) + if answer_box.get('answer'): + return [answer_box.get('answer')] + elif answer_box.get('snippet'): + return [answer_box.get('snippet').replace('\n', ' ')] + elif answer_box.get('snippetHighlighted'): + return answer_box.get('snippetHighlighted') + + if results.get('knowledgeGraph'): + kg = results.get('knowledgeGraph', {}) + title = kg.get('title') + entity_type = kg.get('type') + if entity_type: + snippets.append(f'{title}: {entity_type}.') + description = kg.get('description') + if description: + snippets.append(description) + for attribute, value in kg.get('attributes', {}).items(): + snippets.append(f'{title} {attribute}: {value}.') + + for result in results[self.result_key_for_type[ + self.search_type]][:k]: + if 'snippet' in result: + snippets.append(result['snippet']) + for attribute, value in result.get('attributes', {}).items(): + snippets.append(f'{attribute}: {value}.') + + if len(snippets) == 0: + return ['No good Google Search Result was found'] + return snippets + + def _search(self, + search_term: str, + search_type: Optional[str] = None, + **kwargs) -> Tuple[int, Union[dict, str]]: + """HTTP requests to Serper API. + + Args: + search_term (str): The search query. + search_type (str): search type supported by Serper API, + default to 'search'. + + Returns: + tuple: the return value is a tuple contains: + - status_code (int): HTTP status code from Serper API. + - response (dict): response context with json format. 
+ """ + headers = { + 'X-API-KEY': self.api_key or '', + 'Content-Type': 'application/json', + } + params = { + 'q': search_term, + **{ + key: value + for key, value in kwargs.items() if value is not None + }, + } + try: + response = requests.post( + f'https://google.serper.dev/{search_type or self.search_type}', + headers=headers, + params=params, + timeout=self.timeout) + except Exception as e: + return -1, str(e) + return response.status_code, response.json() + + +class AsyncGoogleSearch(AsyncActionMixin, GoogleSearch): + """Wrapper around the Serper.dev Google Search API. + + To use, you should pass your serper API key to the constructor. + + Code is modified from lang-chain GoogleSerperAPIWrapper + (https://github.com/langchain-ai/langchain/blob/ba5f + baba704a2d729a4b8f568ed70d7c53e799bb/libs/langchain/ + langchain/utilities/google_serper.py) + + Args: + api_key (str): API KEY to use serper google search API, + You can create a free API key at https://serper.dev. + timeout (int): Upper bound of waiting time for a serper request. + search_type (str): Serper API support ['search', 'images', 'news', + 'places'] types of search, currently we only support 'search'. + description (dict): The description of the action. Defaults to ``None``. + parser (Type[BaseParser]): The parser class to process the + action's inputs and outputs. Defaults to :class:`JsonParser`. + """ + + @tool_api + async def run(self, query: str, k: int = 10) -> ActionReturn: + """一个可以从谷歌搜索结果的API。当你需要对于一个特定问题找到简短明了的回答时,可以使用它。输入应该是一个搜索查询。 + + Args: + query (str): the search content + k (int): select first k results in the search results as response + """ + tool_return = ActionReturn(type=self.name) + status_code, response = await self._search(query, k=k) + # convert search results to ToolReturn format + if status_code == -1: + tool_return.errmsg = response + tool_return.state = ActionStatusCode.HTTP_ERROR + elif status_code == 200: + parsed_res = self._parse_results(response) + tool_return.result = [dict(type='text', content=str(parsed_res))] + tool_return.state = ActionStatusCode.SUCCESS + else: + tool_return.errmsg = str(status_code) + tool_return.state = ActionStatusCode.API_ERROR + return tool_return + + async def _search(self, + search_term: str, + search_type: Optional[str] = None, + **kwargs) -> Tuple[int, Union[dict, str]]: + """HTTP requests to Serper API. + + Args: + search_term (str): The search query. + search_type (str): search type supported by Serper API, + default to 'search'. + + Returns: + tuple: the return value is a tuple contains: + - status_code (int): HTTP status code from Serper API. + - response (dict): response context with json format. 
+ """ + headers = { + 'X-API-KEY': self.api_key or '', + 'Content-Type': 'application/json', + } + params = { + 'q': search_term, + **{ + key: value + for key, value in kwargs.items() if value is not None + }, + } + timeout = aiohttp.ClientTimeout(total=self.timeout) + async with aiohttp.ClientSession(timeout=timeout) as session: + try: + async with session.post( + f'https://google.serper.dev/{search_type or self.search_type}', + headers=headers, + params=params) as resp: + code, ret = resp.status, await resp.json() + except aiohttp.ClientError as e: + code, ret = -1, str(e) + return code, ret diff --git a/lagent/actions/ipython_interactive.py b/lagent/actions/ipython_interactive.py new file mode 100644 index 0000000000000000000000000000000000000000..e84df9c22e3e4feee89775e84f87e495500c19b8 --- /dev/null +++ b/lagent/actions/ipython_interactive.py @@ -0,0 +1,273 @@ +import re +import signal +from contextlib import contextmanager, redirect_stdout +from dataclasses import dataclass +from enum import Enum +from io import StringIO +from typing import Optional, Type + +from ..schema import ActionReturn, ActionStatusCode +from .base_action import AsyncActionMixin, BaseAction, tool_api +from .parser import BaseParser, JsonParser + + +class Status(str, Enum): + """Execution status.""" + SUCCESS = 'success' + FAILURE = 'failure' + + +@dataclass +class ExecutionResult: + """Execution result.""" + status: Status + value: Optional[str] = None + msg: Optional[str] = None + + +@contextmanager +def _raise_timeout(timeout): + + def _handler(signum, frame): + raise TimeoutError() + + signal.signal(signal.SIGALRM, _handler) + signal.alarm(timeout) + + try: + yield + finally: + signal.alarm(0) + + +class IPythonInteractive(BaseAction): + """An interactive IPython shell for code execution. + + Args: + timeout (int): Upper bound of waiting time for Python script execution. + Defaults to ``20``. + max_out_len (int): maximum output length. No truncation occurs if negative. + Defaults to ``2048``. + use_signals (bool): whether signals should be used for timing function out + or the multiprocessing. Set to ``False`` when not running in the main + thread, e.g. web applications. Defaults to ``True`` + description (dict): The description of the action. Defaults to ``None``. + parser (Type[BaseParser]): The parser class to process the + action's inputs and outputs. Defaults to :class:`JsonParser`. + """ + + def __init__( + self, + timeout: int = 30, + max_out_len: int = 8192, + use_signals: bool = True, + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + ): + super().__init__(description, parser) + self.timeout = timeout + self._executor = self.create_shell() + self._highlighting = re.compile( + r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]') + self._max_out_len = max_out_len if max_out_len >= 0 else None + self._use_signals = use_signals + + def reset(self): + """Clear the context.""" + self._executor.reset() + + @tool_api + def run(self, command: str, timeout: Optional[int] = None) -> ActionReturn: + """Launch an IPython Interactive Shell to execute code. + + Args: + command (:class:`str`): Python code snippet + timeout (:class:`Optional[int]`): timeout for execution. + This argument only works in the main thread. Defaults to ``None``. 
+ """ + from timeout_decorator import timeout as timer + tool_return = ActionReturn(args={'text': command}, type=self.name) + ret = ( + timer(timeout or self.timeout)(self.exec)(command) + if self._use_signals else self.exec(command)) + if ret.status is Status.SUCCESS: + tool_return.result = [{'type': 'text', 'content': ret.value}] + tool_return.state = ActionStatusCode.SUCCESS + else: + tool_return.errmsg = ret.msg + tool_return.state = ActionStatusCode.API_ERROR + return tool_return + + def exec(self, code: str) -> ExecutionResult: + """Run Python scripts in IPython shell. + + Args: + code (:class:`str`): code block + + Returns: + :py:class:`ExecutionResult`: execution result + """ + with StringIO() as io: + with redirect_stdout(io): + ret = self._executor.run_cell(self.extract_code(code)) + result = ret.result + if result is not None: + return ExecutionResult(Status.SUCCESS, + str(result)[:self._max_out_len]) + outs = io.getvalue().strip().split('\n') + if not outs: + return ExecutionResult(Status.SUCCESS, '') + for i, out in enumerate(outs): + if re.search('Error|Traceback', out, re.S): + if 'TimeoutError' in out: + return ExecutionResult( + Status.FAILURE, + msg=('The code interpreter encountered ' + 'a timeout error.')) + err_idx = i + break + else: + return ExecutionResult(Status.SUCCESS, + outs[-1].strip()[:self._max_out_len]) + return ExecutionResult( + Status.FAILURE, + msg=self._highlighting.sub( + '', '\n'.join(outs[err_idx:])[:self._max_out_len]), + ) + + @staticmethod + def create_shell(): + from IPython import InteractiveShell + from traitlets.config import Config + + c = Config() + c.HistoryManager.enabled = False + c.HistoryManager.hist_file = ':memory:' + return InteractiveShell( + user_ns={'_raise_timeout': _raise_timeout}, config=c) + + @staticmethod + def extract_code(text: str) -> str: + """Extract Python code from markup languages. + + Args: + text (:class:`str`): Markdown-formatted text + + Returns: + :class:`str`: Python code + """ + import json5 + + # Match triple backtick blocks first + triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL) + # Match single backtick blocks second + single_match = re.search(r'`([^`]*)`', text, re.DOTALL) + if triple_match: + text = triple_match.group(1) + elif single_match: + text = single_match.group(1) + else: + try: + text = json5.loads(text)['code'] + except Exception: + pass + # If no code blocks found, return original text + return text + + @staticmethod + def wrap_code_with_timeout(code: str, timeout: int) -> str: + if not code.strip(): + return code + code = code.strip('\n').rstrip() + indent = len(code) - len(code.lstrip()) + handle = ' ' * indent + f'with _raise_timeout({timeout}):\n' + block = '\n'.join([' ' + line for line in code.split('\n')]) + wrapped_code = handle + block + last_line = code.split('\n')[-1] + is_expression = True + try: + compile(last_line.lstrip(), '', 'eval') + except SyntaxError: + is_expression = False + if is_expression: + wrapped_code += '\n' * 5 + last_line + return wrapped_code + + +class AsyncIPythonInteractive(AsyncActionMixin, IPythonInteractive): + """An interactive IPython shell for code execution. + + Args: + timeout (int): Upper bound of waiting time for Python script execution. + Defaults to ``20``. + max_out_len (int): maximum output length. No truncation occurs if negative. + Defaults to ``2048``. + use_signals (bool): whether signals should be used for timing function out + or the multiprocessing. Set to ``False`` when not running in the main + thread, e.g. 
web applications. Defaults to ``True`` + description (dict): The description of the action. Defaults to ``None``. + parser (Type[BaseParser]): The parser class to process the + action's inputs and outputs. Defaults to :class:`JsonParser`. + """ + + @tool_api + async def run(self, + command: str, + timeout: Optional[int] = None) -> ActionReturn: + """Launch an IPython Interactive Shell to execute code. + + Args: + command (:class:`str`): Python code snippet + timeout (:class:`Optional[int]`): timeout for execution. + This argument only works in the main thread. Defaults to ``None``. + """ + tool_return = ActionReturn(args={'text': command}, type=self.name) + ret = await self.exec(command, timeout) + if ret.status is Status.SUCCESS: + tool_return.result = [{'type': 'text', 'content': ret.value}] + tool_return.state = ActionStatusCode.SUCCESS + else: + tool_return.errmsg = ret.msg + tool_return.state = ActionStatusCode.API_ERROR + return tool_return + + async def exec(self, code: str, timeout: int = None) -> ExecutionResult: + """Asynchronously run Python scripts in IPython shell. + + Args: + code (:class:`str`): code block + timeout (:class:`int`): max waiting time for code execution + + Returns: + :py:class:`ExecutionResult`: execution result + """ + with StringIO() as io: + with redirect_stdout(io): + ret = await self._executor.run_cell_async( + # ret = await self.create_shell().run_cell_async( + self.wrap_code_with_timeout( + self.extract_code(code), timeout or self.timeout)) + result = ret.result + if result is not None: + return ExecutionResult(Status.SUCCESS, + str(result)[:self._max_out_len]) + outs = io.getvalue().strip().split('\n') + if not outs: + return ExecutionResult(Status.SUCCESS, '') + for i, out in enumerate(outs): + if re.search('Error|Traceback', out, re.S): + if 'TimeoutError' in out: + return ExecutionResult( + Status.FAILURE, + msg=('The code interpreter encountered a ' + 'timeout error.')) + err_idx = i + break + else: + return ExecutionResult(Status.SUCCESS, + outs[-1].strip()[:self._max_out_len]) + return ExecutionResult( + Status.FAILURE, + msg=self._highlighting.sub( + '', '\n'.join(outs[err_idx:])[:self._max_out_len]), + ) diff --git a/lagent/actions/ipython_interpreter.py b/lagent/actions/ipython_interpreter.py new file mode 100644 index 0000000000000000000000000000000000000000..68e9a0de94602845d191eb63d9e92dc7ae230ca3 --- /dev/null +++ b/lagent/actions/ipython_interpreter.py @@ -0,0 +1,584 @@ +# flake8: noqa: E501 +import asyncio +import base64 +import io +import json +import logging +import os +import queue +import re +import signal +import sys +import tempfile +import traceback +import uuid +from typing import Optional, Tuple, Type + +from jupyter_client import AsyncKernelClient, AsyncKernelManager, AsyncMultiKernelManager +from tenacity import retry, retry_if_result, stop_after_attempt, wait_fixed + +from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api +from lagent.actions.parser import BaseParser, JsonParser +from lagent.schema import ActionReturn, ActionStatusCode + +logger = logging.getLogger(__name__) + +START_CODE = """ +def input(*args, **kwargs): + raise NotImplementedError('Python input() function is disabled.') + +get_ipython().system = lambda *args: print('Assume we have this package, ! 
is disabled!') +{} +""" # noqa + + +class TimeoutError(Exception): + pass + + +class KernelDeath(Exception): + pass + + +async def async_run_code( + km: AsyncKernelManager, + code, + *, + interrupt_after=30, + iopub_timeout=40, + wait_for_ready_timeout=60, + shutdown_kernel=True, +): + assert iopub_timeout > interrupt_after + try: + + async def get_iopub_msg_with_death_detection(kc: AsyncKernelClient, + *, + timeout=None): + loop = asyncio.get_running_loop() + dead_fut = loop.create_future() + + def restarting(): + assert ( + False + ), "Restart shouldn't happen because config.KernelRestarter.restart_limit is expected to be set to 0" + + def dead(): + logger.info("Kernel has died, will NOT restart") + dead_fut.set_result(None) + + msg_task = asyncio.create_task(kc.get_iopub_msg(timeout=timeout)) + km.add_restart_callback(restarting, "restart") + km.add_restart_callback(dead, "dead") + try: + done, _ = await asyncio.wait( + [dead_fut, msg_task], return_when=asyncio.FIRST_COMPLETED) + if dead_fut in done: + raise KernelDeath() + assert msg_task in done + return await msg_task + finally: + msg_task.cancel() + km.remove_restart_callback(restarting, "restart") + km.remove_restart_callback(dead, "dead") + + async def send_interrupt(): + await asyncio.sleep(interrupt_after) + logger.info("Sending interrupt to kernel") + await km.interrupt_kernel() + + @retry( + retry=retry_if_result(lambda ret: ret[-1].strip() in [ + 'KeyboardInterrupt', + f"Kernel didn't respond in {wait_for_ready_timeout} seconds", + ] if isinstance(ret, tuple) else False), + stop=stop_after_attempt(3), + wait=wait_fixed(1), + retry_error_callback=lambda state: state.outcome.result()) + async def run(): + execute_result = None + error_traceback = None + stream_text_list = [] + kc = km.client() + assert isinstance(kc, AsyncKernelClient) + kc.start_channels() + try: + await kc.wait_for_ready(timeout=wait_for_ready_timeout) + msg_id = kc.execute(code) + while True: + message = await get_iopub_msg_with_death_detection( + kc, timeout=iopub_timeout) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + json.dumps(message, indent=2, default=str)) + assert message["parent_header"]["msg_id"] == msg_id + msg_type = message["msg_type"] + if msg_type == "status": + if message["content"]["execution_state"] == "idle": + break + elif msg_type == "stream": + stream_name = message["content"]["name"] + stream_text = message["content"]["text"] + stream_text_list.append(stream_text) + elif msg_type == "execute_result": + execute_result = message["content"]["data"] + elif msg_type == "error": + error_traceback_lines = message["content"]["traceback"] + error_traceback = "\n".join(error_traceback_lines) + elif msg_type == "execute_input": + pass + else: + assert False, f"Unknown message_type: {msg_type}" + finally: + kc.stop_channels() + return execute_result, error_traceback, "".join(stream_text_list) + + if interrupt_after: + run_task = asyncio.create_task(run()) + send_interrupt_task = asyncio.create_task(send_interrupt()) + done, _ = await asyncio.wait([run_task, send_interrupt_task], + return_when=asyncio.FIRST_COMPLETED) + if run_task in done: + send_interrupt_task.cancel() + else: + assert send_interrupt_task in done + result = await run_task + else: + result = await run() + return result + finally: + if shutdown_kernel: + await km.shutdown_kernel() + + +class IPythonInterpreter(BaseAction): + """A IPython executor that can execute Python scripts in a jupyter manner. 
+ + Args: + timeout (int): Upper bound of waiting time for Python script execution. + Defaults to 20. + user_data_dir (str, optional): Specified the user data directory for files + loading. If set to `ENV`, use `USER_DATA_DIR` environment variable. + Defaults to `ENV`. + work_dir (str, optional): Specify which directory to save output images to. + Defaults to ``'./work_dir/tmp_dir'``. + description (dict): The description of the action. Defaults to ``None``. + parser (Type[BaseParser]): The parser class to process the + action's inputs and outputs. Defaults to :class:`JsonParser`. + """ + + _KERNEL_CLIENTS = {} + + def __init__( + self, + timeout: int = 20, + user_data_dir: str = 'ENV', + work_dir='./work_dir/tmp_dir', + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + ): + super().__init__(description, parser) + + self.timeout = timeout + if user_data_dir == 'ENV': + user_data_dir = os.environ.get('USER_DATA_DIR', '') + + if user_data_dir: + user_data_dir = os.path.dirname(user_data_dir) + user_data_dir = f"import os\nos.chdir('{user_data_dir}')" + self.user_data_dir = user_data_dir + self._initialized = False + self.work_dir = work_dir + if not os.path.exists(self.work_dir): + os.makedirs(self.work_dir, exist_ok=True) + + @staticmethod + def start_kernel(): + from jupyter_client import KernelManager + + # start the kernel and manager + km = KernelManager() + km.start_kernel() + kc = km.client() + return km, kc + + def initialize(self): + if self._initialized: + return + pid = os.getpid() + if pid not in self._KERNEL_CLIENTS: + self._KERNEL_CLIENTS[pid] = self.start_kernel() + self.kernel_manager, self.kernel_client = self._KERNEL_CLIENTS[pid] + self._initialized = True + self._call(START_CODE.format(self.user_data_dir), None) + + def reset(self): + if not self._initialized: + self.initialize() + else: + code = "get_ipython().run_line_magic('reset', '-f')\n" + \ + START_CODE.format(self.user_data_dir) + self._call(code, None) + + def _call(self, + command: str, + timeout: Optional[int] = None) -> Tuple[str, bool]: + self.initialize() + command = extract_code(command) + + # check previous remaining result + while True: + try: + msg = self.kernel_client.get_iopub_msg(timeout=5) + msg_type = msg['msg_type'] + if msg_type == 'status': + if msg['content'].get('execution_state') == 'idle': + break + except queue.Empty: + # assume no result + break + + self.kernel_client.execute(command) + + def _inner_call(): + result = '' + images = [] + succeed = True + image_idx = 0 + + while True: + text = '' + image = '' + finished = False + msg_type = 'error' + try: + msg = self.kernel_client.get_iopub_msg(timeout=20) + msg_type = msg['msg_type'] + if msg_type == 'status': + if msg['content'].get('execution_state') == 'idle': + finished = True + elif msg_type == 'execute_result': + text = msg['content']['data'].get('text/plain', '') + if 'image/png' in msg['content']['data']: + image_b64 = msg['content']['data']['image/png'] + image_url = publish_image_to_local( + image_b64, self.work_dir) + image_idx += 1 + image = '![fig-%03d](%s)' % (image_idx, image_url) + + elif msg_type == 'display_data': + if 'image/png' in msg['content']['data']: + image_b64 = msg['content']['data']['image/png'] + image_url = publish_image_to_local( + image_b64, self.work_dir) + image_idx += 1 + image = '![fig-%03d](%s)' % (image_idx, image_url) + + else: + text = msg['content']['data'].get('text/plain', '') + elif msg_type == 'stream': + msg_type = msg['content']['name'] # stdout, stderr + text = 
msg['content']['text'] + elif msg_type == 'error': + succeed = False + text = escape_ansi('\n'.join( + msg['content']['traceback'])) + if 'M6_CODE_INTERPRETER_TIMEOUT' in text: + text = f'Timeout. No response after {timeout} seconds.' # noqa + except queue.Empty: + # stop current task in case break next input. + self.kernel_manager.interrupt_kernel() + succeed = False + text = f'Timeout. No response after {timeout} seconds.' + finished = True + except Exception: + succeed = False + msg = ''.join(traceback.format_exception(*sys.exc_info())) + # text = 'The code interpreter encountered an unexpected error.' # noqa + text = msg + logging.warning(msg) + finished = True + if text: + # result += f'\n\n{msg_type}:\n\n```\n{text}\n```' + result += f'{text}' + + if image: + images.append(image_url) + if finished: + return succeed, dict(text=result, image=images) + + try: + if timeout: + + def handler(signum, frame): + raise TimeoutError() + + signal.signal(signal.SIGALRM, handler) + signal.alarm(timeout) + succeed, result = _inner_call() + except TimeoutError: + succeed = False + text = 'The code interpreter encountered an unexpected error.' + result = f'\n\nerror:\n\n```\n{text}\n```' + finally: + if timeout: + signal.alarm(0) + + # result = result.strip('\n') + return succeed, result + + @tool_api + def run(self, command: str, timeout: Optional[int] = None) -> ActionReturn: + r"""When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail. + + Args: + command (:class:`str`): Python code + timeout (:class:`Optional[int]`): Upper bound of waiting time for Python script execution. + """ + tool_return = ActionReturn(url=None, args=None, type=self.name) + tool_return.args = dict(text=command) + succeed, result = self._call(command, timeout) + if succeed: + text = result['text'] + image = result.get('image', []) + resp = [dict(type='text', content=text)] + if image: + resp.extend([dict(type='image', content=im) for im in image]) + tool_return.result = resp + # tool_return.result = dict( + # text=result['text'], image=result.get('image', [])[0]) + tool_return.state = ActionStatusCode.SUCCESS + else: + tool_return.errmsg = result.get('text', '') if isinstance( + result, dict) else result + tool_return.state = ActionStatusCode.API_ERROR + return tool_return + + +class AsyncIPythonInterpreter(AsyncActionMixin, IPythonInterpreter): + """A IPython executor that can execute Python scripts in a jupyter manner. + + Args: + timeout (int): Upper bound of waiting time for Python script execution. + Defaults to 20. + user_data_dir (str, optional): Specified the user data directory for files + loading. If set to `ENV`, use `USER_DATA_DIR` environment variable. + Defaults to `ENV`. + work_dir (str, optional): Specify which directory to save output images to. + Defaults to ``'./work_dir/tmp_dir'``. + description (dict): The description of the action. Defaults to ``None``. + parser (Type[BaseParser]): The parser class to process the + action's inputs and outputs. Defaults to :class:`JsonParser`. 
+ """ + + _UNBOUND_KERNEL_CLIENTS = asyncio.Queue() + + def __init__( + self, + timeout: int = 20, + user_data_dir: str = 'ENV', + work_dir=os.path.join(tempfile.gettempdir(), 'tmp_dir'), + max_kernels: Optional[int] = None, + reuse_kernel: bool = True, + startup_rate: bool = 32, + connection_dir: str = tempfile.gettempdir(), + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + ): + super().__init__(timeout, user_data_dir, work_dir, description, parser) + from traitlets.config import Config + + c = Config() + c.KernelManager.transport = 'ipc' + self._amkm = AsyncMultiKernelManager( + config=c, connection_dir=connection_dir) + self._max_kernels = max_kernels + self._reuse_kernel = reuse_kernel + self._sem = asyncio.Semaphore(startup_rate) + self._lock = asyncio.Lock() + + async def initialize(self, session_id: str): + session_id = str(session_id) + while True: + if session_id in self._KERNEL_CLIENTS: + return self._KERNEL_CLIENTS[session_id] + if self._reuse_kernel and not self._UNBOUND_KERNEL_CLIENTS.empty(): + self._KERNEL_CLIENTS[ + session_id] = await self._UNBOUND_KERNEL_CLIENTS.get() + return self._KERNEL_CLIENTS[session_id] + async with self._sem: + if self._max_kernels is None or len( + self._KERNEL_CLIENTS + ) + self._UNBOUND_KERNEL_CLIENTS.qsize() < self._max_kernels: + kernel_id = None + try: + kernel_id = await self._amkm.start_kernel() + kernel = self._amkm.get_kernel(kernel_id) + client = kernel.client() + _, error_stacktrace, stream_text = await async_run_code( + kernel, + START_CODE.format(self.user_data_dir), + shutdown_kernel=False) + # check if the output of START_CODE meets expectations + if not (error_stacktrace is None + and stream_text == ''): + raise RuntimeError + except Exception as e: + print(f'Starting kernel error: {e}') + if kernel_id: + await self._amkm.shutdown_kernel(kernel_id) + self._amkm.remove_kernel(kernel_id) + await asyncio.sleep(1) + continue + if self._max_kernels is None: + self._KERNEL_CLIENTS[session_id] = (kernel_id, kernel, + client) + return kernel_id, kernel, client + async with self._lock: + if len(self._KERNEL_CLIENTS + ) + self._UNBOUND_KERNEL_CLIENTS.qsize( + ) < self._max_kernels: + self._KERNEL_CLIENTS[session_id] = (kernel_id, + kernel, client) + return kernel_id, kernel, client + await self._amkm.shutdown_kernel(kernel_id) + self._amkm.remove_kernel(kernel_id) + await asyncio.sleep(1) + + async def reset(self, session_id: str): + session_id = str(session_id) + if session_id not in self._KERNEL_CLIENTS: + return + _, kernel, _ = self._KERNEL_CLIENTS[session_id] + code = "get_ipython().run_line_magic('reset', '-f')\n" + \ + START_CODE.format(self.user_data_dir) + await async_run_code(kernel, code, shutdown_kernel=False) + + async def shutdown(self, session_id: str): + session_id = str(session_id) + if session_id in self._KERNEL_CLIENTS: + kernel_id, _, _ = self._KERNEL_CLIENTS.get(session_id) + await self._amkm.shutdown_kernel(kernel_id) + self._amkm.remove_kernel(kernel_id) + del self._KERNEL_CLIENTS[session_id] + + async def close_session(self, session_id: str): + session_id = str(session_id) + if self._reuse_kernel: + if session_id in self._KERNEL_CLIENTS: + await self.reset(session_id) + await self._UNBOUND_KERNEL_CLIENTS.put( + self._KERNEL_CLIENTS.pop(session_id)) + else: + await self.shutdown(session_id) + + async def _call(self, command, timeout=None, session_id=None): + _, kernel, _ = await self.initialize(str(session_id)) + result = await async_run_code( + kernel, + extract_code(command), + 
interrupt_after=timeout or self.timeout, + shutdown_kernel=False) + execute_result, error_stacktrace, stream_text = result + if error_stacktrace is not None: + ret = re.sub('^-*\n', '', escape_ansi(error_stacktrace)) + if ret.endswith('KeyboardInterrupt: '): + ret = 'The code interpreter encountered a timeout error.' + status, ret = False, ret.strip() + elif execute_result is not None: + status, ret = True, dict(text=execute_result.get('text/plain', '')) + else: + status, ret = True, dict(text=stream_text.strip()) + return status, ret + + @tool_api + async def run(self, + command: str, + timeout: Optional[int] = None, + session_id: Optional[str] = None) -> ActionReturn: + r"""When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail. + + Args: + command (:class:`str`): Python code + timeout (:class:`Optional[int]`): Upper bound of waiting time for Python script execution. + """ + tool_return = ActionReturn(url=None, args=None, type=self.name) + tool_return.args = dict(text=command) + succeed, result = await self._call(command, timeout, session_id) + if succeed: + text = result['text'] + image = result.get('image', []) + resp = [dict(type='text', content=text)] + if image: + resp.extend([dict(type='image', content=im) for im in image]) + tool_return.result = resp + # tool_return.result = dict( + # text=result['text'], image=result.get('image', [])[0]) + tool_return.state = ActionStatusCode.SUCCESS + else: + tool_return.errmsg = result.get('text', '') if isinstance( + result, dict) else result + tool_return.state = ActionStatusCode.API_ERROR + return tool_return + + +def extract_code(text): + import json5 + + # Match triple backtick blocks first + triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL) + # Match single backtick blocks second + single_match = re.search(r'`([^`]*)`', text, re.DOTALL) + if triple_match: + text = triple_match.group(1) + elif single_match: + text = single_match.group(1) + else: + try: + text = json5.loads(text)['code'] + except Exception: + pass + # If no code blocks found, return original text + return text + + +def escape_ansi(line): + ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]') + return ansi_escape.sub('', line) + + +def publish_image_to_local(image_base64: str, work_dir='./work_dir/tmp_dir'): + import PIL.Image + image_file = str(uuid.uuid4()) + '.png' + local_image_file = os.path.join(work_dir, image_file) + + png_bytes = base64.b64decode(image_base64) + assert isinstance(png_bytes, bytes) + bytes_io = io.BytesIO(png_bytes) + PIL.Image.open(bytes_io).save(local_image_file, 'png') + + return local_image_file + + +# local test for code interpreter +def get_multiline_input(hint): + print(hint) + print('// Press ENTER to make a new line. 
Press CTRL-D to end input.') + lines = [] + while True: + try: + line = input() + except EOFError: # CTRL-D + break + lines.append(line) + print('// Input received.') + if lines: + return '\n'.join(lines) + else: + return '' + + +if __name__ == '__main__': + code_interpreter = IPythonInterpreter() + while True: + print(code_interpreter(get_multiline_input('Enter python code:'))) diff --git a/lagent/actions/ipython_manager.py b/lagent/actions/ipython_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..4e8a2fdb1697be4d2e527c30c0542e7003b0bef8 --- /dev/null +++ b/lagent/actions/ipython_manager.py @@ -0,0 +1,220 @@ +import re +import sys +from collections import defaultdict +from contextlib import nullcontext +from io import StringIO +from multiprocessing import Process, Queue +from typing import List, Optional, Type, Union + +from filelock import FileLock +from timeout_decorator import timeout as tm + +from ..schema import ActionReturn, ActionStatusCode +from .base_action import BaseAction +from .parser import BaseParser, JsonParser + + +class IPythonProcess(Process): + + def __init__(self, + in_q: Queue, + out_q: Queue, + timeout: int = 20, + ci_lock: str = None, + daemon: bool = True): + super().__init__(daemon=daemon) + self.in_q = in_q + self.out_q = out_q + self.timeout = timeout + self.session_id2shell = defaultdict(self.create_shell) + self.ci_lock = FileLock( + ci_lock) if ci_lock else nullcontext() # avoid core corruption + self._highlighting = re.compile(r'\x1b\[\d{,3}(;\d{,3}){,3}m') + + def run(self): + while True: + msg = self.in_q.get() + if msg == 'reset': + for session_id, shell in self.session_id2shell.items(): + with self.ci_lock: + try: + shell.reset(new_session=False) + # shell.run_line_magic('reset', '-sf') + except Exception: + self.session_id2shell[ + session_id] = self.create_shell() + self.out_q.put('ok') + elif isinstance(msg, tuple) and len(msg) == 3: + i, session_id, code = msg + res = self.exec(session_id, code) + self.out_q.put((i, session_id, res)) + + def exec(self, session_id, code): + try: + shell = self.session_id2shell[session_id] + with StringIO() as io: + old_stdout = sys.stdout + sys.stdout = io + if self.timeout is False or self.timeout < 0: + shell.run_cell(self.extract_code(code)) + else: + tm(self.timeout)(shell.run_cell)(self.extract_code(code)) + sys.stdout = old_stdout + output = self._highlighting.sub('', io.getvalue().strip()) + output = re.sub(r'^Out\[\d+\]: ', '', output) + if 'Error' in output or 'Traceback' in output: + output = output.lstrip('-').strip() + if output.startswith('TimeoutError'): + output = 'The code interpreter encountered a timeout error.' + return {'status': 'FAILURE', 'msg': output, 'code': code} + return {'status': 'SUCCESS', 'value': output, 'code': code} + except Exception as e: + return {'status': 'FAILURE', 'msg': str(e), 'code': code} + + @staticmethod + def create_shell(enable_history: bool = False, in_memory: bool = True): + from IPython import InteractiveShell + from traitlets.config import Config + + c = Config() + c.HistoryManager.enabled = enable_history + if in_memory: + c.HistoryManager.hist_file = ':memory:' + shell = InteractiveShell(config=c) + return shell + + @staticmethod + def extract_code(text: str) -> str: + """Extract Python code from markup languages. 
+ + Args: + text (:class:`str`): Markdown-formatted text + + Returns: + :class:`str`: Python code + """ + import json5 + + # Match triple backtick blocks first + triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL) + # Match single backtick blocks second + single_match = re.search(r'`([^`]*)`', text, re.DOTALL) + if triple_match: + text = triple_match.group(1) + elif single_match: + text = single_match.group(1) + else: + try: + text = json5.loads(text)['code'] + except Exception: + pass + # If no code blocks found, return original text + return text + + +class IPythonInteractiveManager(BaseAction): + """An interactive IPython shell manager for code execution""" + + def __init__( + self, + max_workers: int = 50, + timeout: int = 20, + ci_lock: str = None, + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + ): + super().__init__(description, parser) + self.max_workers = max_workers + self.timeout = timeout + self.ci_lock = ci_lock + self.id2queue = defaultdict(Queue) + self.id2process = {} + self.out_queue = Queue() + + def __call__(self, + commands: Union[str, List[str]], + session_ids: Union[int, List[int]] = None): + if isinstance(commands, list): + batch_size = len(commands) + is_batch = True + else: + batch_size = 1 + commands = [commands] + is_batch = False + if session_ids is None: + session_ids = range(batch_size) + elif isinstance(session_ids, int): + session_ids = [session_ids] + if len(session_ids) != batch_size or len(session_ids) != len( + set(session_ids)): + raise ValueError( + 'the size of `session_ids` must equal that of `commands`') + try: + exec_results = self.run_code_blocks([ + (session_id, command) + for session_id, command in zip(session_ids, commands) + ]) + except KeyboardInterrupt: + self.clear() + exit(1) + action_returns = [] + for result, code in zip(exec_results, commands): + action_return = ActionReturn({'command': code}, type=self.name) + if result['status'] == 'SUCCESS': + action_return.result = [ + dict(type='text', content=result['value']) + ] + action_return.state = ActionStatusCode.SUCCESS + else: + action_return.errmsg = result['msg'] + action_return.state = ActionStatusCode.API_ERROR + action_returns.append(action_return) + if not is_batch: + return action_returns[0] + return action_returns + + def process_code(self, index, session_id, code): + ipy_id = session_id % self.max_workers + input_queue = self.id2queue[ipy_id] + proc = self.id2process.setdefault( + ipy_id, + IPythonProcess( + input_queue, + self.out_queue, + self.timeout, + self.ci_lock, + daemon=True)) + if not proc.is_alive(): + proc.start() + input_queue.put((index, session_id, code)) + + def run_code_blocks(self, session_code_pairs): + size = len(session_code_pairs) + for index, (session_id, code) in enumerate(session_code_pairs): + self.process_code(index, session_id, code) + results = [] + while len(results) < size: + msg = self.out_queue.get() + if isinstance(msg, tuple) and len(msg) == 3: + index, _, result = msg + results.append((index, result)) + results.sort() + return [item[1] for item in results] + + def clear(self): + self.id2queue.clear() + for proc in self.id2process.values(): + proc.terminate() + self.id2process.clear() + while not self.out_queue.empty(): + self.out_queue.get() + + def reset(self): + cnt = 0 + for q in self.id2queue.values(): + q.put('reset') + cnt += 1 + while cnt > 0: + msg = self.out_queue.get() + if msg == 'ok': + cnt -= 1 diff --git a/lagent/actions/parser.py b/lagent/actions/parser.py new file mode 100644 index 
0000000000000000000000000000000000000000..4188ae39ee27f61e7219ab856414827add1cdab7 --- /dev/null +++ b/lagent/actions/parser.py @@ -0,0 +1,146 @@ +import json +import re +from ast import literal_eval +from typing import Any, List, Union + + +class ParseError(Exception): + """Parsing exception class.""" + + def __init__(self, err_msg: str): + self.err_msg = err_msg + + +class BaseParser: + """Base parser to process inputs and outputs of actions. + + Args: + action (:class:`BaseAction`): action to validate + + Attributes: + PARAMETER_DESCRIPTION (:class:`str`): declare the input format which + LLMs should follow when generating arguments for decided tools. + """ + + PARAMETER_DESCRIPTION: str = '' + + def __init__(self, action): + self.action = action + self._api2param = {} + self._api2required = {} + # perform basic argument validation + if action.description: + for api in action.description.get('api_list', + [action.description]): + name = (f'{action.name}.{api["name"]}' + if self.action.is_toolkit else api['name']) + required_parameters = set(api['required']) + all_parameters = {j['name'] for j in api['parameters']} + if not required_parameters.issubset(all_parameters): + raise ValueError( + f'unknown parameters for function "{name}": ' + f'{required_parameters - all_parameters}') + if self.PARAMETER_DESCRIPTION: + api['parameter_description'] = self.PARAMETER_DESCRIPTION + api_name = api['name'] if self.action.is_toolkit else 'run' + self._api2param[api_name] = api['parameters'] + self._api2required[api_name] = api['required'] + + def parse_inputs(self, inputs: str, name: str = 'run') -> dict: + """Parse inputs LLMs generate for the action. + + Args: + inputs (:class:`str`): input string extracted from responses + + Returns: + :class:`dict`: processed input + """ + inputs = {self._api2param[name][0]['name']: inputs} + return inputs + + def parse_outputs(self, outputs: Any) -> List[dict]: + """Parser outputs returned by the action. + + Args: + outputs (:class:`Any`): raw output of the action + + Returns: + :class:`List[dict]`: processed output of which each member is a + dictionary with two keys - 'type' and 'content'. + """ + if isinstance(outputs, dict): + outputs = json.dumps(outputs, ensure_ascii=False) + elif not isinstance(outputs, str): + outputs = str(outputs) + return [{ + 'type': 'text', + 'content': outputs.encode('gbk', 'ignore').decode('gbk') + }] + + +class JsonParser(BaseParser): + """Json parser to convert input string into a dictionary. 
+ + Args: + action (:class:`BaseAction`): action to validate + """ + + PARAMETER_DESCRIPTION = ( + 'If you call this tool, you must pass arguments in ' + 'the JSON format {key: value}, where the key is the parameter name.') + + def parse_inputs(self, + inputs: Union[str, dict], + name: str = 'run') -> dict: + if not isinstance(inputs, dict): + try: + match = re.search(r'^\s*(```json\n)?(.*)\n```\s*$', inputs, + re.S) + if match: + inputs = match.group(2).strip() + inputs = json.loads(inputs) + except json.JSONDecodeError as exc: + raise ParseError(f'invalid json format: {inputs}') from exc + input_keys = set(inputs) + all_keys = {param['name'] for param in self._api2param[name]} + if not input_keys.issubset(all_keys): + raise ParseError(f'unknown arguments: {input_keys - all_keys}') + required_keys = set(self._api2required[name]) + if not input_keys.issuperset(required_keys): + raise ParseError( + f'missing required arguments: {required_keys - input_keys}') + return inputs + + +class TupleParser(BaseParser): + """Tuple parser to convert input string into a tuple. + + Args: + action (:class:`BaseAction`): action to validate + """ + + PARAMETER_DESCRIPTION = ( + 'If you call this tool, you must pass arguments in the tuple format ' + 'like (arg1, arg2, arg3), and the arguments are ordered.') + + def parse_inputs(self, + inputs: Union[str, tuple], + name: str = 'run') -> dict: + if not isinstance(inputs, tuple): + try: + inputs = literal_eval(inputs) + except Exception as exc: + raise ParseError(f'invalid tuple format: {inputs}') from exc + if len(inputs) < len(self._api2required[name]): + raise ParseError( + f'API takes {len(self._api2required[name])} required positional ' + f'arguments but {len(inputs)} were given') + if len(inputs) > len(self._api2param[name]): + raise ParseError( + f'API takes {len(self._api2param[name])} positional arguments ' + f'but {len(inputs)} were given') + inputs = { + self._api2param[name][i]['name']: item + for i, item in enumerate(inputs) + } + return inputs diff --git a/lagent/actions/ppt.py b/lagent/actions/ppt.py new file mode 100644 index 0000000000000000000000000000000000000000..38bcfa8472519d282421d7776c77781658b8d588 --- /dev/null +++ b/lagent/actions/ppt.py @@ -0,0 +1,233 @@ +from typing import Dict, Optional, Type + +from asyncer import asyncify + +from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api +from lagent.actions.parser import BaseParser, JsonParser + +THEME_MAPPING = { + 'Default': { + 'template': None, + 'title': 'Title Slide', + 'single': 'Title and Content', + 'two': 'Two Content', + } +} + + +class PPT(BaseAction): + """Plugin to create ppt slides with text, paragraph, images in good looking styles.""" + + def __init__( + self, + theme_mapping: Optional[Dict[str, dict]] = None, + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + ): + super().__init__(description, parser) + self.theme_mapping = theme_mapping or THEME_MAPPING + self.pointer = None + self.location = None + + @tool_api(explode_return=True) + def create_file(self, theme: str, abs_location: str) -> dict: + """Create a pptx file with specific themes. + + Args: + theme (:class:`str`): the theme used. The value should be one of ['Default']. 
+ abs_location (:class:`str`): the ppt file's absolute location + + Returns: + :class:`dict`: operation status + * status: the result of the execution + """ + from pptx import Presentation + + self.location = abs_location + try: + self.pointer = Presentation(self.theme_mapping[theme]['template']) + self.pointer.slide_master.name = theme + # print('created') + except Exception as e: + print(e) + return dict(status='created a ppt file.') + + @tool_api(explode_return=True) + def add_first_page(self, title: str, subtitle: str) -> dict: + """Add the first page of ppt. + + Args: + title (:class:`str`): the title of ppt + subtitle (:class:`str`): the subtitle of ppt + + Returns: + :class:`dict`: operation status + * status: the result of the execution + """ + layout_name = self.theme_mapping[self.pointer.slide_master.name]['title'] + layout = next(i for i in self.pointer.slide_master.slide_layouts if i.name == layout_name) + slide = self.pointer.slides.add_slide(layout) + ph_title, ph_subtitle = slide.placeholders + ph_title.text = title + if subtitle: + ph_subtitle.text = subtitle + return dict(status='added page') + + @tool_api(explode_return=True) + def add_text_page(self, title: str, bullet_items: str) -> dict: + """Add text page of ppt. + + Args: + title (:class:`str`): the title of the page + bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them. + + Returns: + :class:`dict`: operation status + * status: the result of the execution + """ # noqa: E501 + layout_name = self.theme_mapping[self.pointer.slide_master.name]['single'] + layout = next(i for i in self.pointer.slide_master.slide_layouts if i.name == layout_name) + slide = self.pointer.slides.add_slide(layout) + ph_title, ph_body = slide.placeholders + ph_title.text = title + ph = ph_body + tf = ph.text_frame + for i, item in enumerate(bullet_items.split('[SPAN]')): + if i == 0: + p = tf.paragraphs[0] + else: + p = tf.add_paragraph() + p.text = item.strip() + p.level = 0 + return dict(status='added page') + + @tool_api(explode_return=True) + def add_text_image_page(self, title: str, bullet_items: str, image: str) -> dict: + """Add a text page with one image. Image should be a path. + + Args: + title (:class:`str`): the title of the page + bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them. 
+            image (:class:`str`): the path of the image
+
+        Returns:
+            :class:`dict`: operation status
+                * status: the result of the execution
+        """  # noqa: E501
+        from PIL import Image
+
+        layout_name = self.theme_mapping[self.pointer.slide_master.name]['two']
+        layout = next(i for i in self.pointer.slide_master.slide_layouts if i.name == layout_name)
+        slide = self.pointer.slides.add_slide(layout)
+        ph_title, ph_body1, ph_body2 = slide.placeholders
+        ph_title.text = title
+        ph = ph_body2
+        # open the image only to compute its display size; add_picture takes the path
+        image_pil = Image.open(image)
+        left = ph.left
+        width = ph.width
+        height = int(width / image_pil.width * image_pil.height)
+        top = (ph.top + (ph.top + ph.height)) // 2 - height // 2
+        slide.shapes.add_picture(image, left, top, width, height)
+
+        ph = ph_body1
+        tf = ph.text_frame
+        for i, item in enumerate(bullet_items.split('[SPAN]')):
+            if i == 0:
+                p = tf.paragraphs[0]
+            else:
+                p = tf.add_paragraph()
+            p.text = item.strip()
+            p.level = 0
+
+        return dict(status='added page')
+
+    @tool_api(explode_return=True)
+    def submit_file(self) -> dict:
+        """When all steps done, YOU MUST use submit_file() to submit your work.
+
+        Returns:
+            :class:`dict`: operation status
+                * status: the result of the execution
+        """
+        # file_path = os.path.join(self.CACHE_DIR, f'{self._return_timestamp()}.pptx')
+        # self.pointer.save(file_path)
+        # retrieval_url = upload_file(file_path)
+        self.pointer.save(self.location)
+        return dict(status=f'submitted. view ppt at {self.location}')
+
+
+class AsyncPPT(AsyncActionMixin, PPT):
+    """Plugin to create ppt slides with text, paragraph, images in good looking styles."""
+
+    @tool_api(explode_return=True)
+    @asyncify
+    def create_file(self, theme: str, abs_location: str) -> dict:
+        """Create a pptx file with specific themes.
+
+        Args:
+            theme (:class:`str`): the theme used. The value should be one of ['Default'].
+            abs_location (:class:`str`): the ppt file's absolute location
+
+        Returns:
+            :class:`dict`: operation status
+                * status: the result of the execution
+        """
+        return super().create_file(theme, abs_location)
+
+    @tool_api(explode_return=True)
+    @asyncify
+    def add_first_page(self, title: str, subtitle: str) -> dict:
+        """Add the first page of ppt.
+
+        Args:
+            title (:class:`str`): the title of ppt
+            subtitle (:class:`str`): the subtitle of ppt
+
+        Returns:
+            :class:`dict`: operation status
+                * status: the result of the execution
+        """
+        return super().add_first_page(title, subtitle)
+
+    @tool_api(explode_return=True)
+    @asyncify
+    def add_text_page(self, title: str, bullet_items: str) -> dict:
+        """Add text page of ppt.
+
+        Args:
+            title (:class:`str`): the title of the page
+            bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them.
+
+        Returns:
+            :class:`dict`: operation status
+                * status: the result of the execution
+        """  # noqa: E501
+        return super().add_text_page(title, bullet_items)
+
+    @tool_api(explode_return=True)
+    @asyncify
+    def add_text_image_page(self, title: str, bullet_items: str, image: str) -> dict:
+        """Add a text page with one image. Image should be a path.
+
+        Args:
+            title (:class:`str`): the title of the page
+            bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them.
+ image (:class:`str`): the path of the image + + Returns: + :class:`dict`: operation status + * status: the result of the execution + """ # noqa: E501 + return super().add_text_image_page(title, bullet_items, image) + + @tool_api(explode_return=True) + @asyncify + def submit_file(self) -> dict: + """When all steps done, YOU MUST use submit_file() to submit your work. + + Returns: + :class:`dict`: operation status + * status: the result of the execution + """ + return super().submit_file() diff --git a/lagent/actions/python_interpreter.py b/lagent/actions/python_interpreter.py new file mode 100644 index 0000000000000000000000000000000000000000..8bd3266b3ac8c642023161d21db969c26547fc3d --- /dev/null +++ b/lagent/actions/python_interpreter.py @@ -0,0 +1,176 @@ +# flake8: noqa: E501 +import copy +import io +from contextlib import redirect_stdout +from typing import Any, Optional, Type + +from asyncer import asyncify + +from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api +from lagent.actions.parser import BaseParser, JsonParser +from lagent.schema import ActionReturn, ActionStatusCode + + +class GenericRuntime: + GLOBAL_DICT = {} + LOCAL_DICT = None + HEADERS = [] + + def __init__(self): + self._global_vars = copy.copy(self.GLOBAL_DICT) + self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None + + for c in self.HEADERS: + self.exec_code(c) + + def exec_code(self, code_piece: str) -> None: + exec(code_piece, self._global_vars) + + def eval_code(self, expr: str) -> Any: + return eval(expr, self._global_vars) + + +class PythonInterpreter(BaseAction): + """A Python executor that can execute Python scripts. + + Args: + answer_symbol (str, Optional): the answer symbol from LLM. Defaults to ``None``. + answer_expr (str, Optional): the answer function name of the Python + script. Defaults to ``'solution()'``. + answer_from_stdout (boolean, Optional): whether the execution results is from + stdout. Defaults to ``False``. + timeout (int, Optional): Upper bound of waiting time for Python script execution. + Defaults to ``20``. + description (dict, Optional): The description of the action. Defaults to ``None``. + parser (Type[BaseParser]): The parser class to process the + action's inputs and outputs. Defaults to :class:`JsonParser`. 
+ """ + + def __init__( + self, + answer_symbol: Optional[str] = None, + answer_expr: Optional[str] = 'solution()', + answer_from_stdout: bool = False, + timeout: int = 20, + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + ) -> None: + super().__init__(description, parser) + self.answer_symbol = answer_symbol + self.answer_expr = answer_expr + self.answer_from_stdout = answer_from_stdout + self.timeout = timeout + + @tool_api + def run(self, command: str) -> ActionReturn: + """用来执行Python代码。代码必须是一个函数,函数名必须得是 'solution',代码对应你的思考过程。代码实例格式如下: + + ```python + # import 依赖包 + import xxx + def solution(): + # 初始化一些变量 + variable_names_with_real_meaning = xxx + # 步骤一 + mid_variable = func(variable_names_with_real_meaning) + # 步骤 x + mid_variable = func(mid_variable) + # 最后结果 + final_answer = func(mid_variable) + return final_answer + ``` + + Args: + command (:class:`str`): Python code snippet + """ + from func_timeout import FunctionTimedOut, func_set_timeout + + self.runtime = GenericRuntime() + try: + tool_return = func_set_timeout(self.timeout)(self._call)(command) + except FunctionTimedOut as e: + tool_return = ActionReturn(type=self.name) + tool_return.errmsg = repr(e) + tool_return.state = ActionStatusCode.API_ERROR + return tool_return + + def _call(self, command: str) -> ActionReturn: + tool_return = ActionReturn(type=self.name) + try: + if '```python' in command: + command = command.split('```python')[1].split('```')[0] + elif '```' in command: + command = command.split('```')[1].split('```')[0] + tool_return.args = dict(text='```python\n' + command + '\n```') + command = command.split('\n') + + if self.answer_from_stdout: + program_io = io.StringIO() + with redirect_stdout(program_io): + self.runtime.exec_code('\n'.join(command)) + program_io.seek(0) + res = program_io.readlines()[-1] + elif self.answer_symbol: + self.runtime.exec_code('\n'.join(command)) + res = self.runtime._global_vars[self.answer_symbol] + elif self.answer_expr: + self.runtime.exec_code('\n'.join(command)) + res = self.runtime.eval_code(self.answer_expr) + else: + self.runtime.exec_code('\n'.join(command[:-1])) + res = self.runtime.eval_code(command[-1]) + except Exception as e: + tool_return.errmsg = repr(e) + tool_return.type = self.name + tool_return.state = ActionStatusCode.API_ERROR + return tool_return + try: + tool_return.result = [dict(type='text', content=str(res))] + tool_return.state = ActionStatusCode.SUCCESS + except Exception as e: + tool_return.errmsg = repr(e) + tool_return.type = self.name + tool_return.state = ActionStatusCode.API_ERROR + return tool_return + + +class AsyncPythonInterpreter(AsyncActionMixin, PythonInterpreter): + """A Python executor that can execute Python scripts. + + Args: + answer_symbol (str, Optional): the answer symbol from LLM. Defaults to ``None``. + answer_expr (str, Optional): the answer function name of the Python + script. Defaults to ``'solution()'``. + answer_from_stdout (boolean, Optional): whether the execution results is from + stdout. Defaults to ``False``. + timeout (int, Optional): Upper bound of waiting time for Python script execution. + Defaults to ``20``. + description (dict, Optional): The description of the action. Defaults to ``None``. + parser (Type[BaseParser]): The parser class to process the + action's inputs and outputs. Defaults to :class:`JsonParser`. 
+ """ + + @tool_api + @asyncify + def run(self, command: str) -> ActionReturn: + """用来执行Python代码。代码必须是一个函数,函数名必须得是 'solution',代码对应你的思考过程。代码实例格式如下: + + ```python + # import 依赖包 + import xxx + def solution(): + # 初始化一些变量 + variable_names_with_real_meaning = xxx + # 步骤一 + mid_variable = func(variable_names_with_real_meaning) + # 步骤 x + mid_variable = func(mid_variable) + # 最后结果 + final_answer = func(mid_variable) + return final_answer + ``` + + Args: + command (:class:`str`): Python code snippet + """ + return super().run(command) diff --git a/lagent/actions/weather_query.py b/lagent/actions/weather_query.py new file mode 100644 index 0000000000000000000000000000000000000000..dbe3e991dbca34e0a6d373d62d457c7237317741 --- /dev/null +++ b/lagent/actions/weather_query.py @@ -0,0 +1,71 @@ +import os +import requests +from lagent.actions.base_action import BaseAction, tool_api +from lagent.schema import ActionReturn, ActionStatusCode + +class WeatherQuery(BaseAction): + def __init__(self): + super().__init__() + self.api_key = os.getenv("weather_token") + print(self.api_key) + if not self.api_key: + raise EnvironmentError("未找到环境变量 'token'。请设置你的和风天气 API Key 到 'weather_token' 环境变量中,比如export weather_token='xxx' ") + + @tool_api + def run(self, location: str) -> dict: + """ + 查询实时天气信息。 + + Args: + location (str): 要查询的地点名称、LocationID 或经纬度坐标(如 "101010100" 或 "116.41,39.92")。 + + Returns: + dict: 包含天气信息的字典 + * location: 地点名称 + * weather: 天气状况 + * temperature: 当前温度 + * wind_direction: 风向 + * wind_speed: 风速(公里/小时) + * humidity: 相对湿度(%) + * report_time: 数据报告时间 + """ + try: + # 如果 location 不是坐标格式(例如 "116.41,39.92"),则调用 GeoAPI 获取 LocationID + if not ("," in location and location.replace(",", "").replace(".", "").isdigit()): + # 使用 GeoAPI 获取 LocationID + geo_url = f"https://geoapi.qweather.com/v2/city/lookup?location={location}&key={self.api_key}" + geo_response = requests.get(geo_url) + geo_data = geo_response.json() + + if geo_data.get("code") != "200" or not geo_data.get("location"): + raise Exception(f"GeoAPI 返回错误码:{geo_data.get('code')} 或未找到位置") + + location = geo_data["location"][0]["id"] + + # 构建天气查询的 API 请求 URL + weather_url = f"https://devapi.qweather.com/v7/weather/now?location={location}&key={self.api_key}" + response = requests.get(weather_url) + data = response.json() + + # 检查 API 响应码 + if data.get("code") != "200": + raise Exception(f"Weather API 返回错误码:{data.get('code')}") + + # 解析和组织天气信息 + weather_info = { + "location": location, + "weather": data["now"]["text"], + "temperature": data["now"]["temp"] + "°C", + "wind_direction": data["now"]["windDir"], + "wind_speed": data["now"]["windSpeed"] + " km/h", + "humidity": data["now"]["humidity"] + "%", + "report_time": data["updateTime"] + } + + return {"result": weather_info} + + except Exception as exc: + return ActionReturn( + errmsg=f"WeatherQuery 异常:{exc}", + state=ActionStatusCode.HTTP_ERROR + ) \ No newline at end of file diff --git a/lagent/actions/web_browser.py b/lagent/actions/web_browser.py new file mode 100644 index 0000000000000000000000000000000000000000..432fe27eee69175220aea506b415278ed97ea767 --- /dev/null +++ b/lagent/actions/web_browser.py @@ -0,0 +1,908 @@ +import asyncio +import hashlib +import hmac +import json +import logging +import random +import re +import time +import warnings +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime +from http.client import HTTPSConnection +from typing import List, Optional, Tuple, Type, Union + +import aiohttp +import aiohttp.client_exceptions +import 
requests +from asyncache import cached as acached +from bs4 import BeautifulSoup +from cachetools import TTLCache, cached +from duckduckgo_search import DDGS, AsyncDDGS + +from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api +from lagent.actions.parser import BaseParser, JsonParser +from lagent.utils import async_as_completed + + +class BaseSearch: + + def __init__(self, topk: int = 3, black_list: List[str] = None): + self.topk = topk + self.black_list = black_list + + def _filter_results(self, results: List[tuple]) -> dict: + filtered_results = {} + count = 0 + for url, snippet, title in results: + if all(domain not in url + for domain in self.black_list) and not url.endswith('.pdf'): + filtered_results[count] = { + 'url': url, + 'summ': json.dumps(snippet, ensure_ascii=False)[1:-1], + 'title': title + } + count += 1 + if count >= self.topk: + break + return filtered_results + + +class DuckDuckGoSearch(BaseSearch): + + def __init__(self, + topk: int = 3, + black_list: List[str] = [ + 'enoN', + 'youtube.com', + 'bilibili.com', + 'researchgate.net', + ], + **kwargs): + self.proxy = kwargs.get('proxy') + self.timeout = kwargs.get('timeout', 30) + super().__init__(topk, black_list) + + @cached(cache=TTLCache(maxsize=100, ttl=600)) + def search(self, query: str, max_retry: int = 3) -> dict: + for attempt in range(max_retry): + try: + response = self._call_ddgs( + query, timeout=self.timeout, proxy=self.proxy) + return self._parse_response(response) + except Exception as e: + logging.exception(str(e)) + warnings.warn( + f'Retry {attempt + 1}/{max_retry} due to error: {e}') + time.sleep(random.randint(2, 5)) + raise Exception( + 'Failed to get search results from DuckDuckGo after retries.') + + @acached(cache=TTLCache(maxsize=100, ttl=600)) + async def asearch(self, query: str, max_retry: int = 3) -> dict: + for attempt in range(max_retry): + try: + ddgs = AsyncDDGS(timeout=self.timeout, proxy=self.proxy) + response = await ddgs.atext(query.strip("'"), max_results=10) + return self._parse_response(response) + except Exception as e: + if isinstance(e, asyncio.TimeoutError): + logging.exception('Request to DDGS timed out.') + logging.exception(str(e)) + warnings.warn( + f'Retry {attempt + 1}/{max_retry} due to error: {e}') + await asyncio.sleep(random.randint(2, 5)) + raise Exception( + 'Failed to get search results from DuckDuckGo after retries.') + + async def _async_call_ddgs(self, query: str, **kwargs) -> dict: + ddgs = DDGS(**kwargs) + try: + response = await asyncio.wait_for( + asyncio.to_thread(ddgs.text, query.strip("'"), max_results=10), + timeout=self.timeout) + return response + except asyncio.TimeoutError: + logging.exception('Request to DDGS timed out.') + raise + + def _call_ddgs(self, query: str, **kwargs) -> dict: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + response = loop.run_until_complete( + self._async_call_ddgs(query, **kwargs)) + return response + finally: + loop.close() + + def _parse_response(self, response: dict) -> dict: + raw_results = [] + for item in response: + raw_results.append( + (item['href'], item['description'] + if 'description' in item else item['body'], item['title'])) + return self._filter_results(raw_results) + + +class BingSearch(BaseSearch): + + def __init__(self, + api_key: str, + region: str = 'zh-CN', + topk: int = 3, + black_list: List[str] = [ + 'enoN', + 'youtube.com', + 'bilibili.com', + 'researchgate.net', + ], + **kwargs): + self.api_key = api_key + self.market = region + self.proxy = 
kwargs.get('proxy') + super().__init__(topk, black_list) + + @cached(cache=TTLCache(maxsize=100, ttl=600)) + def search(self, query: str, max_retry: int = 3) -> dict: + for attempt in range(max_retry): + try: + response = self._call_bing_api(query) + return self._parse_response(response) + except Exception as e: + logging.exception(str(e)) + warnings.warn( + f'Retry {attempt + 1}/{max_retry} due to error: {e}') + time.sleep(random.randint(2, 5)) + raise Exception( + 'Failed to get search results from Bing Search after retries.') + + @acached(cache=TTLCache(maxsize=100, ttl=600)) + async def asearch(self, query: str, max_retry: int = 3) -> dict: + for attempt in range(max_retry): + try: + response = await self._async_call_bing_api(query) + return self._parse_response(response) + except Exception as e: + logging.exception(str(e)) + warnings.warn( + f'Retry {attempt + 1}/{max_retry} due to error: {e}') + await asyncio.sleep(random.randint(2, 5)) + raise Exception( + 'Failed to get search results from Bing Search after retries.') + + def _call_bing_api(self, query: str) -> dict: + endpoint = 'https://api.bing.microsoft.com/v7.0/search' + params = {'q': query, 'mkt': self.market, 'count': f'{self.topk * 2}'} + headers = {'Ocp-Apim-Subscription-Key': self.api_key} + response = requests.get( + endpoint, headers=headers, params=params, proxies=self.proxy) + response.raise_for_status() + return response.json() + + async def _async_call_bing_api(self, query: str) -> dict: + endpoint = 'https://api.bing.microsoft.com/v7.0/search' + params = {'q': query, 'mkt': self.market, 'count': f'{self.topk * 2}'} + headers = {'Ocp-Apim-Subscription-Key': self.api_key} + async with aiohttp.ClientSession(raise_for_status=True) as session: + async with session.get( + endpoint, + headers=headers, + params=params, + proxy=self.proxy and + (self.proxy.get('http') or self.proxy.get('https'))) as resp: + return await resp.json() + + def _parse_response(self, response: dict) -> dict: + webpages = { + w['id']: w + for w in response.get('webPages', {}).get('value', []) + } + raw_results = [] + + for item in response.get('rankingResponse', + {}).get('mainline', {}).get('items', []): + if item['answerType'] == 'WebPages': + webpage = webpages.get(item['value']['id']) + if webpage: + raw_results.append( + (webpage['url'], webpage['snippet'], webpage['name'])) + elif item['answerType'] == 'News' and item['value'][ + 'id'] == response.get('news', {}).get('id'): + for news in response.get('news', {}).get('value', []): + raw_results.append( + (news['url'], news['description'], news['name'])) + + return self._filter_results(raw_results) + + +class BraveSearch(BaseSearch): + """ + Wrapper around the Brave Search API. + + To use, you should pass your Brave Search API key to the constructor. + + Args: + api_key (str): API KEY to use Brave Search API. + You can create a free API key at https://api.search.brave.com/app/keys. + search_type (str): Brave Search API supports ['web', 'news', 'images', 'videos'], + currently only supports 'news' and 'web'. + topk (int): The number of search results returned in response from API search results. + region (str): The country code string. Specifies the country where the search results come from. + language (str): The language code string. Specifies the preferred language for the search results. + extra_snippets (bool): Allows retrieving up to 5 additional snippets, which are alternative excerpts from the search results. + **kwargs: Any other parameters related to the Brave Search API. 
Find more details at + https://api.search.brave.com/app/documentation/web-search/get-started. + """ + + def __init__(self, + api_key: str, + region: str = 'ALL', + language: str = 'zh-hans', + extra_snippests: bool = True, + topk: int = 3, + black_list: List[str] = [ + 'enoN', + 'youtube.com', + 'bilibili.com', + 'researchgate.net', + ], + **kwargs): + self.api_key = api_key + self.market = region + self.proxy = kwargs.get('proxy') + self.language = language + self.extra_snippests = extra_snippests + self.search_type = kwargs.get('search_type', 'web') + self.kwargs = kwargs + super().__init__(topk, black_list) + + @cached(cache=TTLCache(maxsize=100, ttl=600)) + def search(self, query: str, max_retry: int = 3) -> dict: + for attempt in range(max_retry): + try: + response = self._call_brave_api(query) + return self._parse_response(response) + except Exception as e: + logging.exception(str(e)) + warnings.warn( + f'Retry {attempt + 1}/{max_retry} due to error: {e}') + time.sleep(random.randint(2, 5)) + raise Exception( + 'Failed to get search results from Brave Search after retries.') + + @acached(cache=TTLCache(maxsize=100, ttl=600)) + async def asearch(self, query: str, max_retry: int = 3) -> dict: + for attempt in range(max_retry): + try: + response = await self._async_call_brave_api(query) + return self._parse_response(response) + except Exception as e: + logging.exception(str(e)) + warnings.warn( + f'Retry {attempt + 1}/{max_retry} due to error: {e}') + await asyncio.sleep(random.randint(2, 5)) + raise Exception( + 'Failed to get search results from Brave Search after retries.') + + def _call_brave_api(self, query: str) -> dict: + endpoint = f'https://api.search.brave.com/res/v1/{self.search_type}/search' + params = { + 'q': query, + 'country': self.market, + 'search_lang': self.language, + 'extra_snippets': self.extra_snippests, + 'count': self.topk, + **{ + key: value + for key, value in self.kwargs.items() if value is not None + }, + } + headers = { + 'X-Subscription-Token': self.api_key or '', + 'Accept': 'application/json' + } + response = requests.get( + endpoint, headers=headers, params=params, proxies=self.proxy) + response.raise_for_status() + return response.json() + + async def _async_call_brave_api(self, query: str) -> dict: + endpoint = f'https://api.search.brave.com/res/v1/{self.search_type}/search' + params = { + 'q': query, + 'country': self.market, + 'search_lang': self.language, + 'extra_snippets': self.extra_snippests, + 'count': self.topk, + **{ + key: value + for key, value in self.kwargs.items() if value is not None + }, + } + headers = { + 'X-Subscription-Token': self.api_key or '', + 'Accept': 'application/json' + } + async with aiohttp.ClientSession(raise_for_status=True) as session: + async with session.get( + endpoint, + headers=headers, + params=params, + proxy=self.proxy and + (self.proxy.get('http') or self.proxy.get('https'))) as resp: + return await resp.json() + + def _parse_response(self, response: dict) -> dict: + if self.search_type == 'web': + filtered_result = response.get('web', {}).get('results', []) + else: + filtered_result = response.get('results', {}) + raw_results = [] + + for item in filtered_result: + raw_results.append(( + item.get('url', ''), + ' '.join( + filter(None, [ + item.get('description'), + *item.get('extra_snippets', []) + ])), + item.get('title', ''), + )) + return self._filter_results(raw_results) + + +class GoogleSearch(BaseSearch): + """ + Wrapper around the Serper.dev Google Search API. 
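All of the searcher backends above expose the same cached, retrying `search`/`asearch` interface and share `_filter_results`, which drops black-listed domains and PDF links before numbering the hits. A minimal sketch with the key-free DuckDuckGo backend (illustrative only; it needs the `duckduckgo_search` package and network access):

```python
# Illustrative only: the searcher backends are interchangeable; DuckDuckGoSearch
# is shown because it needs no API key.
from lagent.actions.web_browser import DuckDuckGoSearch

searcher = DuckDuckGoSearch(topk=3, timeout=10)
results = searcher.search('InternLM agent framework')
# search() retries up to max_retry times and caches results in a 10-minute TTLCache.
for idx, hit in results.items():
    print(idx, hit['title'], hit['url'])
```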
+ + To use, you should pass your serper API key to the constructor. + + Args: + api_key (str): API KEY to use serper google search API. + You can create a free API key at https://serper.dev. + search_type (str): Serper API supports ['search', 'images', 'news', + 'places'] types of search, currently we only support 'search' and 'news'. + topk (int): The number of search results returned in response from api search results. + **kwargs: Any other parameters related to the Serper API. Find more details at + https://serper.dev/playground + """ + + result_key_for_type = { + 'news': 'news', + 'places': 'places', + 'images': 'images', + 'search': 'organic', + } + + def __init__(self, + api_key: str, + topk: int = 3, + black_list: List[str] = [ + 'enoN', + 'youtube.com', + 'bilibili.com', + 'researchgate.net', + ], + **kwargs): + self.api_key = api_key + self.proxy = kwargs.get('proxy') + self.search_type = kwargs.get('search_type', 'search') + self.kwargs = kwargs + super().__init__(topk, black_list) + + @cached(cache=TTLCache(maxsize=100, ttl=600)) + def search(self, query: str, max_retry: int = 3) -> dict: + for attempt in range(max_retry): + try: + response = self._call_serper_api(query) + return self._parse_response(response) + except Exception as e: + logging.exception(str(e)) + warnings.warn( + f'Retry {attempt + 1}/{max_retry} due to error: {e}') + time.sleep(random.randint(2, 5)) + raise Exception( + 'Failed to get search results from Google Serper Search after retries.' + ) + + @acached(cache=TTLCache(maxsize=100, ttl=600)) + async def asearch(self, query: str, max_retry: int = 3) -> dict: + for attempt in range(max_retry): + try: + response = await self._async_call_serper_api(query) + return self._parse_response(response) + except Exception as e: + logging.exception(str(e)) + warnings.warn( + f'Retry {attempt + 1}/{max_retry} due to error: {e}') + await asyncio.sleep(random.randint(2, 5)) + raise Exception( + 'Failed to get search results from Google Serper Search after retries.' 
+ ) + + def _call_serper_api(self, query: str) -> dict: + endpoint = f'https://google.serper.dev/{self.search_type}' + params = { + 'q': query, + 'num': self.topk, + **{ + key: value + for key, value in self.kwargs.items() if value is not None + }, + } + headers = { + 'X-API-KEY': self.api_key or '', + 'Content-Type': 'application/json' + } + response = requests.get( + endpoint, headers=headers, params=params, proxies=self.proxy) + response.raise_for_status() + return response.json() + + async def _async_call_serper_api(self, query: str) -> dict: + endpoint = f'https://google.serper.dev/{self.search_type}' + params = { + 'q': query, + 'num': self.topk, + **{ + key: value + for key, value in self.kwargs.items() if value is not None + }, + } + headers = { + 'X-API-KEY': self.api_key or '', + 'Content-Type': 'application/json' + } + async with aiohttp.ClientSession(raise_for_status=True) as session: + async with session.get( + endpoint, + headers=headers, + params=params, + proxy=self.proxy and + (self.proxy.get('http') or self.proxy.get('https'))) as resp: + return await resp.json() + + def _parse_response(self, response: dict) -> dict: + raw_results = [] + + if response.get('answerBox'): + answer_box = response.get('answerBox', {}) + if answer_box.get('answer'): + raw_results.append(('', answer_box.get('answer'), '')) + elif answer_box.get('snippet'): + raw_results.append( + ('', answer_box.get('snippet').replace('\n', ' '), '')) + elif answer_box.get('snippetHighlighted'): + raw_results.append( + ('', answer_box.get('snippetHighlighted'), '')) + + if response.get('knowledgeGraph'): + kg = response.get('knowledgeGraph', {}) + description = kg.get('description', '') + attributes = '. '.join( + f'{attribute}: {value}' + for attribute, value in kg.get('attributes', {}).items()) + raw_results.append( + (kg.get('descriptionLink', ''), + f'{description}. {attributes}' if attributes else description, + f"{kg.get('title', '')}: {kg.get('type', '')}.")) + + for result in response[self.result_key_for_type[ + self.search_type]][:self.topk]: + description = result.get('snippet', '') + attributes = '. '.join( + f'{attribute}: {value}' + for attribute, value in result.get('attributes', {}).items()) + raw_results.append( + (result.get('link', ''), + f'{description}. {attributes}' if attributes else description, + result.get('title', ''))) + + return self._filter_results(raw_results) + + +class TencentSearch(BaseSearch): + """Wrapper around the tencentclound Search API. + + To use, you should pass your secret_id and secret_key to the constructor. + + Args: + secret_id (str): Your Tencent Cloud secret ID for accessing the API. + For more details, refer to the documentation: https://cloud.tencent.com/document/product/598/40488. + secret_key (str): Your Tencent Cloud secret key for accessing the API. + api_key (str, optional): Additional API key, if required. + action (str): The action for this interface, use `SearchCommon`. + version (str): The API version, use `2020-12-29`. + service (str): The service name, use `tms`. + host (str): The API host, use `tms.tencentcloudapi.com`. + topk (int): The maximum number of search results to return. + tsn (int): Time filter for search results. Valid values: + 1 (within 1 day), 2 (within 1 week), 3 (within 1 month), + 4 (within 1 year), 5 (within 6 months), 6 (within 3 years). + insite (str): Specify a site to search within (supports only a single site). + If not specified, the entire web is searched. Example: `zhihu.com`. 
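`GoogleSearch._parse_response` prefers an answer box, then the knowledge graph, then the organic results, and everything still passes through the black-list filter. A small sketch with a hand-written, hypothetical Serper-style payload makes the ordering visible (illustrative only; `_parse_response` is a private helper and no network call is made):

```python
# Illustrative only: a hypothetical Serper-style payload, not real API output.
from lagent.actions.web_browser import GoogleSearch

searcher = GoogleSearch(api_key='dummy-key', topk=2)
fake_response = {
    'answerBox': {'answer': '42'},
    'knowledgeGraph': {
        'title': 'Example', 'type': 'Thing',
        'description': 'A made-up entity.',
        'descriptionLink': 'https://example.org',
        'attributes': {'founded': '1970'},
    },
    'organic': [
        {'link': 'https://example.com/a', 'snippet': 'First organic hit.', 'title': 'A'},
        {'link': 'https://example.com/b', 'snippet': 'Second organic hit.', 'title': 'B'},
    ],
}
# The answer box comes first, then the knowledge graph, then organic results,
# truncated to topk entries by _filter_results.
print(searcher._parse_response(fake_response))
```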
+ category (str): Vertical category for filtering results. Optional values include: + `baike` (encyclopedia), `weather`, `calendar`, `medical`, `news`, `train`, `star` (horoscope). + vrid (str): Result card type(s). Different `vrid` values represent different types of result cards. + Supports multiple values separated by commas. Example: `30010255`. + """ + + def __init__(self, + secret_id: str = 'Your SecretId', + secret_key: str = 'Your SecretKey', + api_key: str = '', + action: str = 'SearchCommon', + version: str = '2020-12-29', + service: str = 'tms', + host: str = 'tms.tencentcloudapi.com', + topk: int = 3, + tsn: int = None, + insite: str = None, + category: str = None, + vrid: str = None, + black_list: List[str] = [ + 'enoN', + 'youtube.com', + 'bilibili.com', + 'researchgate.net', + ]): + self.secret_id = secret_id + self.secret_key = secret_key + self.api_key = api_key + self.action = action + self.version = version + self.service = service + self.host = host + self.tsn = tsn + self.insite = insite + self.category = category + self.vrid = vrid + super().__init__(topk, black_list=black_list) + + @cached(cache=TTLCache(maxsize=100, ttl=600)) + def search(self, query: str, max_retry: int = 3) -> dict: + for attempt in range(max_retry): + try: + response = self._call_tencent_api(query) + return self._parse_response(response) + except Exception as e: + logging.exception(str(e)) + warnings.warn( + f'Retry {attempt + 1}/{max_retry} due to error: {e}') + time.sleep(random.randint(2, 5)) + raise Exception( + 'Failed to get search results from Bing Search after retries.') + + @acached(cache=TTLCache(maxsize=100, ttl=600)) + async def asearch(self, query: str, max_retry: int = 3) -> dict: + for attempt in range(max_retry): + try: + response = await self._async_call_tencent_api(query) + return self._parse_response(response) + except Exception as e: + logging.exception(str(e)) + warnings.warn( + f'Retry {attempt + 1}/{max_retry} due to error: {e}') + await asyncio.sleep(random.randint(2, 5)) + raise Exception( + 'Failed to get search results from Bing Search after retries.') + + def _get_headers_and_payload(self, query: str) -> tuple: + + def sign(key, msg): + return hmac.new(key, msg.encode('utf-8'), hashlib.sha256).digest() + + params = dict(Query=query) + # if self.topk: + # params['Cnt'] = self.topk + if self.tsn: + params['Tsn'] = self.tsn + if self.insite: + params['Insite'] = self.insite + if self.category: + params['Category'] = self.category + if self.vrid: + params['Vrid'] = self.vrid + payload = json.dumps(params) + algorithm = 'TC3-HMAC-SHA256' + timestamp = int(time.time()) + date = datetime.utcfromtimestamp(timestamp).strftime('%Y-%m-%d') + + # ************* 步骤 1:拼接规范请求串 ************* + http_request_method = 'POST' + canonical_uri = '/' + canonical_querystring = '' + ct = 'application/json; charset=utf-8' + canonical_headers = f'content-type:{ct}\nhost:{self.host}\nx-tc-action:{self.action.lower()}\n' + signed_headers = 'content-type;host;x-tc-action' + hashed_request_payload = hashlib.sha256( + payload.encode('utf-8')).hexdigest() + canonical_request = ( + http_request_method + '\n' + canonical_uri + '\n' + + canonical_querystring + '\n' + canonical_headers + '\n' + + signed_headers + '\n' + hashed_request_payload) + + # ************* 步骤 2:拼接待签名字符串 ************* + credential_scope = date + '/' + self.service + '/' + 'tc3_request' + hashed_canonical_request = hashlib.sha256( + canonical_request.encode('utf-8')).hexdigest() + string_to_sign = ( + algorithm + '\n' + str(timestamp) 
+ '\n' + credential_scope + + '\n' + hashed_canonical_request) + + # ************* 步骤 3:计算签名 ************* + secret_date = sign(('TC3' + self.secret_key).encode('utf-8'), date) + secret_service = sign(secret_date, self.service) + secret_signing = sign(secret_service, 'tc3_request') + signature = hmac.new(secret_signing, string_to_sign.encode('utf-8'), + hashlib.sha256).hexdigest() + + # ************* 步骤 4:拼接 Authorization ************* + authorization = ( + algorithm + ' ' + 'Credential=' + self.secret_id + '/' + + credential_scope + ', ' + 'SignedHeaders=' + signed_headers + + ', ' + 'Signature=' + signature) + + # ************* 步骤 5:构造并发起请求 ************* + headers = { + 'Authorization': authorization, + 'Content-Type': 'application/json; charset=utf-8', + 'Host': self.host, + 'X-TC-Action': self.action, + 'X-TC-Timestamp': str(timestamp), + 'X-TC-Version': self.version + } + # if self.region: + # headers["X-TC-Region"] = self.region + if self.api_key: + headers['X-TC-Token'] = self.api_key + return headers, payload + + def _call_tencent_api(self, query: str) -> dict: + headers, payload = self._get_headers_and_payload(query) + req = HTTPSConnection(self.host) + req.request('POST', '/', headers=headers, body=payload.encode('utf-8')) + resp = req.getresponse() + try: + resp = json.loads(resp.read().decode('utf-8')) + except Exception as e: + logging.warning(str(e)) + import ast + resp = ast.literal_eval(resp) + return resp.get('Response', dict()) + + async def _async_call_tencent_api(self, query: str): + headers, payload = self._get_headers_and_payload(query) + async with aiohttp.ClientSession(raise_for_status=True) as session: + async with session.post( + 'https://' + self.host.lstrip('/'), + headers=headers, + data=payload) as resp: + return (await resp.json()).get('Response', {}) + + def _parse_response(self, response: dict) -> dict: + raw_results = [] + for item in response.get('Pages', []): + display = json.loads(item['Display']) + if not display['url']: + continue + raw_results.append((display['url'], display['content'] + or display['abstract_info'], display['title'])) + return self._filter_results(raw_results) + + +class ContentFetcher: + + def __init__(self, timeout: int = 5): + self.timeout = timeout + + @cached(cache=TTLCache(maxsize=100, ttl=600)) + def fetch(self, url: str) -> Tuple[bool, str]: + try: + response = requests.get(url, timeout=self.timeout) + response.raise_for_status() + html = response.content + except requests.RequestException as e: + return False, str(e) + + text = BeautifulSoup(html, 'html.parser').get_text() + cleaned_text = re.sub(r'\n+', '\n', text) + return True, cleaned_text + + @acached(cache=TTLCache(maxsize=100, ttl=600)) + async def afetch(self, url: str) -> Tuple[bool, str]: + try: + async with aiohttp.ClientSession( + raise_for_status=True, + timeout=aiohttp.ClientTimeout(self.timeout)) as session: + async with session.get(url) as resp: + html = await resp.text(errors='ignore') + text = BeautifulSoup(html, 'html.parser').get_text() + cleaned_text = re.sub(r'\n+', '\n', text) + return True, cleaned_text + except Exception as e: + return False, str(e) + + +class WebBrowser(BaseAction): + """Wrapper around the Web Browser Tool. 
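`ContentFetcher` turns a URL into plain text: it downloads the page, strips the markup with BeautifulSoup, and collapses blank lines; `WebBrowser.select` later truncates this text to 8192 characters. A minimal sketch (illustrative only):

```python
# Illustrative only: fetching one page as plain text.
from lagent.actions.web_browser import ContentFetcher

fetcher = ContentFetcher(timeout=5)
ok, text = fetcher.fetch('https://example.org')
if ok:
    print(text[:200])             # BeautifulSoup-extracted text, blank lines collapsed
else:
    print('fetch failed:', text)  # on failure the second element is the error string
```

An `afetch` coroutine with the same `(success, text)` convention backs the asynchronous browser.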
+ """ + + def __init__(self, + searcher_type: str = 'DuckDuckGoSearch', + timeout: int = 5, + black_list: Optional[List[str]] = [ + 'enoN', + 'youtube.com', + 'bilibili.com', + 'researchgate.net', + ], + topk: int = 20, + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + **kwargs): + self.searcher = eval(searcher_type)( + black_list=black_list, topk=topk, **kwargs) + self.fetcher = ContentFetcher(timeout=timeout) + self.search_results = None + super().__init__(description, parser) + + @tool_api + def search(self, query: Union[str, List[str]]) -> dict: + """BING search API + Args: + query (List[str]): list of search query strings + """ + queries = query if isinstance(query, list) else [query] + search_results = {} + + with ThreadPoolExecutor() as executor: + future_to_query = { + executor.submit(self.searcher.search, q): q + for q in queries + } + + for future in as_completed(future_to_query): + query = future_to_query[future] + try: + results = future.result() + except Exception as exc: + warnings.warn(f'{query} generated an exception: {exc}') + else: + for result in results.values(): + if result['url'] not in search_results: + search_results[result['url']] = result + else: + search_results[ + result['url']]['summ'] += f"\n{result['summ']}" + + self.search_results = { + idx: result + for idx, result in enumerate(search_results.values()) + } + return self.search_results + + @tool_api + def select(self, select_ids: List[int]) -> dict: + """get the detailed content on the selected pages. + + Args: + select_ids (List[int]): list of index to select. Max number of index to be selected is no more than 4. + """ + if not self.search_results: + raise ValueError('No search results to select from.') + + new_search_results = {} + with ThreadPoolExecutor() as executor: + future_to_id = { + executor.submit(self.fetcher.fetch, self.search_results[select_id]['url']): select_id + for select_id in select_ids if select_id in self.search_results + } + for future in as_completed(future_to_id): + select_id = future_to_id[future] + try: + web_success, web_content = future.result() + except Exception as exc: + warnings.warn(f'{select_id} generated an exception: {exc}') + else: + if web_success: + self.search_results[select_id][ + 'content'] = web_content[:8192] + new_search_results[select_id] = self.search_results[ + select_id].copy() + new_search_results[select_id].pop('summ') + + return new_search_results + + @tool_api + def open_url(self, url: str) -> dict: + print(f'Start Browsing: {url}') + web_success, web_content = self.fetcher.fetch(url) + if web_success: + return {'type': 'text', 'content': web_content} + else: + return {'error': web_content} + + +class AsyncWebBrowser(AsyncActionMixin, WebBrowser): + """Wrapper around the Web Browser Tool. 
+ """ + + @tool_api + async def search(self, query: Union[str, List[str]]) -> dict: + """BING search API + + Args: + query (List[str]): list of search query strings + """ + queries = query if isinstance(query, list) else [query] + search_results = {} + + tasks = [] + for q in queries: + task = asyncio.create_task(self.searcher.asearch(q)) + task.query = q + tasks.append(task) + async for future in async_as_completed(tasks): + query = future.query + try: + results = await future + except Exception as exc: + warnings.warn(f'{query} generated an exception: {exc}') + else: + for result in results.values(): + if result['url'] not in search_results: + search_results[result['url']] = result + else: + search_results[ + result['url']]['summ'] += f"\n{result['summ']}" + + self.search_results = { + idx: result + for idx, result in enumerate(search_results.values()) + } + return self.search_results + + @tool_api + async def select(self, select_ids: List[int]) -> dict: + """get the detailed content on the selected pages. + + Args: + select_ids (List[int]): list of index to select. Max number of index to be selected is no more than 4. + """ + if not self.search_results: + raise ValueError('No search results to select from.') + + new_search_results = {} + tasks = [] + for select_id in select_ids: + if select_id in self.search_results: + task = asyncio.create_task( + self.fetcher.afetch(self.search_results[select_id]['url'])) + task.select_id = select_id + tasks.append(task) + async for future in async_as_completed(tasks): + select_id = future.select_id + try: + web_success, web_content = await future + except Exception as exc: + warnings.warn(f'{select_id} generated an exception: {exc}') + else: + if web_success: + self.search_results[select_id][ + 'content'] = web_content[:8192] + new_search_results[select_id] = self.search_results[ + select_id].copy() + new_search_results[select_id].pop('summ') + return new_search_results + + @tool_api + async def open_url(self, url: str) -> dict: + print(f'Start Browsing: {url}') + web_success, web_content = await self.fetcher.afetch(url) + if web_success: + return {'type': 'text', 'content': web_content} + else: + return {'error': web_content} diff --git a/lagent/agents/__init__.py b/lagent/agents/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f06972cc56e13012e8fe54a9fe8764748ae93f43 --- /dev/null +++ b/lagent/agents/__init__.py @@ -0,0 +1,9 @@ +from .agent import Agent, AgentDict, AgentList, AsyncAgent, AsyncSequential, Sequential +from .react import AsyncReAct, ReAct +from .stream import AgentForInternLM, AsyncAgentForInternLM, AsyncMathCoder, MathCoder + +__all__ = [ + 'Agent', 'AgentDict', 'AgentList', 'AsyncAgent', 'AgentForInternLM', + 'AsyncAgentForInternLM', 'MathCoder', 'AsyncMathCoder', 'ReAct', + 'AsyncReAct', 'Sequential', 'AsyncSequential' +] diff --git a/lagent/agents/__pycache__/__init__.cpython-310.pyc b/lagent/agents/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..877f0681e291e3dce4c73f61ba7bc3b3016a3bbb Binary files /dev/null and b/lagent/agents/__pycache__/__init__.cpython-310.pyc differ diff --git a/lagent/agents/__pycache__/agent.cpython-310.pyc b/lagent/agents/__pycache__/agent.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27c032fc6725726693843177f6d287cbf976f1b0 Binary files /dev/null and b/lagent/agents/__pycache__/agent.cpython-310.pyc differ diff --git a/lagent/agents/__pycache__/react.cpython-310.pyc 
b/lagent/agents/__pycache__/react.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05c5774d794f281f5b2cc4f91b6e0f132fc89563 Binary files /dev/null and b/lagent/agents/__pycache__/react.cpython-310.pyc differ diff --git a/lagent/agents/__pycache__/stream.cpython-310.pyc b/lagent/agents/__pycache__/stream.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cd8aa9919efba5b4aefb53eb3a0c3866f7edd57 Binary files /dev/null and b/lagent/agents/__pycache__/stream.cpython-310.pyc differ diff --git a/lagent/agents/agent.py b/lagent/agents/agent.py new file mode 100644 index 0000000000000000000000000000000000000000..b1e941baa442a52deb37755f64002724316bcf08 --- /dev/null +++ b/lagent/agents/agent.py @@ -0,0 +1,400 @@ +import copy +import warnings +from collections import OrderedDict, UserDict, UserList, abc +from functools import wraps +from itertools import chain, repeat +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union + +from lagent.agents.aggregator import DefaultAggregator +from lagent.hooks import Hook, RemovableHandle +from lagent.llms import BaseLLM +from lagent.memory import Memory, MemoryManager +from lagent.prompts.parsers import StrParser +from lagent.prompts.prompt_template import PromptTemplate +from lagent.schema import AgentMessage +from lagent.utils import create_object + + +class Agent: + """Agent is the basic unit of the system. It is responsible for + communicating with the LLM, managing the memory, and handling the + message aggregation and parsing. It can also be extended with hooks + + Args: + llm (Union[BaseLLM, Dict]): The language model used by the agent. + template (Union[PromptTemplate, str]): The template used to format the + messages. + memory (Dict): The memory used by the agent. + output_format (Dict): The output format used by the agent. + aggregator (Dict): The aggregator used by the agent. + name (Optional[str]): The name of the agent. + description (Optional[str]): The description of the agent. + hooks (Optional[Union[List[Dict], Dict]]): The hooks used by the agent. + + Returns: + AgentMessage: The response message. 
+ """ + + def __init__( + self, + llm: Union[BaseLLM, Dict] = None, + template: Union[PromptTemplate, str, dict, List[dict]] = None, + memory: Dict = dict(type=Memory), + output_format: Optional[Dict] = None, + aggregator: Dict = dict(type=DefaultAggregator), + name: Optional[str] = None, + description: Optional[str] = None, + hooks: Optional[Union[List[Dict], Dict]] = None, + ): + self.name = name or self.__class__.__name__ + self.llm: BaseLLM = create_object(llm) + self.memory: MemoryManager = MemoryManager(memory) if memory else None + self.output_format: StrParser = create_object(output_format) + self.template = template + self.description = description + self.aggregator: DefaultAggregator = create_object(aggregator) + self._hooks: Dict[int, Hook] = OrderedDict() + if hooks: + for hook in hooks: + hook = create_object(hook) + self.register_hook(hook) + + def update_memory(self, message, session_id=0): + if self.memory: + self.memory.add(message, session_id=session_id) + + def __call__( + self, + *message: Union[str, AgentMessage, List[AgentMessage]], + session_id=0, + **kwargs, + ) -> AgentMessage: + # message.receiver = self.name + message = [ + AgentMessage(sender='user', content=m) + if isinstance(m, str) else copy.deepcopy(m) for m in message + ] + for hook in self._hooks.values(): + result = hook.before_agent(self, message, session_id) + if result: + message = result + self.update_memory(message, session_id=session_id) + response_message = self.forward( + *message, session_id=session_id, **kwargs) + if not isinstance(response_message, AgentMessage): + response_message = AgentMessage( + sender=self.name, + content=response_message, + ) + self.update_memory(response_message, session_id=session_id) + response_message = copy.deepcopy(response_message) + for hook in self._hooks.values(): + result = hook.after_agent(self, response_message, session_id) + if result: + response_message = result + return response_message + + def forward(self, + *message: AgentMessage, + session_id=0, + **kwargs) -> Union[AgentMessage, str]: + formatted_messages = self.aggregator.aggregate( + self.memory.get(session_id), + self.name, + self.output_format, + self.template, + ) + llm_response = self.llm.chat(formatted_messages, **kwargs) + if self.output_format: + formatted_messages = self.output_format.parse_response( + llm_response) + return AgentMessage( + sender=self.name, + content=llm_response, + formatted=formatted_messages, + ) + return llm_response + + def __setattr__(self, __name: str, __value: Any) -> None: + if isinstance(__value, Agent): + _agents = getattr(self, '_agents', OrderedDict()) + _agents[__name] = __value + super().__setattr__('_agents', _agents) + super().__setattr__(__name, __value) + + def state_dict(self, session_id=0): + state_dict, stack = {}, [('', self)] + while stack: + prefix, node = stack.pop() + key = prefix + 'memory' + if node.memory is not None: + if session_id not in node.memory.memory_map: + warnings.warn(f'No session id {session_id} in {key}') + memory = node.memory.get(session_id) + state_dict[key] = memory and memory.save() or [] + if hasattr(node, '_agents'): + for name, value in reversed(node._agents.items()): + stack.append((prefix + name + '.', value)) + return state_dict + + def load_state_dict(self, state_dict: Dict, session_id=0): + _state_dict = self.state_dict() + missing_keys = set(_state_dict) - set(state_dict) + if missing_keys: + raise KeyError(f'Missing keys: {missing_keys}') + extra_keys = set(state_dict) - set(_state_dict) + if extra_keys: + 
warnings.warn(f'Mismatch keys which are not used: {extra_keys}') + for key in _state_dict: + obj = self + for attr in key.split('.')[:-1]: + if isinstance(obj, AgentList): + assert attr.isdigit() + obj = obj[int(attr)] + elif isinstance(obj, AgentDict): + obj = obj[attr] + else: + obj = getattr(obj, attr) + if obj.memory is not None: + if session_id not in obj.memory.memory_map: + obj.memory.create_instance(session_id) + obj.memory.memory_map[session_id].load(state_dict[key] or []) + + def register_hook(self, hook: Callable): + handle = RemovableHandle(self._hooks) + self._hooks[handle.id] = hook + return handle + + def reset(self, + session_id=0, + keypath: Optional[str] = None, + recursive: bool = False): + assert not (keypath and + recursive), 'keypath and recursive can\'t be used together' + if keypath: + keys, agent = keypath.split('.'), self + for key in keys: + agents = getattr(agent, '_agents', {}) + if key not in agents: + raise KeyError(f'No sub-agent named {key} in {agent}') + agent = agents[key] + agent.reset(session_id, recursive=False) + else: + if self.memory: + self.memory.reset(session_id=session_id) + if recursive: + for agent in getattr(self, '_agents', {}).values(): + agent.reset(session_id, recursive=True) + + def __repr__(self): + + def _rcsv_repr(agent, n_indent=1): + res = agent.__class__.__name__ + (f"(name='{agent.name}')" + if agent.name else '') + modules = [ + f"{n_indent * ' '}({name}): {_rcsv_repr(agent, n_indent + 1)}" + for name, agent in getattr(agent, '_agents', {}).items() + ] + if modules: + res += '(\n' + '\n'.join( + modules) + f'\n{(n_indent - 1) * " "})' + elif not res.endswith(')'): + res += '()' + return res + + return _rcsv_repr(self) + + +class AsyncAgent(Agent): + + async def __call__(self, + *message: AgentMessage | List[AgentMessage], + session_id=0, + **kwargs) -> AgentMessage: + message = [ + AgentMessage(sender='user', content=m) + if isinstance(m, str) else copy.deepcopy(m) for m in message + ] + for hook in self._hooks.values(): + result = hook.before_agent(self, message, session_id) + if result: + message = result + self.update_memory(message, session_id=session_id) + response_message = await self.forward( + *message, session_id=session_id, **kwargs) + if not isinstance(response_message, AgentMessage): + response_message = AgentMessage( + sender=self.name, + content=response_message, + ) + self.update_memory(response_message, session_id=session_id) + response_message = copy.deepcopy(response_message) + for hook in self._hooks.values(): + result = hook.after_agent(self, response_message, session_id) + if result: + response_message = result + return response_message + + async def forward(self, + *message: AgentMessage, + session_id=0, + **kwargs) -> Union[AgentMessage, str]: + formatted_messages = self.aggregator.aggregate( + self.memory.get(session_id), + self.name, + self.output_format, + self.template, + ) + llm_response = await self.llm.chat(formatted_messages, session_id, + **kwargs) + if self.output_format: + formatted_messages = self.output_format.parse_response( + llm_response) + return AgentMessage( + sender=self.name, + content=llm_response, + formatted=formatted_messages, + ) + return llm_response + + +class Sequential(Agent): + """Sequential is an agent container that forwards messages to each agent + in the order they are added.""" + + def __init__(self, *agents: Union[Agent, AsyncAgent, Iterable], **kwargs): + super().__init__(**kwargs) + self._agents = OrderedDict() + if not agents: + raise ValueError('At least one agent 
should be provided') + if isinstance(agents[0], + Iterable) and not isinstance(agents[0], Agent): + if not agents[0]: + raise ValueError('At least one agent should be provided') + agents = agents[0] + for key, agent in enumerate(agents): + if isinstance(agents, Mapping): + key, agent = agent, agents[agent] + elif isinstance(agent, tuple): + key, agent = agent + self.add_agent(key, agent) + + def add_agent(self, name: str, agent: Union[Agent, AsyncAgent]): + assert isinstance( + agent, (Agent, AsyncAgent + )), f'{type(agent)} is not an Agent or AsyncAgent subclass' + self._agents[str(name)] = agent + + def forward(self, + *message: AgentMessage, + session_id=0, + exit_at: Optional[int] = None, + **kwargs) -> AgentMessage: + assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0' + if exit_at is None: + exit_at = len(self) - 1 + iterator = chain.from_iterable(repeat(self._agents.values())) + for _ in range(exit_at + 1): + agent = next(iterator) + if isinstance(message, AgentMessage): + message = (message, ) + message = agent(*message, session_id=session_id, **kwargs) + return message + + def __getitem__(self, key): + if isinstance(key, int) and key < 0: + assert key >= -len(self), 'index out of range' + key = len(self) + key + return self._agents[str(key)] + + def __len__(self): + return len(self._agents) + + +class AsyncSequential(Sequential, AsyncAgent): + + async def forward(self, + *message: AgentMessage, + session_id=0, + exit_at: Optional[int] = None, + **kwargs) -> AgentMessage: + assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0' + if exit_at is None: + exit_at = len(self) - 1 + iterator = chain.from_iterable(repeat(self._agents.values())) + for _ in range(exit_at + 1): + agent = next(iterator) + if isinstance(message, AgentMessage): + message = (message, ) + message = await agent(*message, session_id=session_id, **kwargs) + return message + + +class AgentContainerMixin: + + def __init_subclass__(cls): + super().__init_subclass__() + + def wrap_api(func): + + @wraps(func) + def wrapped_func(self, *args, **kwargs): + data = self.data.copy() if hasattr(self, 'data') else None + + def _backup(d): + if d is None: + self.data.clear() + else: + self.data = d + + ret = func(self, *args, **kwargs) + agents = OrderedDict() + for k, item in (self.data.items() if isinstance( + self.data, abc.Mapping) else enumerate(self.data)): + if isinstance(self.data, + abc.Mapping) and not isinstance(k, str): + _backup(data) + raise KeyError( + f'agent name should be a string, got {type(k)}') + if isinstance(k, str) and '.' 
in k: + _backup(data) + raise KeyError( + f'agent name can\'t contain ".", got {k}') + if not isinstance(item, (Agent, AsyncAgent)): + _backup(data) + raise TypeError( + f'{type(item)} is not an Agent or AsyncAgent subclass' + ) + agents[str(k)] = item + self._agents = agents + return ret + + return wrapped_func + + for method in [ + 'append', 'sort', 'reverse', 'pop', 'clear', 'update', + 'insert', 'extend', 'remove', '__init__', '__setitem__', + '__delitem__', '__add__', '__iadd__', '__radd__', '__mul__', + '__imul__', '__rmul__' + ]: + if hasattr(cls, method): + setattr(cls, method, wrap_api(getattr(cls, method))) + + +class AgentList(Agent, UserList, AgentContainerMixin): + + def __init__(self, + agents: Optional[Iterable[Union[Agent, AsyncAgent]]] = None): + Agent.__init__(self, memory=None) + UserList.__init__(self, agents) + self.name = None + + +class AgentDict(Agent, UserDict, AgentContainerMixin): + + def __init__(self, + agents: Optional[Mapping[str, Union[Agent, + AsyncAgent]]] = None): + Agent.__init__(self, memory=None) + UserDict.__init__(self, agents) + self.name = None diff --git a/lagent/agents/aggregator/__init__.py b/lagent/agents/aggregator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d9065ccf6f783f2cd85c427b9371e85e6699776 --- /dev/null +++ b/lagent/agents/aggregator/__init__.py @@ -0,0 +1,4 @@ +from .default_aggregator import DefaultAggregator +from .tool_aggregator import InternLMToolAggregator + +__all__ = ['DefaultAggregator', 'InternLMToolAggregator'] diff --git a/lagent/agents/aggregator/__pycache__/__init__.cpython-310.pyc b/lagent/agents/aggregator/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4933129dd3e2a01cf121a0cfa8a212439a5075ca Binary files /dev/null and b/lagent/agents/aggregator/__pycache__/__init__.cpython-310.pyc differ diff --git a/lagent/agents/aggregator/__pycache__/default_aggregator.cpython-310.pyc b/lagent/agents/aggregator/__pycache__/default_aggregator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f81e969a25aa20e7d1c61c2d3b59f2ed0bc8939 Binary files /dev/null and b/lagent/agents/aggregator/__pycache__/default_aggregator.cpython-310.pyc differ diff --git a/lagent/agents/aggregator/__pycache__/tool_aggregator.cpython-310.pyc b/lagent/agents/aggregator/__pycache__/tool_aggregator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2936bcb95d458d54a434b665c6d80b325615170 Binary files /dev/null and b/lagent/agents/aggregator/__pycache__/tool_aggregator.cpython-310.pyc differ diff --git a/lagent/agents/aggregator/default_aggregator.py b/lagent/agents/aggregator/default_aggregator.py new file mode 100644 index 0000000000000000000000000000000000000000..0888aef54b79e4d28753dd5bc1186f412964408e --- /dev/null +++ b/lagent/agents/aggregator/default_aggregator.py @@ -0,0 +1,44 @@ +from typing import Dict, List + +from lagent.memory import Memory +from lagent.prompts import StrParser + + +class DefaultAggregator: + + def aggregate(self, + messages: Memory, + name: str, + parser: StrParser = None, + system_instruction: str = None) -> List[Dict[str, str]]: + _message = [] + messages = messages.get_memory() + if system_instruction: + _message.extend( + self.aggregate_system_intruction(system_instruction)) + for message in messages: + if message.sender == name: + _message.append( + dict(role='assistant', content=str(message.content))) + else: + user_message = message.content + if 
len(_message) > 0 and _message[-1]['role'] == 'user': + _message[-1]['content'] += user_message + else: + _message.append(dict(role='user', content=user_message)) + return _message + + @staticmethod + def aggregate_system_intruction(system_intruction) -> List[dict]: + if isinstance(system_intruction, str): + system_intruction = dict(role='system', content=system_intruction) + if isinstance(system_intruction, dict): + system_intruction = [system_intruction] + if isinstance(system_intruction, list): + for msg in system_intruction: + if not isinstance(msg, dict): + raise TypeError(f'Unsupported message type: {type(msg)}') + if not ('role' in msg and 'content' in msg): + raise KeyError( + f"Missing required key 'role' or 'content': {msg}") + return system_intruction diff --git a/lagent/agents/aggregator/tool_aggregator.py b/lagent/agents/aggregator/tool_aggregator.py new file mode 100644 index 0000000000000000000000000000000000000000..0ea473794b897690d67019aabfab24c89a209105 --- /dev/null +++ b/lagent/agents/aggregator/tool_aggregator.py @@ -0,0 +1,106 @@ +from typing import Dict, List, Optional, Union + +from lagent.agents.aggregator.default_aggregator import DefaultAggregator +from lagent.memory.base_memory import Memory +from lagent.prompts.parsers.tool_parser import MixedToolParser, ToolParser, ToolStatusCode + + +class InternLMToolAggregator(DefaultAggregator): + + def __init__(self, + environment_role='environment', + environment_begin='', + environment_end='', + user_names: Optional[List[str]] = None, + few_shot: Optional[List[List[dict]]] = None): + self.environment_role = environment_role + self.environment_begin = environment_begin + self.environment_end = environment_end + self.user_names = user_names or ['user'] + self.few_shot = few_shot or [] + + def aggregate(self, + messages: Memory, + name: str, + parser: Union[ToolParser, MixedToolParser], + system_instruction: str = None) -> List[Dict[str, str]]: + _message = [] + messages = messages.get_memory() + if system_instruction: + _message.extend( + self.aggregate_system_intruction(system_instruction)) + tool_instruction = parser.format_instruction() + if tool_instruction: + if isinstance(tool_instruction, str): + tool_instruction = dict( + role='system', content=tool_instruction) + if parser.tool_type: + tool_instruction['name'] = parser.tool_type + if isinstance(tool_instruction, dict): + tool_instruction = [tool_instruction] + _message.extend(tool_instruction) + + for shot in self.few_shot: + i = 0 + while i < len(shot): + msg = shot[i] + if msg['role'] in ['assistant', 'user', 'system']: + _message.append(msg) + elif msg['role'] == self.environment_role: + if not msg['content'].startswith(self.environment_begin): + msg['content'] = self.environment_begin + msg['content'] + if not msg['content'].endswith(self.environment_end): + msg['content'] += self.environment_end + _message.append(msg) + elif msg['role'] in ['thought', 'language']: + if i < len(shot) - 1 and shot[i + 1]['role'] == 'tool': + _message.append( + dict( + role='assistant', + content=parser.format_response( + dict( + tool_type=shot[i + 1]['name'], + thought=msg['content'], + action=shot[i + 1]['content'], + status=None)))) + i += 1 + else: + _message.append( + dict( + role='assistant', + content=parser.format_response( + dict( + tool_type=None, + thought=msg['content'], + action=None, + status=None)))) + else: + raise KeyError(f'Unkown role: {msg["role"]}') + i += 1 + + tool_type = None + for message in messages: + if message.sender == name: + if 
isinstance(message.formatted, dict): + parsed = message.formatted + if parsed['status'] == ToolStatusCode.PARSING_ERROR: + continue + _message.append( + dict( + role='assistant', + content=parser.format_response(parsed))) + tool_type = parsed['tool_type'] + else: + _message.append( + dict(role='assistant', content=str(message.content))) + elif message.sender in self.user_names: + _message.append(dict(role='user', content=message.content)) + else: + msg = dict( + role=self.environment_role, + content=self.environment_begin + str(message.content) + + self.environment_end) + if tool_type: + msg['name'] = tool_type + _message.append(msg) + return _message diff --git a/lagent/agents/react.py b/lagent/agents/react.py new file mode 100644 index 0000000000000000000000000000000000000000..41d2414d0f1d15066aba5f56cae9afd9c9140c7c --- /dev/null +++ b/lagent/agents/react.py @@ -0,0 +1,161 @@ +import json +from typing import Callable, Dict, List, Union + +from pydantic import BaseModel, Field + +from lagent.actions import ActionExecutor, AsyncActionExecutor, BaseAction +from lagent.agents.agent import Agent, AsyncAgent +from lagent.agents.aggregator import DefaultAggregator +from lagent.hooks import ActionPreprocessor +from lagent.llms import BaseLLM +from lagent.memory import Memory +from lagent.prompts.parsers.json_parser import JSONParser +from lagent.prompts.prompt_template import PromptTemplate +from lagent.schema import AgentMessage +from lagent.utils import create_object + +select_action_template = """你是一个可以调用外部工具的助手,可以使用的工具包括: +{action_info} +{output_format} +开始!""" + +output_format_template = """如果使用工具请遵循以下格式回复: +{function_format} + +如果你已经知道了答案,或者你不需要工具,请遵循以下格式回复 +{finish_format}""" + + +class ReAct(Agent): + + def __init__(self, + llm: Union[BaseLLM, Dict], + actions: Union[BaseAction, List[BaseAction]], + template: Union[PromptTemplate, str] = None, + memory: Dict = dict(type=Memory), + output_format: Dict = dict(type=JSONParser), + aggregator: Dict = dict(type=DefaultAggregator), + hooks: List = [dict(type=ActionPreprocessor)], + finish_condition: Callable[[AgentMessage], bool] = lambda m: + 'conclusion' in m.content or 'conclusion' in m.formatted, + max_turn: int = 5, + **kwargs): + self.max_turn = max_turn + self.finish_condition = finish_condition + actions = dict( + type=ActionExecutor, + actions=actions, + hooks=hooks, + ) + self.actions: ActionExecutor = create_object(actions) + select_agent = dict( + type=Agent, + llm=llm, + template=template.format( + action_info=json.dumps(self.actions.description()), + output_format=output_format.format_instruction()), + output_format=output_format, + memory=memory, + aggregator=aggregator, + hooks=hooks, + ) + self.select_agent = create_object(select_agent) + super().__init__(**kwargs) + + def forward(self, message: AgentMessage, **kwargs) -> AgentMessage: + for _ in range(self.max_turn): + message = self.select_agent(message) + if self.finish_condition(message): + return message + message = self.actions(message) + return message + + +class AsyncReAct(AsyncAgent): + + def __init__(self, + llm: Union[BaseLLM, Dict], + actions: Union[BaseAction, List[BaseAction]], + template: Union[PromptTemplate, str] = None, + memory: Dict = dict(type=Memory), + output_format: Dict = dict(type=JSONParser), + aggregator: Dict = dict(type=DefaultAggregator), + hooks: List = [dict(type=ActionPreprocessor)], + finish_condition: Callable[[AgentMessage], bool] = lambda m: + 'conclusion' in m.content or 'conclusion' in m.formatted, + max_turn: int = 5, + **kwargs): + 
self.max_turn = max_turn + self.finish_condition = finish_condition + actions = dict( + type=AsyncActionExecutor, + actions=actions, + hooks=hooks, + ) + self.actions: AsyncActionExecutor = create_object(actions) + select_agent = dict( + type=AsyncAgent, + llm=llm, + template=template.format( + action_info=json.dumps(self.actions.description()), + output_format=output_format.format_instruction()), + output_format=output_format, + memory=memory, + aggregator=aggregator, + hooks=hooks, + ) + self.select_agent = create_object(select_agent) + super().__init__(**kwargs) + + async def forward(self, message: AgentMessage, **kwargs) -> AgentMessage: + for _ in range(self.max_turn): + message = await self.select_agent(message) + if self.finish_condition(message): + return message + message = await self.actions(message) + return message + + +if __name__ == '__main__': + from lagent.llms import GPTAPI + + class ActionCall(BaseModel): + name: str = Field(description='调用的函数名称') + parameters: Dict = Field(description='调用函数的参数') + + class ActionFormat(BaseModel): + thought_process: str = Field( + description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。') + action: ActionCall = Field(description='当前步骤需要执行的操作,包括函数名称和参数。') + + class FinishFormat(BaseModel): + thought_process: str = Field( + description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。') + conclusion: str = Field(description='总结当前的搜索结果,回答问题。') + + prompt_template = PromptTemplate(select_action_template) + output_format = JSONParser( + output_format_template, + function_format=ActionFormat, + finish_format=FinishFormat) + + llm = dict( + type=GPTAPI, + model_type='gpt-4o-2024-05-13', + key=None, + max_new_tokens=4096, + proxies=dict(), + retry=1000) + + agent = ReAct( + llm=llm, + template=prompt_template, + output_format=output_format, + aggregator=dict(type='DefaultAggregator'), + actions=[dict(type='PythonInterpreter')], + ) + response = agent( + AgentMessage(sender='user', content='用 Python 计算一下 3 ** 5')) + print(response) + response = agent(AgentMessage(sender='user', content=' 2 ** 5 呢')) + print(response) diff --git a/lagent/agents/stream.py b/lagent/agents/stream.py new file mode 100644 index 0000000000000000000000000000000000000000..512250ff02c7dd3f09dd844e999e343b597feab8 --- /dev/null +++ b/lagent/agents/stream.py @@ -0,0 +1,316 @@ +import json +import warnings +from copy import deepcopy +from typing import Callable, Dict, List, Union + +from lagent.actions import ActionExecutor, AsyncActionExecutor, AsyncIPythonInterpreter, IPythonInteractive +from lagent.agents.agent import Agent, AsyncAgent +from lagent.agents.aggregator import InternLMToolAggregator +from lagent.hooks import InternLMActionProcessor +from lagent.llms import BaseLLM +from lagent.memory import Memory +from lagent.prompts.parsers import InterpreterParser, MixedToolParser, PluginParser, ToolStatusCode +from lagent.schema import AgentMessage +from lagent.utils import create_object + +API_PREFIX = ( + "This is the subfunction for tool '{tool_name}', you can use this tool. " + 'The description of this function is: \n{description}') + +META_CN = ('当开启工具以及代码时,根据需求选择合适的工具进行调用') + +INTERPRETER_CN = ('你现在已经能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。' + '当你向 python 发送含有 Python 代码的消息时,它将在该环境中执行。' + '这个工具适用于多种场景,如数据分析或处理(包括数据操作、统计分析、图表绘制),' + '复杂的计算问题(解决数学和物理难题),编程示例(理解编程概念或特性),' + '文本处理和分析(比如文本解析和自然语言处理),' + '机器学习和数据科学(用于展示模型训练和数据可视化),' + '以及文件操作和数据导入(处理CSV、JSON等格式的文件)。') + +PLUGIN_CN = ('你可以使用如下工具:' + '\n{prompt}\n' + '如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! 
' + '同时注意你可以使用的工具,不要随意捏造!') + + +def get_plugin_prompt(actions, api_desc_template=API_PREFIX): + plugin_descriptions = [] + for action in actions if isinstance(actions, list) else [actions]: + action = create_object(action) + action_desc = deepcopy(action.description) + if action.is_toolkit: + for api in action_desc['api_list']: + api['name'] = f"{action.name}.{api['name']}" + api['description'] = api_desc_template.format( + tool_name=action.name, description=api['description']) + api['parameters'] = [ + param for param in api['parameters'] + if param['name'] in api['required'] + ] + plugin_descriptions.append(api) + else: + action_desc['description'] = api_desc_template.format( + tool_name=action.name, description=action_desc['description']) + action_desc['parameters'] = [ + param for param in action_desc['parameters'] + if param['name'] in action_desc['required'] + ] + plugin_descriptions.append(action_desc) + return json.dumps(plugin_descriptions, ensure_ascii=False, indent=4) + + +class AgentForInternLM(Agent): + + _INTERNAL_AGENT_CLS = Agent + + def __init__( + self, + llm: Union[BaseLLM, Dict], + plugins: Union[dict, List[dict]] = None, + interpreter: dict = None, + template: Union[str, dict, List[dict]] = None, + memory: Dict = dict(type=Memory), + output_format: Dict = dict( + type=MixedToolParser, + template=META_CN, + parsers=[ + dict(type=PluginParser, template=PLUGIN_CN), + dict(type=InterpreterParser, template=INTERPRETER_CN), + ]), + aggregator: Dict = dict(type=InternLMToolAggregator), + action_hooks: List = [dict(type=InternLMActionProcessor)], + finish_condition: Callable[ + [AgentMessage], + bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, + max_turn: int = 4, + **kwargs, + ): + agent = dict( + type=self._INTERNAL_AGENT_CLS, + llm=llm, + template=template, + output_format=output_format, + memory=memory, + aggregator=aggregator, + hooks=kwargs.pop('hooks', None), + ) + self.agent = create_object(agent) + self.plugin_executor = plugins and ActionExecutor( + plugins, hooks=action_hooks) + self.interpreter_executor = interpreter and ActionExecutor( + interpreter, hooks=action_hooks) + if not (self.plugin_executor or self.interpreter_executor): + warnings.warn( + 'Neither plugin nor interpreter executor is initialized. 
' + 'An exception will be thrown when the agent calls a tool.') + self.finish_condition = finish_condition + self.max_turn = max_turn + super().__init__(**kwargs) + + def forward(self, message: AgentMessage, session_id=0, **kwargs): + if isinstance(message, str): + message = AgentMessage(sender='user', content=message) + for _ in range(self.max_turn): + message = self.agent(message, session_id=session_id, **kwargs) + assert isinstance(message.formatted, dict) + if self.finish_condition(message): + return message + if message.formatted['tool_type']: + tool_type = message.formatted["tool_type"] + executor = getattr(self, f'{tool_type}_executor', None) + if not executor: + raise RuntimeError(f'No available {tool_type} executor') + message = executor(message, session_id=session_id) + return message + + def get_steps(self, session_id=0): + steps, tool_type = [], None + for msg in self.agent.memory.get_memory(session_id): + if msg.sender == self.agent.name: + steps.append( + dict(role='thought', content=msg.formatted['thought'])) + if msg.formatted['tool_type']: + tool_type = msg.formatted['tool_type'] + steps.append( + dict( + role='tool', + content=msg.formatted['action'], + name=tool_type)) + elif msg.sender != 'user': + feedback = dict(role='environment', content=msg.content) + if tool_type: + feedback['name'] = tool_type + steps.append(feedback) + return steps + + +class MathCoder(AgentForInternLM): + + def __init__( + self, + llm: Union[BaseLLM, Dict], + interpreter: dict = dict( + type=IPythonInteractive, timeout=20, max_out_len=8192), + template: Union[str, dict, List[dict]] = None, + memory: Dict = dict(type=Memory), + output_format: Dict = dict( + type=InterpreterParser, + template= + ('Integrate step-by-step reasoning and Python code to solve math problems ' + 'using the following guidelines:\n' + '- Analyze the question and write jupyter code to solve the problem;\n' + r"- Present the final result in LaTeX using a '\boxed{{}}' without any " + 'units. 
\n')), + aggregator: Dict = dict(type=InternLMToolAggregator), + action_hooks: List = [dict(type=InternLMActionProcessor)], + finish_condition: Callable[ + [AgentMessage], + bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, + max_turn: int = 6, + **kwargs, + ): + kwargs.pop('plugins', None) + super().__init__( + llm=llm, + interpreter=interpreter, + template=template, + memory=memory, + output_format=output_format, + aggregator=aggregator, + action_hooks=action_hooks, + finish_condition=finish_condition, + max_turn=max_turn, + **kwargs) + + +class AsyncAgentForInternLM(AsyncAgent): + + _INTERNAL_AGENT_CLS = AsyncAgent + + def __init__( + self, + llm: Union[BaseLLM, Dict], + plugins: Union[dict, List[dict]] = None, + interpreter: dict = None, + template: Union[str, dict, List[dict]] = None, + memory: Dict = dict(type=Memory), + output_format: Dict = dict( + type=MixedToolParser, + template=META_CN, + parsers=[ + dict(type=PluginParser, template=PLUGIN_CN), + dict(type=InterpreterParser, template=INTERPRETER_CN), + ]), + aggregator: Dict = dict(type=InternLMToolAggregator), + action_hooks: List = [dict(type=InternLMActionProcessor)], + finish_condition: Callable[ + [AgentMessage], + bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, + max_turn: int = 4, + **kwargs, + ): + agent = dict( + type=self._INTERNAL_AGENT_CLS, + llm=llm, + template=template, + output_format=output_format, + memory=memory, + aggregator=aggregator, + hooks=kwargs.pop('hooks', None), + ) + self.agent = create_object(agent) + self.plugin_executor = plugins and AsyncActionExecutor( + plugins, hooks=action_hooks) + self.interpreter_executor = interpreter and AsyncActionExecutor( + interpreter, hooks=action_hooks) + if not (self.plugin_executor or self.interpreter_executor): + warnings.warn( + 'Neither plugin nor interpreter executor is initialized. 
' + 'An exception will be thrown when the agent calls a tool.') + self.finish_condition = finish_condition + self.max_turn = max_turn + super().__init__(**kwargs) + + async def forward(self, message: AgentMessage, session_id=0, **kwargs): + if isinstance(message, str): + message = AgentMessage(sender='user', content=message) + for _ in range(self.max_turn): + message = await self.agent( + message, session_id=session_id, **kwargs) + assert isinstance(message.formatted, dict) + if self.finish_condition(message): + return message + if message.formatted['tool_type']: + tool_type = message.formatted["tool_type"] + executor = getattr(self, f'{tool_type}_executor', None) + if not executor: + raise RuntimeError(f'No available {tool_type} executor') + message = await executor(message, session_id=session_id) + return message + + def get_steps(self, session_id=0): + steps, tool_type = [], None + for msg in self.agent.memory.get_memory(session_id): + if msg.sender == self.agent.name: + steps.append( + dict(role='thought', content=msg.formatted['thought'])) + if msg.formatted['tool_type']: + tool_type = msg.formatted['tool_type'] + steps.append( + dict( + role='tool', + content=msg.formatted['action'], + name=tool_type)) + elif msg.sender != 'user': + feedback = dict(role='environment', content=msg.content) + if tool_type: + feedback['name'] = tool_type + steps.append(feedback) + return steps + + +class AsyncMathCoder(AsyncAgentForInternLM): + + def __init__( + self, + llm: Union[BaseLLM, Dict], + interpreter: dict = dict(type=AsyncIPythonInterpreter), + template: Union[str, dict, List[dict]] = None, + memory: Dict = dict(type=Memory), + output_format: Dict = dict( + type=InterpreterParser, + template= + ('Integrate step-by-step reasoning and Python code to solve math problems ' + 'using the following guidelines:\n' + '- Analyze the question and write jupyter code to solve the problem;\n' + r"- Present the final result in LaTeX using a '\boxed{{}}' without any " + 'units. 
\n')), + aggregator: Dict = dict(type=InternLMToolAggregator), + action_hooks: List = [dict(type=InternLMActionProcessor)], + finish_condition: Callable[ + [AgentMessage], + bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, + max_turn: int = 6, + **kwargs, + ): + kwargs.pop('plugins', None) + super().__init__( + llm=llm, + interpreter=interpreter, + template=template, + memory=memory, + output_format=output_format, + aggregator=aggregator, + action_hooks=action_hooks, + finish_condition=finish_condition, + max_turn=max_turn, + **kwargs) + + async def forward(self, message: AgentMessage, session_id=0, **kwargs): + try: + return await super().forward(message, session_id, **kwargs) + finally: + interpreter = next( + iter(self.interpreter_executor.actions.values())) + if interpreter.name == 'AsyncIPythonInterpreter': + await interpreter.close_session(session_id) diff --git a/lagent/distributed/__init__.py b/lagent/distributed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..98ad5210225eaa60ff2e2f361950af365b487e47 --- /dev/null +++ b/lagent/distributed/__init__.py @@ -0,0 +1,8 @@ +from .http_serve import AgentAPIServer, AsyncHTTPAgentClient, AsyncHTTPAgentServer, HTTPAgentClient, HTTPAgentServer +from .ray_serve import AgentRayActor, AsyncAgentRayActor + +__all__ = [ + 'AsyncAgentRayActor', 'AgentRayActor', 'HTTPAgentServer', + 'HTTPAgentClient', 'AsyncHTTPAgentServer', 'AsyncHTTPAgentClient', + 'AgentAPIServer' +] diff --git a/lagent/distributed/http_serve/__init__.py b/lagent/distributed/http_serve/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a7082f30d659d980cc60d096f14cb4573a69debf --- /dev/null +++ b/lagent/distributed/http_serve/__init__.py @@ -0,0 +1,7 @@ +from .api_server import AsyncHTTPAgentClient, AsyncHTTPAgentServer, HTTPAgentClient, HTTPAgentServer +from .app import AgentAPIServer + +__all__ = [ + 'HTTPAgentServer', 'HTTPAgentClient', 'AsyncHTTPAgentClient', + 'AsyncHTTPAgentServer', 'AgentAPIServer' +] diff --git a/lagent/distributed/http_serve/api_server.py b/lagent/distributed/http_serve/api_server.py new file mode 100644 index 0000000000000000000000000000000000000000..0cb6907ab4ce1bc3b73dbc74898b29cbd3f8c6f8 --- /dev/null +++ b/lagent/distributed/http_serve/api_server.py @@ -0,0 +1,123 @@ +import json +import os +import subprocess +import sys +import time + +import aiohttp +import requests + +from lagent.schema import AgentMessage + + +class HTTPAgentClient: + + def __init__(self, host='127.0.0.1', port=8090, timeout=None): + self.host = host + self.port = port + self.timeout = timeout + + @property + def is_alive(self): + try: + resp = requests.get( + f'http://{self.host}:{self.port}/health_check', + timeout=self.timeout) + return resp.status_code == 200 + except: + return False + + def __call__(self, *message, session_id: int = 0, **kwargs): + response = requests.post( + f'http://{self.host}:{self.port}/chat_completion', + json={ + 'message': [ + m if isinstance(m, str) else m.model_dump() + for m in message + ], + 'session_id': session_id, + **kwargs, + }, + headers={'Content-Type': 'application/json'}, + timeout=self.timeout) + resp = response.json() + if response.status_code != 200: + return resp + return AgentMessage.model_validate(resp) + + def state_dict(self, session_id: int = 0): + resp = requests.get( + f'http://{self.host}:{self.port}/memory/{session_id}', + timeout=self.timeout) + return resp.json() + + +class HTTPAgentServer(HTTPAgentClient): + + def __init__(self, gpu_id, 
config, host='127.0.0.1', port=8090): + super().__init__(host, port) + self.gpu_id = gpu_id + self.config = config + self.start_server() + + def start_server(self): + # set CUDA_VISIBLE_DEVICES in subprocess + env = os.environ.copy() + env['CUDA_VISIBLE_DEVICES'] = self.gpu_id + cmds = [ + sys.executable, 'lagent/distributed/http_serve/app.py', '--host', + self.host, '--port', + str(self.port), '--config', + json.dumps(self.config) + ] + self.process = subprocess.Popen( + cmds, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True) + + while True: + output = self.process.stdout.readline() + if not output: # break out of the loop on EOF + break + sys.stdout.write(output) # echo the subprocess output to stdout + sys.stdout.flush() + if 'Uvicorn running on' in output: # adjust this marker to match the actual server output + break + time.sleep(0.1) + + def shutdown(self): + self.process.terminate() + self.process.wait() + + +class AsyncHTTPAgentMixin: + + async def __call__(self, *message, session_id: int = 0, **kwargs): + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(self.timeout)) as session: + async with session.post( + f'http://{self.host}:{self.port}/chat_completion', + json={ + 'message': [ + m if isinstance(m, str) else m.model_dump() + for m in message + ], + 'session_id': session_id, + **kwargs, + }, + headers={'Content-Type': 'application/json'}, + ) as response: + resp = await response.json() + if response.status != 200: + return resp + return AgentMessage.model_validate(resp) + + +class AsyncHTTPAgentClient(AsyncHTTPAgentMixin, HTTPAgentClient): + pass + + +class AsyncHTTPAgentServer(AsyncHTTPAgentMixin, HTTPAgentServer): + pass diff --git a/lagent/distributed/http_serve/app.py b/lagent/distributed/http_serve/app.py new file mode 100644 index 0000000000000000000000000000000000000000..7d05660a8b7677c842e6107045b714407fa1c4f4 --- /dev/null +++ b/lagent/distributed/http_serve/app.py @@ -0,0 +1,96 @@ +import argparse +import json +import logging +import time + +import uvicorn +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from fastapi.requests import Request + +from lagent.schema import AgentMessage +from lagent.utils import load_class_from_string + + +class AgentAPIServer: + + def __init__(self, + config: dict, + host: str = '127.0.0.1', + port: int = 8090): + self.app = FastAPI(docs_url='/') + self.app.add_middleware( + CORSMiddleware, + allow_origins=['*'], + allow_credentials=True, + allow_methods=['*'], + allow_headers=['*'], + ) + cls_name = config.pop('type') + python_path = config.pop('python_path', None) + cls_name = load_class_from_string(cls_name, python_path) if isinstance( + cls_name, str) else cls_name + self.agent = cls_name(**config) + self.setup_routes() + self.run(host, port) + + def setup_routes(self): + + def heartbeat(): + return {'status': 'success', 'timestamp': time.time()} + + async def process_message(request: Request): + try: + body = await request.json() + message = [ + m if isinstance(m, str) else AgentMessage.model_validate(m) + for m in body.pop('message') + ] + result = await self.agent(*message, **body) + return result + except Exception as e: + logging.error(f'Error processing message: {str(e)}') + raise HTTPException( + status_code=500, detail='Internal Server Error') + + def get_memory(session_id: int = 0): + try: + result = self.agent.state_dict(session_id) + return result + except KeyError: + raise HTTPException( + status_code=404, detail="Session ID not found") + except Exception as e: + logging.error(f'Error processing message: {str(e)}') + 
raise HTTPException( + status_code=500, detail='Internal Server Error') + + self.app.add_api_route('/health_check', heartbeat, methods=['GET']) + self.app.add_api_route( + '/chat_completion', process_message, methods=['POST']) + self.app.add_api_route( + '/memory/{session_id}', get_memory, methods=['GET']) + + def run(self, host='127.0.0.1', port=8090): + logging.info(f'Starting server at {host}:{port}') + uvicorn.run(self.app, host=host, port=port) + + +def parse_args(): + parser = argparse.ArgumentParser(description='Async Agent API Server') + parser.add_argument('--host', type=str, default='127.0.0.1') + parser.add_argument('--port', type=int, default=8090) + parser.add_argument( + '--config', + type=json.loads, + required=True, + help='JSON configuration for the agent') + args = parser.parse_args() + + return args + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + args = parse_args() + AgentAPIServer(args.config, host=args.host, port=args.port) diff --git a/lagent/distributed/ray_serve/__init__.py b/lagent/distributed/ray_serve/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8739b2a1b4f20ce21e06e6aea90460b073f1a396 --- /dev/null +++ b/lagent/distributed/ray_serve/__init__.py @@ -0,0 +1,3 @@ +from .ray_warpper import AgentRayActor, AsyncAgentRayActor + +__all__ = ['AsyncAgentRayActor', 'AgentRayActor'] diff --git a/lagent/distributed/ray_serve/ray_warpper.py b/lagent/distributed/ray_serve/ray_warpper.py new file mode 100644 index 0000000000000000000000000000000000000000..d401c9443ffc8f806060855fb4e6810fdc60954d --- /dev/null +++ b/lagent/distributed/ray_serve/ray_warpper.py @@ -0,0 +1,48 @@ +import importlib +import sys +from typing import Dict + +import ray + +from lagent.schema import AgentMessage +from lagent.utils import load_class_from_string + + +class AsyncAgentRayActor: + + def __init__( + self, + config: Dict, + num_gpus: int, + ): + cls_name = config.pop('type') + python_path = config.pop('python_path', None) + cls_name = load_class_from_string(cls_name, python_path) if isinstance( + cls_name, str) else cls_name + AsyncAgentActor = ray.remote(num_gpus=num_gpus)(cls_name) + self.agent_actor = AsyncAgentActor.remote(**config) + + async def __call__(self, *message: AgentMessage, session_id=0, **kwargs): + response = await self.agent_actor.__call__.remote( + *message, session_id=session_id, **kwargs) + return response + + +class AgentRayActor: + + def __init__( + self, + config: Dict, + num_gpus: int, + ): + cls_name = config.pop('type') + python_path = config.pop('python_path', None) + cls_name = load_class_from_string(cls_name, python_path) if isinstance( + cls_name, str) else cls_name + AgentActor = ray.remote(num_gpus=num_gpus)(cls_name) + self.agent_actor = AgentActor.remote(**config) + + def __call__(self, *message: AgentMessage, session_id=0, **kwargs): + response = self.agent_actor.__call__.remote( + *message, session_id=session_id, **kwargs) + return ray.get(response) diff --git a/lagent/hooks/__init__.py b/lagent/hooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0e935281e0ec2127e00da58b77f621fdbc0dd94f --- /dev/null +++ b/lagent/hooks/__init__.py @@ -0,0 +1,8 @@ +from .action_preprocessor import ActionPreprocessor, InternLMActionProcessor +from .hook import Hook, RemovableHandle +from .logger import MessageLogger + +__all__ = [ + 'Hook', 'RemovableHandle', 'ActionPreprocessor', 'InternLMActionProcessor', + 'MessageLogger' +] diff --git 
a/lagent/hooks/__pycache__/__init__.cpython-310.pyc b/lagent/hooks/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af26dadd500f7028a2a636bac88a7c870992044e Binary files /dev/null and b/lagent/hooks/__pycache__/__init__.cpython-310.pyc differ diff --git a/lagent/hooks/__pycache__/action_preprocessor.cpython-310.pyc b/lagent/hooks/__pycache__/action_preprocessor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2f03caa921f3fc36e4ff5e3dcc02152518db35f Binary files /dev/null and b/lagent/hooks/__pycache__/action_preprocessor.cpython-310.pyc differ diff --git a/lagent/hooks/__pycache__/hook.cpython-310.pyc b/lagent/hooks/__pycache__/hook.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c3c46cd0699b9324120545670d1f99420b064c7 Binary files /dev/null and b/lagent/hooks/__pycache__/hook.cpython-310.pyc differ diff --git a/lagent/hooks/__pycache__/logger.cpython-310.pyc b/lagent/hooks/__pycache__/logger.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd4b8634d6497782b0104086a0e60c4ec881d298 Binary files /dev/null and b/lagent/hooks/__pycache__/logger.cpython-310.pyc differ diff --git a/lagent/hooks/action_preprocessor.py b/lagent/hooks/action_preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..51083aa86d4241acc307817b77708299c32daf93 --- /dev/null +++ b/lagent/hooks/action_preprocessor.py @@ -0,0 +1,62 @@ +from copy import deepcopy + +from lagent.schema import ActionReturn, ActionStatusCode, FunctionCall +from .hook import Hook + + +class ActionPreprocessor(Hook): + """The ActionPreprocessor is a hook that preprocesses the action message + and postprocesses the action return message. 
+ + """ + + def before_action(self, executor, message, session_id): + assert isinstance(message.formatted, FunctionCall) or ( + isinstance(message.formatted, dict) and 'name' in message.content + and 'parameters' in message.formatted) or ( + 'action' in message.formatted + and 'parameters' in message.formatted['action'] + and 'name' in message.formatted['action']) + if isinstance(message.formatted, dict): + name = message.formatted.get('name', + message.formatted['action']['name']) + parameters = message.formatted.get( + 'parameters', message.formatted['action']['parameters']) + else: + name = message.formatted.name + parameters = message.formatted.parameters + message.content = dict(name=name, parameters=parameters) + return message + + def after_action(self, executor, message, session_id): + action_return = message.content + if isinstance(action_return, ActionReturn): + if action_return.state == ActionStatusCode.SUCCESS: + response = action_return.format_result() + else: + response = action_return.errmsg + else: + response = action_return + message.content = response + return message + + +class InternLMActionProcessor(ActionPreprocessor): + + def __init__(self, code_parameter: str = 'command'): + self.code_parameter = code_parameter + + def before_action(self, executor, message, session_id): + message = deepcopy(message) + assert isinstance(message.formatted, dict) and set( + message.formatted).issuperset( + {'tool_type', 'thought', 'action', 'status'}) + if isinstance(message.formatted['action'], str): + # encapsulate code interpreter arguments + action_name = next(iter(executor.actions)) + parameters = {self.code_parameter: message.formatted['action']} + if action_name in ['AsyncIPythonInterpreter']: + parameters['session_id'] = session_id + message.formatted['action'] = dict( + name=action_name, parameters=parameters) + return super().before_action(executor, message, session_id) diff --git a/lagent/hooks/hook.py b/lagent/hooks/hook.py new file mode 100644 index 0000000000000000000000000000000000000000..c3b3e303dae379797e491fee4e227f3c274251db --- /dev/null +++ b/lagent/hooks/hook.py @@ -0,0 +1,50 @@ +from itertools import count +from typing import Tuple + +from lagent.schema import AgentMessage + + +class Hook: + + def before_agent( + self, + agent, + message: Tuple[AgentMessage], + session_id: int, + ): + pass + + def after_agent( + self, + agent, + message: AgentMessage, + session_id: int, + ): + pass + + def before_action( + self, + executor, + message: AgentMessage, + session_id: int, + ): + pass + + def after_action( + self, + executor, + message: AgentMessage, + session_id: int, + ): + pass + + +class RemovableHandle: + _id_iter = count(0) + + def __init__(self, hooks_dict): + self.hooks_dict = hooks_dict + self.id = next(self._id_iter) + + def remove(self): + del self.hooks_dict[self.id] diff --git a/lagent/hooks/logger.py b/lagent/hooks/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..50224e432a6ca1177f2f39ff760fd5855fcf43d9 --- /dev/null +++ b/lagent/hooks/logger.py @@ -0,0 +1,37 @@ +import random +from typing import Optional + +from termcolor import COLORS, colored + +from lagent.utils import get_logger +from .hook import Hook + + +class MessageLogger(Hook): + + def __init__(self, name: str = 'lagent'): + self.logger = get_logger( + name, 'info', '%(asctime)s %(levelname)8s %(name)8s - %(message)s') + self.sender2color = {} + + def before_agent(self, agent, messages, session_id): + for message in messages: + self._process_message(message, 
session_id) + + def after_agent(self, agent, message, session_id): + self._process_message(message, session_id) + + def before_action(self, executor, message, session_id): + self._process_message(message, session_id) + + def after_action(self, executor, message, session_id): + self._process_message(message, session_id) + + def _process_message(self, message, session_id): + sender = message.sender + color = self.sender2color.setdefault(sender, + random.choice(list(COLORS))) + self.logger.info( + colored( + f'session id: {session_id}, message sender: {sender}\n' + f'{message.content}', color)) diff --git a/lagent/llms/__init__.py b/lagent/llms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fcbbd07d4622b1bf53a9b0daebb6a1c35a6a1711 --- /dev/null +++ b/lagent/llms/__init__.py @@ -0,0 +1,32 @@ +from .base_api import AsyncBaseAPILLM, BaseAPILLM +from .base_llm import AsyncBaseLLM, BaseLLM +from .huggingface import HFTransformer, HFTransformerCasualLM, HFTransformerChat +from .lmdeploy_wrapper import (AsyncLMDeployClient, AsyncLMDeployPipeline, + AsyncLMDeployServer, LMDeployClient, + LMDeployPipeline, LMDeployServer) +from .meta_template import INTERNLM2_META +from .openai import GPTAPI, AsyncGPTAPI +from .sensenova import SensenovaAPI +from .vllm_wrapper import AsyncVllmModel, VllmModel + +__all__ = [ + 'AsyncBaseLLM', + 'BaseLLM', + 'AsyncBaseAPILLM', + 'BaseAPILLM', + 'AsyncGPTAPI', + 'GPTAPI', + 'LMDeployClient', + 'AsyncLMDeployClient', + 'LMDeployPipeline', + 'AsyncLMDeployPipeline', + 'LMDeployServer', + 'AsyncLMDeployServer', + 'HFTransformer', + 'HFTransformerCasualLM', + 'INTERNLM2_META', + 'HFTransformerChat', + 'VllmModel', + 'AsyncVllmModel', + 'SensenovaAPI', +] diff --git a/lagent/llms/__pycache__/__init__.cpython-310.pyc b/lagent/llms/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1d1eab97843124fa39ed95693ee60e4ffd2188b Binary files /dev/null and b/lagent/llms/__pycache__/__init__.cpython-310.pyc differ diff --git a/lagent/llms/__pycache__/base_api.cpython-310.pyc b/lagent/llms/__pycache__/base_api.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc79446a8dd2ef27afd3da8fea01672c174a3df1 Binary files /dev/null and b/lagent/llms/__pycache__/base_api.cpython-310.pyc differ diff --git a/lagent/llms/__pycache__/base_llm.cpython-310.pyc b/lagent/llms/__pycache__/base_llm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5db19fb6a212e3f62ebba94fec7e2cdc7aa8cfe Binary files /dev/null and b/lagent/llms/__pycache__/base_llm.cpython-310.pyc differ diff --git a/lagent/llms/__pycache__/huggingface.cpython-310.pyc b/lagent/llms/__pycache__/huggingface.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c468ae61c70a93a24cac5d1c298ebc72d860011d Binary files /dev/null and b/lagent/llms/__pycache__/huggingface.cpython-310.pyc differ diff --git a/lagent/llms/__pycache__/lmdeploy_wrapper.cpython-310.pyc b/lagent/llms/__pycache__/lmdeploy_wrapper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db5b1a1f9206e1cf6e72b25692a100f1c2a9b106 Binary files /dev/null and b/lagent/llms/__pycache__/lmdeploy_wrapper.cpython-310.pyc differ diff --git a/lagent/llms/__pycache__/meta_template.cpython-310.pyc b/lagent/llms/__pycache__/meta_template.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39686ce2f981fff65727d59540fdc23a8eb85b09 Binary files 
/dev/null and b/lagent/llms/__pycache__/meta_template.cpython-310.pyc differ diff --git a/lagent/llms/__pycache__/openai.cpython-310.pyc b/lagent/llms/__pycache__/openai.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a8bf0dbc8ab16bbaf5ee5fd25cc3385f7f179fd Binary files /dev/null and b/lagent/llms/__pycache__/openai.cpython-310.pyc differ diff --git a/lagent/llms/__pycache__/sensenova.cpython-310.pyc b/lagent/llms/__pycache__/sensenova.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e965a03ecd21f0b903ff7503acffac382b3c7f2e Binary files /dev/null and b/lagent/llms/__pycache__/sensenova.cpython-310.pyc differ diff --git a/lagent/llms/__pycache__/vllm_wrapper.cpython-310.pyc b/lagent/llms/__pycache__/vllm_wrapper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..479db8b77c08c408fa8228cff2b1abec5b13ca9a Binary files /dev/null and b/lagent/llms/__pycache__/vllm_wrapper.cpython-310.pyc differ diff --git a/lagent/llms/base_api.py b/lagent/llms/base_api.py new file mode 100644 index 0000000000000000000000000000000000000000..e0c2205dcd2c99f96ed97d5caf26fba6644bf59a --- /dev/null +++ b/lagent/llms/base_api.py @@ -0,0 +1,175 @@ +import warnings +from typing import Dict, List, Optional, Tuple, Union + +from lagent.llms.base_llm import AsyncLLMMixin, BaseLLM + + +class APITemplateParser: + """Intermediate prompt template parser, specifically for API models. + + Args: + meta_template (Dict): The meta template for the model. + """ + + def __init__(self, meta_template: Optional[Dict] = None): + self.meta_template = meta_template + # Check meta template + if meta_template: + assert isinstance(meta_template, list) + self.roles: Dict[str, dict] = dict() # maps role name to config + for item in meta_template: + assert isinstance(item, dict) + assert item['role'] not in self.roles, \ + 'role in meta prompt must be unique!' + self.roles[item['role']] = item.copy() + + def __call__(self, dialog: List[Union[str, List]]): + """Parse the intermediate prompt template, and wrap it with meta + template if applicable. When the meta template is set and the input is + a list, the return value will be a list containing the full + conversation history. Each item looks like: + + .. code-block:: python + + {'role': 'user', 'content': '...'} + + Args: + dialog (List[str or list]): An intermediate prompt + template (potentially before being wrapped by meta template). + + Returns: + List[str or list]: The finalized prompt or a conversation. 
+ """ + assert isinstance(dialog, (str, list)) + if isinstance(dialog, str): + return dialog + if self.meta_template: + + prompt = list() + # Whether to keep generating the prompt + generate = True + for i, item in enumerate(dialog): + if not generate: + break + if isinstance(item, str): + if item.strip(): + # TODO: logger + warnings.warn('Non-empty string in prompt template ' + 'will be ignored in API models.') + else: + api_prompts = self._prompt2api(item) + prompt.append(api_prompts) + + # merge the consecutive prompts assigned to the same role + new_prompt = list([prompt[0]]) + last_role = prompt[0]['role'] + for item in prompt[1:]: + if item['role'] == last_role: + new_prompt[-1]['content'] += '\n' + item['content'] + else: + last_role = item['role'] + new_prompt.append(item) + prompt = new_prompt + + else: + # in case the model does not have any meta template + prompt = '' + last_sep = '' + for item in dialog: + if isinstance(item, str): + if item: + prompt += last_sep + item + elif item.get('content', ''): + prompt += last_sep + item.get('content', '') + last_sep = '\n' + return prompt + + def _prompt2api(self, prompts: Union[List, str]) -> Tuple[str, bool]: + """Convert the prompts to a API-style prompts, given an updated + role_dict. + + Args: + prompts (Union[List, str]): The prompts to be converted. + role_dict (Dict[str, Dict]): The updated role dict. + for_gen (bool): If True, the prompts will be converted for + generation tasks. The conversion stops before the first + role whose "generate" is set to True. + + Returns: + Tuple[str, bool]: The converted string, and whether the follow-up + conversion should be proceeded. + """ + if isinstance(prompts, str): + return prompts + elif isinstance(prompts, dict): + api_role = self._role2api_role(prompts) + return api_role + + res = [] + for prompt in prompts: + if isinstance(prompt, str): + raise TypeError('Mixing str without explicit role is not ' + 'allowed in API models!') + else: + api_role = self._role2api_role(prompt) + res.append(api_role) + return res + + def _role2api_role(self, role_prompt: Dict) -> Tuple[str, bool]: + merged_prompt = self.roles[role_prompt['role']] + if merged_prompt.get('fallback_role'): + merged_prompt = self.roles[self.roles[ + merged_prompt['fallback_role']]] + res = role_prompt.copy() + res['role'] = merged_prompt['api_role'] + res['content'] = merged_prompt.get('begin', '') + res['content'] += role_prompt.get('content', '') + res['content'] += merged_prompt.get('end', '') + return res + + +class BaseAPILLM(BaseLLM): + """Base class for API model wrapper. + + Args: + model_type (str): The type of model. + retry (int): Number of retires if the API call fails. Defaults to 2. + meta_template (Dict, optional): The model's meta prompt + template if needed, in case the requirement of injecting or + wrapping of any meta instructions. 
+ """ + + is_api: bool = True + + def __init__(self, + model_type: str, + retry: int = 2, + template_parser: 'APITemplateParser' = APITemplateParser, + meta_template: Optional[Dict] = None, + *, + max_new_tokens: int = 512, + top_p: float = 0.8, + top_k: int = 40, + temperature: float = 0.8, + repetition_penalty: float = 0.0, + stop_words: Union[List[str], str] = None): + self.model_type = model_type + self.meta_template = meta_template + self.retry = retry + if template_parser: + self.template_parser = template_parser(meta_template) + + if isinstance(stop_words, str): + stop_words = [stop_words] + self.gen_params = dict( + max_new_tokens=max_new_tokens, + top_p=top_p, + top_k=top_k, + temperature=temperature, + repetition_penalty=repetition_penalty, + stop_words=stop_words, + skip_special_tokens=False) + + +class AsyncBaseAPILLM(AsyncLLMMixin, BaseAPILLM): + pass diff --git a/lagent/llms/base_llm.py b/lagent/llms/base_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..62bb5d6296d5d304362614d438617aa4658b5cba --- /dev/null +++ b/lagent/llms/base_llm.py @@ -0,0 +1,305 @@ +from copy import copy +from typing import Dict, List, Optional, Tuple, Union + + +class LMTemplateParser: + """Intermidate prompt template parser, specifically for language models. + + Args: + meta_template (list of dict, optional): The meta template for the + model. + """ + + def __init__(self, meta_template: Optional[List[Dict]] = None): + self.meta_template = meta_template + if meta_template: + assert isinstance(meta_template, list) + self.roles: Dict[str, dict] = dict() # maps role name to config + for item in meta_template: + assert isinstance(item, dict) + assert item['role'] not in self.roles, \ + 'role in meta prompt must be unique!' + self.roles[item['role']] = item.copy() + + def __call__(self, dialog) -> str: + """Parse a prompt template, and wrap it with meta template if + applicable. + + Args: + dialog (List[str or PromptList]): A prompt + template (potentially before being wrapped by meta template). + + Returns: + str: The final string. 
+ """ + assert isinstance(dialog, (str, list)) + if isinstance(dialog, str): + return dialog + if self.meta_template: + + prompt = '' + for index, item in enumerate(dialog): + if isinstance(item, str): + prompt += item + else: + new_str = self._prompt2str(item, index == len(dialog) - 1) + prompt += new_str + else: + # in case the model does not have any meta template + prompt = '' + last_sep = '' + for item in dialog: + if isinstance(item, str): + if item: + prompt += last_sep + item + elif item.get('content', ''): + prompt += last_sep + item.get('prompt', '') + last_sep = '\n' + return prompt + + def _format_begin(self, role_cfg, message): + name = message.get('name', None) + if name is not None: + begin = role_cfg['begin'].get('with_name', '') + if name in role_cfg['begin'].get('name', {}): + begin = begin.format(name=role_cfg['begin']['name'][name]) + else: + begin = begin.format(name=name) + else: + if isinstance(role_cfg.get('begin', ''), str): + begin = role_cfg.get('begin', '') + elif isinstance(role_cfg['begin'], dict): + begin = role_cfg['begin'].get('without_name', '') + return begin + + def _prompt2str(self, + prompt: Union[str, Dict], + last: bool = False) -> Tuple[str, bool]: + if isinstance(prompt, str): + return prompt + merged_prompt = self.roles.get(prompt['role']) + + if merged_prompt.get('fallback_role'): + merged_prompt = self.roles.get(merged_prompt['fallback_role']) + begin = self._format_begin(merged_prompt, prompt) + res = begin + if last and merged_prompt.get('generate', False): + res += prompt.get('content', '') + return res + res += prompt.get('content', '') + merged_prompt.get('end', '') + if last and merged_prompt['role'] != 'assistant': + res += self._format_begin(self.roles['assistant'], {}) + return res + return res + + +class BaseLLM: + """Base class for model wrapper. + + Args: + path (str): The path to the model. + max_new_tokens (int): Maximum length of output expected to be generated by the model. Defaults + to 512. + tokenizer_only (bool): If True, only the tokenizer will be initialized. + Defaults to False. + meta_template (list of dict, optional): The model's meta prompt + template if needed, in case the requirement of injecting or + wrapping of any meta instructions. + """ + + def __init__(self, + path: str, + tokenizer_only: bool = False, + template_parser: 'LMTemplateParser' = LMTemplateParser, + meta_template: Optional[List[Dict]] = None, + *, + max_new_tokens: int = 512, + top_p: float = 0.8, + top_k: float = 40, + temperature: float = 0.8, + repetition_penalty: float = 1.0, + stop_words: Union[List[str], str] = None): + self.path = path + self.tokenizer_only = tokenizer_only + # meta template + self.template_parser = template_parser(meta_template) + self.eos_token_id = None + if meta_template and 'eos_token_id' in meta_template: + self.eos_token_id = meta_template['eos_token_id'] + + if isinstance(stop_words, str): + stop_words = [stop_words] + self.gen_params = dict( + max_new_tokens=max_new_tokens, + top_p=top_p, + top_k=top_k, + temperature=temperature, + repetition_penalty=repetition_penalty, + stop_words=stop_words) + + def generate(self, inputs: Union[str, List[str]], **gen_params) -> str: + """Generate results given a str (or list of) inputs. + + Args: + inputs (Union[str, List[str]]): + gen_params (dict): The input params for generation. + + Returns: + Union[str, List[str]]: A (list of) generated strings. + + eg. 
+ batched = True + if isinstance(inputs, str): + inputs = [inputs] + batched = False + response = [''] + if batched: + return response + return response[0] + """ + raise NotImplementedError + + def stream_generate(self, inputs: str, **gen_params) -> List[str]: + """Generate results as streaming given a str inputs. + + Args: + inputs (str): + gen_params (dict): The input params for generation. + + Returns: + str: A generated string. + """ + raise NotImplementedError + + def chat(self, + inputs: Union[List[dict], List[List[dict]]], + session_ids: Union[int, List[int]] = None, + **gen_params): + """Generate completion from a list of templates. + + Args: + inputs (Union[List[dict], List[List[dict]]]): + gen_params (dict): The input params for generation. + Returns: + """ + if isinstance(inputs[0], list): + _inputs = list() + for msg in inputs: + _inputs.append(self.template_parser(msg)) + else: + _inputs = self.template_parser(inputs) + return self.generate(_inputs, **gen_params) + + def stream_chat(self, inputs: List[dict], **gen_params): + """Generate results as streaming given a list of templates. + + Args: + inputs (Union[List[dict]): + gen_params (dict): The input params for generation. + Returns: + """ + raise NotImplementedError + + def tokenize(self, prompts: Union[str, List[str], List[dict], + List[List[dict]]]): + """Tokenize the input prompts. + + Args: + prompts(str | List[str]): user's prompt, or a batch prompts + + Returns: + Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): prompt's token + ids, ids' length and requested output length + """ + raise NotImplementedError + + def update_gen_params(self, **kwargs): + gen_params = copy(self.gen_params) + gen_params.update(kwargs) + return gen_params + + +class AsyncLLMMixin: + + async def generate(self, + inputs: Union[str, List[str]], + session_ids: Union[int, List[int]] = None, + **gen_params) -> str: + """Generate results given a str (or list of) inputs. + + Args: + inputs (Union[str, List[str]]): + gen_params (dict): The input params for generation. + + Returns: + Union[str, List[str]]: A (list of) generated strings. + + eg. + batched = True + if isinstance(inputs, str): + inputs = [inputs] + batched = False + response = [''] + if batched: + return response + return response[0] + """ + raise NotImplementedError + + async def stream_generate(self, inputs: str, **gen_params) -> List[str]: + """Generate results as streaming given a str inputs. + + Args: + inputs (str): + gen_params (dict): The input params for generation. + + Returns: + str: A generated string. + """ + raise NotImplementedError + + async def chat(self, + inputs: Union[List[dict], List[List[dict]]], + session_ids: Union[int, List[int]] = None, + **gen_params): + """Generate completion from a list of templates. + + Args: + inputs (Union[List[dict], List[List[dict]]]): + gen_params (dict): The input params for generation. + Returns: + """ + if isinstance(inputs[0], list): + _inputs = list() + for msg in inputs: + _inputs.append(self.template_parser(msg)) + else: + _inputs = self.template_parser(inputs) + return await self.generate(_inputs, session_ids, **gen_params) + + async def stream_chat(self, inputs: List[dict], **gen_params): + """Generate results as streaming given a list of templates. + + Args: + inputs (Union[List[dict]): + gen_params (dict): The input params for generation. + Returns: + """ + raise NotImplementedError + + async def tokenize(self, prompts: Union[str, List[str], List[dict], + List[List[dict]]]): + """Tokenize the input prompts. 
+ + Args: + prompts(str | List[str]): user's prompt, or a batch prompts + + Returns: + Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): prompt's token + ids, ids' length and requested output length + """ + raise NotImplementedError + + +class AsyncBaseLLM(AsyncLLMMixin, BaseLLM): + pass diff --git a/lagent/llms/huggingface.py b/lagent/llms/huggingface.py new file mode 100644 index 0000000000000000000000000000000000000000..87c779a8f7a1b6b5a9ff78b201b2150c8e277e62 --- /dev/null +++ b/lagent/llms/huggingface.py @@ -0,0 +1,337 @@ +import copy +import logging +from typing import Dict, List, Optional, Union + +from lagent.schema import ModelStatusCode +from .base_api import APITemplateParser +from .base_llm import BaseLLM + +logger = logging.getLogger(__name__) + + +class HFTransformer(BaseLLM): + """Model wrapper around HuggingFace general models. + + Adapted from Internlm (https://github.com/InternLM/InternLM/blob/main/ + chat/web_demo.py) + + Args: + path (str): The name or path to HuggingFace's model. + tokenizer_path (str): The path to the tokenizer. Defaults to None. + tokenizer_kwargs (dict): Keyword arguments for the tokenizer. + Defaults to {}. + tokenizer_only (bool): If True, only the tokenizer will be initialized. + Defaults to False. + model_kwargs (dict): Keyword arguments for the model, used in loader. + Defaults to dict(device_map='auto'). + meta_template (Dict, optional): The model's meta prompt + template if needed, in case the requirement of injecting or + wrapping of any meta instructions. + """ + + def __init__(self, + path: str, + tokenizer_path: Optional[str] = None, + tokenizer_kwargs: dict = dict(), + tokenizer_only: bool = False, + model_kwargs: dict = dict(device_map='auto'), + meta_template: Optional[Dict] = None, + stop_words_id: Union[List[int], int] = None, + **kwargs): + super().__init__( + path=path, + tokenizer_only=tokenizer_only, + meta_template=meta_template, + **kwargs) + if isinstance(stop_words_id, int): + stop_words_id = [stop_words_id] + self.gen_params.update(stop_words_id=stop_words_id) + if self.gen_params['stop_words'] is not None and \ + self.gen_params['stop_words_id'] is not None: + logger.warning('Both stop_words and stop_words_id are specified,' + 'only stop_words_id will be used.') + + self._load_tokenizer( + path=path, + tokenizer_path=tokenizer_path, + tokenizer_kwargs=tokenizer_kwargs) + if not tokenizer_only: + self._load_model(path=path, model_kwargs=model_kwargs) + + from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList # noqa: E501 + self.logits_processor = LogitsProcessorList() + self.stopping_criteria = StoppingCriteriaList() + self.prefix_allowed_tokens_fn = None + + stop_words_id = [] + if self.gen_params.get('stop_words_id'): + stop_words_id = self.gen_params.get('stop_words_id') + elif self.gen_params.get('stop_words'): + for sw in self.gen_params.get('stop_words'): + stop_words_id.append(self.tokenizer(sw)['input_ids'][-1]) + self.additional_eos_token_id = stop_words_id + + def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], + tokenizer_kwargs: dict): + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + tokenizer_path if tokenizer_path else path, + trust_remote_code=True, + **tokenizer_kwargs) + + if self.tokenizer.pad_token_id is None: + if self.tokenizer.eos_token is not None: + logger.warning( + f'Using eos_token_id {self.tokenizer.eos_token} ' + 'as pad_token_id.') + self.tokenizer.pad_token = self.tokenizer.eos_token + else: + from 
transformers.generation import GenerationConfig + self.gcfg = GenerationConfig.from_pretrained(path) + + if self.gcfg.pad_token_id is not None: + logger.warning( + f'Using pad_token_id {self.gcfg.pad_token_id} ' + 'as pad_token_id.') + self.tokenizer.pad_token_id = self.gcfg.pad_token_id + else: + raise ValueError( + 'pad_token_id is not set for this tokenizer. Try to ' + 'set pad_token_id via passing ' + '`pad_token_id={PAD_TOKEN_ID}` in model_cfg.') + + def _load_model(self, path: str, model_kwargs: dict): + import torch + from transformers import AutoModel + model_kwargs.setdefault('torch_dtype', torch.float16) + self.model = AutoModel.from_pretrained( + path, trust_remote_code=True, **model_kwargs) + self.model.eval() + + def tokenize(self, inputs: str): + assert isinstance(inputs, str) + inputs = self.tokenizer( + inputs, return_tensors='pt', return_length=True) + return inputs['input_ids'].tolist() + + def generate( + self, + inputs: Union[str, List[str]], + do_sample: bool = True, + **kwargs, + ): + """Return the chat completions in non-stream mode. + + Args: + inputs (Union[str, List[str]]): input texts to be completed. + do_sample (bool): do sampling if enabled + Returns: + (a list of/batched) text/chat completion + """ + for status, chunk, _ in self.stream_generate(inputs, do_sample, + **kwargs): + response = chunk + return response + + def stream_generate( + self, + inputs: List[str], + do_sample: bool = True, + **kwargs, + ): + """Return the chat completions in stream mode. + + Args: + inputs (Union[str, List[str]]): input texts to be completed. + do_sample (bool): do sampling if enabled + Returns: + tuple(Status, str, int): status, text/chat completion, + generated token number + """ + import torch + from torch import nn + with torch.no_grad(): + batched = True + if isinstance(inputs, str): + inputs = [inputs] + batched = False + inputs = self.tokenizer( + inputs, padding=True, return_tensors='pt', return_length=True) + input_length = inputs['length'] + for k, v in inputs.items(): + inputs[k] = v.cuda() + input_ids = inputs['input_ids'] + attention_mask = inputs['attention_mask'] + batch_size = input_ids.shape[0] + input_ids_seq_length = input_ids.shape[-1] + generation_config = self.model.generation_config + generation_config = copy.deepcopy(generation_config) + new_gen_params = self.update_gen_params(**kwargs) + generation_config.update(**new_gen_params) + generation_config.update(**kwargs) + model_kwargs = generation_config.to_dict() + model_kwargs['attention_mask'] = attention_mask + _, eos_token_id = ( # noqa: F841 # pylint: disable=W0612 + generation_config.bos_token_id, + generation_config.eos_token_id, + ) + if eos_token_id is None: + if self.gcfg.eos_token_id is not None: + eos_token_id = self.gcfg.eos_token_id + else: + eos_token_id = [] + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + if self.additional_eos_token_id is not None: + eos_token_id.extend(self.additional_eos_token_id) + eos_token_id_tensor = torch.tensor(eos_token_id).to( + input_ids.device) if eos_token_id is not None else None + generation_config.max_length = ( + generation_config.max_new_tokens + input_ids_seq_length) + # Set generation parameters if not already defined + logits_processor = self.logits_processor + stopping_criteria = self.stopping_criteria + + logits_processor = self.model._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + 
prefix_allowed_tokens_fn=self.prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + stopping_criteria = self.model._get_stopping_criteria( + generation_config=generation_config, + stopping_criteria=stopping_criteria) + logits_warper = self.model._get_logits_warper(generation_config) + + unfinished_sequences = input_ids.new(batch_size).fill_(1) + scores = None + while True: + model_inputs = self.model.prepare_inputs_for_generation( + input_ids, **model_kwargs) + # forward pass to get next token + outputs = self.model( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, + next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + if do_sample: + next_tokens = torch.multinomial( + probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + + # update generated ids, model inputs, + # and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], + dim=-1) + model_kwargs = self.model._update_model_kwargs_for_generation( # noqa: E501 + outputs, + model_kwargs, + is_encoder_decoder=False) + unfinished_sequences = unfinished_sequences.mul( + next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne( + eos_token_id_tensor.unsqueeze(1)).prod(dim=0)) + output_token_ids = input_ids.cpu().tolist() + for i in range(len(output_token_ids)): + output_token_ids[i] = output_token_ids[i][:][ + input_length[i]:] + # Find the first occurrence of + # an EOS token in the sequence + first_eos_idx = next( + (idx + for idx, token_id in enumerate(output_token_ids[i]) + if token_id in eos_token_id), None) + # If an EOS token is found, only the previous + # part of it is retained + if first_eos_idx is not None: + output_token_ids[i] = output_token_ids[ + i][:first_eos_idx] + + response = self.tokenizer.batch_decode(output_token_ids) + # print(response) + if not batched: + response = response[0] + yield ModelStatusCode.STREAM_ING, response, None + # stop when each sentence is finished, + # or if we exceed the maximum length + if (unfinished_sequences.max() == 0 + or stopping_criteria(input_ids, scores)): + break + yield ModelStatusCode.END, response, None + + def stream_chat( + self, + inputs: List[dict], + do_sample: bool = True, + **kwargs, + ): + """Return the chat completions in stream mode. + + Args: + inputs (List[dict]): input messages to be completed. + do_sample (bool): do sampling if enabled + Returns: + the text/chat completion + """ + prompt = self.template_parser(inputs) + yield from self.stream_generate(prompt, do_sample, **kwargs) + + +class HFTransformerCasualLM(HFTransformer): + + def _load_model(self, path: str, model_kwargs: dict): + import torch + from transformers import AutoModelForCausalLM + model_kwargs.setdefault('torch_dtype', torch.float16) + self.model = AutoModelForCausalLM.from_pretrained( + path, trust_remote_code=True, **model_kwargs) + self.model.eval() + + +class HFTransformerChat(HFTransformerCasualLM): + + def __init__(self, template_parser=APITemplateParser, **kwargs): + super().__init__(template_parser=template_parser, **kwargs) + + def chat(self, + inputs: Union[List[dict], List[List[dict]]], + do_sample: bool = True, + **kwargs): + """Return the chat completions in stream mode. 
+ + Args: + inputs (Union[List[dict], List[List[dict]]]): input messages to be completed. + do_sample (bool): do sampling if enabled + Returns: + the text/chat completion + """ + # handle batch inference with vanilla for loop + if isinstance(inputs[0], list): + resps = [] + for input in inputs: + resps.append(self.chat(input, do_sample, **kwargs)) + return resps + prompt = self.template_parser(inputs) + query = prompt[-1]['content'] + history = prompt[:-1] + try: + response, history = self.model.chat( + self.tokenizer, query, history=history) + except Exception as e: + # handle over-length input error + logger.warning(str(e)) + response = '' + return response diff --git a/lagent/llms/lmdeploy_wrapper.py b/lagent/llms/lmdeploy_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..283d50e8de090878d9e15c1d78d60ee8818fdc8a --- /dev/null +++ b/lagent/llms/lmdeploy_wrapper.py @@ -0,0 +1,790 @@ +import asyncio +import copy +import logging +from dataclasses import asdict +from typing import List, Optional, Union + +import aiohttp + +from lagent.llms.base_llm import AsyncLLMMixin, BaseLLM +from lagent.schema import ModelStatusCode +from lagent.utils.util import filter_suffix + + +class TritonClient(BaseLLM): + """TritonClient is a wrapper of TritonClient for LLM. + + Args: + tritonserver_addr (str): the address in format "ip:port" of + triton inference server + model_name (str): the name of the model + session_len (int): the context size + max_tokens (int): the expected generated token numbers + """ + + def __init__(self, + tritonserver_addr: str, + model_name: str, + session_len: int = 32768, + log_level: str = 'WARNING', + **kwargs): + super().__init__(path=None, **kwargs) + try: + from lmdeploy.serve.turbomind.chatbot import Chatbot, StatusCode + except Exception as e: + logging.error(f'{e}') + raise RuntimeError('DO NOT use turbomind.chatbot since it has ' + 'been removed by lmdeploy since v0.5.2') + self.state_map = { + StatusCode.TRITON_STREAM_END: ModelStatusCode.END, + StatusCode.TRITON_SERVER_ERR: ModelStatusCode.SERVER_ERR, + StatusCode.TRITON_SESSION_CLOSED: ModelStatusCode.SESSION_CLOSED, + StatusCode.TRITON_STREAM_ING: ModelStatusCode.STREAM_ING, + StatusCode.TRITON_SESSION_OUT_OF_LIMIT: + ModelStatusCode.SESSION_OUT_OF_LIMIT, + StatusCode.TRITON_SESSION_INVALID_ARG: + ModelStatusCode.SESSION_INVALID_ARG, + StatusCode.TRITON_SESSION_READY: ModelStatusCode.SESSION_READY + } + self.chatbot = Chatbot( + tritonserver_addr=tritonserver_addr, + model_name=model_name, + session_len=session_len, + log_level=log_level, + **kwargs) + + def generate(self, + inputs: Union[str, List[str]], + session_id: int = 2967, + request_id: str = '', + sequence_start: bool = True, + sequence_end: bool = True, + skip_special_tokens: bool = False, + **kwargs): + """Start a new round conversation of a session. Return the chat + completions in non-stream mode. + + Args: + inputs (str, List[str]): user's prompt(s) in this round + session_id (int): the identical id of a session + request_id (str): the identical id of this round conversation + sequence_start (bool): start flag of a session + sequence_end (bool): end flag of a session + skip_special_tokens (bool): Whether or not to remove special tokens + in the decoding. Default to be False. 
+ Returns: + (a list of/batched) text/chat completion + """ + from lmdeploy.serve.turbomind.chatbot import Session, get_logger + if isinstance(inputs, str): + inputs = [inputs] + prompt = inputs + + assert isinstance(session_id, int), \ + f'INT session id is required, but got {type(session_id)}' + + self.chatbot.cfg = self._update_gen_params(**kwargs) + max_new_tokens = self.chatbot.cfg.max_new_tokens + + logger = get_logger('service.ft', log_level=self.chatbot.log_level) + logger.info(f'session {session_id}, request_id {request_id}, ' + f'max_out_len {max_new_tokens}') + + if self.chatbot._session is None: + sequence_start = True + self.chatbot._session = Session(session_id=session_id) + elif self.chatbot._session.status == 0: + logger.error(f'session {session_id} has been ended. Please set ' + f'`sequence_start` be True if you want to restart it') + return '' + + self.chatbot._session.status = 1 + self.chatbot._session.request_id = request_id + self.chatbot._session.response = '' + + status, res, _ = None, '', 0 + for status, res, _ in self.chatbot._stream_infer( + self.chatbot._session, + prompt, + max_new_tokens, + sequence_start, + sequence_end, + skip_special_tokens=skip_special_tokens): + status = self.state_map.get(status) + if status < ModelStatusCode.END: + return '' + elif status == ModelStatusCode.END: + self.chatbot._session.histories = ( + self.chatbot._session.histories + + self.chatbot._session.prompt + + self.chatbot._session.response) + # remove stop_words + res = filter_suffix(res, self.gen_params.get('stop_words')) + return res + + def stream_chat(self, + inputs: List[dict], + session_id: int = 2967, + request_id: str = '', + sequence_start: bool = True, + sequence_end: bool = True, + skip_special_tokens: bool = False, + **kwargs): + """Start a new round conversation of a session. Return the chat + completions in stream mode. + + Args: + session_id (int): the identical id of a session + inputs (List[dict]): user's inputs in this round conversation + request_id (str): the identical id of this round conversation + sequence_start (bool): start flag of a session + sequence_end (bool): end flag of a session + skip_special_tokens (bool): Whether or not to remove special tokens + in the decoding. Default to be False. + Returns: + tuple(Status, str, int): status, text/chat completion, + generated token number + """ + from lmdeploy.serve.turbomind.chatbot import Session, get_logger + assert isinstance(session_id, int), \ + f'INT session id is required, but got {type(session_id)}' + + self.chatbot.cfg = self._update_gen_params(**kwargs) + max_new_tokens = self.chatbot.cfg.max_new_tokens + + logger = get_logger('service.ft', log_level=self.chatbot.log_level) + logger.info(f'session {session_id}, request_id {request_id}, ' + f'max_out_len {max_new_tokens}') + + if self.chatbot._session is None: + sequence_start = True + self.chatbot._session = Session(session_id=session_id) + elif self.chatbot._session.status == 0: + logger.error(f'session {session_id} has been ended. 
Please set ' + f'`sequence_start` be True if you want to restart it') + return ModelStatusCode.SESSION_CLOSED, '', 0 + + self.chatbot._session.status = 1 + self.chatbot._session.request_id = request_id + self.chatbot._session.response = '' + + prompt = self.template_parser(inputs) + status, res, _ = None, '', 0 + for status, res, _ in self.chatbot._stream_infer( + self.chatbot._session, + prompt, + max_new_tokens, + sequence_start, + sequence_end, + skip_special_tokens=skip_special_tokens): + status = self.state_map.get(status) + # The stop symbol also appears in the output of the last STREAM_ING state. + res = filter_suffix(res, self.gen_params.get('stop_words')) + if status < ModelStatusCode.END: + return status, res, _ + elif status == ModelStatusCode.END: # remove stop_words + self.chatbot._session.histories = ( + self.chatbot._session.histories + + self.chatbot._session.prompt + + self.chatbot._session.response) + yield status, res, _ + break + else: + yield status, res, _ + + def _update_gen_params(self, **kwargs): + import mmengine + new_gen_params = self.update_gen_params(**kwargs) + self.gen_params['stop_words'] = new_gen_params.pop('stop_words') + stop_words = self.chatbot._stop_words( + self.gen_params.get('stop_words')) + cfg = mmengine.Config( + dict( + session_len=self.chatbot.model.session_len, + stop_words=stop_words, + bad_words=self.chatbot.cfg.bad_words, + **new_gen_params)) + return cfg + + +class LMDeployPipeline(BaseLLM): + """ + + Args: + path (str): The path to the model. + It could be one of the following options: + - i) A local directory path of a turbomind model which is + converted by `lmdeploy convert` command or download + from ii) and iii). + - ii) The model_id of a lmdeploy-quantized model hosted + inside a model repo on huggingface.co, such as + "InternLM/internlm-chat-20b-4bit", + "lmdeploy/llama2-chat-70b-4bit", etc. + - iii) The model_id of a model hosted inside a model repo + on huggingface.co, such as "internlm/internlm-chat-7b", + "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" + and so on. + model_name (str): needed when model_path is a pytorch model on + huggingface.co, such as "internlm-chat-7b", + "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on. 
+ tp (int): tensor parallel + pipeline_cfg (dict): config of pipeline + """ + + def __init__(self, + path: str, + model_name: Optional[str] = None, + tp: int = 1, + pipeline_cfg=dict(), + **kwargs): + import lmdeploy + from lmdeploy import ChatTemplateConfig, TurbomindEngineConfig, pipeline, version_info + + self.str_version = lmdeploy.__version__ + self.version = version_info + self.do_sample = kwargs.pop('do_sample', None) + if self.do_sample is not None and self.version < (0, 6, 0): + raise RuntimeError( + '`do_sample` parameter is not supported by lmdeploy until ' + f'v0.6.0, but currently using lmdeploy {self.str_version}') + super().__init__(path=path, **kwargs) + backend_config = copy.deepcopy(pipeline_cfg) + backend_config.update(tp=tp) + backend_config = { + k: v + for k, v in backend_config.items() + if hasattr(TurbomindEngineConfig, k) + } + backend_config = TurbomindEngineConfig(**backend_config) + chat_template_config = ChatTemplateConfig( + model_name=model_name) if model_name else None + self.model = pipeline( + model_path=self.path, + backend_config=backend_config, + chat_template_config=chat_template_config, + log_level='WARNING') + + def generate(self, + inputs: Union[str, List[str]], + do_preprocess: bool = None, + skip_special_tokens: bool = False, + return_dict: bool = False, + **kwargs): + """Return the chat completions in non-stream mode. + + Args: + inputs (Union[str, List[str]]): input texts to be completed. + do_preprocess (bool): whether pre-process the messages. Default to + True, which means chat_template will be applied. + skip_special_tokens (bool): Whether or not to remove special tokens + in the decoding. Default to be False. + Returns: + (a list of/batched) text/chat completion + """ + from lmdeploy.messages import GenerationConfig + batched = True + if isinstance(inputs, str): + inputs = [inputs] + batched = False + prompt = inputs + do_sample = kwargs.pop('do_sample', None) + gen_params = self.update_gen_params(**kwargs) + + if do_sample is None: + do_sample = self.do_sample + if do_sample is not None and self.version < (0, 6, 0): + raise RuntimeError( + '`do_sample` parameter is not supported by lmdeploy until ' + f'v0.6.0, but currently using lmdeploy {self.str_version}') + if self.version >= (0, 6, 0): + if do_sample is None: + do_sample = gen_params['top_k'] > 1 or gen_params[ 'temperature'] > 0 + gen_params.update(do_sample=do_sample) + + gen_config = GenerationConfig( + skip_special_tokens=skip_special_tokens, **gen_params) + response = self.model.batch_infer( + prompt, gen_config=gen_config, do_preprocess=do_preprocess) + texts = [resp.text for resp in response] + # remove stop_words + texts = filter_suffix(texts, self.gen_params.get('stop_words')) + for resp, text in zip(response, texts): + resp.text = text + if batched: + return [asdict(resp) + for resp in response] if return_dict else texts + return asdict(response[0]) if return_dict else texts[0] + + +class LMDeployServer(BaseLLM): + """ + + Args: + path (str): The path to the model. + It could be one of the following options: + - i) A local directory path of a turbomind model which is + converted by `lmdeploy convert` command or download from + ii) and iii). + - ii) The model_id of a lmdeploy-quantized model hosted + inside a model repo on huggingface.co, such as + "InternLM/internlm-chat-20b-4bit", + "lmdeploy/llama2-chat-70b-4bit", etc.
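# Usage sketch (illustrative, not part of the diff): a minimal LMDeployPipeline call.
# The model path is a placeholder, and generation parameters are assumed to be
# accepted as keyword arguments by BaseLLM. With lmdeploy >= 0.6.0 the wrapper
# derives `do_sample` from top_k/temperature unless it is passed explicitly.
from lagent.llms.lmdeploy_wrapper import LMDeployPipeline

llm = LMDeployPipeline(
    path='internlm/internlm2_5-7b-chat',  # placeholder HF model id
    tp=1,
    top_k=1,
    temperature=0.0,
    max_new_tokens=256)
# A single string yields a single completion; a list of strings is batched.
print(llm.generate('Write a haiku about autumn.'))
print(llm.generate(['Hi', 'Hello'], return_dict=False))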
+ - iii) The model_id of a model hosted inside a model repo + on huggingface.co, such as "internlm/internlm-chat-7b", + "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" + and so on. + model_name (str): needed when model_path is a pytorch model on + huggingface.co, such as "internlm-chat-7b", + "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on. + server_name (str): host ip for serving + server_port (int): server port + tp (int): tensor parallel + log_level (str): set log level whose value among + [CRITICAL, ERROR, WARNING, INFO, DEBUG] + """ + + def __init__(self, + path: str, + model_name: Optional[str] = None, + server_name: str = '0.0.0.0', + server_port: int = 23333, + tp: int = 1, + log_level: str = 'WARNING', + serve_cfg=dict(), + **kwargs): + super().__init__(path=path, **kwargs) + self.model_name = model_name + # TODO get_logger issue in multi processing + import lmdeploy + self.client = lmdeploy.serve( + model_path=self.path, + model_name=model_name, + server_name=server_name, + server_port=server_port, + tp=tp, + log_level=log_level, + **serve_cfg) + + def generate(self, + inputs: Union[str, List[str]], + session_id: int = 2967, + sequence_start: bool = True, + sequence_end: bool = True, + ignore_eos: bool = False, + skip_special_tokens: Optional[bool] = False, + timeout: int = 30, + **kwargs) -> List[str]: + """Start a new round conversation of a session. Return the chat + completions in non-stream mode. + + Args: + inputs (str, List[str]): user's prompt(s) in this round + session_id (int): the identical id of a session + sequence_start (bool): start flag of a session + sequence_end (bool): end flag of a session + ignore_eos (bool): indicator for ignoring eos + skip_special_tokens (bool): Whether or not to remove special tokens + in the decoding. Default to be False. + timeout (int): max time to wait for response + Returns: + (a list of/batched) text/chat completion + """ + + batched = True + if isinstance(inputs, str): + inputs = [inputs] + batched = False + + gen_params = self.update_gen_params(**kwargs) + max_new_tokens = gen_params.pop('max_new_tokens') + gen_params.update(max_tokens=max_new_tokens) + + resp = [''] * len(inputs) + for text in self.client.completions_v1( + self.model_name, + inputs, + session_id=session_id, + sequence_start=sequence_start, + sequence_end=sequence_end, + stream=False, + ignore_eos=ignore_eos, + skip_special_tokens=skip_special_tokens, + timeout=timeout, + **gen_params): + resp = [ + resp[i] + item['text'] + for i, item in enumerate(text['choices']) + ] + # remove stop_words + resp = filter_suffix(resp, self.gen_params.get('stop_words')) + if not batched: + return resp[0] + return resp + + def stream_chat(self, + inputs: List[dict], + session_id=0, + sequence_start: bool = True, + sequence_end: bool = True, + stream: bool = True, + ignore_eos: bool = False, + skip_special_tokens: Optional[bool] = False, + timeout: int = 30, + **kwargs): + """Start a new round conversation of a session. Return the chat + completions in stream mode. + + Args: + session_id (int): the identical id of a session + inputs (List[dict]): user's inputs in this round conversation + sequence_start (bool): start flag of a session + sequence_end (bool): end flag of a session + stream (bool): return in a streaming format if enabled + ignore_eos (bool): indicator for ignoring eos + skip_special_tokens (bool): Whether or not to remove special tokens + in the decoding. Default to be False. 
+ timeout (int): max time to wait for response + Returns: + tuple(Status, str, int): status, text/chat completion, + generated token number + """ + gen_params = self.update_gen_params(**kwargs) + max_new_tokens = gen_params.pop('max_new_tokens') + gen_params.update(max_tokens=max_new_tokens) + prompt = self.template_parser(inputs) + + resp = '' + finished = False + stop_words = self.gen_params.get('stop_words') + for text in self.client.completions_v1( + self.model_name, + prompt, + session_id=session_id, + sequence_start=sequence_start, + sequence_end=sequence_end, + stream=stream, + ignore_eos=ignore_eos, + skip_special_tokens=skip_special_tokens, + timeout=timeout, + **gen_params): + resp += text['choices'][0]['text'] + if not resp: + continue + # remove stop_words + for sw in stop_words: + if sw in resp: + resp = filter_suffix(resp, stop_words) + finished = True + break + yield ModelStatusCode.STREAM_ING, resp, None + if finished: + break + yield ModelStatusCode.END, resp, None + + +class LMDeployClient(LMDeployServer): + """ + + Args: + url (str): communicating address 'http://:' of + api_server + model_name (str): needed when model_path is a pytorch model on + huggingface.co, such as "internlm-chat-7b", + "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on. + """ + + def __init__(self, url: str, model_name: str, **kwargs): + BaseLLM.__init__(self, path=url, **kwargs) + from lmdeploy.serve.openai.api_client import APIClient + self.client = APIClient(url) + self.model_name = model_name + + +class AsyncLMDeployPipeline(AsyncLLMMixin, LMDeployPipeline): + """ + + Args: + path (str): The path to the model. + It could be one of the following options: + - i) A local directory path of a turbomind model which is + converted by `lmdeploy convert` command or download + from ii) and iii). + - ii) The model_id of a lmdeploy-quantized model hosted + inside a model repo on huggingface.co, such as + "InternLM/internlm-chat-20b-4bit", + "lmdeploy/llama2-chat-70b-4bit", etc. + - iii) The model_id of a model hosted inside a model repo + on huggingface.co, such as "internlm/internlm-chat-7b", + "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" + and so on. + model_name (str): needed when model_path is a pytorch model on + huggingface.co, such as "internlm-chat-7b", + "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on. + tp (int): tensor parallel + pipeline_cfg (dict): config of pipeline + """ + + async def generate(self, + inputs: Union[str, List[str]], + session_ids: Union[int, List[int]] = None, + do_preprocess: bool = None, + skip_special_tokens: bool = False, + return_dict: bool = False, + **kwargs): + """Return the chat completions in non-stream mode. + + Args: + inputs (Union[str, List[str]]): input texts to be completed. + do_preprocess (bool): whether pre-process the messages. Default to + True, which means chat_template will be applied. + skip_special_tokens (bool): Whether or not to remove special tokens + in the decoding. Default to be False. 
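# Usage sketch (illustrative, not part of the diff): LMDeployServer launches a local
# api_server via lmdeploy.serve(), while LMDeployClient only attaches to one that is
# already running. Paths, ports and the served model name below are placeholders and
# must match whatever the server actually registers.
from lagent.llms.lmdeploy_wrapper import LMDeployClient, LMDeployServer

server_llm = LMDeployServer(
    path='internlm/internlm2_5-7b-chat',   # placeholder model id
    model_name='internlm2',                # placeholder served name
    server_port=23333,
    max_new_tokens=256)
print(server_llm.generate('Hello!'))

# Attach to the same server from another process:
client_llm = LMDeployClient(
    url='http://127.0.0.1:23333', model_name='internlm2', max_new_tokens=256)
for status, text, _ in client_llm.stream_chat(
        [dict(role='user', content='Tell me a joke.')]):
    pass  # `text` grows incrementally; the final status is ModelStatusCode.END
print(text)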
+ Returns: + (a list of/batched) text/chat completion + """ + from lmdeploy.messages import GenerationConfig, Response + + batched = True + if isinstance(inputs, str): + inputs = [inputs] + batched = False + if session_ids is None: + session_ids = list(range(len(inputs))) + elif isinstance(session_ids, (int, str)): + session_ids = [session_ids] + assert len(inputs) == len(session_ids) + + prompt = inputs + gen_params = self.update_gen_params(**kwargs) + gen_config = GenerationConfig( + skip_special_tokens=skip_special_tokens, **gen_params) + + async def _inner_generate(uid, text): + resp = Response('', 0, 0, uid) + async for out in self.model.generate( + text, + uid, + gen_config, + stream_response=True, + sequence_start=True, + sequence_end=True, + do_preprocess=do_preprocess, + **kwargs): + resp.text += out.response + resp.generate_token_len = out.generate_token_len + resp.input_token_len = out.input_token_len + resp.finish_reason = out.finish_reason + if out.token_ids: + resp.token_ids.extend(out.token_ids) + if out.logprobs: + if resp.logprobs is None: + resp.logprobs = [] + resp.logprobs.extend(out.logprobs) + return resp + + response = await asyncio.gather(*[ + _inner_generate(sid, inp) for sid, inp in zip(session_ids, prompt) + ]) + texts = [resp.text for resp in response] + # remove stop_words + texts = filter_suffix(texts, self.gen_params.get('stop_words')) + for resp, text in zip(response, texts): + resp.text = text + if batched: + return [asdict(resp) + for resp in response] if return_dict else texts + return asdict(response[0]) if return_dict else texts[0] + + +class AsyncLMDeployServer(AsyncLLMMixin, LMDeployServer): + """ + + Args: + path (str): The path to the model. + It could be one of the following options: + - i) A local directory path of a turbomind model which is + converted by `lmdeploy convert` command or download from + ii) and iii). + - ii) The model_id of a lmdeploy-quantized model hosted + inside a model repo on huggingface.co, such as + "InternLM/internlm-chat-20b-4bit", + "lmdeploy/llama2-chat-70b-4bit", etc. + - iii) The model_id of a model hosted inside a model repo + on huggingface.co, such as "internlm/internlm-chat-7b", + "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" + and so on. + model_name (str): needed when model_path is a pytorch model on + huggingface.co, such as "internlm-chat-7b", + "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on. + server_name (str): host ip for serving + server_port (int): server port + tp (int): tensor parallel + log_level (str): set log level whose value among + [CRITICAL, ERROR, WARNING, INFO, DEBUG] + """ + + async def generate( + self, + inputs: Union[str, List[str]], + session_ids: Union[int, List[int]] = None, + sequence_start: bool = True, + sequence_end: bool = True, + ignore_eos: bool = False, + skip_special_tokens: Optional[bool] = False, + timeout: int = 30, + **kwargs, + ): + """Start a new round conversation of a session. Return the chat + completions in non-stream mode. + + Args: + inputs (str, List[str]): user's prompt(s) in this round + session_ids (int, List[int]): session id(s) + sequence_start (bool): start flag of a session + sequence_end (bool): end flag of a session + ignore_eos (bool): indicator for ignoring eos + skip_special_tokens (bool): Whether or not to remove special tokens + in the decoding. Default to be False. 
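# Usage sketch (illustrative, not part of the diff): AsyncLMDeployPipeline.generate is
# a coroutine that fans each prompt out to pipeline.generate() with its own session id
# and gathers the results. The model path is a placeholder.
import asyncio

from lagent.llms.lmdeploy_wrapper import AsyncLMDeployPipeline


async def main():
    llm = AsyncLMDeployPipeline(
        path='internlm/internlm2_5-7b-chat',  # placeholder model id
        max_new_tokens=128)
    # One session id per prompt; defaults to range(len(inputs)) when omitted.
    texts = await llm.generate(['Hi', 'Hello'], session_ids=[0, 1])
    print(texts)


asyncio.run(main())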
+ timeout (int): max time to wait for response + Returns: + (a list of/batched) text/chat completion + """ + from lmdeploy.serve.openai.api_client import json_loads + + batched = True + if isinstance(inputs, str): + inputs = [inputs] + batched = False + + gen_params = self.update_gen_params(**kwargs) + max_new_tokens = gen_params.pop('max_new_tokens') + gen_params.update(max_tokens=max_new_tokens) + + responses = [''] * len(inputs) + pload = dict( + model=self.model_name, + prompt=inputs, + sequence_start=sequence_start, + sequence_end=sequence_end, + stream=False, + ignore_eos=ignore_eos, + skip_special_tokens=skip_special_tokens, + timeout=timeout, + **gen_params) + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(3 * 3600)) as session: + async with session.post( + self.client.completions_v1_url, + headers=self.client.headers, + json=pload) as resp: + async for chunk in resp.content: + if chunk: + decoded = chunk.decode('utf-8') + output = json_loads(decoded) + responses = [ + response + item['text'] for response, item in zip( + responses, output['choices']) + ] + # remove stop_words + responses = filter_suffix(responses, self.gen_params.get('stop_words')) + if not batched: + return responses[0] + return responses + + async def stream_chat( + self, + inputs: List[dict], + session_id: int = None, + sequence_start: bool = True, + sequence_end: bool = True, + stream: bool = True, + ignore_eos: bool = False, + skip_special_tokens: Optional[bool] = False, + timeout: int = 30, + **kwargs, + ): + """Start a new round conversation of a session. Return the chat + completions in stream mode. + + Args: + inputs (List[dict]): user's inputs in this round conversation + session_id (int): session id + sequence_start (bool): start flag of a session + sequence_end (bool): end flag of a session + stream (bool): return in a streaming format if enabled + ignore_eos (bool): indicator for ignoring eos + skip_special_tokens (bool): Whether or not to remove special tokens + in the decoding. Default to be False. 
+ timeout (int): max time to wait for response + Returns: + tuple(Status, str, int): status, text/chat completion, + generated token number + """ + from lmdeploy.serve.openai.api_client import json_loads + + gen_params = self.update_gen_params(**kwargs) + max_new_tokens = gen_params.pop('max_new_tokens') + gen_params.update(max_tokens=max_new_tokens) + prompt = self.template_parser(inputs) + + response = '' + finished = False + stop_words = self.gen_params.get('stop_words') + + pload = dict( + model=self.model_name, + prompt=prompt, + sequence_start=sequence_start, + sequence_end=sequence_end, + stream=stream, + ignore_eos=ignore_eos, + skip_special_tokens=skip_special_tokens, + timeout=timeout, + **gen_params) + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(3 * 3600)) as session: + async with session.post( + self.client.completions_v1_url, + headers=self.client.headers, + json=pload) as resp: + async for chunk in resp.content: + if chunk: + decoded = chunk.decode('utf-8') + if not decoded.strip() or decoded.rstrip( + ) == 'data: [DONE]': + continue + if decoded[:6] == 'data: ': + decoded = decoded[6:] + output = json_loads(decoded) + response += output['choices'][0]['text'] + if not response: + continue + # remove stop_words + for sw in stop_words: + if sw in response: + response = filter_suffix(response, stop_words) + finished = True + break + yield ModelStatusCode.STREAM_ING, response, None + if finished: + break + yield ModelStatusCode.END, response, None + + +class AsyncLMDeployClient(AsyncLMDeployServer): + """ + + Args: + url (str): communicating address 'http://:' of + api_server + model_name (str): needed when model_path is a pytorch model on + huggingface.co, such as "internlm-chat-7b", + "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on. 
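# Usage sketch (illustrative, not part of the diff): consuming the async streaming
# interface. stream_chat is an async generator yielding (status, accumulated_text, None)
# tuples until ModelStatusCode.END. The URL and model name are placeholders for an
# api_server that is already running.
import asyncio

from lagent.llms.lmdeploy_wrapper import AsyncLMDeployClient
from lagent.schema import ModelStatusCode


async def main():
    llm = AsyncLMDeployClient(
        url='http://127.0.0.1:23333',  # placeholder api_server address
        model_name='internlm2',        # placeholder served name
        max_new_tokens=128)
    async for status, text, _ in llm.stream_chat(
            [dict(role='user', content='Summarize asyncio in one line.')]):
        if status == ModelStatusCode.END:
            print(text)


asyncio.run(main())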
+ """ + + def __init__(self, url: str, model_name: str, **kwargs): + BaseLLM.__init__(self, path=url, **kwargs) + from lmdeploy.serve.openai.api_client import APIClient + self.client = APIClient(url) + self.model_name = model_name diff --git a/lagent/llms/meta_template.py b/lagent/llms/meta_template.py new file mode 100644 index 0000000000000000000000000000000000000000..9b4ed978889619367636431801ffba8fcc462f0e --- /dev/null +++ b/lagent/llms/meta_template.py @@ -0,0 +1,40 @@ +INTERNLM2_META = [ + dict( + role='system', + begin=dict( + with_name='<|im_start|>system name={name}\n', + without_name='<|im_start|>system\n', + name={ + 'interpreter': '<|interpreter|>', + 'plugin': '<|plugin|>', + }), + end='<|im_end|>\n', + ), + dict( + role='user', + begin=dict( + with_name='<|im_start|>user name={name}\n', + without_name='<|im_start|>user\n', + ), + end='<|im_end|>\n'), + dict( + role='assistant', + begin=dict( + with_name='<|im_start|>assistant name={name}\n', + without_name='<|im_start|>assistant\n', + name={ + 'interpreter': '<|interpreter|>', + 'plugin': '<|plugin|>', + }), + end='<|im_end|>\n'), + dict( + role='environment', + begin=dict( + with_name='<|im_start|>environment name={name}\n', + without_name='<|im_start|>environment\n', + name={ + 'interpreter': '<|interpreter|>', + 'plugin': '<|plugin|>', + }), + end='<|im_end|>\n'), +] diff --git a/lagent/llms/openai.py b/lagent/llms/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..ffbd1b3de10bb6799c673784367acb476fe495cf --- /dev/null +++ b/lagent/llms/openai.py @@ -0,0 +1,924 @@ +import asyncio +import json +import os +import time +import traceback +import warnings +from concurrent.futures import ThreadPoolExecutor +from logging import getLogger +from threading import Lock +from typing import AsyncGenerator, Dict, List, Optional, Union + +import aiohttp +import requests + +from ..schema import ModelStatusCode +from ..utils import filter_suffix +from .base_api import AsyncBaseAPILLM, BaseAPILLM + +warnings.simplefilter('default') + +OPENAI_API_BASE = 'https://api.openai.com/v1/chat/completions' + + +class GPTAPI(BaseAPILLM): + """Model wrapper around OpenAI's models. + + Args: + model_type (str): The name of OpenAI's model. + retry (int): Number of retires if the API call fails. Defaults to 2. + key (str or List[str]): OpenAI key(s). In particular, when it + is set to "ENV", the key will be fetched from the environment + variable $OPENAI_API_KEY, as how openai defaults to be. If it's a + list, the keys will be used in round-robin manner. Defaults to + 'ENV'. + org (str or List[str], optional): OpenAI organization(s). If not + specified, OpenAI uses the default organization bound to each API + key. If specified, the orgs will be posted with each request in + round-robin manner. Defaults to None. + meta_template (Dict, optional): The model's meta prompt + template if needed, in case the requirement of injecting or + wrapping of any meta instructions. + api_base (str): The base url of OpenAI's API. Defaults to + 'https://api.openai.com/v1/chat/completions'. + gen_params: Default generation configuration which could be overridden + on the fly of generation. 
+ """ + + is_api: bool = True + + def __init__(self, + model_type: str = 'gpt-3.5-turbo', + retry: int = 2, + json_mode: bool = False, + key: Union[str, List[str]] = 'ENV', + org: Optional[Union[str, List[str]]] = None, + meta_template: Optional[Dict] = [ + dict(role='system', api_role='system'), + dict(role='user', api_role='user'), + dict(role='assistant', api_role='assistant'), + dict(role='environment', api_role='system') + ], + api_base: str = OPENAI_API_BASE, + proxies: Optional[Dict] = None, + **gen_params): + if 'top_k' in gen_params: + warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.', + DeprecationWarning) + gen_params.pop('top_k') + super().__init__( + model_type=model_type, + meta_template=meta_template, + retry=retry, + **gen_params) + self.gen_params.pop('top_k') + self.logger = getLogger(__name__) + + if isinstance(key, str): + self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key] + else: + self.keys = key + + # record invalid keys and skip them when requesting API + # - keys have insufficient_quota + self.invalid_keys = set() + + self.key_ctr = 0 + if isinstance(org, str): + self.orgs = [org] + else: + self.orgs = org + self.org_ctr = 0 + self.url = api_base + self.model_type = model_type + self.proxies = proxies + self.json_mode = json_mode + + def chat( + self, + inputs: Union[List[dict], List[List[dict]]], + **gen_params, + ) -> Union[str, List[str]]: + """Generate responses given the contexts. + + Args: + inputs (Union[List[dict], List[List[dict]]]): a list of messages + or list of lists of messages + gen_params: additional generation configuration + + Returns: + Union[str, List[str]]: generated string(s) + """ + assert isinstance(inputs, list) + if 'max_tokens' in gen_params: + raise NotImplementedError('unsupported parameter: max_tokens') + gen_params = {**self.gen_params, **gen_params} + with ThreadPoolExecutor(max_workers=20) as executor: + tasks = [ + executor.submit(self._chat, + self.template_parser._prompt2api(messages), + **gen_params) + for messages in ( + [inputs] if isinstance(inputs[0], dict) else inputs) + ] + ret = [task.result() for task in tasks] + return ret[0] if isinstance(inputs[0], dict) else ret + + def stream_chat( + self, + inputs: List[dict], + **gen_params, + ): + """Generate responses given the contexts. + + Args: + inputs (List[dict]): a list of messages + gen_params: additional generation configuration + + Returns: + str: generated string + """ + assert isinstance(inputs, list) + if 'max_tokens' in gen_params: + raise NotImplementedError('unsupported parameter: max_tokens') + gen_params = self.update_gen_params(**gen_params) + gen_params['stream'] = True + + resp = '' + finished = False + stop_words = gen_params.get('stop_words') + if stop_words is None: + stop_words = [] + # mapping to role that openai supports + messages = self.template_parser._prompt2api(inputs) + for text in self._stream_chat(messages, **gen_params): + if self.model_type.lower().startswith('qwen'): + resp = text + else: + resp += text + if not resp: + continue + # remove stop_words + for sw in stop_words: + if sw in resp: + resp = filter_suffix(resp, stop_words) + finished = True + break + yield ModelStatusCode.STREAM_ING, resp, None + if finished: + break + yield ModelStatusCode.END, resp, None + + def _chat(self, messages: List[dict], **gen_params) -> str: + """Generate completion from a list of templates. 
+ + Args: + messages (List[dict]): a list of prompt dictionaries + gen_params: additional generation configuration + + Returns: + str: The generated string. + """ + assert isinstance(messages, list) + + header, data = self.generate_request_data( + model_type=self.model_type, + messages=messages, + gen_params=gen_params, + json_mode=self.json_mode) + + max_num_retries, errmsg = 0, '' + while max_num_retries < self.retry: + with Lock(): + if len(self.invalid_keys) == len(self.keys): + raise RuntimeError('All keys have insufficient quota.') + + # find the next valid key + while True: + self.key_ctr += 1 + if self.key_ctr == len(self.keys): + self.key_ctr = 0 + + if self.keys[self.key_ctr] not in self.invalid_keys: + break + + key = self.keys[self.key_ctr] + header['Authorization'] = f'Bearer {key}' + + if self.orgs: + with Lock(): + self.org_ctr += 1 + if self.org_ctr == len(self.orgs): + self.org_ctr = 0 + header['OpenAI-Organization'] = self.orgs[self.org_ctr] + + response = dict() + try: + raw_response = requests.post( + self.url, + headers=header, + data=json.dumps(data), + proxies=self.proxies) + response = raw_response.json() + return response['choices'][0]['message']['content'].strip() + except requests.ConnectionError: + errmsg = 'Got connection error ' + str(traceback.format_exc()) + self.logger.error(errmsg) + continue + except requests.JSONDecodeError: + errmsg = 'JsonDecode error, got ' + str(raw_response.content) + self.logger.error(errmsg) + continue + except KeyError: + if 'error' in response: + if response['error']['code'] == 'rate_limit_exceeded': + time.sleep(1) + continue + elif response['error']['code'] == 'insufficient_quota': + self.invalid_keys.add(key) + self.logger.warn(f'insufficient_quota key: {key}') + continue + + errmsg = 'Find error message in response: ' + str( + response['error']) + self.logger.error(errmsg) + except Exception as error: + errmsg = str(error) + '\n' + str(traceback.format_exc()) + self.logger.error(errmsg) + max_num_retries += 1 + + raise RuntimeError('Calling OpenAI failed after retrying for ' + f'{max_num_retries} times. Check the logs for ' + f'details. errmsg: {errmsg}') + + def _stream_chat(self, messages: List[dict], **gen_params) -> str: + """Generate completion from a list of templates. + + Args: + messages (List[dict]): a list of prompt dictionaries + gen_params: additional generation configuration + + Returns: + str: The generated string. 
+ """ + + def streaming(raw_response): + for chunk in raw_response.iter_lines( + chunk_size=8192, decode_unicode=False, delimiter=b'\n'): + if chunk: + decoded = chunk.decode('utf-8') + if decoded.startswith('data: [DONE]'): + return + if decoded[:5] == 'data:': + decoded = decoded[5:] + if decoded[0] == ' ': + decoded = decoded[1:] + else: + print(decoded) + continue + try: + response = json.loads(decoded) + if 'code' in response and response['code'] == -20003: + # Context exceeds maximum length + yield '' + return + if self.model_type.lower().startswith('qwen'): + choice = response['output']['choices'][0] + yield choice['message']['content'] + if choice['finish_reason'] == 'stop': + return + else: + choice = response['choices'][0] + if choice['finish_reason'] == 'stop': + return + yield choice['delta'].get('content', '') + except Exception as exc: + msg = f'response {decoded} lead to exception of {str(exc)}' + self.logger.error(msg) + raise Exception(msg) from exc + + assert isinstance(messages, list) + + header, data = self.generate_request_data( + model_type=self.model_type, + messages=messages, + gen_params=gen_params, + json_mode=self.json_mode) + + max_num_retries, errmsg = 0, '' + while max_num_retries < self.retry: + if len(self.invalid_keys) == len(self.keys): + raise RuntimeError('All keys have insufficient quota.') + + # find the next valid key + while True: + self.key_ctr += 1 + if self.key_ctr == len(self.keys): + self.key_ctr = 0 + + if self.keys[self.key_ctr] not in self.invalid_keys: + break + + key = self.keys[self.key_ctr] + header['Authorization'] = f'Bearer {key}' + + if self.orgs: + self.org_ctr += 1 + if self.org_ctr == len(self.orgs): + self.org_ctr = 0 + header['OpenAI-Organization'] = self.orgs[self.org_ctr] + + response = dict() + try: + raw_response = requests.post( + self.url, + headers=header, + data=json.dumps(data), + proxies=self.proxies) + return streaming(raw_response) + except requests.ConnectionError: + errmsg = 'Got connection error ' + str(traceback.format_exc()) + self.logger.error(errmsg) + continue + except requests.JSONDecodeError: + errmsg = 'JsonDecode error, got ' + str(raw_response.content) + self.logger.error(errmsg) + continue + except KeyError: + if 'error' in response: + if response['error']['code'] == 'rate_limit_exceeded': + time.sleep(1) + continue + elif response['error']['code'] == 'insufficient_quota': + self.invalid_keys.add(key) + self.logger.warn(f'insufficient_quota key: {key}') + continue + + errmsg = 'Find error message in response: ' + str( + response['error']) + self.logger.error(errmsg) + except Exception as error: + errmsg = str(error) + '\n' + str(traceback.format_exc()) + self.logger.error(errmsg) + max_num_retries += 1 + + raise RuntimeError('Calling OpenAI failed after retrying for ' + f'{max_num_retries} times. Check the logs for ' + f'details. errmsg: {errmsg}') + + def generate_request_data(self, + model_type, + messages, + gen_params, + json_mode=False): + """ + Generates the request data for different model types. + + Args: + model_type (str): The type of the model (e.g., 'gpt', 'internlm', 'qwen'). + messages (list): The list of messages to be sent to the model. + gen_params (dict): The generation parameters. + json_mode (bool): Flag to determine if the response format should be JSON. + + Returns: + tuple: A tuple containing the header and the request data. 
+ """ + # Copy generation parameters to avoid modifying the original dictionary + gen_params = gen_params.copy() + + # Hold out 100 tokens due to potential errors in token calculation + max_tokens = min(gen_params.pop('max_new_tokens'), 4096) + if max_tokens <= 0: + return '', '' + + # Initialize the header + header = { + 'content-type': 'application/json', + } + + # Common parameters processing + gen_params['max_tokens'] = max_tokens + if 'stop_words' in gen_params: + gen_params['stop'] = gen_params.pop('stop_words') + if 'repetition_penalty' in gen_params: + gen_params['frequency_penalty'] = gen_params.pop( + 'repetition_penalty') + + # Model-specific processing + data = {} + if model_type.lower().startswith('gpt'): + if 'top_k' in gen_params: + warnings.warn( + '`top_k` parameter is deprecated in OpenAI APIs.', + DeprecationWarning) + gen_params.pop('top_k') + gen_params.pop('skip_special_tokens', None) + gen_params.pop('session_id', None) + data = { + 'model': model_type, + 'messages': messages, + 'n': 1, + **gen_params + } + if json_mode: + data['response_format'] = {'type': 'json_object'} + elif model_type.lower().startswith('internlm'): + data = { + 'model': model_type, + 'messages': messages, + 'n': 1, + **gen_params + } + if json_mode: + data['response_format'] = {'type': 'json_object'} + elif model_type.lower().startswith('qwen'): + header['X-DashScope-SSE'] = 'enable' + gen_params.pop('skip_special_tokens', None) + gen_params.pop('session_id', None) + if 'frequency_penalty' in gen_params: + gen_params['repetition_penalty'] = gen_params.pop( + 'frequency_penalty') + gen_params['result_format'] = 'message' + data = { + 'model': model_type, + 'input': { + 'messages': messages + }, + 'parameters': { + **gen_params + } + } + else: + raise NotImplementedError( + f'Model type {model_type} is not supported') + + return header, data + + def tokenize(self, prompt: str) -> list: + """Tokenize the input prompt. + + Args: + prompt (str): Input string. + + Returns: + list: token ids + """ + import tiktoken + self.tiktoken = tiktoken + enc = self.tiktoken.encoding_for_model(self.model_type) + return enc.encode(prompt) + + +class AsyncGPTAPI(AsyncBaseAPILLM): + """Model wrapper around OpenAI's models. + + Args: + model_type (str): The name of OpenAI's model. + retry (int): Number of retires if the API call fails. Defaults to 2. + key (str or List[str]): OpenAI key(s). In particular, when it + is set to "ENV", the key will be fetched from the environment + variable $OPENAI_API_KEY, as how openai defaults to be. If it's a + list, the keys will be used in round-robin manner. Defaults to + 'ENV'. + org (str or List[str], optional): OpenAI organization(s). If not + specified, OpenAI uses the default organization bound to each API + key. If specified, the orgs will be posted with each request in + round-robin manner. Defaults to None. + meta_template (Dict, optional): The model's meta prompt + template if needed, in case the requirement of injecting or + wrapping of any meta instructions. + api_base (str): The base url of OpenAI's API. Defaults to + 'https://api.openai.com/v1/chat/completions'. + gen_params: Default generation configuration which could be overridden + on the fly of generation. 
+ """ + + is_api: bool = True + + def __init__(self, + model_type: str = 'gpt-3.5-turbo', + retry: int = 2, + json_mode: bool = False, + key: Union[str, List[str]] = 'ENV', + org: Optional[Union[str, List[str]]] = None, + meta_template: Optional[Dict] = [ + dict(role='system', api_role='system'), + dict(role='user', api_role='user'), + dict(role='assistant', api_role='assistant') + ], + api_base: str = OPENAI_API_BASE, + proxies: Optional[Dict] = None, + **gen_params): + if 'top_k' in gen_params: + warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.', + DeprecationWarning) + gen_params.pop('top_k') + super().__init__( + model_type=model_type, + meta_template=meta_template, + retry=retry, + **gen_params) + self.gen_params.pop('top_k') + self.logger = getLogger(__name__) + + if isinstance(key, str): + self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key] + else: + self.keys = key + + # record invalid keys and skip them when requesting API + # - keys have insufficient_quota + self.invalid_keys = set() + + self.key_ctr = 0 + if isinstance(org, str): + self.orgs = [org] + else: + self.orgs = org + self.org_ctr = 0 + self.url = api_base + self.model_type = model_type + self.proxies = proxies or {} + self.json_mode = json_mode + + async def chat( + self, + inputs: Union[List[dict], List[List[dict]]], + session_ids: Union[int, List[int]] = None, + **gen_params, + ) -> Union[str, List[str]]: + """Generate responses given the contexts. + + Args: + inputs (Union[List[dict], List[List[dict]]]): a list of messages + or list of lists of messages + gen_params: additional generation configuration + + Returns: + Union[str, List[str]]: generated string(s) + """ + assert isinstance(inputs, list) + if 'max_tokens' in gen_params: + raise NotImplementedError('unsupported parameter: max_tokens') + gen_params = {**self.gen_params, **gen_params} + tasks = [ + self._chat(messages, **gen_params) for messages in ( + [inputs] if isinstance(inputs[0], dict) else inputs) + ] + ret = await asyncio.gather(*tasks) + return ret[0] if isinstance(inputs[0], dict) else ret + + async def stream_chat( + self, + inputs: List[dict], + **gen_params, + ): + """Generate responses given the contexts. + + Args: + inputs (List[dict]): a list of messages + gen_params: additional generation configuration + + Returns: + str: generated string + """ + assert isinstance(inputs, list) + if 'max_tokens' in gen_params: + raise NotImplementedError('unsupported parameter: max_tokens') + gen_params = self.update_gen_params(**gen_params) + gen_params['stream'] = True + + resp = '' + finished = False + stop_words = gen_params.get('stop_words') + if stop_words is None: + stop_words = [] + # mapping to role that openai supports + messages = self.template_parser._prompt2api(inputs) + async for text in self._stream_chat(messages, **gen_params): + if self.model_type.lower().startswith('qwen'): + resp = text + else: + resp += text + if not resp: + continue + # remove stop_words + for sw in stop_words: + if sw in resp: + resp = filter_suffix(resp, stop_words) + finished = True + break + yield ModelStatusCode.STREAM_ING, resp, None + if finished: + break + yield ModelStatusCode.END, resp, None + + async def _chat(self, messages: List[dict], **gen_params) -> str: + """Generate completion from a list of templates. + + Args: + messages (List[dict]): a list of prompt dictionaries + gen_params: additional generation configuration + + Returns: + str: The generated string. 
+ """ + assert isinstance(messages, list) + + header, data = self.generate_request_data( + model_type=self.model_type, + messages=messages, + gen_params=gen_params, + json_mode=self.json_mode) + + max_num_retries, errmsg = 0, '' + while max_num_retries < self.retry: + if len(self.invalid_keys) == len(self.keys): + raise RuntimeError('All keys have insufficient quota.') + + # find the next valid key + while True: + self.key_ctr += 1 + if self.key_ctr == len(self.keys): + self.key_ctr = 0 + + if self.keys[self.key_ctr] not in self.invalid_keys: + break + + key = self.keys[self.key_ctr] + header['Authorization'] = f'Bearer {key}' + + if self.orgs: + self.org_ctr += 1 + if self.org_ctr == len(self.orgs): + self.org_ctr = 0 + header['OpenAI-Organization'] = self.orgs[self.org_ctr] + + response = dict() + try: + async with aiohttp.ClientSession() as session: + async with session.post( + self.url, + headers=header, + json=data, + proxy=self.proxies.get( + 'https', self.proxies.get('http'))) as resp: + response = await resp.json() + return response['choices'][0]['message'][ + 'content'].strip() + except aiohttp.ClientConnectionError: + errmsg = 'Got connection error ' + str(traceback.format_exc()) + self.logger.error(errmsg) + continue + except aiohttp.ClientResponseError as e: + errmsg = 'Response error, got ' + str(e) + self.logger.error(errmsg) + continue + except json.JSONDecodeError: + errmsg = 'JsonDecode error, got ' + (await resp.text( + errors='replace')) + self.logger.error(errmsg) + continue + except KeyError: + if 'error' in response: + if response['error']['code'] == 'rate_limit_exceeded': + time.sleep(1) + continue + elif response['error']['code'] == 'insufficient_quota': + self.invalid_keys.add(key) + self.logger.warn(f'insufficient_quota key: {key}') + continue + + errmsg = 'Find error message in response: ' + str( + response['error']) + self.logger.error(errmsg) + except Exception as error: + errmsg = str(error) + '\n' + str(traceback.format_exc()) + self.logger.error(errmsg) + max_num_retries += 1 + + raise RuntimeError('Calling OpenAI failed after retrying for ' + f'{max_num_retries} times. Check the logs for ' + f'details. errmsg: {errmsg}') + + async def _stream_chat(self, messages: List[dict], + **gen_params) -> AsyncGenerator[str, None]: + """Generate completion from a list of templates. + + Args: + messages (List[dict]): a list of prompt dictionaries + gen_params: additional generation configuration + + Returns: + str: The generated string. 
+ """ + + async def streaming(raw_response): + async for chunk in raw_response.content: + if chunk: + decoded = chunk.decode('utf-8') + if decoded.startswith('data: [DONE]'): + return + if decoded[:5] == 'data:': + decoded = decoded[5:] + if decoded[0] == ' ': + decoded = decoded[1:] + else: + print(decoded) + continue + try: + response = json.loads(decoded) + if 'code' in response and response['code'] == -20003: + # Context exceeds maximum length + yield '' + return + if self.model_type.lower().startswith('qwen'): + choice = response['output']['choices'][0] + yield choice['message']['content'] + if choice['finish_reason'] == 'stop': + return + else: + choice = response['choices'][0] + if choice['finish_reason'] == 'stop': + return + yield choice['delta'].get('content', '') + except Exception as exc: + msg = f'response {decoded} lead to exception of {str(exc)}' + self.logger.error(msg) + raise Exception(msg) from exc + + assert isinstance(messages, list) + + header, data = self.generate_request_data( + model_type=self.model_type, + messages=messages, + gen_params=gen_params, + json_mode=self.json_mode) + + max_num_retries, errmsg = 0, '' + while max_num_retries < self.retry: + if len(self.invalid_keys) == len(self.keys): + raise RuntimeError('All keys have insufficient quota.') + + # find the next valid key + while True: + self.key_ctr += 1 + if self.key_ctr == len(self.keys): + self.key_ctr = 0 + + if self.keys[self.key_ctr] not in self.invalid_keys: + break + + key = self.keys[self.key_ctr] + header['Authorization'] = f'Bearer {key}' + + if self.orgs: + self.org_ctr += 1 + if self.org_ctr == len(self.orgs): + self.org_ctr = 0 + header['OpenAI-Organization'] = self.orgs[self.org_ctr] + + response = dict() + try: + async with aiohttp.ClientSession() as session: + async with session.post( + self.url, + headers=header, + json=data, + proxy=self.proxies.get( + 'https', + self.proxies.get('http'))) as raw_response: + async for msg in streaming(raw_response): + yield msg + return + except aiohttp.ClientConnectionError: + errmsg = 'Got connection error ' + str(traceback.format_exc()) + self.logger.error(errmsg) + continue + except aiohttp.ClientResponseError as e: + errmsg = 'Response error, got ' + str(e) + self.logger.error(errmsg) + continue + except KeyError: + if 'error' in response: + if response['error']['code'] == 'rate_limit_exceeded': + time.sleep(1) + continue + elif response['error']['code'] == 'insufficient_quota': + self.invalid_keys.add(key) + self.logger.warn(f'insufficient_quota key: {key}') + continue + + errmsg = 'Find error message in response: ' + str( + response['error']) + self.logger.error(errmsg) + except Exception as error: + errmsg = str(error) + '\n' + str(traceback.format_exc()) + self.logger.error(errmsg) + max_num_retries += 1 + + raise RuntimeError('Calling OpenAI failed after retrying for ' + f'{max_num_retries} times. Check the logs for ' + f'details. errmsg: {errmsg}') + + def generate_request_data(self, + model_type, + messages, + gen_params, + json_mode=False): + """ + Generates the request data for different model types. + + Args: + model_type (str): The type of the model (e.g., 'gpt', 'internlm', 'qwen'). + messages (list): The list of messages to be sent to the model. + gen_params (dict): The generation parameters. + json_mode (bool): Flag to determine if the response format should be JSON. + + Returns: + tuple: A tuple containing the header and the request data. 
+ """ + # Copy generation parameters to avoid modifying the original dictionary + gen_params = gen_params.copy() + + # Hold out 100 tokens due to potential errors in token calculation + max_tokens = min(gen_params.pop('max_new_tokens'), 4096) + if max_tokens <= 0: + return '', '' + + # Initialize the header + header = { + 'content-type': 'application/json', + } + + # Common parameters processing + gen_params['max_tokens'] = max_tokens + if 'stop_words' in gen_params: + gen_params['stop'] = gen_params.pop('stop_words') + if 'repetition_penalty' in gen_params: + gen_params['frequency_penalty'] = gen_params.pop( + 'repetition_penalty') + + # Model-specific processing + data = {} + if model_type.lower().startswith('gpt'): + if 'top_k' in gen_params: + warnings.warn( + '`top_k` parameter is deprecated in OpenAI APIs.', + DeprecationWarning) + gen_params.pop('top_k') + gen_params.pop('skip_special_tokens', None) + gen_params.pop('session_id', None) + data = { + 'model': model_type, + 'messages': messages, + 'n': 1, + **gen_params + } + if json_mode: + data['response_format'] = {'type': 'json_object'} + elif model_type.lower().startswith('internlm'): + data = { + 'model': model_type, + 'messages': messages, + 'n': 1, + **gen_params + } + if json_mode: + data['response_format'] = {'type': 'json_object'} + elif model_type.lower().startswith('qwen'): + header['X-DashScope-SSE'] = 'enable' + gen_params.pop('skip_special_tokens', None) + gen_params.pop('session_id', None) + if 'frequency_penalty' in gen_params: + gen_params['repetition_penalty'] = gen_params.pop( + 'frequency_penalty') + gen_params['result_format'] = 'message' + data = { + 'model': model_type, + 'input': { + 'messages': messages + }, + 'parameters': { + **gen_params + } + } + else: + raise NotImplementedError( + f'Model type {model_type} is not supported') + + return header, data + + def tokenize(self, prompt: str) -> list: + """Tokenize the input prompt. + + Args: + prompt (str): Input string. + + Returns: + list: token ids + """ + import tiktoken + self.tiktoken = tiktoken + enc = self.tiktoken.encoding_for_model(self.model_type) + return enc.encode(prompt) diff --git a/lagent/llms/sensenova.py b/lagent/llms/sensenova.py new file mode 100644 index 0000000000000000000000000000000000000000..54a9f9eeb0f6b90d1c2ddb6cfaa415a7b5927ba4 --- /dev/null +++ b/lagent/llms/sensenova.py @@ -0,0 +1,406 @@ +import json +import os +import time +import warnings +from concurrent.futures import ThreadPoolExecutor +from logging import getLogger +from threading import Lock +from typing import Dict, Generator, List, Optional, Tuple, Union + +import requests + +from lagent.schema import ModelStatusCode +from lagent.utils.util import filter_suffix +from .base_api import BaseAPILLM + +warnings.simplefilter('default') + +SENSENOVA_API_BASE = 'https://api.sensenova.cn/v1/llm/chat-completions' + +sensechat_models = {'SenseChat-5': 131072, 'SenseChat-5-Cantonese': 32768} + + +class SensenovaAPI(BaseAPILLM): + """Model wrapper around SenseTime's models. + + Args: + model_type (str): The name of SenseTime's model. + retry (int): Number of retires if the API call fails. Defaults to 2. + key (str or List[str]): SenseTime key(s). In particular, when it + is set to "ENV", the key will be fetched from the environment + variable $SENSENOVA_API_KEY. If it's a list, the keys will be + used in round-robin manner. Defaults to 'ENV'. 
+ meta_template (Dict, optional): The model's meta prompt + template if needed, in case the requirement of injecting or + wrapping of any meta instructions. + sensenova_api_base (str): The base url of SenseTime's API. Defaults to + 'https://api.sensenova.cn/v1/llm/chat-completions'. + gen_params: Default generation configuration which could be overridden + on the fly of generation. + """ + + is_api: bool = True + + def __init__( + self, + model_type: str = 'SenseChat-5-Cantonese', + retry: int = 2, + json_mode: bool = False, + key: Union[str, List[str]] = 'ENV', + meta_template: Optional[Dict] = [ + dict(role='system', api_role='system'), + dict(role='user', api_role='user'), + dict(role='assistant', api_role='assistant'), + dict(role='environment', api_role='system'), + ], + sensenova_api_base: str = SENSENOVA_API_BASE, + proxies: Optional[Dict] = None, + **gen_params, + ): + + super().__init__( + model_type=model_type, + meta_template=meta_template, + retry=retry, + **gen_params, + ) + self.logger = getLogger(__name__) + + if isinstance(key, str): + # First, apply for SenseNova's ak and sk from SenseTime staff + # Then, generated SENSENOVA_API_KEY using lagent.utils.gen_key.auto_gen_jwt_token(ak, sk) + self.keys = [ + os.getenv('SENSENOVA_API_KEY') if key == 'ENV' else key + ] + else: + self.keys = key + + # record invalid keys and skip them when requesting API + # - keys have insufficient_quota + self.invalid_keys = set() + + self.key_ctr = 0 + self.url = sensenova_api_base + self.model_type = model_type + self.proxies = proxies + self.json_mode = json_mode + + def chat( + self, + inputs: Union[List[dict], List[List[dict]]], + **gen_params, + ) -> Union[str, List[str]]: + """Generate responses given the contexts. + + Args: + inputs (Union[List[dict], List[List[dict]]]): a list of messages + or list of lists of messages + gen_params: additional generation configuration + + Returns: + Union[str, List[str]]: generated string(s) + """ + assert isinstance(inputs, list) + if 'max_tokens' in gen_params: + raise NotImplementedError('unsupported parameter: max_tokens') + gen_params = {**self.gen_params, **gen_params} + with ThreadPoolExecutor(max_workers=20) as executor: + tasks = [ + executor.submit(self._chat, + self.template_parser._prompt2api(messages), + **gen_params) + for messages in ( + [inputs] if isinstance(inputs[0], dict) else inputs) + ] + ret = [task.result() for task in tasks] + return ret[0] if isinstance(inputs[0], dict) else ret + + def stream_chat( + self, + inputs: List[dict], + **gen_params, + ) -> Generator[Tuple[ModelStatusCode, str, Optional[str]], None, None]: + """Generate responses given the contexts. 
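# Usage sketch (illustrative, not part of the diff): SensenovaAPI mirrors GPTAPI's
# interface. With key='ENV' it reads SENSENOVA_API_KEY, which is assumed to be a JWT
# generated from the ak/sk pair as described in the __init__ comment above. The model
# name is taken from `sensechat_models`.
from lagent.llms.sensenova import SensenovaAPI

llm = SensenovaAPI(model_type='SenseChat-5', key='ENV', max_new_tokens=256)
print(llm.chat([dict(role='user', content='Introduce Hong Kong in one sentence.')]))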
+ + Args: + inputs (List[dict]): a list of messages + gen_params: additional generation configuration + + Yields: + Tuple[ModelStatusCode, str, Optional[str]]: Status code, generated string, and optional metadata + """ + assert isinstance(inputs, list) + if 'max_tokens' in gen_params: + raise NotImplementedError('unsupported parameter: max_tokens') + gen_params = self.update_gen_params(**gen_params) + gen_params['stream'] = True + + resp = '' + finished = False + stop_words = gen_params.get('stop_words') or [] + messages = self.template_parser._prompt2api(inputs) + for text in self._stream_chat(messages, **gen_params): + # TODO: verify whether this should be `resp = text` or `resp += text` + resp += text + if not resp: + continue + # remove stop_words + for sw in stop_words: + if sw in resp: + resp = filter_suffix(resp, stop_words) + finished = True + break + yield ModelStatusCode.STREAM_ING, resp, None + if finished: + break + yield ModelStatusCode.END, resp, None + + def _chat(self, messages: List[dict], **gen_params) -> str: + """Generate completion from a list of templates. + + Args: + messages (List[dict]): a list of prompt dictionaries + gen_params: additional generation configuration + + Returns: + str: The generated string. + """ + assert isinstance(messages, list) + + header, data = self.generate_request_data( + model_type=self.model_type, + messages=messages, + gen_params=gen_params, + json_mode=self.json_mode, + ) + + max_num_retries = 0 + while max_num_retries < self.retry: + self._wait() + + with Lock(): + if len(self.invalid_keys) == len(self.keys): + raise RuntimeError('All keys have insufficient quota.') + + # find the next valid key + while True: + self.key_ctr += 1 + if self.key_ctr == len(self.keys): + self.key_ctr = 0 + + if self.keys[self.key_ctr] not in self.invalid_keys: + break + + key = self.keys[self.key_ctr] + header['Authorization'] = f'Bearer {key}' + + response = dict() + try: + raw_response = requests.post( + self.url, + headers=header, + data=json.dumps(data), + proxies=self.proxies, + ) + response = raw_response.json() + return response['choices'][0]['message']['content'].strip() + except requests.ConnectionError: + print('Got connection error, retrying...') + continue + except requests.JSONDecodeError: + print('JsonDecode error, got', str(raw_response.content)) + continue + except KeyError: + if 'error' in response: + if response['error']['code'] == 'rate_limit_exceeded': + time.sleep(1) + continue + elif response['error']['code'] == 'insufficient_quota': + self.invalid_keys.add(key) + self.logger.warn(f'insufficient_quota key: {key}') + continue + + print('Find error message in response: ', + str(response['error'])) + except Exception as error: + print(str(error)) + max_num_retries += 1 + + raise RuntimeError('Calling SenseTime failed after retrying for ' + f'{max_num_retries} times. Check the logs for ' + 'details.') + + def _stream_chat(self, messages: List[dict], **gen_params) -> str: + """Generate completion from a list of templates. + + Args: + messages (List[dict]): a list of prompt dictionaries + gen_params: additional generation configuration + + Returns: + str: The generated string.
+ """ + + def streaming(raw_response): + for chunk in raw_response.iter_lines(): + if chunk: + try: + decoded_chunk = chunk.decode('utf-8') + # print(f"Decoded chunk: {decoded_chunk}") + + if decoded_chunk == 'data:[DONE]': + # print("Stream ended") + break + + if decoded_chunk.startswith('data:'): + json_str = decoded_chunk[5:] + chunk_data = json.loads(json_str) + + if 'data' in chunk_data and 'choices' in chunk_data[ + 'data']: + choice = chunk_data['data']['choices'][0] + if 'delta' in choice: + content = choice['delta'] + yield content + else: + print(f'Unexpected format: {decoded_chunk}') + + except json.JSONDecodeError as e: + print(f'JSON parsing error: {e}') + except Exception as e: + print( + f'An error occurred while processing the chunk: {e}' + ) + + assert isinstance(messages, list) + + header, data = self.generate_request_data( + model_type=self.model_type, + messages=messages, + gen_params=gen_params, + json_mode=self.json_mode, + ) + + max_num_retries = 0 + while max_num_retries < self.retry: + if len(self.invalid_keys) == len(self.keys): + raise RuntimeError('All keys have insufficient quota.') + + # find the next valid key + while True: + self.key_ctr += 1 + if self.key_ctr == len(self.keys): + self.key_ctr = 0 + + if self.keys[self.key_ctr] not in self.invalid_keys: + break + + key = self.keys[self.key_ctr] + header['Authorization'] = f'Bearer {key}' + + response = dict() + try: + raw_response = requests.post( + self.url, + headers=header, + data=json.dumps(data), + proxies=self.proxies, + ) + return streaming(raw_response) + except requests.ConnectionError: + print('Got connection error, retrying...') + continue + except requests.JSONDecodeError: + print('JsonDecode error, got', str(raw_response.content)) + continue + except KeyError: + if 'error' in response: + if response['error']['code'] == 'rate_limit_exceeded': + time.sleep(1) + continue + elif response['error']['code'] == 'insufficient_quota': + self.invalid_keys.add(key) + self.logger.warn(f'insufficient_quota key: {key}') + continue + + print('Find error message in response: ', + str(response['error'])) + except Exception as error: + print(str(error)) + max_num_retries += 1 + + raise RuntimeError('Calling SenseTime failed after retrying for ' + f'{max_num_retries} times. Check the logs for ' + 'details.') + + def generate_request_data(self, + model_type, + messages, + gen_params, + json_mode=False): + """ + Generates the request data for different model types. + + Args: + model_type (str): The type of the model (e.g., 'sense'). + messages (list): The list of messages to be sent to the model. + gen_params (dict): The generation parameters. + json_mode (bool): Flag to determine if the response format should be JSON. + + Returns: + tuple: A tuple containing the header and the request data. 
+ """ + # Copy generation parameters to avoid modifying the original dictionary + gen_params = gen_params.copy() + + # Hold out 100 tokens due to potential errors in token calculation + max_tokens = min(gen_params.pop('max_new_tokens'), 4096) + if max_tokens <= 0: + return '', '' + + # Initialize the header + header = { + 'content-type': 'application/json', + } + + # Common parameters processing + gen_params['max_tokens'] = max_tokens + if 'stop_words' in gen_params: + gen_params['stop'] = gen_params.pop('stop_words') + if 'repetition_penalty' in gen_params: + gen_params['frequency_penalty'] = gen_params.pop( + 'repetition_penalty') + + # Model-specific processing + data = {} + if model_type.lower().startswith('sense'): + gen_params.pop('skip_special_tokens', None) + gen_params.pop('session_id', None) + data = { + 'model': model_type, + 'messages': messages, + 'n': 1, + **gen_params + } + if json_mode: + data['response_format'] = {'type': 'json_object'} + else: + raise NotImplementedError( + f'Model type {model_type} is not supported') + + return header, data + + def tokenize(self, prompt: str) -> list: + """Tokenize the input prompt. + + Args: + prompt (str): Input string. + + Returns: + list: token ids + """ + import tiktoken + + self.tiktoken = tiktoken + enc = self.tiktoken.encoding_for_model('gpt-4o') + return enc.encode(prompt) diff --git a/lagent/llms/vllm_wrapper.py b/lagent/llms/vllm_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..c4a01cbe5b26ce2676bbc1beaf201483ee6a233b --- /dev/null +++ b/lagent/llms/vllm_wrapper.py @@ -0,0 +1,176 @@ +import asyncio +from typing import List, Union + +from lagent.llms.base_llm import AsyncBaseLLM, BaseLLM +from lagent.utils.util import filter_suffix + + +def asdict_completion(output): + return { + key: getattr(output, key) + for key in [ + 'text', 'token_ids', 'cumulative_logprob', 'logprobs', + 'finish_reason', 'stop_reason' + ] + } + + +class VllmModel(BaseLLM): + """ + A wrapper of vLLM model. + + Args: + path (str): The path to the model. + It could be one of the following options: + - i) A local directory path of a huggingface model. + - ii) The model_id of a model hosted inside a model repo + on huggingface.co, such as "internlm/internlm-chat-7b", + "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" + and so on. + tp (int): tensor parallel + vllm_cfg (dict): Other kwargs for vllm model initialization. + """ + + def __init__(self, path: str, tp: int = 1, vllm_cfg=dict(), **kwargs): + + super().__init__(path=path, **kwargs) + from vllm import LLM + self.model = LLM( + model=self.path, + trust_remote_code=True, + tensor_parallel_size=tp, + **vllm_cfg) + + def generate(self, + inputs: Union[str, List[str]], + do_preprocess: bool = None, + skip_special_tokens: bool = False, + return_dict: bool = False, + **kwargs): + """Return the chat completions in non-stream mode. + + Args: + inputs (Union[str, List[str]]): input texts to be completed. + do_preprocess (bool): whether pre-process the messages. Default to + True, which means chat_template will be applied. + skip_special_tokens (bool): Whether or not to remove special tokens + in the decoding. Default to be False. 
+ Returns: + (a list of/batched) text/chat completion + """ + from vllm import SamplingParams + + batched = True + if isinstance(inputs, str): + inputs = [inputs] + batched = False + prompt = inputs + gen_params = self.update_gen_params(**kwargs) + max_new_tokens = gen_params.pop('max_new_tokens') + stop_words = gen_params.pop('stop_words') + + sampling_config = SamplingParams( + skip_special_tokens=skip_special_tokens, + max_tokens=max_new_tokens, + stop=stop_words, + **gen_params) + response = self.model.generate(prompt, sampling_params=sampling_config) + texts = [resp.outputs[0].text for resp in response] + # remove stop_words + texts = filter_suffix(texts, self.gen_params.get('stop_words')) + for resp, text in zip(response, texts): + resp.outputs[0].text = text + if batched: + return [asdict_completion(resp.outputs[0]) + for resp in response] if return_dict else texts + return asdict_completion( + response[0].outputs[0]) if return_dict else texts[0] + + +class AsyncVllmModel(AsyncBaseLLM): + """ + A asynchronous wrapper of vLLM model. + + Args: + path (str): The path to the model. + It could be one of the following options: + - i) A local directory path of a huggingface model. + - ii) The model_id of a model hosted inside a model repo + on huggingface.co, such as "internlm/internlm-chat-7b", + "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" + and so on. + tp (int): tensor parallel + vllm_cfg (dict): Other kwargs for vllm model initialization. + """ + + def __init__(self, path: str, tp: int = 1, vllm_cfg=dict(), **kwargs): + super().__init__(path=path, **kwargs) + from vllm import AsyncEngineArgs, AsyncLLMEngine + + engine_args = AsyncEngineArgs( + model=self.path, + trust_remote_code=True, + tensor_parallel_size=tp, + **vllm_cfg) + self.model = AsyncLLMEngine.from_engine_args(engine_args) + + async def generate(self, + inputs: Union[str, List[str]], + session_ids: Union[int, List[int]] = None, + do_preprocess: bool = None, + skip_special_tokens: bool = False, + return_dict: bool = False, + **kwargs): + """Return the chat completions in non-stream mode. + + Args: + inputs (Union[str, List[str]]): input texts to be completed. + do_preprocess (bool): whether pre-process the messages. Default to + True, which means chat_template will be applied. + skip_special_tokens (bool): Whether or not to remove special tokens + in the decoding. Default to be False. 
+ Returns: + (a list of/batched) text/chat completion + """ + from vllm import SamplingParams + + batched = True + if isinstance(inputs, str): + inputs = [inputs] + batched = False + if session_ids is None: + session_ids = list(range(len(inputs))) + elif isinstance(session_ids, (int, str)): + session_ids = [session_ids] + assert len(inputs) == len(session_ids) + + prompt = inputs + gen_params = self.update_gen_params(**kwargs) + max_new_tokens = gen_params.pop('max_new_tokens') + stop_words = gen_params.pop('stop_words') + + sampling_config = SamplingParams( + skip_special_tokens=skip_special_tokens, + max_tokens=max_new_tokens, + stop=stop_words, + **gen_params) + + async def _inner_generate(uid, text): + resp, generator = '', self.model.generate( + text, sampling_params=sampling_config, request_id=uid) + async for out in generator: + resp = out.outputs[0] + return resp + + response = await asyncio.gather(*[ + _inner_generate(sid, inp) for sid, inp in zip(session_ids, prompt) + ]) + texts = [resp.text for resp in response] + # remove stop_words + texts = filter_suffix(texts, self.gen_params.get('stop_words')) + for resp, text in zip(response, texts): + resp.text = text + if batched: + return [asdict_completion(resp) + for resp in response] if return_dict else texts + return asdict_completion(response[0]) if return_dict else texts[0] diff --git a/lagent/memory/__init__.py b/lagent/memory/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13534ada0dd646809bb41c10e2fb82c6c18cbdcf --- /dev/null +++ b/lagent/memory/__init__.py @@ -0,0 +1,4 @@ +from .base_memory import Memory +from .manager import MemoryManager + +__all__ = ['Memory', 'MemoryManager'] diff --git a/lagent/memory/__pycache__/__init__.cpython-310.pyc b/lagent/memory/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86fdf88a6f6393ca8f2f231c4e472086ac9b8e4b Binary files /dev/null and b/lagent/memory/__pycache__/__init__.cpython-310.pyc differ diff --git a/lagent/memory/__pycache__/base_memory.cpython-310.pyc b/lagent/memory/__pycache__/base_memory.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e3608816a2181bf2457d02b4f38b97116b10e8e Binary files /dev/null and b/lagent/memory/__pycache__/base_memory.cpython-310.pyc differ diff --git a/lagent/memory/__pycache__/manager.cpython-310.pyc b/lagent/memory/__pycache__/manager.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..907354951e0e4a8a8f12fd153d35214451080614 Binary files /dev/null and b/lagent/memory/__pycache__/manager.cpython-310.pyc differ diff --git a/lagent/memory/base_memory.py b/lagent/memory/base_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..c60d9780715eed3a4d43c4b48ca175b44abef611 --- /dev/null +++ b/lagent/memory/base_memory.py @@ -0,0 +1,60 @@ +from typing import Callable, Dict, List, Optional, Union + +from lagent.schema import AgentMessage + + +class Memory: + + def __init__(self, recent_n=None) -> None: + self.memory: List[AgentMessage] = [] + self.recent_n = recent_n + + def get_memory( + self, + recent_n: Optional[int] = None, + filter_func: Optional[Callable[[int, dict], bool]] = None, + ) -> list: + recent_n = recent_n or self.recent_n + if recent_n is not None: + memory = self.memory[-recent_n:] + else: + memory = self.memory + if filter_func is not None: + memory = [m for i, m in enumerate(memory) if filter_func(i, m)] + return memory + + def add(self, memories: 
Union[List[Dict], Dict, None]) -> None: + for memory in memories if isinstance(memories, + (list, tuple)) else [memories]: + if isinstance(memory, str): + memory = AgentMessage(sender='user', content=memory) + if isinstance(memory, AgentMessage): + self.memory.append(memory) + + def delete(self, index: Union[List, int]) -> None: + if isinstance(index, int): + del self.memory[index] + else: + for i in index: + del self.memory[i] + + def load( + self, + memories: Union[str, Dict, List], + overwrite: bool = True, + ) -> None: + if overwrite: + self.memory = [] + if isinstance(memories, dict): + self.memory.append(AgentMessage(**memories)) + elif isinstance(memories, list): + for m in memories: + self.memory.append(AgentMessage(**m)) + else: + raise TypeError(f'{type(memories)} is not supported') + + def save(self) -> List[dict]: + memory = [] + for m in self.memory: + memory.append(m.model_dump()) + return memory diff --git a/lagent/memory/manager.py b/lagent/memory/manager.py new file mode 100644 index 0000000000000000000000000000000000000000..1d42ac9600dc86936aa8c41fb3d73d5c7110b9bb --- /dev/null +++ b/lagent/memory/manager.py @@ -0,0 +1,29 @@ +from typing import Dict + +from ..utils import create_object +from .base_memory import Memory + + +class MemoryManager: + + def __init__(self, cfg) -> None: + self.cfg = cfg + self.memory_map: Dict[str, Memory] = {} + + def create_instance(self, session_id): + self.memory_map[session_id] = create_object(self.cfg) + + def get_memory(self, session_id=0, **kwargs) -> list: + return self.memory_map[session_id].get_memory(**kwargs) + + def add(self, memory, session_id=0, **kwargs) -> None: + if session_id not in self.memory_map: + self.create_instance(session_id) + self.memory_map[session_id].add(memory, **kwargs) + + def get(self, session_id=0) -> Memory: + return self.memory_map.get(session_id, None) + + def reset(self, session_id=0) -> None: + if session_id in self.memory_map: + del self.memory_map[session_id] diff --git a/lagent/prompts/__init__.py b/lagent/prompts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7f76f4705a8f03a0a0983f8c33fdf8ee1409866 --- /dev/null +++ b/lagent/prompts/__init__.py @@ -0,0 +1,4 @@ +from .parsers import * # noqa +from .prompt_template import PromptTemplate + +__all__ = ['PromptTemplate'] diff --git a/lagent/prompts/__pycache__/__init__.cpython-310.pyc b/lagent/prompts/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63941ec568925abd242824c9ba50bcd7d5bc5d5c Binary files /dev/null and b/lagent/prompts/__pycache__/__init__.cpython-310.pyc differ diff --git a/lagent/prompts/__pycache__/prompt_template.cpython-310.pyc b/lagent/prompts/__pycache__/prompt_template.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40722a657c65ed221c205183d13f7e75dff83368 Binary files /dev/null and b/lagent/prompts/__pycache__/prompt_template.cpython-310.pyc differ diff --git a/lagent/prompts/parsers/__init__.py b/lagent/prompts/parsers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d9b2ce25332d926663a1afc98f995d6ef0aba4ff --- /dev/null +++ b/lagent/prompts/parsers/__init__.py @@ -0,0 +1,9 @@ +from .custom_parser import CustomFormatParser +from .json_parser import JSONParser +from .str_parser import StrParser +from .tool_parser import InterpreterParser, MixedToolParser, PluginParser, ToolParser, ToolStatusCode + +__all__ = [ + 'CustomFormatParser', 'JSONParser', 'StrParser', 'ToolParser', 
+ 'InterpreterParser', 'PluginParser', 'MixedToolParser', 'ToolStatusCode' +] diff --git a/lagent/prompts/parsers/__pycache__/__init__.cpython-310.pyc b/lagent/prompts/parsers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c04a59ad7501a1ac8b0011a3c0ad6d125594db1a Binary files /dev/null and b/lagent/prompts/parsers/__pycache__/__init__.cpython-310.pyc differ diff --git a/lagent/prompts/parsers/__pycache__/custom_parser.cpython-310.pyc b/lagent/prompts/parsers/__pycache__/custom_parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d27ca6117f84e95c58b6658dfcc18f91585caa7c Binary files /dev/null and b/lagent/prompts/parsers/__pycache__/custom_parser.cpython-310.pyc differ diff --git a/lagent/prompts/parsers/__pycache__/json_parser.cpython-310.pyc b/lagent/prompts/parsers/__pycache__/json_parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11fc0800a97cbcb5301b4183acf221d14f1889cc Binary files /dev/null and b/lagent/prompts/parsers/__pycache__/json_parser.cpython-310.pyc differ diff --git a/lagent/prompts/parsers/__pycache__/str_parser.cpython-310.pyc b/lagent/prompts/parsers/__pycache__/str_parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c09e9bb26fa411b39ae53f4b854b29aaff948d26 Binary files /dev/null and b/lagent/prompts/parsers/__pycache__/str_parser.cpython-310.pyc differ diff --git a/lagent/prompts/parsers/__pycache__/tool_parser.cpython-310.pyc b/lagent/prompts/parsers/__pycache__/tool_parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4fce1deb336399344d97164eaab34e5b1fb9ffe Binary files /dev/null and b/lagent/prompts/parsers/__pycache__/tool_parser.cpython-310.pyc differ diff --git a/lagent/prompts/parsers/custom_parser.py b/lagent/prompts/parsers/custom_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..68171522e6d452abc2694b2dfc727ff508e78e05 --- /dev/null +++ b/lagent/prompts/parsers/custom_parser.py @@ -0,0 +1,92 @@ +import re +from typing import Any, Dict, Union + +from pydantic import BaseModel, ValidationError + +from lagent.prompts.parsers.str_parser import StrParser + + +class CustomFormatParser(StrParser): + + def _extract_fields_with_metadata( + self, model: BaseModel) -> Dict[str, Dict[str, Any]]: + fields_metadata = {} + for field_name, field in model.model_fields.items(): + fields_metadata[field_name] = { + 'annotation': field.annotation, + 'default': field.default + if field.default is not None else '', + 'comment': field.description if field.description else '' + } + return fields_metadata + + def format_to_string(self, format_model: BaseModel) -> str: + fields = self._extract_fields_with_metadata(format_model) + formatted_str = '' + for field_name, metadata in fields.items(): + comment = metadata.get('comment', '') + field_annotation = metadata['annotation'].__name__ if metadata[ + 'annotation'] is not None else 'Any' + if comment: + formatted_str += f'<!-- {comment} -->\n' + formatted_str += f'<{field_name} type="{field_annotation}">{metadata["default"] if metadata["default"] != "" else ""}</{field_name}>\n' + return formatted_str + + def parse_response(self, data: str) -> Union[dict, BaseModel]: + pattern = re.compile(r'(<!--\s*(.*?)\s*-->)?\s*<(\w+)[^>]*>(.*?)</\3>', + re.DOTALL) + matches = pattern.findall(data) + + data_dict = {} + for _, comment_text, key, value in matches: + if comment_text: + self.fields[key]['comment'] = comment_text.strip() + data_dict[key] = value + +
model = self.default_format + if self.unknown_format and not self._is_valid_format( + data_dict, self.default_format): + model = self.unknown_format + + return model.model_validate(data_dict) + + def _is_valid_format(self, data: Dict, format_model: BaseModel) -> bool: + try: + format_model.model_validate(data) + return True + except ValidationError: + return False + + +if __name__ == '__main__': + # Example usage + class DefaultFormat(BaseModel): + name: str + age: int + + class UnknownFormat(BaseModel): + title: str + year: int + + template = """如果了解该问题请按照一下格式回复 + ```html + {format} + ``` + 否则请回复 + ```html + {unknown_format} + ``` + """ + parser = CustomFormatParser( + template, default_format=DefaultFormat, unknown_format=UnknownFormat) + + # Example data + response = ''' + + <name type="str">John Doe</name> + + <age type="int">30</age> + ''' + + result = parser.parse_response(response) + print(result) diff --git a/lagent/prompts/parsers/json_parser.py b/lagent/prompts/parsers/json_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..49f57215b184030e88d454294c6a2756589699de --- /dev/null +++ b/lagent/prompts/parsers/json_parser.py @@ -0,0 +1,183 @@ +import json +from typing import Any, Dict, List, Union, get_args, get_origin + +from pydantic import BaseModel, Field +from pydantic_core import PydanticUndefined + +from lagent.prompts.parsers.str_parser import StrParser + + +def get_field_type_name(field_type): + # Get the origin of the field type (for composite types such as List, Dict, etc.) + origin = get_origin(field_type) + if origin: + # Get all type arguments of the composite type + args = get_args(field_type) + # Rebuild the type name, e.g. List[str] or Optional[int] + args_str = ', '.join([get_field_type_name(arg) for arg in args]) + return f'{origin.__name__}[{args_str}]' + # Not a composite type: return the type's name directly + elif hasattr(field_type, '__name__'): + return field_type.__name__ + else: + return str(field_type) # Handle special cases, e.g. types from unknown libraries + + +# class JSONParser(BaseParser): +class JSONParser(StrParser): + + def _extract_fields_with_metadata( + self, model: BaseModel) -> Dict[str, Dict[str, Any]]: + fields_metadata = {} + for field_name, field in model.model_fields.items(): + fields_metadata[field_name] = { + 'annotation': field.annotation, + 'default': field.default + if field.default is not PydanticUndefined else '', + 'comment': field.description if field.description else '' + } + + # Type check so that subclasses of BaseModel are also supported + origin = get_origin(field.annotation) + args = get_args(field.annotation) + if origin is None: + # Not a composite type: check directly whether it is a subclass of BaseModel + if isinstance(field.annotation, type) and issubclass( + field.annotation, BaseModel): + fields_metadata[field_name][ + 'fields'] = self._extract_fields_with_metadata( + field.annotation) + else: + # Composite type: check whether any of its arguments is a subclass of BaseModel + for arg in args: + if isinstance(arg, type) and issubclass(arg, BaseModel): + fields_metadata[field_name][ + 'fields'] = self._extract_fields_with_metadata(arg) + break + return fields_metadata + + def _format_field(self, + field_name: str, + metadata: Dict[str, Any], + indent: int = 1) -> str: + comment = metadata.get('comment', '') + field_type = get_field_type_name( + metadata['annotation'] + ) if metadata['annotation'] is not None else 'Any' + default_value = metadata['default'] + indent_str = ' ' * indent + formatted_lines = [] + + if comment: + formatted_lines.append(f'{indent_str}// {comment}') + + if 'fields' in metadata: + formatted_lines.append(f'{indent_str}"{field_name}": {{') + for sub_field_name, sub_metadata in metadata['fields'].items(): + formatted_lines.append( + self._format_field(sub_field_name, sub_metadata, + indent + 1)) +
formatted_lines.append(f'{indent_str}}},') + else: + if default_value == '': + formatted_lines.append( + f'{indent_str}"{field_name}": "{field_type}", // required' + ) + else: + formatted_lines.append( + f'{indent_str}"{field_name}": "{field_type}", // default: {default_value}' + ) + + return '\n'.join(formatted_lines) + + def format_to_string(self, format_model) -> str: + fields = self._extract_fields_with_metadata(format_model) + formatted_lines = [] + for field_name, metadata in fields.items(): + formatted_lines.append(self._format_field(field_name, metadata)) + + # Remove the trailing comma from the last line + if formatted_lines and formatted_lines[-1].endswith(','): + formatted_lines[-1] = formatted_lines[-1].rstrip(',') + + return '{\n' + '\n'.join(formatted_lines) + '\n}' + + def parse_response(self, data: str) -> Union[dict, BaseModel]: + # Remove comments + data_no_comments = '\n'.join( + line for line in data.split('\n') + if not line.strip().startswith('//')) + try: + data_dict = json.loads(data_no_comments) + parsed_data = {} + + for field_name, value in self.format_field.items(): + if self._is_valid_format(data_dict, value): + model = value + break + + self.fields = self._extract_fields_with_metadata(model) + + for field_name, value in data_dict.items(): + if field_name in self.fields: + metadata = self.fields[field_name] + if value in [ + 'str', 'int', 'float', 'bool', 'list', 'dict' + ]: + if metadata['default'] == '': + raise ValueError( + f"Field '{field_name}' is required but not provided" + ) + parsed_data[field_name] = metadata['default'] + else: + parsed_data[field_name] = value + + return model.model_validate(parsed_data).dict() + except json.JSONDecodeError: + raise ValueError('Input string is not a valid JSON.') + + def _is_valid_format(self, data: dict, format_model: BaseModel) -> bool: + try: + format_model.model_validate(data) + return True + except Exception: + return False + + +if __name__ == '__main__': + + # Example usage + class DefaultFormat(BaseModel): + name: List[str] = Field(description='Name of the person') + age: int = Field(description='Age of the person') + + class UnknownFormat(BaseModel): + title: str + year: int + + TEMPLATE = """如果了解该问题请按照一下格式回复 + ```json + {format} + ``` + 否则请回复 + ```json + {unknown_format} + ``` + """ + + parser = JSONParser( + template=TEMPLATE, + default_format=DefaultFormat, + unknown_format=UnknownFormat, + ) + + # Example data + data = ''' + { + "name": ["John Doe"], + "age": 30 + } + ''' + print(parser.format()) + result = parser.parse_response(data) + print(result) diff --git a/lagent/prompts/parsers/str_parser.py b/lagent/prompts/parsers/str_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..6af7aa6ecafacf8ff13ad0f3f2cbcdd41d65b735 --- /dev/null +++ b/lagent/prompts/parsers/str_parser.py @@ -0,0 +1,28 @@ +from typing import Any + + +class StrParser: + + def __init__( + self, + template: str = '', + **format_field, + ): + self.template = template + self.format_field = format_field + + def format_instruction(self) -> Any: + format_data = { + key: self.format_to_string(value) + for key, value in self.format_field.items() + } + return self.template.format(**format_data) + + def format_to_string(self, format_model: Any) -> str: + return format_model + + def format_response(self, parsed: dict) -> str: + raise NotImplementedError + + def parse_response(self, data: str) -> str: + return data diff --git a/lagent/prompts/parsers/tool_parser.py b/lagent/prompts/parsers/tool_parser.py new file mode 
100644 index 0000000000000000000000000000000000000000..534331275b71d1b443aba81bec710eae0deba88d --- /dev/null +++ b/lagent/prompts/parsers/tool_parser.py @@ -0,0 +1,143 @@ +import json +from enum import IntEnum + +# import re +from typing import Any, Callable, List, Optional + +from lagent.prompts.parsers import StrParser +from lagent.utils import create_object, load_class_from_string + + +def default_plugin_validate(plugin: str): + plugin = plugin.strip() + if not (plugin.startswith('{') and plugin.endswith("}")): + raise json.decoder.JSONDecodeError + return json.loads(plugin) + + +class ToolStatusCode(IntEnum): + NO_TOOL = 0 + VALID_TOOL = 1 + PARSING_ERROR = -1 + + +class ToolParser(StrParser): + + def __init__(self, + tool_type: str, + template: str = '', + begin: str = '\n', + end: str = '\n', + validate: Callable[[str], Any] = None, + **kwargs): + super().__init__(template, begin=begin, end=end, **kwargs) + self.template = template + self.tool_type = tool_type + # self.pattern = re.compile( + # '(.*?){}(.*)({})?'.format(re.escape(begin), re.escape(end)), + # re.DOTALL) + self.validate = load_class_from_string(validate) if isinstance( + validate, str) else validate + + def parse_response(self, data: str) -> dict: + if self.format_field['begin'] not in data: + return dict( + tool_type=None, + thought=data, + action=None, + status=ToolStatusCode.NO_TOOL) + thought, action, *_ = data.split(self.format_field["begin"]) + action = action.split(self.format_field['end'])[0] + status = ToolStatusCode.VALID_TOOL + if self.validate: + try: + action = self.validate(action) + except Exception: + status = ToolStatusCode.PARSING_ERROR + return dict( + tool_type=self.tool_type, + thought=thought, + action=action, + status=status) + + def format_response(self, parsed: dict) -> str: + if parsed['action'] is None: + return parsed['thought'] + assert parsed['tool_type'] == self.tool_type + if isinstance(parsed['action'], dict): + action = json.dumps(parsed['action'], ensure_ascii=False) + else: + action = str(parsed['action']) + return parsed['thought'] + self.format_field[ + 'begin'] + action + self.format_field['end'] + + +class InterpreterParser(ToolParser): + + def __init__(self, + tool_type: str = 'interpreter', + template: str = '', + begin: str = '<|action_start|><|interpreter|>\n', + end: str = '<|action_end|>\n', + validate: Callable[[str], Any] = None, + **kwargs): + super().__init__(tool_type, template, begin, end, validate, **kwargs) + + +class PluginParser(ToolParser): + + def __init__(self, + tool_type: str = 'plugin', + template: str = '', + begin: str = '<|action_start|><|plugin|>\n', + end: str = '<|action_end|>\n', + validate: Callable[[str], Any] = default_plugin_validate, + **kwargs): + super().__init__(tool_type, template, begin, end, validate, **kwargs) + + +class MixedToolParser(StrParser): + + def __init__(self, + tool_type: Optional[str] = None, + template='', + parsers: List[ToolParser] = None, + **format_field): + self.parsers = {} + self.tool_type = tool_type + for parser in parsers or []: + parser = create_object(parser) + self.parsers[parser.tool_type] = parser + super().__init__(template, **format_field) + + def format_instruction(self) -> List[dict]: + inst = [] + content = super().format_instruction() + if content.strip(): + msg = dict(role='system', content=content) + if self.tool_type: + msg['name'] = self.tool_type + inst.append(msg) + for name, parser in self.parsers.items(): + content = parser.format_instruction() + if content.strip(): + 
inst.append(dict(role='system', content=content, name=name)) + return inst + + def parse_response(self, data: str) -> dict: + res = dict( + tool_type=None, + thought=data, + action=None, + status=ToolStatusCode.NO_TOOL) + for name, parser in self.parsers.items(): + res = parser.parse_response(data) + if res['tool_type'] == name: + break + return res + + def format_response(self, parsed: dict) -> str: + if parsed['action'] is None: + return parsed['thought'] + assert parsed['tool_type'] in self.parsers + return self.parsers[parsed['tool_type']].format_response(parsed) diff --git a/lagent/prompts/prompt_template.py b/lagent/prompts/prompt_template.py new file mode 100644 index 0000000000000000000000000000000000000000..15facf92cbcdf4f0963770587f68cbdc99aac7db --- /dev/null +++ b/lagent/prompts/prompt_template.py @@ -0,0 +1,188 @@ +import re +from dataclasses import asdict, is_dataclass +from typing import Any, Dict, Optional, Union + +import jinja2 +from pydantic import BaseModel + + +class PromptTemplate: + """prompt templates. + + Args: + template (str): The template string. + variables (Optional[Union[Dict[str, str], BaseModel, Any]]): Variables for the template. + format_type (str): The format type of the template ('json' or 'jinja'). + + """ + + def __init__(self, template: str, format_type: str = 'json') -> None: + self.template = template + self.format_type = format_type + + def _convert_to_dict( + self, variables: Optional[Union[Dict[str, str], BaseModel, Any]] + ) -> Dict[str, str]: + """ + Convert variables to a dictionary. + + Args: + variables (Optional[Union[Dict[str, str], BaseModel, Any]]): + Variables to convert. + + Returns: + Dict[str, str]: The converted dictionary. + + Raises: + ValueError: If the variables type is unsupported. + """ + if variables is None: + return {} + if isinstance(variables, BaseModel): + return variables.dict() + if is_dataclass(variables): + return asdict(variables) + if isinstance(variables, dict): + return variables + raise ValueError( + 'Unsupported variables type. Must be a dict, BaseModel, or ' + 'dataclass.') + + def parse_template(self, template: str) -> Dict[str, str]: + """ + Extract variables from the template. + + Args: + template (str): The template string. + + Returns: + Dict[str, str]: A dictionary of variables with None values. + """ + if self.format_type == 'jinja': + variables = re.findall(r'\{\{(.*?)\}\}', template) + + elif self.format_type == 'json': + variables = re.findall(r'\{(.*?)\}', template) + variables = [var for var in variables if '{' not in var] + else: + variables = [] + return {var.strip(): None for var in variables} + + def format_json(self, template: str, variables: Dict[str, str]) -> str: + """ + Format the JSON template. + + Args: + template (str): The JSON template string. + variables (Dict[str, str]): The variables to fill in the template. + + Returns: + str: The formatted JSON string. + + Raises: + ValueError: If the template is not a valid JSON. + """ + try: + return template.format(**variables) + except KeyError as e: + raise ValueError('Invalid JSON template') from e + + def format_jinja(self, template: str, variables: Dict[str, str]) -> str: + """ + Format the Jinja template. + + Args: + template (str): The Jinja template string. + variables (Dict[str, str]): The variables to fill in the template. + + Returns: + str: The formatted Jinja string. + + Raises: + ValueError: If the template is not a valid Jinja template. 
+ """ + try: + jinja_template = jinja2.Template(template) + return jinja_template.render(variables) + except jinja2.TemplateError as e: + raise ValueError('Invalid Jinja template') from e + + def _update_variables_with_info(self) -> Dict[str, str]: + """ + Update variables dictionary with action_info and agents_info. + + Returns: + Dict[str, str]: The updated variables dictionary. + """ + variables = self.variables.copy() + if 'action_info' not in variables and self.actions_info: + variables['action_info'] = self.actions_info + if 'agents_info' not in variables and self.agents_info: + variables['agents_info'] = self.agents_info + return variables + + def _check_variables_match(self, parsed_variables: Dict[str, str], + variables: Dict[str, str]) -> None: + """ + Check if all keys in variables are present in parsed_variables. + + Args: + parsed_variables (Dict[str, str]): The parsed variables from + the template. + variables (Dict[str, str]): The variables to check. + + Raises: + ValueError: If any key in variables is not present in + parsed_variables. + """ + if not all(key in parsed_variables for key in variables.keys()): + raise ValueError( + 'Variables keys do not match the template variables') + + def format( + self, + **kwargs: Optional[Union[Dict[str, str], BaseModel, Any]], + ) -> Any: + self.variables = kwargs + return str(self) + + def __str__(self) -> Any: + """ + Call the template formatting based on format_type. + + Returns: + Any: The formatted template. + + Raises: + ValueError: If the format_type is unsupported. + """ + parsed_variables = self.parse_template(self.template) + updated_variables = self._update_variables_with_info() + self._check_variables_match(parsed_variables, updated_variables) + + if self.format_type == 'json': + return self.format_json(self.template, updated_variables) + elif self.format_type == 'jinja': + return self.format_jinja(self.template, updated_variables) + else: + raise ValueError('Unsupported format type') + + @property + def actions_info(self) -> Optional[Dict[str, Any]]: + """Get the action information.""" + return getattr(self, '_action_info', None) + + @actions_info.setter + def actions_info(self, value: Dict[str, Any]) -> None: + """Set the action information.""" + self._action_info = value + + @property + def agents_info(self) -> Optional[Dict[str, Any]]: + """Get the agent information.""" + return getattr(self, '_agents_info', None) + + @agents_info.setter + def agents_info(self, value: Dict[str, Any]) -> None: + """Set the agent information.""" + self._agents_info = value diff --git a/lagent/schema.py b/lagent/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..668846fbe563ddc7cb26b2d77e6a88258cb29962 --- /dev/null +++ b/lagent/schema.py @@ -0,0 +1,97 @@ +from dataclasses import asdict, dataclass +from enum import IntEnum +from typing import Any, Dict, List, Optional, Union + +from pydantic import BaseModel + + +def enum_dict_factory(inputs): + inputs = [(i[0], i[-1].value) if isinstance(i[-1], IntEnum) else i + for i in inputs] + return dict(inputs) + + +def dataclass2dict(data): + return asdict(data, dict_factory=enum_dict_factory) + + +@dataclass +class FunctionCall: + name: str + parameters: Union[Dict, str] + + +class ActionStatusCode(IntEnum): + ING = 1 + SUCCESS = 0 + HTTP_ERROR = -1000 # http error + ARGS_ERROR = -1001 # parameter error + API_ERROR = -1002 # unknown error + + +class ActionValidCode(IntEnum): + FINISH = 1 + OPEN = 0 + CLOSED = -1 + INVALID = -2 + ABSENT = -3 # NO ACTION + + +@dataclass 
+class ActionReturn: + args: Optional[dict] = None + url: Optional[str] = None + type: Optional[str] = None + result: Optional[List[dict]] = None + errmsg: Optional[str] = None + state: Union[ActionStatusCode, int] = ActionStatusCode.SUCCESS + thought: Optional[str] = None + valid: Optional[ActionValidCode] = ActionValidCode.OPEN + + def format_result(self) -> str: + """Concatenate items in result.""" + result = [] + for item in self.result or []: + if item['type'] == 'text': + result.append(item['content']) + else: + result.append(f"[{item['type']}]({item['content']})") + result = '\n'.join(result) + return result + + +# need to integrate int, so asdict can convert AgentStatusCode to int +class ModelStatusCode(IntEnum): + END = 0 # end of streaming + STREAM_ING = 1 # response is in streaming + SERVER_ERR = -1 # triton server's error + SESSION_CLOSED = -2 # session has been closed + SESSION_OUT_OF_LIMIT = -3 # request length out of limit + SESSION_INVALID_ARG = -4 # invalid argument + SESSION_READY = 2 # session is ready for inference + + +class AgentStatusCode(IntEnum): + END = 0 # end of streaming + STREAM_ING = 1 # response is in streaming + SERVER_ERR = -1 # triton server's error + SESSION_CLOSED = -2 # session has been closed + SESSION_OUT_OF_LIMIT = -3 # request length out of limit + SESSION_INVALID_ARG = -4 # invalid argument + SESSION_READY = 2 # session is ready for inference + PLUGIN_START = 3 # start tool + PLUGIN_END = 4 # finish tool + PLUGIN_RETURN = 5 # finish tool + CODING = 6 # start python + CODE_END = 7 # end python + CODE_RETURN = 8 # python return + + +class AgentMessage(BaseModel): + content: Any + sender: str = 'user' + formatted: Optional[Any] = None + extra_info: Optional[Any] = None + type: Optional[str] = None + receiver: Optional[str] = None + stream_state: Union[ModelStatusCode, AgentStatusCode] = AgentStatusCode.END diff --git a/lagent/utils/__init__.py b/lagent/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a0ac54935d812f97f1a99ea168f106cf708b6e64 --- /dev/null +++ b/lagent/utils/__init__.py @@ -0,0 +1,14 @@ +from .package import is_module_exist +from .util import ( + GeneratorWithReturn, + async_as_completed, + create_object, + filter_suffix, + get_logger, + load_class_from_string, +) + +__all__ = [ + 'is_module_exist', 'filter_suffix', 'create_object', 'get_logger', + 'load_class_from_string', 'async_as_completed', 'GeneratorWithReturn' +] diff --git a/lagent/utils/__pycache__/__init__.cpython-310.pyc b/lagent/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c87a7aab8bffe3133051f2c567f83423816d9f7a Binary files /dev/null and b/lagent/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/lagent/utils/__pycache__/package.cpython-310.pyc b/lagent/utils/__pycache__/package.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f844d8d42e89f89b32c6635875cfeb6a31461bf6 Binary files /dev/null and b/lagent/utils/__pycache__/package.cpython-310.pyc differ diff --git a/lagent/utils/__pycache__/util.cpython-310.pyc b/lagent/utils/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8150b6baee2473c21afb195d906ebe61bf189b45 Binary files /dev/null and b/lagent/utils/__pycache__/util.cpython-310.pyc differ diff --git a/lagent/utils/gen_key.py b/lagent/utils/gen_key.py new file mode 100644 index 0000000000000000000000000000000000000000..95642820fb019a26dbcb5cff5a6834220ad390e8 --- 
/dev/null +++ b/lagent/utils/gen_key.py @@ -0,0 +1,32 @@ +# API token generation code +import time +import jwt +import os + +minutes = 525600 + + +def encode_jwt_token(ak, sk): + headers = { + "alg": "HS256", + "typ": "JWT" + } + payload = { + "iss": ak, + "exp": int(time.time()) + minutes, # set your desired expiration time; this example is meant to be the current time plus one year + "nbf": int(time.time()) - 5 # set your desired effective (not-before) time; this example is the current time minus 5 seconds + } + token = jwt.encode(payload, sk, headers=headers) + return token + + +def auto_gen_jwt_token(ak, sk): + token = encode_jwt_token(ak, sk) + return token + + +if __name__ == '__main__': + ak = os.getenv('NOVA_AK') + sk = os.getenv('NOVA_SK') + token = encode_jwt_token(ak, sk) + print(token) diff --git a/lagent/utils/package.py b/lagent/utils/package.py new file mode 100644 index 0000000000000000000000000000000000000000..874fc8370fefab31e40084f4b389bafa4ac1ca49 --- /dev/null +++ b/lagent/utils/package.py @@ -0,0 +1,9 @@ +from importlib.util import find_spec + + +def is_module_exist(module_name): + spec = find_spec(module_name) + if spec is None: + return False + else: + return True diff --git a/lagent/utils/util.py b/lagent/utils/util.py new file mode 100644 index 0000000000000000000000000000000000000000..a40482b53c1c94a1a6f65d64a80f5282b29bd4c2 --- /dev/null +++ b/lagent/utils/util.py @@ -0,0 +1,139 @@ +import asyncio +import importlib +import inspect +import logging +import os +import os.path as osp +import sys +import time +from functools import partial +from logging.handlers import RotatingFileHandler +from typing import Any, Dict, Generator, Iterable, List, Optional, Union + + +def load_class_from_string(class_path: str, path=None): + path_in_sys = False + if path: + if path not in sys.path: + path_in_sys = True + sys.path.insert(0, path) + + try: + module_name, class_name = class_path.rsplit('.', 1) + module = importlib.import_module(module_name) + cls = getattr(module, class_name) + return cls + finally: + if path and path_in_sys: + sys.path.remove(path) + + +def create_object(config: Union[Dict, Any] = None): + """Create an instance based on the configuration, where 'type' is a + reserved key to indicate the class (path). When accepting non-dictionary + input, the function degenerates to an identity. + """ + if config is None or not isinstance(config, dict): + return config + assert isinstance(config, dict) and 'type' in config + + config = config.copy() + obj_type = config.pop('type') + if isinstance(obj_type, str): + obj_type = load_class_from_string(obj_type) + if inspect.isclass(obj_type): + obj = obj_type(**config) + else: + assert callable(obj_type) + obj = partial(obj_type, **config) + return obj + + +async def async_as_completed(futures: Iterable[asyncio.Future]): + """An asynchronous wrapper for `asyncio.as_completed`.""" + loop = asyncio.get_event_loop() + wrappers = [] + for fut in futures: + assert isinstance(fut, asyncio.Future) + wrapper = loop.create_future() + fut.add_done_callback(wrapper.set_result) + wrappers.append(wrapper) + for next_completed in asyncio.as_completed(wrappers): + yield await next_completed + + +def filter_suffix(response: Union[str, List[str]], + suffixes: Optional[List[str]] = None) -> str: + """Filter response with suffixes. + + Args: + response (Union[str, List[str]]): responses generated by LLMs. + suffixes (Optional[List[str]]): a list of suffixes to be deleted. + + Returns: + str: a clean response.
+ """ + if suffixes is None: + return response + batched = True + if isinstance(response, str): + response = [response] + batched = False + processed = [] + for resp in response: + for item in suffixes: + # if response.endswith(item): + # response = response[:len(response) - len(item)] + if item in resp: + resp = resp.split(item)[0] + processed.append(resp) + if not batched: + return processed[0] + return processed + + +def get_logger( + name: str = 'lagent', + level: str = 'debug', + fmt: + str = '%(asctime)s %(levelname)8s %(filename)20s %(lineno)4s - %(message)s', + add_file_handler: bool = False, + log_dir: str = 'log', + log_file: str = time.strftime('%Y-%m-%d.log', time.localtime()), + max_bytes: int = 5 * 1024 * 1024, + backup_count: int = 3, +): + logger = logging.getLogger(name) + logger.propagate = False + logger.setLevel(getattr(logging, level.upper(), logging.DEBUG)) + + formatter = logging.Formatter(fmt) + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + + if add_file_handler: + if not osp.exists(log_dir): + os.makedirs(log_dir) + log_file_path = osp.join(log_dir, log_file) + file_handler = RotatingFileHandler( + log_file_path, + maxBytes=max_bytes, + backupCount=backup_count, + encoding='utf-8') + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + return logger + + +class GeneratorWithReturn: + """Generator wrapper to capture the return value.""" + + def __init__(self, generator: Generator): + self.generator = generator + self.ret = None + + def __iter__(self): + self.ret = yield from self.generator + return self.ret diff --git a/lagent/version.py b/lagent/version.py new file mode 100644 index 0000000000000000000000000000000000000000..d9c59dd319e7e1be2581dc9330a2ed8120073173 --- /dev/null +++ b/lagent/version.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +__version__ = '0.5.0rc1' + + +def parse_version_info(version_str): + """Parse the version information. + + Args: + version_str (str): version string like '0.1.0'. + + Returns: + tuple: version information contains major, minor, micro version. + """ + version_info = [] + for x in version_str.split('.'): + if x.isdigit(): + version_info.append(int(x)) + elif x.find('rc') != -1: + patch_version = x.split('rc') + version_info.append(int(patch_version[0])) + version_info.append(f'rc{patch_version[1]}') + return tuple(version_info) + + +version_info = parse_version_info(__version__) + +__all__ = ['__version__', 'version_info', 'parse_version_info']