File size: 5,756 Bytes
9852b1b 570bb74 030a0f8 e852933 030a0f8 570bb74 030a0f8 e852933 030a0f8 fb0c713 030a0f8 99fe246 030a0f8 d320fdd 030a0f8 d320fdd f57bdfa 030a0f8 f57bdfa 030a0f8 f57bdfa 030a0f8 57ffe79 030a0f8 f57bdfa 030a0f8 f57bdfa 030a0f8 f57bdfa 030a0f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
from typing import List
from functools import lru_cache
import torch
from torch.nn import functional as F
import transformers
from utils import get_cls
def sample_from_values(unscaled_probs, values):
    """Pick one entry of ``values`` per row, weighted by ``unscaled_probs``.

    Args:
        unscaled_probs: Non-negative sampling weights; they are handed to
            ``torch.multinomial`` as-is, so they need not sum to one.
        values: Tensor gathered along dimension 1 with the drawn indices.

    Returns:
        The gathered entries of ``values``, one drawn column per row.
    """
    drawn_columns = torch.multinomial(unscaled_probs, num_samples=1)
    return values.take_along_dim(drawn_columns, dim=1)
class TopKWithTemperatureSampler:
    """Sampler that draws the next token from the top-k candidates."""

    def __call__(self, input_ids, output_logits, top_k, temperature, **kwargs):
        """Sample one next-token id per row of ``output_logits``.

        Args:
            input_ids: Accepted for interface compatibility; not used here.
            output_logits: Model scores; only index ``-1`` along dimension 1
                is read (assumes a (batch, sequence, vocab) layout -- TODO
                confirm against the caller).
            top_k: Number of highest-scoring candidates to sample from.
            temperature: Divisor applied to the candidates' log-probabilities
                before exponentiation.
            **kwargs: Ignored.

        Returns:
            Tensor with one sampled token id per row.
        """
        last_step_log_probs = F.log_softmax(output_logits[:, -1], dim=-1)
        best_log_probs, best_token_ids = last_step_log_probs.topk(top_k, -1)
        # Temperature is applied to the already-selected candidates only.
        sampling_weights = torch.exp(best_log_probs / temperature)
        chosen = sample_from_values(sampling_weights, best_token_ids)
        return chosen.squeeze(1)
