import gc
import copy

from tenacity import RetryError
from tenacity import retry, stop_after_attempt, wait_fixed

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    LogitsProcessorList,
    MinNewTokensLengthLogitsProcessor,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
)


def get_output_batch(
    model, tokenizer, prompts, generation_config
):
    """Run non-streaming generation for one or more prompts and return the decoded outputs."""
    if len(prompts) == 1:
        encoding = tokenizer(prompts, return_tensors="pt")
        input_ids = encoding["input_ids"].cuda()
        generated_id = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            max_new_tokens=256,
        )

        decoded = tokenizer.batch_decode(generated_id)
        del input_ids, generated_id
        torch.cuda.empty_cache()
        return decoded
    else:
        encodings = tokenizer(prompts, padding=True, return_tensors="pt").to("cuda")
        generated_ids = model.generate(
            **encodings,
            generation_config=generation_config,
            max_new_tokens=256,
        )

        decoded = tokenizer.batch_decode(generated_ids)
        del encodings, generated_ids
        torch.cuda.empty_cache()
        return decoded


# StreamModel is borrowed from the basaran project.
# Please find more info about it here: https://github.com/hyperonym/basaran
class StreamModel:
    """StreamModel wraps around a language model to provide stream decoding."""

    def __init__(self, model, tokenizer):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def __call__(
        self,
        prompt,
        min_tokens=0,
        max_tokens=16,
        temperature=1.0,
        top_p=1.0,
        n=1,
        logprobs=0,
    ):
        """Create a completion stream for the provided prompt."""
        input_ids = self.tokenize(prompt)
        logprobs = max(logprobs, 0)

        # chunk_size must be greater than 1, otherwise no intermediate
        # chunks are yielded inside the generation loop.
        chunk_size = 2
        chunk_count = 0

        # Accumulate completion tokens; use an integer dtype so the
        # collected ids can be decoded by the tokenizer.
        final_tokens = torch.empty(0, dtype=torch.long).to(self.device)

        try:
            for tokens in self.generate(
                input_ids[None, :].repeat(n, 1),
                logprobs=logprobs,
                min_new_tokens=min_tokens,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
            ):
                if chunk_count < chunk_size:
                    chunk_count = chunk_count + 1

                final_tokens = torch.cat((final_tokens, tokens))

                if chunk_count == chunk_size - 1:
                    chunk_count = 0
                    yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)

            if chunk_count > 0:
                yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)

        except RetryError as e:
            print(e)

        del input_ids
        gc.collect()
        del final_tokens

        if self.device == "cuda":
            torch.cuda.empty_cache()

    @retry(stop=stop_after_attempt(5), wait=wait_fixed(1))
    def _infer(self, model_fn, **kwargs):
        """Call a model function in inference mode with auto retrying."""
        # This is a temporary workaround for bitsandbytes #162:
        # https://github.com/TimDettmers/bitsandbytes/issues/162
        with torch.inference_mode():
            return model_fn(**kwargs)

    def _logits_processor(self, config, input_length):
        """Set up logits processor based on the generation config."""
        processor = LogitsProcessorList()

        # Add processor for enforcing a min-length of new tokens.
        if (
            config.min_new_tokens is not None
            and config.min_new_tokens > 0
            and config.eos_token_id is not None
        ):
            processor.append(
                MinNewTokensLengthLogitsProcessor(
                    prompt_length_to_skip=input_length,
                    min_new_tokens=config.min_new_tokens,
                    eos_token_id=config.eos_token_id,
                )
            )

        # Add processor for scaling output probability distribution.
        if (
            config.temperature is not None
            and config.temperature > 0
            and config.temperature != 1.0
        ):
            processor.append(TemperatureLogitsWarper(config.temperature))

        # Add processor for nucleus sampling.
        if config.top_p is not None and config.top_p > 0 and config.top_p < 1:
            processor.append(TopPLogitsWarper(config.top_p))

        return processor

    def tokenize(self, text):
        """Tokenize a string into a tensor of token IDs."""
        batch = self.tokenizer.encode(text, return_tensors="pt")
        return batch[0].to(self.device)

    def generate(self, input_ids, logprobs=0, **kwargs):
        """Generate a stream of predicted tokens using the language model."""

        # Store the original batch size and input length.
        batch_size = input_ids.shape[0]
        input_length = input_ids.shape[-1]

        # Separate model arguments from generation config.
        config = self.model.generation_config
        config = copy.deepcopy(config)
        kwargs = config.update(**kwargs)
        kwargs["output_attentions"] = False
        kwargs["output_hidden_states"] = False
        kwargs["use_cache"] = True  # config.use_cache

        # Collect special token IDs.
        pad_token_id = config.pad_token_id
        bos_token_id = config.bos_token_id
        eos_token_id = config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        if pad_token_id is None and eos_token_id is not None:
            pad_token_id = eos_token_id[0]

        # Generate from eos if no input is specified.
        if input_length == 0:
            input_ids = input_ids.new_ones((batch_size, 1)).long()
            if eos_token_id is not None:
                input_ids = input_ids * eos_token_id[0]
            input_length = 1

        # Prepare inputs for encoder-decoder models.
        if self.model.config.is_encoder_decoder:
            # Get outputs from the encoder.
            encoder = self.model.get_encoder()
            encoder_kwargs = kwargs.copy()
            encoder_kwargs.pop("use_cache", None)
            encoder_kwargs["input_ids"] = input_ids
            encoder_kwargs["return_dict"] = True
            encoder_outputs = self._infer(encoder, **encoder_kwargs)
            kwargs["encoder_outputs"] = encoder_outputs

            # Reinitialize inputs for the decoder.
            decoder_start_token_id = config.decoder_start_token_id
            if decoder_start_token_id is None:
                decoder_start_token_id = bos_token_id
            input_ids = input_ids.new_ones((batch_size, 1))
            input_ids = input_ids * decoder_start_token_id
            input_length = 1

        # Set up logits processor.
        processor = self._logits_processor(config, input_length)

        # Keep track of which sequences are already finished.
        unfinished = input_ids.new_ones(batch_size)

        # Start auto-regressive generation.
        while True:
            inputs = self.model.prepare_inputs_for_generation(
                input_ids, **kwargs
            )  # noqa: E501
            outputs = self._infer(
                self.model,
                **inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            # Pre-process the probability distribution of the next tokens.
            logits = outputs.logits[:, -1, :]
            with torch.inference_mode():
                logits = processor(input_ids, logits)
            probs = torch.nn.functional.softmax(logits, dim=-1)

            # Select deterministic or stochastic decoding strategy.
            if (config.top_p is not None and config.top_p <= 0) or (
                config.temperature is not None and config.temperature <= 0
            ):
                tokens = torch.argmax(probs, dim=-1)[:, None]
            else:
                tokens = torch.multinomial(probs, num_samples=1)

            tokens = tokens.squeeze(1)

            # Finished sequences should have their next token be a padding.
            if pad_token_id is not None:
                tokens = tokens * unfinished + pad_token_id * (1 - unfinished)

            # Append selected tokens to the inputs.
            input_ids = torch.cat([input_ids, tokens[:, None]], dim=-1)

            # Mark sequences with eos tokens as finished.
            if eos_token_id is not None:
                not_eos = sum(tokens != i for i in eos_token_id)
                unfinished = unfinished.mul(not_eos.long())

            # Set status to -1 if exceeded the max length.
            status = unfinished.clone()
            if input_ids.shape[-1] - input_length >= config.max_new_tokens:
                status = 0 - status

            # Yield predictions and status.
            yield tokens

            # Stop when finished or exceeded the max length.
            if status.max() <= 0:
                break
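

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# The "gpt2" checkpoint and the prompts below are placeholder assumptions;
# any causal LM checkpoint loadable via AutoModelForCausalLM should work the
# same way. get_output_batch moves tensors to CUDA unconditionally, so that
# path is guarded by a device check here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import GenerationConfig

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # GPT-2 has no pad token by default; reuse eos so padded batches work.
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.to("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()

    # Streaming generation: each yield is the cumulative completion so far.
    stream = StreamModel(model, tokenizer)
    completion = ""
    for completion in stream(
        "The quick brown fox",
        max_tokens=32,
        temperature=0.7,
        top_p=0.9,
    ):
        pass
    print(completion)

    # Batch (non-streaming) generation; this path assumes a CUDA device.
    if torch.cuda.is_available():
        generation_config = GenerationConfig(
            temperature=0.7, top_p=0.9, do_sample=True
        )
        print(
            get_output_batch(
                model,
                tokenizer,
                ["Hello, my name is", "The capital of France is"],
                generation_config,
            )
        )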