import warnings
from typing import Dict, List, Optional, Tuple, Union

from lagent.llms.base_llm import AsyncLLMMixin, BaseLLM


class APITemplateParser:
    """Intermediate prompt template parser, specifically for API models.

    Args:
        meta_template (List[Dict]): The meta template for the model.
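
    Example:
        An illustrative OpenAI-style template (the exact ``api_role`` values
        depend on the target API; the ``begin``/``end`` affixes and
        ``fallback_role`` are optional):

        .. code-block:: python

            meta_template = [
                dict(role='system', api_role='system'),
                dict(role='user', api_role='user'),
                dict(role='assistant', api_role='assistant',
                     begin='', end=''),
            ]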
""" | |

    def __init__(self, meta_template: Optional[List[Dict]] = None):
        self.meta_template = meta_template
        # Check meta template
        if meta_template:
            assert isinstance(meta_template, list)
            self.roles: Dict[str, dict] = dict()  # maps role name to config
            for item in meta_template:
                assert isinstance(item, dict)
                assert item['role'] not in self.roles, \
                    'role in meta prompt must be unique!'
                self.roles[item['role']] = item.copy()

    def __call__(self, dialog: List[Union[str, List]]):
        """Parse the intermediate prompt template and wrap it with the meta
        template if applicable. When the meta template is set and the input
        is a list, the return value is a list containing the full
        conversation history, where each item looks like:

        .. code-block:: python

            {'role': 'user', 'content': '...'}

        Args:
            dialog (List[str or list]): An intermediate prompt
                template (potentially before being wrapped by meta template).

        Returns:
            List[str or list]: The finalized prompt or a conversation.
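
        Example:
            A minimal sketch, assuming a meta template that maps ``user``
            and ``assistant`` onto API roles of the same names:

            .. code-block:: python

                parser = APITemplateParser([
                    dict(role='user', api_role='user'),
                    dict(role='assistant', api_role='assistant'),
                ])
                parser([dict(role='user', content='Hi')])
                # -> [{'role': 'user', 'content': 'Hi'}]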
""" | |
assert isinstance(dialog, (str, list)) | |
if isinstance(dialog, str): | |
return dialog | |
if self.meta_template: | |
prompt = list() | |
# Whether to keep generating the prompt | |
generate = True | |
for i, item in enumerate(dialog): | |
if not generate: | |
break | |
if isinstance(item, str): | |
if item.strip(): | |
# TODO: logger | |
warnings.warn('Non-empty string in prompt template ' | |
'will be ignored in API models.') | |
else: | |
api_prompts = self._prompt2api(item) | |
prompt.append(api_prompts) | |

            # merge the consecutive prompts assigned to the same role
            new_prompt = [prompt[0]]
            last_role = prompt[0]['role']
            for item in prompt[1:]:
                if item['role'] == last_role:
                    new_prompt[-1]['content'] += '\n' + item['content']
                else:
                    last_role = item['role']
                    new_prompt.append(item)
            prompt = new_prompt
        else:
            # in case the model does not have any meta template
            prompt = ''
            last_sep = ''
            for item in dialog:
                if isinstance(item, str):
                    if item:
                        prompt += last_sep + item
                elif item.get('content', ''):
                    prompt += last_sep + item.get('content', '')
                last_sep = '\n'
        return prompt

    def _prompt2api(self, prompts: Union[List, str, Dict]
                    ) -> Union[str, Dict, List[Dict]]:
        """Convert the prompts to API-style prompts, according to the roles
        parsed from the meta template.

        Args:
            prompts (Union[List, str, Dict]): The prompts to be converted.

        Returns:
            Union[str, Dict, List[Dict]]: The converted prompt(s), keeping
                the structure of the input (a plain string, a single role
                dict, or a list of role dicts).
        """
        if isinstance(prompts, str):
            return prompts
        elif isinstance(prompts, dict):
            api_role = self._role2api_role(prompts)
            return api_role

        res = []
        for prompt in prompts:
            if isinstance(prompt, str):
                raise TypeError('Mixing str without explicit role is not '
                                'allowed in API models!')
            else:
                api_role = self._role2api_role(prompt)
                res.append(api_role)
        return res

    def _role2api_role(self, role_prompt: Dict) -> Dict:
        """Wrap a single role prompt with the ``api_role``, ``begin`` and
        ``end`` fields configured in the meta template."""
        merged_prompt = self.roles[role_prompt['role']]
        if merged_prompt.get('fallback_role'):
            merged_prompt = self.roles[merged_prompt['fallback_role']]
        res = role_prompt.copy()
        res['role'] = merged_prompt['api_role']
        res['content'] = merged_prompt.get('begin', '')
        res['content'] += role_prompt.get('content', '')
        res['content'] += merged_prompt.get('end', '')
        return res


class BaseAPILLM(BaseLLM):
    """Base class for API model wrappers.

    Args:
        model_type (str): The type of model.
        retry (int): Number of retries if the API call fails. Defaults to 2.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case any meta instructions need to be
            injected or wrapped.
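
    Example:
        An illustrative construction call; ``MyAPILLM`` is a hypothetical
        subclass that implements the actual request logic:

        .. code-block:: python

            llm = MyAPILLM(
                model_type='my-chat-model',
                retry=3,
                meta_template=[
                    dict(role='system', api_role='system'),
                    dict(role='user', api_role='user'),
                    dict(role='assistant', api_role='assistant'),
                ],
                max_new_tokens=256,
                temperature=0.2,
            )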
""" | |

    is_api: bool = True

    def __init__(self,
                 model_type: str,
                 retry: int = 2,
                 template_parser: 'APITemplateParser' = APITemplateParser,
                 meta_template: Optional[Dict] = None,
                 *,
                 max_new_tokens: int = 512,
                 top_p: float = 0.8,
                 top_k: int = 40,
                 temperature: float = 0.8,
                 repetition_penalty: float = 0.0,
                 stop_words: Optional[Union[List[str], str]] = None):
        self.model_type = model_type
        self.meta_template = meta_template
        self.retry = retry
        if template_parser:
            self.template_parser = template_parser(meta_template)

        if isinstance(stop_words, str):
            stop_words = [stop_words]
        self.gen_params = dict(
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            stop_words=stop_words,
            skip_special_tokens=False)


class AsyncBaseAPILLM(AsyncLLMMixin, BaseAPILLM):
    pass
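

if __name__ == '__main__':
    # Minimal usage sketch of ``APITemplateParser`` (illustrative only; the
    # role/api_role mapping below is an assumed OpenAI-style template, not a
    # template shipped with any particular API).
    example_meta_template = [
        dict(role='system', api_role='system'),
        dict(role='user', api_role='user'),
        dict(role='assistant', api_role='assistant'),
    ]
    parser = APITemplateParser(example_meta_template)
    dialog = [
        dict(role='system', content='You are a helpful assistant.'),
        dict(role='user', content='Hello!'),
    ]
    # Every item is rewritten to use its ``api_role`` name, and consecutive
    # prompts assigned to the same role are merged.
    print(parser(dialog))
    # [{'role': 'system', 'content': 'You are a helpful assistant.'},
    #  {'role': 'user', 'content': 'Hello!'}]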