from typing import Optional, Union

import torch
import transformers


class Generator:
    """Token-by-token sampler around a Hugging Face causal LM.

    At every step the LM's next-token logits are handed to one of two
    pluggable samplers: a CAIF sampler (``set_caif_sampler``) or an ordinary
    sampler (``set_ordinary_sampler``). The choice is made either on a fixed
    period (``caif_period``) or, when an entropy threshold is set, per
    sequence from the entropy of the LM's next-token distribution.
    """

    def __init__(self, lm_model_name, device, entropy=None):
        """Load tokenizer and LM and reset the sampling state.

        Args:
            lm_model_name: Name or path passed to ``from_pretrained`` for both
                the tokenizer and the causal LM.
            device: Device the LM and all generated tensors are moved to.
            entropy: Initial entropy threshold. Note that ``sample_sequences``
                overwrites ``self.entropy`` with its own ``entropy`` argument
                on every call.
        """
        self.device = device
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(lm_model_name)
        self.lm = transformers.AutoModelForCausalLM.from_pretrained(
            lm_model_name
        ).to(device)
        self.lm.eval()
        # Use the EOS token as the padding token for both model and tokenizer.
        self.lm.config.pad_token_id = self.lm.config.eos_token_id
        self.tokenizer.add_special_tokens(
            {"pad_token": self.tokenizer.decode(self.lm.config.eos_token_id)}
        )
        self.caif_sampler = None
        self.ordinary_sampler = None
        # Running statistics, only updated when entropy-based routing is on:
        #   "skips":       sum over steps of the fraction of the batch whose
        #                  entropy reached the threshold;
        #   "avg_entropy": sum over steps of the batch-mean entropy
        #                  (divide by "count" to get an average);
        #   "count":       number of steps recorded;
        #   "entropy":     (added lazily) flat list of every per-sequence
        #                  entropy value seen.
        self.entropy_based_stats = {
            "skips": 0,
            "avg_entropy": 0,
            "count": 0,
        }
        self.entropy = entropy

    def set_caif_sampler(self, sampler):
        """Register the sampler used on CAIF steps / high-entropy sequences."""
        self.caif_sampler = sampler

    def set_ordinary_sampler(self, sampler):
        """Register the sampler used on all other steps / sequences."""
        self.ordinary_sampler = sampler

    def sample_sequences(
        self,
        num_samples: int,
        input_prompt: Optional[str],
        max_length: int,
        caif_period: int,
        caif_tokens_num: Union[int, None] = None,
        entropy: Optional[float] = None,
        **sampler_kwargs
    ):
        """Generate ``num_samples`` continuations of ``input_prompt``.

        Args:
            num_samples: Number of sequences generated in one batch.
            input_prompt: Optional text placed after the BOS token.
            max_length: Maximum number of generation steps; each step
                appends exactly one token to every sequence.
            caif_period: Without an entropy threshold, steps whose index is a
                multiple of this value use the CAIF sampler (if one is set).
            caif_tokens_num: Forwarded to the samplers unchanged.
            entropy: Entropy threshold for per-sequence routing; ``None``
                selects the periodic schedule. Stored in ``self.entropy``.
            **sampler_kwargs: Forwarded to the samplers unchanged.

        Returns:
            ``(texts, input_ids)``: the sequences decoded with special tokens
            skipped, and the raw token-id tensor (BOS and prompt included).
            Generation stops early once every sequence has produced EOS.
        """
        self.entropy = entropy
        input_ids, past, ended_sequences = self.get_input_ids(
            input_prompt, num_samples
        )
        for step in range(max_length):
            is_caif_step = (
                step % caif_period == 0 and self.caif_sampler is not None
            )
            input_ids, past, ended_sequences = self.generation_step(
                input_ids,
                past,
                ended_sequences,
                is_caif_step,
                caif_tokens_num=caif_tokens_num,
                **sampler_kwargs
            )
            if ended_sequences.all():
                break
        texts = [
            self.tokenizer.decode(sequence, skip_special_tokens=True)
            for sequence in input_ids
        ]
        return texts, input_ids

    def generation_step(
        self,
        input_ids,
        past,
        ended_sequences,
        is_caif_step: bool,
        caif_tokens_num=None,
        **sampler_kwargs
    ):
        """Run the LM once and append one sampled token to every sequence.

        When ``self.entropy`` is set, each sequence is routed by the entropy
        of its next-token distribution and ``is_caif_step`` is ignored.
        Otherwise the whole batch uses the CAIF sampler if ``is_caif_step``
        is true and the ordinary sampler if not. Sequences that have already
        ended are forced to emit EOS.

        Args:
            input_ids: Token ids generated so far, one row per sequence.
            past: KV cache from the previous step (``None`` on the first).
            ended_sequences: Boolean flag per sequence, updated in place.
            is_caif_step: Whether this step uses the CAIF sampler (periodic
                mode only).
            caif_tokens_num, **sampler_kwargs: Forwarded to the sampler(s).

        Returns:
            ``(input_ids, past, ended_sequences)`` advanced by one token.
        """
        prepared_inputs = self.lm.prepare_inputs_for_generation(
            input_ids, past, use_cache=True
        )
        # Generation never back-propagates through the LM. Without no_grad the
        # autograd graph grows with the KV cache every step, and tensors
        # stored in ``entropy_based_stats`` would keep it alive afterwards.
        with torch.no_grad():
            outputs = self.lm(
                **prepared_inputs,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True
            )
        past = outputs.past_key_values
        logits = outputs.logits

        if self.entropy is not None:
            next_tokens = self._entropy_routed_tokens(
                input_ids, logits, caif_tokens_num, **sampler_kwargs
            )
        else:
            sampler = self.caif_sampler if is_caif_step else self.ordinary_sampler
            next_tokens = sampler(
                input_ids, logits, caif_tokens_num=caif_tokens_num, **sampler_kwargs
            )

        next_tokens = self._force_eos_for_ended(next_tokens, ended_sequences)
        input_ids = torch.cat(
            [input_ids, next_tokens[:, None].to(self.device)], dim=-1
        )
        ended_sequences |= next_tokens == self.lm.config.eos_token_id
        return input_ids, past, ended_sequences

    def _entropy_routed_tokens(
        self, input_ids, logits, caif_tokens_num, **sampler_kwargs
    ):
        """Sample next tokens, sending high-entropy sequences to the CAIF sampler.

        Sequences whose next-token entropy is >= ``self.entropy`` are sampled
        by the CAIF sampler, the rest by the ordinary sampler. If no CAIF
        sampler is set, every sequence uses the ordinary sampler (mirroring
        the periodic schedule). Also updates ``self.entropy_based_stats``.
        Returns one token id per sequence (not yet EOS-masked).
        """
        # log_softmax is computed per position, so restricting to the last
        # position first yields the same entropies as the full tensor would.
        log_probs = torch.nn.functional.log_softmax(logits[:, -1], dim=-1)
        output_entropy = (log_probs.exp() * -log_probs).sum(-1)
        batch_size = output_entropy.shape[0]
        caif_mask = torch.ge(output_entropy, self.entropy)

        stats = self.entropy_based_stats
        stats["skips"] += caif_mask.sum() / batch_size
        stats["count"] += 1
        stats["avg_entropy"] += output_entropy.sum() / batch_size
        stats.setdefault("entropy", []).extend(
            output_entropy.view(-1).cpu().tolist()
        )

        if self.caif_sampler is None:
            # Nothing to route to; fall back to the ordinary sampler.
            caif_mask = torch.zeros_like(caif_mask)
        ordinary_mask = ~caif_mask

        num_caif = int(caif_mask.sum())
        if num_caif == 0:
            return self.ordinary_sampler(
                input_ids, logits, caif_tokens_num=caif_tokens_num, **sampler_kwargs
            )
        if num_caif == batch_size:
            return self.caif_sampler(
                input_ids, logits, caif_tokens_num=caif_tokens_num, **sampler_kwargs
            )

        # Mixed batch: each sampler only sees its own (non-empty) subset.
        caif_tokens = self.caif_sampler(
            input_ids[caif_mask],
            logits[caif_mask],
            caif_tokens_num=caif_tokens_num,
            **sampler_kwargs
        )
        ordinary_tokens = self.ordinary_sampler(
            input_ids[ordinary_mask],
            logits[ordinary_mask],
            caif_tokens_num=caif_tokens_num,
            **sampler_kwargs
        )
        next_tokens = torch.ones(batch_size, dtype=torch.long, device=self.device)
        next_tokens[caif_mask] = caif_tokens.long()
        next_tokens[ordinary_mask] = ordinary_tokens.long()
        return next_tokens

    def _force_eos_for_ended(self, next_tokens, ended_sequences):
        """Return ``next_tokens`` as int64 with EOS for already-ended sequences."""
        return next_tokens.long().masked_fill(
            ended_sequences, self.lm.config.eos_token_id
        )

    def get_input_ids(self, input_prompt, num_samples):
        """Build the initial batch: BOS plus the optional tokenized prompt.

        Returns:
            ``(input_ids, past, ended_sequences)`` where ``input_ids`` has
            ``num_samples`` identical rows on ``self.device``, ``past`` is
            ``None`` and ``ended_sequences`` is an all-False boolean vector.
        """
        input_ids = torch.tensor([[self.lm.config.bos_token_id]])
        if input_prompt is not None:
            prompt_ids = self.tokenizer(input_prompt, return_tensors="pt").input_ids
            input_ids = torch.cat([input_ids, prompt_ids], dim=1)
        input_ids = input_ids.repeat(num_samples, 1).to(self.device)
        ended_sequences = torch.zeros(
            input_ids.shape[0], dtype=torch.bool, device=self.device
        )
        return input_ids, None, ended_sequences

    @staticmethod
    def sample(unscaled_probs, values):
        """Draw one index per row of ``unscaled_probs`` and gather it from ``values``.

        ``torch.multinomial`` treats each row as (non-negative) weights, so
        rows need not sum to one. Presumably both arguments are 2-D with
        matching rows (required by ``dim=1``) — TODO confirm against callers.
        Returns a tensor with one column holding the selected ``values``.
        """
        samples = torch.multinomial(unscaled_probs, 1)
        return torch.take_along_dim(values, samples, dim=1)