import copy
import logging
from typing import Dict, List, Optional, Union

from lagent.schema import ModelStatusCode

from .base_api import APITemplateParser
from .base_llm import BaseLLM

logger = logging.getLogger(__name__)


class HFTransformer(BaseLLM):
    """Model wrapper around HuggingFace general models.

    Adapted from Internlm (https://github.com/InternLM/InternLM/blob/main/
    chat/web_demo.py)

    Args:
        path (str): The name or path to HuggingFace's model.
        tokenizer_path (str): The path to the tokenizer. Defaults to None.
        tokenizer_kwargs (dict): Keyword arguments for the tokenizer.
            Defaults to {}.
        tokenizer_only (bool): If True, only the tokenizer will be
            initialized. Defaults to False.
        model_kwargs (dict): Keyword arguments for the model, used in the
            loader. Defaults to dict(device_map='auto').
        meta_template (Dict, optional): The model's meta prompt template,
            if needed, in case any meta instructions have to be injected
            or wrapped.
        stop_words_id (Union[List[int], int], optional): Token id(s) that
            terminate generation, in addition to the eos token.
    """

    def __init__(self,
                 path: str,
                 tokenizer_path: Optional[str] = None,
                 tokenizer_kwargs: dict = dict(),
                 tokenizer_only: bool = False,
                 model_kwargs: dict = dict(device_map='auto'),
                 meta_template: Optional[Dict] = None,
                 stop_words_id: Union[List[int], int] = None,
                 **kwargs):
        super().__init__(
            path=path,
            tokenizer_only=tokenizer_only,
            meta_template=meta_template,
            **kwargs)
        if isinstance(stop_words_id, int):
            stop_words_id = [stop_words_id]
        self.gen_params.update(stop_words_id=stop_words_id)
        if self.gen_params['stop_words'] is not None and \
                self.gen_params['stop_words_id'] is not None:
            logger.warning('Both stop_words and stop_words_id are specified; '
                           'only stop_words_id will be used.')

        self._load_tokenizer(
            path=path,
            tokenizer_path=tokenizer_path,
            tokenizer_kwargs=tokenizer_kwargs)

        if not tokenizer_only:
            self._load_model(path=path, model_kwargs=model_kwargs)

        from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList  # noqa: E501
        self.logits_processor = LogitsProcessorList()
        self.stopping_criteria = StoppingCriteriaList()
        self.prefix_allowed_tokens_fn = None

        stop_words_id = []
        if self.gen_params.get('stop_words_id'):
            stop_words_id = self.gen_params.get('stop_words_id')
        elif self.gen_params.get('stop_words'):
            for sw in self.gen_params.get('stop_words'):
                stop_words_id.append(self.tokenizer(sw)['input_ids'][-1])
        self.additional_eos_token_id = stop_words_id

    def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
                        tokenizer_kwargs: dict):
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path if tokenizer_path else path,
            trust_remote_code=True,
            **tokenizer_kwargs)

        if self.tokenizer.pad_token_id is None:
            if self.tokenizer.eos_token is not None:
                logger.warning(
                    f'Using eos_token {self.tokenizer.eos_token} '
                    'as pad_token.')
                self.tokenizer.pad_token = self.tokenizer.eos_token
            else:
                from transformers.generation import GenerationConfig
                self.gcfg = GenerationConfig.from_pretrained(path)
                if self.gcfg.pad_token_id is not None:
                    logger.warning(
                        f'Using pad_token_id {self.gcfg.pad_token_id} '
                        'as pad_token_id.')
                    self.tokenizer.pad_token_id = self.gcfg.pad_token_id
                else:
                    raise ValueError(
                        'pad_token_id is not set for this tokenizer. Try to '
                        'set pad_token_id via passing '
                        '`pad_token_id={PAD_TOKEN_ID}` in model_cfg.')
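
    # Note on the fallback above: when the tokenizer ships without a pad
    # token, the eos token is reused; failing that, pad_token_id is read from
    # the checkpoint's GenerationConfig, and a ValueError is raised only when
    # neither source provides one.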

    def _load_model(self, path: str, model_kwargs: dict):
        import torch
        from transformers import AutoModel
        model_kwargs.setdefault('torch_dtype', torch.float16)
        self.model = AutoModel.from_pretrained(
            path, trust_remote_code=True, **model_kwargs)
        self.model.eval()

    def tokenize(self, inputs: str):
        assert isinstance(inputs, str)
        inputs = self.tokenizer(
            inputs, return_tensors='pt', return_length=True)
        return inputs['input_ids'].tolist()

    def generate(
        self,
        inputs: Union[str, List[str]],
        do_sample: bool = True,
        **kwargs,
    ):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_sample (bool): do sampling if enabled
        Returns:
            (a list of/batched) text/chat completion
        """
        for status, chunk, _ in self.stream_generate(inputs, do_sample,
                                                     **kwargs):
            response = chunk
        return response

    def stream_generate(
        self,
        inputs: Union[str, List[str]],
        do_sample: bool = True,
        **kwargs,
    ):
        """Return the chat completions in stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_sample (bool): do sampling if enabled
        Returns:
            tuple(Status, str, int): status, text/chat completion,
            generated token number
        """
        import torch
        from torch import nn
        with torch.no_grad():
            batched = True
            if isinstance(inputs, str):
                inputs = [inputs]
                batched = False
            inputs = self.tokenizer(
                inputs, padding=True, return_tensors='pt', return_length=True)
            input_length = inputs['length']
            for k, v in inputs.items():
                inputs[k] = v.cuda()
            input_ids = inputs['input_ids']
            attention_mask = inputs['attention_mask']
            batch_size = input_ids.shape[0]
            input_ids_seq_length = input_ids.shape[-1]
            generation_config = self.model.generation_config
            generation_config = copy.deepcopy(generation_config)
            new_gen_params = self.update_gen_params(**kwargs)
            generation_config.update(**new_gen_params)
            generation_config.update(**kwargs)
            model_kwargs = generation_config.to_dict()
            model_kwargs['attention_mask'] = attention_mask
            _, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
                generation_config.bos_token_id,
                generation_config.eos_token_id,
            )
            if eos_token_id is None:
                if self.gcfg.eos_token_id is not None:
                    eos_token_id = self.gcfg.eos_token_id
                else:
                    eos_token_id = []
            if isinstance(eos_token_id, int):
                eos_token_id = [eos_token_id]
            if self.additional_eos_token_id is not None:
                eos_token_id.extend(self.additional_eos_token_id)
            eos_token_id_tensor = torch.tensor(eos_token_id).to(
                input_ids.device) if eos_token_id is not None else None
            generation_config.max_length = (
                generation_config.max_new_tokens + input_ids_seq_length)
            # Set generation parameters if not already defined
            logits_processor = self.logits_processor
            stopping_criteria = self.stopping_criteria

            logits_processor = self.model._get_logits_processor(
                generation_config=generation_config,
                input_ids_seq_length=input_ids_seq_length,
                encoder_input_ids=input_ids,
                prefix_allowed_tokens_fn=self.prefix_allowed_tokens_fn,
                logits_processor=logits_processor,
            )

            stopping_criteria = self.model._get_stopping_criteria(
                generation_config=generation_config,
                stopping_criteria=stopping_criteria)
            logits_warper = self.model._get_logits_warper(generation_config)

            unfinished_sequences = input_ids.new(batch_size).fill_(1)
            scores = None
            while True:
                model_inputs = self.model.prepare_inputs_for_generation(
                    input_ids, **model_kwargs)
                # forward pass to get the next token
                outputs = self.model(
                    **model_inputs,
                    return_dict=True,
                    output_attentions=False,
                    output_hidden_states=False,
                )

                next_token_logits = outputs.logits[:, -1, :]
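                # `next_token_logits` has shape (batch_size, vocab_size): the
                # scores for the next token of every sequence in the batch.
                # The processor and warper below apply the constraints from
                # `generation_config` (repetition penalty, top-k/top-p,
                # temperature, ...) before a token is chosen.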
                # pre-process distribution
                next_token_scores = logits_processor(input_ids,
                                                     next_token_logits)
                next_token_scores = logits_warper(input_ids,
                                                  next_token_scores)

                # sample
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                if do_sample:
                    next_tokens = torch.multinomial(
                        probs, num_samples=1).squeeze(1)
                else:
                    next_tokens = torch.argmax(probs, dim=-1)

                # update generated ids, model inputs,
                # and length for next step
                input_ids = torch.cat([input_ids, next_tokens[:, None]],
                                      dim=-1)
                model_kwargs = self.model._update_model_kwargs_for_generation(  # noqa: E501
                    outputs, model_kwargs, is_encoder_decoder=False)
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(
                        eos_token_id_tensor.unsqueeze(1)).prod(dim=0))
                output_token_ids = input_ids.cpu().tolist()
                for i in range(len(output_token_ids)):
                    output_token_ids[i] = output_token_ids[i][:][
                        input_length[i]:]
                    # Find the first occurrence of
                    # an EOS token in the sequence
                    first_eos_idx = next(
                        (idx
                         for idx, token_id in enumerate(output_token_ids[i])
                         if token_id in eos_token_id), None)
                    # If an EOS token is found, only the part
                    # before it is retained
                    if first_eos_idx is not None:
                        output_token_ids[i] = output_token_ids[
                            i][:first_eos_idx]

                response = self.tokenizer.batch_decode(output_token_ids)
                if not batched:
                    response = response[0]
                yield ModelStatusCode.STREAM_ING, response, None
                # stop when each sentence is finished,
                # or if we exceed the maximum length
                if (unfinished_sequences.max() == 0
                        or stopping_criteria(input_ids, scores)):
                    break
            yield ModelStatusCode.END, response, None

    def stream_chat(
        self,
        inputs: List[dict],
        do_sample: bool = True,
        **kwargs,
    ):
        """Return the chat completions in stream mode.

        Args:
            inputs (List[dict]): input messages to be completed.
            do_sample (bool): do sampling if enabled
        Returns:
            the text/chat completion
        """
        prompt = self.template_parser(inputs)
        yield from self.stream_generate(prompt, do_sample, **kwargs)


class HFTransformerCasualLM(HFTransformer):

    def _load_model(self, path: str, model_kwargs: dict):
        import torch
        from transformers import AutoModelForCausalLM
        model_kwargs.setdefault('torch_dtype', torch.float16)
        self.model = AutoModelForCausalLM.from_pretrained(
            path, trust_remote_code=True, **model_kwargs)
        self.model.eval()


class HFTransformerChat(HFTransformerCasualLM):

    def __init__(self, template_parser=APITemplateParser, **kwargs):
        super().__init__(template_parser=template_parser, **kwargs)

    def chat(self,
             inputs: Union[List[dict], List[List[dict]]],
             do_sample: bool = True,
             **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): input messages
                to be completed.
            do_sample (bool): do sampling if enabled
        Returns:
            the text/chat completion
        """
        # handle batch inference with a vanilla for loop
        if isinstance(inputs[0], list):
            resps = []
            for input in inputs:
                resps.append(self.chat(input, do_sample, **kwargs))
            return resps
        prompt = self.template_parser(inputs)
        query = prompt[-1]['content']
        history = prompt[:-1]
        try:
            response, history = self.model.chat(
                self.tokenizer, query, history=history)
        except Exception as e:
            # handle over-length input error
            logger.warning(str(e))
            response = ''
        return response
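

# A minimal usage sketch, not part of the library itself. It assumes a causal
# LM checkpoint is reachable (the model name below is illustrative only), that
# a CUDA device is available (`stream_generate` moves inputs to the GPU), and
# that generation parameters such as `max_new_tokens` are accepted as keyword
# arguments by `BaseLLM`.
if __name__ == '__main__':
    llm = HFTransformerCasualLM(
        path='internlm/internlm2-chat-7b',  # illustrative checkpoint name
        max_new_tokens=64)

    # Non-stream completion: `generate` drains `stream_generate` and returns
    # the final chunk.
    print(llm.generate('The capital of France is', do_sample=False))

    # Stream completion: every yielded chunk is the full decoded text so far.
    for status, text, _ in llm.stream_generate('Hello, my name is'):
        if status == ModelStatusCode.END:
            print(text)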