import warnings from typing import Dict, List, Optional, Tuple, Union from lagent.llms.base_llm import AsyncLLMMixin, BaseLLM class APITemplateParser: """Intermidate prompt template parser, specifically for API models. Args: meta_template (Dict): The meta template for the model. """ def __init__(self, meta_template: Optional[Dict] = None): self.meta_template = meta_template # Check meta template if meta_template: assert isinstance(meta_template, list) self.roles: Dict[str, dict] = dict() # maps role name to config for item in meta_template: assert isinstance(item, dict) assert item['role'] not in self.roles, \ 'role in meta prompt must be unique!' self.roles[item['role']] = item.copy() def __call__(self, dialog: List[Union[str, List]]): """Parse the intermidate prompt template, and wrap it with meta template if applicable. When the meta template is set and the input is a list, the return value will be a list containing the full conversation history. Each item looks like: .. code-block:: python {'role': 'user', 'content': '...'}). Args: dialog (List[str or list]): An intermidate prompt template (potentially before being wrapped by meta template). Returns: List[str or list]: The finalized prompt or a conversation. """ assert isinstance(dialog, (str, list)) if isinstance(dialog, str): return dialog if self.meta_template: prompt = list() # Whether to keep generating the prompt generate = True for i, item in enumerate(dialog): if not generate: break if isinstance(item, str): if item.strip(): # TODO: logger warnings.warn('Non-empty string in prompt template ' 'will be ignored in API models.') else: api_prompts = self._prompt2api(item) prompt.append(api_prompts) # merge the consecutive prompts assigned to the same role new_prompt = list([prompt[0]]) last_role = prompt[0]['role'] for item in prompt[1:]: if item['role'] == last_role: new_prompt[-1]['content'] += '\n' + item['content'] else: last_role = item['role'] new_prompt.append(item) prompt = new_prompt else: # in case the model does not have any meta template prompt = '' last_sep = '' for item in dialog: if isinstance(item, str): if item: prompt += last_sep + item elif item.get('content', ''): prompt += last_sep + item.get('content', '') last_sep = '\n' return prompt def _prompt2api(self, prompts: Union[List, str]) -> Tuple[str, bool]: """Convert the prompts to a API-style prompts, given an updated role_dict. Args: prompts (Union[List, str]): The prompts to be converted. role_dict (Dict[str, Dict]): The updated role dict. for_gen (bool): If True, the prompts will be converted for generation tasks. The conversion stops before the first role whose "generate" is set to True. Returns: Tuple[str, bool]: The converted string, and whether the follow-up conversion should be proceeded. """ if isinstance(prompts, str): return prompts elif isinstance(prompts, dict): api_role = self._role2api_role(prompts) return api_role res = [] for prompt in prompts: if isinstance(prompt, str): raise TypeError('Mixing str without explicit role is not ' 'allowed in API models!') else: api_role = self._role2api_role(prompt) res.append(api_role) return res def _role2api_role(self, role_prompt: Dict) -> Tuple[str, bool]: merged_prompt = self.roles[role_prompt['role']] if merged_prompt.get('fallback_role'): merged_prompt = self.roles[self.roles[ merged_prompt['fallback_role']]] res = role_prompt.copy() res['role'] = merged_prompt['api_role'] res['content'] = merged_prompt.get('begin', '') res['content'] += role_prompt.get('content', '') res['content'] += merged_prompt.get('end', '') return res class BaseAPILLM(BaseLLM): """Base class for API model wrapper. Args: model_type (str): The type of model. retry (int): Number of retires if the API call fails. Defaults to 2. meta_template (Dict, optional): The model's meta prompt template if needed, in case the requirement of injecting or wrapping of any meta instructions. """ is_api: bool = True def __init__(self, model_type: str, retry: int = 2, template_parser: 'APITemplateParser' = APITemplateParser, meta_template: Optional[Dict] = None, *, max_new_tokens: int = 512, top_p: float = 0.8, top_k: int = 40, temperature: float = 0.8, repetition_penalty: float = 0.0, stop_words: Union[List[str], str] = None): self.model_type = model_type self.meta_template = meta_template self.retry = retry if template_parser: self.template_parser = template_parser(meta_template) if isinstance(stop_words, str): stop_words = [stop_words] self.gen_params = dict( max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k, temperature=temperature, repetition_penalty=repetition_penalty, stop_words=stop_words, skip_special_tokens=False) class AsyncBaseAPILLM(AsyncLLMMixin, BaseAPILLM): pass