class CAIFSampler:
    """Classifier-guided top-k sampler.

    The language model's best ``top_k_classifier`` next-token candidates are
    decoded to text together with the current prefix, scored by an external
    classifier, and re-weighted by that score before the final top-k sampling.
    """

    # NOTE: ``__init__`` is intentionally not wrapped in ``functools.lru_cache``.
    # Such a cache is keyed on ``self``, so it can never produce a hit for a new
    # instance; its only effect is to keep cached instances (and their
    # classifier models) alive.
    def __init__(self, classifier_name, lm_tokenizer, device, invert_cls_probs: bool = False):
        """Load the classifier and remember the language-model tokenizer.

        Args:
            classifier_name: Identifier passed to
                ``transformers.AutoTokenizer.from_pretrained`` and ``get_cls``.
            lm_tokenizer: Tokenizer whose ``decode`` turns language-model token
                ids back into text for the classifier.
            device: Device the classifier and its inputs are moved to.
            invert_cls_probs: When true, candidates are weighted by
                ``log(1 - p)`` instead of ``log(p)``.
        """
        self.device = device
        self.classifier_tokenizer = transformers.AutoTokenizer.from_pretrained(
            classifier_name
        )
        # ``get_cls`` is a project helper; its result is used as a model whose
        # output exposes ``.logits``.
        self.classifier_model = get_cls(classifier_name).to(device)
        self.classifier_model.eval()
        self.lm_tokenizer = lm_tokenizer
        self.invert_cls_probs = invert_cls_probs

    def __call__(
        self,
        input_ids,
        output_logis,
        top_k,
        temperature,
        top_k_classifier,
        classifier_weight,
        caif_tokens_num=None,
        act_type: str = "sigmoid",
        target_cls_id: int = 0,
        **kwargs
    ):
        """Sample one next-token id per row using classifier guidance.

        Args:
            input_ids: Token ids of the prefix generated so far, one row per
                sequence.
            output_logis: Language-model scores (the misspelt name is kept so
                existing keyword callers keep working); only index ``-1``
                along dimension 1 is read.
            top_k: Number of re-weighted candidates to sample from. Values
                larger than the candidate set are clamped to it.
            temperature: Divisor applied to the combined log-scores.
            top_k_classifier: Number of language-model candidates that are
                scored by the classifier.
            classifier_weight: Multiplier of the classifier log-probabilities;
                must not be zero.
            caif_tokens_num: If given, only the last ``caif_tokens_num``
                classifier tokens of each candidate text are scored.
            act_type: ``"sigmoid"`` or ``"softmax"``; how classifier logits
                are turned into probabilities.
            target_cls_id: Column of the classifier logits to steer towards.
            **kwargs: Ignored.

        Returns:
            Tensor with one sampled token id per row.
        """
        next_token_logits = output_logis[:, -1]
        next_token_log_probs = F.log_softmax(next_token_logits, dim=-1)
        candidate_probs, candidate_ids = self.get_unnormalized_probs(
            input_ids,
            next_token_log_probs,
            temperature,
            top_k_classifier,
            classifier_weight,
            target_cls_id=target_cls_id,
            act_type=act_type,
            caif_tokens_num=caif_tokens_num,
        )
        # Only ``top_k_classifier`` candidates exist; asking for more would
        # make ``topk`` raise.
        top_k = min(top_k, candidate_probs.size(-1))
        topk_probs, topk_positions = candidate_probs.topk(top_k, -1)
        next_tokens = sample_from_values(
            topk_probs,
            # Map positions inside the candidate set back to vocabulary ids.
            torch.take_along_dim(candidate_ids, topk_positions, dim=1),
        ).squeeze(1)
        return next_tokens

    def get_unnormalized_probs(
        self,
        input_ids,
        next_token_log_probs,
        temperature,
        top_k_classifier,
        classifier_weight,
        target_cls_id: int = 0,
        act_type: str = "sigmoid",
        caif_tokens_num=None
    ):
        """Combine language-model and classifier scores for the top candidates.

        Args:
            input_ids: Prefix token ids, one row per sequence.
            next_token_log_probs: Log-probabilities of the next token, one row
                per sequence.
            temperature: Divisor applied to the combined log-scores.
            top_k_classifier: Number of candidates scored by the classifier.
            classifier_weight: Multiplier of the classifier log-probabilities.
            target_cls_id: Column of the classifier logits to use.
            act_type: ``"sigmoid"`` or ``"softmax"``.
            caif_tokens_num: Optional limit on scored classifier tokens.

        Returns:
            Tuple ``(weights, token_ids)``, both with ``top_k_classifier``
            columns: unnormalized sampling weights and the matching token ids.

        Raises:
            ValueError: If ``classifier_weight`` is zero.
        """
        if classifier_weight == 0.0:
            raise ValueError(
                "classifier weight equal to 0 is not supported for CAIF Sampling"
            )
        top_log_probs, top_token_ids = next_token_log_probs.topk(top_k_classifier, -1)
        # One row per (sequence, candidate): the prefix followed by the candidate.
        candidate_sequences = torch.cat(
            [
                input_ids.unsqueeze(1).repeat(1, top_k_classifier, 1).flatten(0, 1),
                top_token_ids.reshape(-1).unsqueeze(-1),
            ],
            -1,
        )
        classifier_input = [
            self.lm_tokenizer.decode(sequence, skip_special_tokens=True)
            for sequence in candidate_sequences
        ]
        if self.invert_cls_probs:
            classifier_log_probs = torch.log(
                1 - self.get_classifier_probs(
                    classifier_input,
                    caif_tokens_num=caif_tokens_num,
                    target_cls_id=target_cls_id,
                    act_type=act_type,
                ).view(-1, top_k_classifier)
            )
        else:
            classifier_log_probs = self.get_classifier_log_probs(
                classifier_input,
                caif_tokens_num=caif_tokens_num,
                target_cls_id=target_cls_id,
                act_type=act_type,
            ).view(-1, top_k_classifier)
        # ``keepdim=True`` keeps the per-row means broadcastable against the
        # (rows, candidates) score matrices for any batch size.
        centered_lm = top_log_probs - top_log_probs.mean(-1, keepdim=True)
        centered_cls = classifier_log_probs - classifier_log_probs.mean(-1, keepdim=True)
        next_token_probs = torch.exp(
            (centered_lm + classifier_weight * centered_cls) / temperature
        )
        return next_token_probs, top_token_ids

    def _encode_for_classifier(self, texts, caif_tokens_num=None):
        """Tokenize ``texts`` for the classifier, optionally keeping only the
        last ``caif_tokens_num`` positions of every row."""
        encoded = self.classifier_tokenizer(
            texts, padding=True, return_tensors="pt"
        ).to(self.device)
        if caif_tokens_num is not None:
            # Trim along the sequence axis so the batch stays intact.
            for key in ("input_ids", "attention_mask", "token_type_ids"):
                if key in encoded:
                    encoded[key] = encoded[key][:, -caif_tokens_num:]
        return encoded

    def get_classifier_log_probs(self, input, caif_tokens_num=None, target_cls_id: int = 0, act_type: str = "sigmoid"):
        """Return the classifier's log-probability of the target class.

        Args:
            input: List of texts to score.
            caif_tokens_num: Optional limit on scored classifier tokens.
            target_cls_id: Column of the classifier logits to use.
            act_type: ``"sigmoid"`` scores the target logit on its own;
                ``"softmax"`` normalizes over all classes first.

        Returns:
            Tensor of log-probabilities, one per text.

        Raises:
            ValueError: If ``act_type`` is not ``"sigmoid"`` or ``"softmax"``.
        """
        if act_type not in ("sigmoid", "softmax"):
            raise ValueError(
                f"unsupported act_type {act_type!r}; expected 'sigmoid' or 'softmax'"
            )
        encoded = self._encode_for_classifier(input, caif_tokens_num)
        logits = self.classifier_model(**encoded).logits
        if act_type == "sigmoid":
            return F.logsigmoid(logits[:, target_cls_id].squeeze(-1))
        return F.log_softmax(logits, dim=-1)[:, target_cls_id].squeeze(-1)

    def get_classifier_probs(self, input, caif_tokens_num=None, target_cls_id: int = 0, act_type: str = "sigmoid"):
        """Return the classifier's probability of the target class.

        Args:
            input: List of texts to score.
            caif_tokens_num: Optional limit on scored classifier tokens.
            target_cls_id: Column of the classifier logits to use.
            act_type: ``"sigmoid"`` (default) or ``"softmax"``.

        Returns:
            Tensor of probabilities, one per text.

        Raises:
            ValueError: If ``act_type`` is not ``"sigmoid"`` or ``"softmax"``.
        """
        if act_type not in ("sigmoid", "softmax"):
            raise ValueError(
                f"unsupported act_type {act_type!r}; expected 'sigmoid' or 'softmax'"
            )
        encoded = self._encode_for_classifier(input, caif_tokens_num)
        logits = self.classifier_model(**encoded).logits
        if act_type == "sigmoid":
            return torch.sigmoid(logits[:, target_cls_id].squeeze(-1))
        return F.softmax(logits, dim=-1)[:, target_cls_id].squeeze(-1)
|