import asyncio
import importlib
import inspect
import logging
import os
import os.path as osp
import sys
import time
from functools import partial
from logging.handlers import RotatingFileHandler
from typing import Any, Dict, Generator, Iterable, List, Optional, Union


def load_class_from_string(class_path: str, path=None):
    """Import a module and return the attribute named by ``class_path``.

    Args:
        class_path (str): Dotted path such as ``'pkg.module.ClassName'``.
            Everything before the last dot is imported with
            :func:`importlib.import_module`; the last component is then
            fetched from that module with ``getattr``.
        path (str, optional): Directory to put at the front of ``sys.path``
            while importing. If this call added it, it is removed again
            afterwards, even when the import fails.

    Returns:
        The attribute looked up on the imported module (typically a class).

    Raises:
        ValueError: If ``class_path`` contains no ``'.'`` separator.
    """
    if '.' not in class_path:
        raise ValueError(
            f'class_path must be a dotted path like "module.Class", '
            f'got {class_path!r}')
    module_name, class_name = class_path.rsplit('.', 1)
    # Only undo the sys.path change if this call was the one that made it.
    added_to_sys_path = False
    if path and path not in sys.path:
        sys.path.insert(0, path)
        added_to_sys_path = True
    try:
        module = importlib.import_module(module_name)
        return getattr(module, class_name)
    finally:
        if added_to_sys_path:
            sys.path.remove(path)


def create_object(config: Union[Dict, Any] = None):
    """Create an object from a config dict whose ``'type'`` key names it.

    ``'type'`` may be a class, a callable, or a dotted-path string that is
    resolved with :func:`load_class_from_string`. All remaining keys are
    used as keyword arguments. A class is instantiated with them. Any other
    callable is wrapped with :func:`functools.partial` and is not called.

    Non-dict input (including ``None``) is returned unchanged, so the
    function degenerates to the identity.

    Args:
        config: The configuration dict, or any other value.

    Returns:
        The created instance, a ``functools.partial``, or ``config`` itself.
    """
    if not isinstance(config, dict):
        return config
    assert 'type' in config, "config dict must contain a 'type' key"
    config = config.copy()  # do not mutate the caller's dict
    obj_type = config.pop('type')
    if isinstance(obj_type, str):
        obj_type = load_class_from_string(obj_type)
    if inspect.isclass(obj_type):
        return obj_type(**config)
    assert callable(obj_type), f'{obj_type!r} is neither a class nor callable'
    return partial(obj_type, **config)


async def async_as_completed(futures: Iterable[asyncio.Future]):
    """Asynchronously iterate over ``futures`` in completion order.

    Unlike :func:`asyncio.as_completed`, whose awaitables resolve to each
    future's *result*, this yields the original future objects themselves.
    Callers can therefore tell which one finished, and can inspect
    ``.result()`` / ``.exception()`` without the error being raised here.
    """
    loop = asyncio.get_running_loop()
    wrappers = []
    for fut in futures:
        assert isinstance(fut, asyncio.Future)
        wrapper = loop.create_future()
        # Done-callbacks receive the finished future as their only argument,
        # so each wrapper resolves to the original future object.
        fut.add_done_callback(wrapper.set_result)
        wrappers.append(wrapper)
    for next_completed in asyncio.as_completed(wrappers):
        yield await next_completed


def filter_suffix(response: Union[str, List[str]],
                  suffixes: Optional[List[str]] = None
                  ) -> Union[str, List[str]]:
    """Cut each response at the first occurrence of any of ``suffixes``.

    Suffixes are tried in order. Whenever one occurs in the (possibly
    already shortened) text, everything from its first occurrence onward is
    dropped. Empty suffixes are ignored.

    Args:
        response (Union[str, List[str]]): Generated response(s) from an LLM.
        suffixes (List[str], optional): Stop strings to cut at. With
            ``None``, ``response`` is returned unchanged.

    Returns:
        Union[str, List[str]]: A cleaned string when ``response`` is a
        string, otherwise a list of cleaned strings.
    """
    if suffixes is None:
        return response
    batched = not isinstance(response, str)
    responses = response if batched else [response]
    processed = []
    for resp in responses:
        for item in suffixes:
            # An empty suffix "occurs" everywhere, and str.split('') raises
            # ValueError, so skip it.
            if item:
                resp = resp.partition(item)[0]
        processed.append(resp)
    return processed if batched else processed[0]


def get_logger(
    name: str = 'lagent',
    level: str = 'debug',
    fmt: str = '%(asctime)s %(levelname)8s %(filename)20s %(lineno)4s - %(message)s',
    add_file_handler: bool = False,
    log_dir: str = 'log',
    log_file: Optional[str] = None,
    max_bytes: int = 5 * 1024 * 1024,
    backup_count: int = 3,
):
    """Return the named logger, configured for console (and optionally file) output.

    Calling this repeatedly with the same ``name`` does not stack handlers.
    An existing console handler, or an existing rotating file handler for
    the same path, is reused and given the new formatter. Attaching a
    duplicate would emit every record twice.

    Args:
        name: Logger name passed to :func:`logging.getLogger`.
        level: Level name such as ``'info'``. Names not defined on the
            ``logging`` module fall back to ``DEBUG``.
        fmt: Format string for :class:`logging.Formatter`.
        add_file_handler: Also log to a size-rotated file when True.
        log_dir: Directory for the log file; created if missing.
        log_file: File name inside ``log_dir``. Defaults to the current
            local date (``YYYY-MM-DD.log``), computed at call time.
        max_bytes: File size at which the log file is rotated.
        backup_count: Number of rotated files kept.

    Returns:
        logging.Logger: The configured logger, with propagation to ancestor
        loggers disabled.
    """
    logger = logging.getLogger(name)
    logger.propagate = False
    logger.setLevel(getattr(logging, level.upper(), logging.DEBUG))
    formatter = logging.Formatter(fmt)

    # RotatingFileHandler is itself a StreamHandler subclass, so match the
    # exact type to find only a plain console handler.
    console_handler = next(
        (h for h in logger.handlers if type(h) is logging.StreamHandler),
        None)
    if console_handler is None:
        console_handler = logging.StreamHandler()
        logger.addHandler(console_handler)
    console_handler.setFormatter(formatter)

    if add_file_handler:
        if log_file is None:
            # Evaluated per call, not once at import time.
            log_file = time.strftime('%Y-%m-%d.log', time.localtime())
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        # FileHandler stores an absolute baseFilename; compare like with like.
        log_file_path = osp.abspath(osp.join(log_dir, log_file))
        file_handler = next(
            (h for h in logger.handlers
             if isinstance(h, RotatingFileHandler)
             and h.baseFilename == log_file_path), None)
        if file_handler is None:
            file_handler = RotatingFileHandler(
                log_file_path,
                maxBytes=max_bytes,
                backupCount=backup_count,
                encoding='utf-8')
            logger.addHandler(file_handler)
        file_handler.setFormatter(formatter)
    return logger


class GeneratorWithReturn:
    """Iterable wrapper that captures a generator's ``return`` value.

    Iterating the wrapper yields whatever ``generator`` yields. When the
    generator finishes, its return value is stored in ``self.ret``, which
    is ``None`` until then.
    """

    def __init__(self, generator: Generator):
        self.generator = generator
        self.ret = None  # set once iteration over the generator completes

    def __iter__(self):
        self.ret = yield from self.generator
        return self.ret