from copy import copy from typing import Dict, List, Optional, Tuple, Union class LMTemplateParser: """Intermidate prompt template parser, specifically for language models. Args: meta_template (list of dict, optional): The meta template for the model. """ def __init__(self, meta_template: Optional[List[Dict]] = None): self.meta_template = meta_template if meta_template: assert isinstance(meta_template, list) self.roles: Dict[str, dict] = dict() # maps role name to config for item in meta_template: assert isinstance(item, dict) assert item['role'] not in self.roles, \ 'role in meta prompt must be unique!' self.roles[item['role']] = item.copy() def __call__(self, dialog) -> str: """Parse a prompt template, and wrap it with meta template if applicable. Args: dialog (List[str or PromptList]): A prompt template (potentially before being wrapped by meta template). Returns: str: The final string. """ assert isinstance(dialog, (str, list)) if isinstance(dialog, str): return dialog if self.meta_template: prompt = '' for index, item in enumerate(dialog): if isinstance(item, str): prompt += item else: new_str = self._prompt2str(item, index == len(dialog) - 1) prompt += new_str else: # in case the model does not have any meta template prompt = '' last_sep = '' for item in dialog: if isinstance(item, str): if item: prompt += last_sep + item elif item.get('content', ''): prompt += last_sep + item.get('prompt', '') last_sep = '\n' return prompt def _format_begin(self, role_cfg, message): name = message.get('name', None) if name is not None: begin = role_cfg['begin'].get('with_name', '') if name in role_cfg['begin'].get('name', {}): begin = begin.format(name=role_cfg['begin']['name'][name]) else: begin = begin.format(name=name) else: if isinstance(role_cfg.get('begin', ''), str): begin = role_cfg.get('begin', '') elif isinstance(role_cfg['begin'], dict): begin = role_cfg['begin'].get('without_name', '') return begin def _prompt2str(self, prompt: Union[str, Dict], last: bool = False) -> Tuple[str, bool]: if isinstance(prompt, str): return prompt merged_prompt = self.roles.get(prompt['role']) if merged_prompt.get('fallback_role'): merged_prompt = self.roles.get(merged_prompt['fallback_role']) begin = self._format_begin(merged_prompt, prompt) res = begin if last and merged_prompt.get('generate', False): res += prompt.get('content', '') return res res += prompt.get('content', '') + merged_prompt.get('end', '') if last and merged_prompt['role'] != 'assistant': res += self._format_begin(self.roles['assistant'], {}) return res return res class BaseLLM: """Base class for model wrapper. Args: path (str): The path to the model. max_new_tokens (int): Maximum length of output expected to be generated by the model. Defaults to 512. tokenizer_only (bool): If True, only the tokenizer will be initialized. Defaults to False. meta_template (list of dict, optional): The model's meta prompt template if needed, in case the requirement of injecting or wrapping of any meta instructions. """ def __init__(self, path: str, tokenizer_only: bool = False, template_parser: 'LMTemplateParser' = LMTemplateParser, meta_template: Optional[List[Dict]] = None, *, max_new_tokens: int = 512, top_p: float = 0.8, top_k: float = 40, temperature: float = 0.8, repetition_penalty: float = 1.0, stop_words: Union[List[str], str] = None): self.path = path self.tokenizer_only = tokenizer_only # meta template self.template_parser = template_parser(meta_template) self.eos_token_id = None if meta_template and 'eos_token_id' in meta_template: self.eos_token_id = meta_template['eos_token_id'] if isinstance(stop_words, str): stop_words = [stop_words] self.gen_params = dict( max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k, temperature=temperature, repetition_penalty=repetition_penalty, stop_words=stop_words) def generate(self, inputs: Union[str, List[str]], **gen_params) -> str: """Generate results given a str (or list of) inputs. Args: inputs (Union[str, List[str]]): gen_params (dict): The input params for generation. Returns: Union[str, List[str]]: A (list of) generated strings. eg. batched = True if isinstance(inputs, str): inputs = [inputs] batched = False response = [''] if batched: return response return response[0] """ raise NotImplementedError def stream_generate(self, inputs: str, **gen_params) -> List[str]: """Generate results as streaming given a str inputs. Args: inputs (str): gen_params (dict): The input params for generation. Returns: str: A generated string. """ raise NotImplementedError def chat(self, inputs: Union[List[dict], List[List[dict]]], session_ids: Union[int, List[int]] = None, **gen_params): """Generate completion from a list of templates. Args: inputs (Union[List[dict], List[List[dict]]]): gen_params (dict): The input params for generation. Returns: """ if isinstance(inputs[0], list): _inputs = list() for msg in inputs: _inputs.append(self.template_parser(msg)) else: _inputs = self.template_parser(inputs) return self.generate(_inputs, **gen_params) def stream_chat(self, inputs: List[dict], **gen_params): """Generate results as streaming given a list of templates. Args: inputs (Union[List[dict]): gen_params (dict): The input params for generation. Returns: """ raise NotImplementedError def tokenize(self, prompts: Union[str, List[str], List[dict], List[List[dict]]]): """Tokenize the input prompts. Args: prompts(str | List[str]): user's prompt, or a batch prompts Returns: Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): prompt's token ids, ids' length and requested output length """ raise NotImplementedError def update_gen_params(self, **kwargs): gen_params = copy(self.gen_params) gen_params.update(kwargs) return gen_params class AsyncLLMMixin: async def generate(self, inputs: Union[str, List[str]], session_ids: Union[int, List[int]] = None, **gen_params) -> str: """Generate results given a str (or list of) inputs. Args: inputs (Union[str, List[str]]): gen_params (dict): The input params for generation. Returns: Union[str, List[str]]: A (list of) generated strings. eg. batched = True if isinstance(inputs, str): inputs = [inputs] batched = False response = [''] if batched: return response return response[0] """ raise NotImplementedError async def stream_generate(self, inputs: str, **gen_params) -> List[str]: """Generate results as streaming given a str inputs. Args: inputs (str): gen_params (dict): The input params for generation. Returns: str: A generated string. """ raise NotImplementedError async def chat(self, inputs: Union[List[dict], List[List[dict]]], session_ids: Union[int, List[int]] = None, **gen_params): """Generate completion from a list of templates. Args: inputs (Union[List[dict], List[List[dict]]]): gen_params (dict): The input params for generation. Returns: """ if isinstance(inputs[0], list): _inputs = list() for msg in inputs: _inputs.append(self.template_parser(msg)) else: _inputs = self.template_parser(inputs) return await self.generate(_inputs, session_ids, **gen_params) async def stream_chat(self, inputs: List[dict], **gen_params): """Generate results as streaming given a list of templates. Args: inputs (Union[List[dict]): gen_params (dict): The input params for generation. Returns: """ raise NotImplementedError async def tokenize(self, prompts: Union[str, List[str], List[dict], List[List[dict]]]): """Tokenize the input prompts. Args: prompts(str | List[str]): user's prompt, or a batch prompts Returns: Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): prompt's token ids, ids' length and requested output length """ raise NotImplementedError class AsyncBaseLLM(AsyncLLMMixin, BaseLLM): pass