caif / generator.py
Балаганский Никита Николаевич
fix app
124c8d3
raw
history blame
7.54 kB
from typing import Optional, Union
import torch
import transformers
class Generator:
    """Token-by-token text sampler built on a Hugging Face causal LM.

    The language model produces next-token logits; the actual token choice
    is delegated to two pluggable callables, the "CAIF" sampler and the
    "ordinary" sampler, installed through :meth:`set_caif_sampler` and
    :meth:`set_ordinary_sampler`. Both are invoked as
    ``sampler(input_ids, logits, caif_tokens_num=..., **sampler_kwargs)``
    and must return a 1-D tensor holding one token id per batch row.

    Which sampler handles a step is decided either periodically (every
    ``caif_period`` steps) or, when an entropy threshold is set, per row by
    comparing the entropy of the next-token distribution to the threshold.
    """

    def __init__(self, lm_model_name, device, entropy=None):
        """Load tokenizer and model and put the model in eval mode.

        Args:
            lm_model_name: Name or path passed to ``from_pretrained``.
            device: Device the model and all generated tensors live on.
            entropy: Initial entropy threshold (``None`` disables
                entropy-based routing). Overwritten by every call to
                :meth:`sample_sequences`.
        """
        self.device = device
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            lm_model_name
        )
        self.lm = transformers.AutoModelForCausalLM.from_pretrained(
            lm_model_name
        ).to(device)
        self.lm.eval()
        # Reuse the EOS token as the padding token on both the model
        # config and the tokenizer.
        self.lm.config.pad_token_id = self.lm.config.eos_token_id
        self.tokenizer.add_special_tokens(
            {"pad_token": self.tokenizer.decode(self.lm.config.eos_token_id)}
        )
        self.caif_sampler = None
        self.ordinary_sampler = None
        # Running statistics collected only in entropy mode. "skips" and
        # "avg_entropy" are running *sums* of per-step batch means; divide
        # by "count" to obtain averages. A per-row "entropy" list is added
        # lazily on the first entropy-mode step.
        self.entropy_based_stats = {
            "skips": 0,
            "avg_entropy": 0,
            "count": 0,
        }
        self.entropy = entropy

    def set_caif_sampler(self, sampler):
        """Install the callable used on CAIF steps / high-entropy rows."""
        self.caif_sampler = sampler

    def set_ordinary_sampler(self, sampler):
        """Install the callable used on all other steps / rows."""
        self.ordinary_sampler = sampler

    def sample_sequences(
        self,
        num_samples: int,
        input_prompt: Optional[str],
        max_length: int,
        caif_period: int,
        caif_tokens_num: Union[int, None] = None,
        entropy: Optional[float] = None,
        **sampler_kwargs
    ):
        """Generate ``num_samples`` continuations of ``input_prompt``.

        Args:
            num_samples: Number of sequences generated in parallel.
            input_prompt: Text appended after the BOS token, or ``None`` to
                start from BOS alone.
            max_length: Maximum number of generation steps.
            caif_period: A step ``i`` is a CAIF step when
                ``i % caif_period == 0`` and a CAIF sampler is installed.
                Must be non-zero if a CAIF sampler is installed.
            caif_tokens_num: Forwarded unchanged to the samplers.
            entropy: Entropy threshold for this call; ``None`` selects
                purely periodic routing.
            **sampler_kwargs: Extra keyword arguments forwarded to the
                samplers.

        Returns:
            ``(texts, input_ids)``: the decoded strings (special tokens
            skipped) and the tensor of token ids they were decoded from.
        """
        # NOTE(review): this unconditionally replaces the threshold given to
        # __init__, so the constructor value never survives a call here.
        # Kept as-is because callers may rely on entropy=None to switch
        # entropy routing off per call.
        self.entropy = entropy
        input_ids, past, ended_sequences = self.get_input_ids(
            input_prompt,
            num_samples,
        )
        has_caif_sampler = self.caif_sampler is not None
        for step in range(max_length):
            # Test for the sampler first: without one, caif_period is
            # irrelevant and must not be able to raise ZeroDivisionError.
            is_caif_step = has_caif_sampler and step % caif_period == 0
            input_ids, past, ended_sequences = self.generation_step(
                input_ids,
                past,
                ended_sequences,
                is_caif_step,
                caif_tokens_num=caif_tokens_num,
                **sampler_kwargs
            )
            if ended_sequences.all():
                break
        texts = [
            self.tokenizer.decode(sequence, skip_special_tokens=True)
            for sequence in input_ids
        ]
        return texts, input_ids

    def generation_step(
        self,
        input_ids,
        past,
        ended_sequences,
        is_caif_step: bool,
        caif_tokens_num=None,
        **sampler_kwargs
    ):
        """Run the LM once and append one sampled token to every row.

        Args:
            input_ids: Token ids generated so far, one row per sequence.
            past: Cached key/values from the previous step, or ``None``.
            ended_sequences: Boolean tensor, one flag per row, marking rows
                that already produced EOS. Updated in place.
            is_caif_step: Use the CAIF sampler for the whole batch. Ignored
                when ``self.entropy`` is set.
            caif_tokens_num: Forwarded unchanged to the samplers.
            **sampler_kwargs: Extra keyword arguments for the samplers.

        Returns:
            ``(input_ids, past, ended_sequences)`` after the step.
        """
        prepared_inputs = self.lm.prepare_inputs_for_generation(
            input_ids, past, use_cache=True
        )
        # Sampling never back-propagates through the LM; without no_grad
        # the logits, the cache and the accumulated statistics would keep
        # the autograd graph of every step alive.
        with torch.no_grad():
            outputs = self.lm(
                **prepared_inputs,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True
            )
        past = outputs.past_key_values
        logits = outputs.logits

        if self.entropy is None:
            sampler = (
                self.caif_sampler if is_caif_step else self.ordinary_sampler
            )
            next_tokens = self._sample_next_tokens(
                sampler,
                input_ids,
                logits,
                ended_sequences,
                caif_tokens_num,
                sampler_kwargs,
            )
        else:
            caif_mask = self._entropy_caif_mask(logits)
            if not caif_mask.any():
                next_tokens = self._sample_next_tokens(
                    self.ordinary_sampler,
                    input_ids,
                    logits,
                    ended_sequences,
                    caif_tokens_num,
                    sampler_kwargs,
                )
            elif caif_mask.all():
                next_tokens = self._sample_next_tokens(
                    self.caif_sampler,
                    input_ids,
                    logits,
                    ended_sequences,
                    caif_tokens_num,
                    sampler_kwargs,
                )
            else:
                # Mixed batch: each sampler only sees its own rows and the
                # results are scattered back into batch order. The CAIF
                # sampler is called first, as in the original ordering.
                ordinary_mask = ~caif_mask
                next_tokens = torch.ones(
                    caif_mask.shape[0], dtype=torch.long, device=self.device
                )
                next_tokens[caif_mask] = self._sample_next_tokens(
                    self.caif_sampler,
                    input_ids[caif_mask],
                    logits[caif_mask],
                    ended_sequences[caif_mask],
                    caif_tokens_num,
                    sampler_kwargs,
                )
                next_tokens[ordinary_mask] = self._sample_next_tokens(
                    self.ordinary_sampler,
                    input_ids[ordinary_mask],
                    logits[ordinary_mask],
                    ended_sequences[ordinary_mask],
                    caif_tokens_num,
                    sampler_kwargs,
                )

        input_ids = torch.cat(
            [input_ids, next_tokens[:, None].to(self.device)], dim=-1
        )
        # In-place OR so the caller's tensor keeps tracking finished rows.
        ended_sequences |= next_tokens == self.lm.config.eos_token_id
        return input_ids, past, ended_sequences

    def _sample_next_tokens(
        self,
        sampler,
        input_ids,
        logits,
        ended_sequences,
        caif_tokens_num,
        sampler_kwargs,
    ):
        """Call ``sampler`` and overwrite finished rows with the EOS id."""
        if sampler is None:
            raise RuntimeError(
                "The sampler required for this step is not set; call "
                "set_caif_sampler() / set_ordinary_sampler() first."
            )
        next_tokens = sampler(
            input_ids,
            logits,
            caif_tokens_num=caif_tokens_num,
            **sampler_kwargs
        )
        ended = ended_sequences.long()
        eos_token_id = self.lm.config.eos_token_id
        # Rows that already ended keep emitting EOS regardless of the draw.
        return (next_tokens * (1 - ended) + eos_token_id * ended).long()

    def _entropy_caif_mask(self, logits):
        """Return the rows routed to the CAIF sampler and update the stats.

        A row is routed to the CAIF sampler when the entropy of its
        next-token distribution is at least ``self.entropy``.
        """
        # Only the last position matters, so soft-max just that slice.
        log_probs = torch.nn.functional.log_softmax(logits[:, -1], dim=-1)
        output_entropy = (torch.exp(log_probs) * -log_probs).sum(-1)
        batch_size = output_entropy.shape[0]
        caif_mask = torch.ge(output_entropy, self.entropy)

        stats = self.entropy_based_stats
        # "skips" accumulates the fraction of rows at or above the threshold.
        stats["skips"] += caif_mask.sum() / batch_size
        stats["count"] += 1
        stats["avg_entropy"] += output_entropy.sum() / batch_size
        stats.setdefault("entropy", []).extend(
            output_entropy.view(-1).cpu().tolist()
        )
        return caif_mask

    def get_input_ids(self, input_prompt, num_samples):
        """Build the initial generation state.

        Returns:
            ``(input_ids, past, ended_sequences)`` where ``input_ids`` is the
            BOS token (followed by the tokenized prompt, if any) repeated
            ``num_samples`` times, ``past`` is ``None`` and
            ``ended_sequences`` is an all-``False`` flag per row.
        """
        input_ids = torch.tensor([[self.lm.config.bos_token_id]])
        if input_prompt is not None:
            prompt_ids = self.tokenizer(
                input_prompt, return_tensors="pt"
            ).input_ids
            input_ids = torch.cat([input_ids, prompt_ids], 1)
        input_ids = input_ids.repeat(num_samples, 1).to(self.device)
        past = None
        ended_sequences = torch.zeros(
            input_ids.shape[0], dtype=torch.bool, device=self.device
        )
        return input_ids, past, ended_sequences

    @staticmethod
    def sample(unscaled_probs, values):
        """Draw one entry of ``values`` per row, weighted by ``unscaled_probs``.

        ``unscaled_probs`` are non-negative weights (they need not sum to
        one); the result has a single column holding the picked value of
        each row.
        """
        samples = torch.multinomial(unscaled_probs, 1)
        return torch.take_along_dim(values, samples, dim=1)