from transformers import AutoTokenizer, PreTrainedModel
from transformers.models.bert.modeling_bert import BertForMaskedLM
from transformers.modeling_outputs import TokenClassifierOutput
import torch
from torch import nn
from torch.nn import TransformerDecoder, TransformerDecoderLayer
from typing import Optional
import wandb
import numpy as np


class DenoSentModel(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.pooler = config.pooler
        self.sent_embedding_projector = nn.Linear(config.hidden_size, config.hidden_size)
        self.decoder = TransformerDecoder(
            TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.decoder_num_heads, batch_first=True, dropout=0.1),
            num_layers=config.decoder_num_layers,
        )
        self.decoder_noise_dropout = nn.Dropout(config.decoder_noise_dropout)
        self.sim = nn.CosineSimilarity(dim=-1)
        self.init_weights()
        self.tokenizer = AutoTokenizer.from_pretrained(config.encoder_name_or_path)
        # Load the full masked-LM checkpoint once, keep its LM head as the generative
        # prediction head, and keep only the bare BERT encoder for sentence encoding.
        self.encoder = BertForMaskedLM.from_pretrained(config.encoder_name_or_path)
        self.prediction_head = self.encoder.cls
        self.encoder = self.encoder.bert
        self.post_init()

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def encode(self, sentences, batch_size=32, **kwargs):
        """Returns a list of embeddings for the given sentences.

        Args:
            sentences (`List[str]`): List of sentences to encode
            batch_size (`int`): Batch size for the encoding
        Returns:
            `List[np.ndarray]` or `List[tensor]`: List of embeddings for the given sentences
        """
        self.eval()
        all_embeddings = []
        # Sort by length so that sentences within a batch have similar lengths.
        length_sorted_idx = np.argsort([len(sen) for sen in sentences])
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
        prompt_length = 0
        if self.config.pooler == 'mask':
            # Wrap each (truncated) sentence into the prompt template and insert the mask token.
            prompt_length = len(self.tokenizer(self.config.prompt_format, add_special_tokens=False)['input_ids'])
            sentences_sorted = self.tokenizer.batch_decode(self.tokenizer(sentences_sorted, padding=True, truncation=True, max_length=self.config.max_length, return_tensors='pt').input_ids, skip_special_tokens=True)
            sentences_sorted = [self.config.prompt_format.replace('[X]', s).replace('[MASK]', self.tokenizer.mask_token) for s in sentences_sorted]
        for start_index in range(0, len(sentences), batch_size):
            sentences_batch = sentences_sorted[start_index:start_index + batch_size]
            inputs = self.tokenizer(sentences_batch, padding='max_length', truncation=True, return_tensors="pt", max_length=self.config.max_length + prompt_length)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                encoder_outputs = self.encoder(**inputs, output_hidden_states=True, output_attentions=True, return_dict=True)
            last_hidden_state = encoder_outputs.last_hidden_state
            if self.config.pooler == 'cls':
                embeddings = last_hidden_state[:, 0, :]
            elif self.config.pooler == 'mean':
                embeddings = (last_hidden_state * inputs['attention_mask'].unsqueeze(-1)).sum(1) / inputs['attention_mask'].sum(-1).unsqueeze(-1)
            elif self.config.pooler == 'mask':
                embeddings = last_hidden_state[inputs['input_ids'] == self.tokenizer.mask_token_id]
            else:
                raise NotImplementedError()
            all_embeddings.extend(embeddings.cpu().numpy())
        # Restore the original sentence order.
        all_embeddings = torch.tensor(np.array([all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]))
        return all_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        positive_input_ids: Optional[torch.LongTensor] = None,
        positive_attention_mask: Optional[torch.LongTensor] = None,
        negative_input_ids: Optional[torch.LongTensor] = None,
        negative_attention_mask: Optional[torch.LongTensor] = None,
        global_step: Optional[int] = None,
        max_steps: Optional[int] = None,
    ):
        batch_size = input_ids.size(0)
        # Encode the anchor, positive and (optional) negative sentences in a single pass.
        if negative_input_ids is not None:
            encoder_input_ids = torch.cat([input_ids, positive_input_ids, negative_input_ids], dim=0).to(self.device)
            encoder_attention_mask = torch.cat([attention_mask, positive_attention_mask, negative_attention_mask], dim=0).to(self.device)
        elif positive_input_ids is not None:
            encoder_input_ids = torch.cat([input_ids, positive_input_ids], dim=0).to(self.device)
            encoder_attention_mask = torch.cat([attention_mask, positive_attention_mask], dim=0).to(self.device)
        elif self.config.do_contrastive:
            # No explicit positives: duplicate the inputs so that encoder dropout yields two views.
            encoder_input_ids = torch.cat([input_ids, input_ids], dim=0).to(self.device)
            encoder_attention_mask = torch.cat([attention_mask, attention_mask], dim=0).to(self.device)
        elif self.config.do_generative and not self.config.do_contrastive:
            encoder_input_ids = input_ids.to(self.device)
            encoder_attention_mask = attention_mask.to(self.device)
        else:
            raise NotImplementedError()
        encoder_outputs = self.encoder(input_ids=encoder_input_ids, attention_mask=encoder_attention_mask, return_dict=True, output_hidden_states=True, output_attentions=True)
        if self.pooler == 'cls':
            sent_embedding = encoder_outputs.last_hidden_state[:, 0, :]
        elif self.pooler == 'mean':
            sent_embedding = (encoder_outputs.last_hidden_state * encoder_attention_mask.unsqueeze(-1)).sum(1) / encoder_attention_mask.sum(-1).unsqueeze(-1)
        elif self.pooler == 'mask':
            sent_embedding = encoder_outputs.last_hidden_state[encoder_input_ids == self.tokenizer.mask_token_id]
        else:
            raise NotImplementedError()
        sent_embedding = sent_embedding.unsqueeze(1)
        sent_embedding = self.sent_embedding_projector(sent_embedding)
        if self.config.do_generative:
            # Reconstruct the target sentence from the anchor's sentence embedding: the decoder
            # uses the sentence embedding as memory and a dropout-noised copy of the target's
            # input embeddings (hidden_states[0]) as tgt.
            if positive_input_ids is not None:
                tgt = encoder_outputs.hidden_states[0][batch_size:2 * batch_size].detach()
                tgt_key_padding_mask = (positive_input_ids == self.tokenizer.pad_token_id)
                labels = positive_input_ids
            else:
                tgt = encoder_outputs.hidden_states[0][:batch_size].detach()
                tgt_key_padding_mask = (input_ids == self.tokenizer.pad_token_id)
                labels = input_ids
            tgt = self.decoder_noise_dropout(tgt)
            decoder_outputs = self.decoder(tgt=tgt, memory=sent_embedding[:batch_size], tgt_mask=None, tgt_key_padding_mask=tgt_key_padding_mask)
            logits = self.prediction_head(decoder_outputs)
            loss_fct = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
            generative_loss = loss_fct(logits.view(-1, self.encoder.config.vocab_size), labels.view(-1))
            wandb.log({'train/generative_loss': generative_loss})
        if self.config.do_contrastive:
            # In-batch contrastive loss over cosine similarities, with optional hard negatives.
            positive_sim = self.sim(sent_embedding[:batch_size], sent_embedding[batch_size:2 * batch_size].transpose(0, 1))
            cos_sim = positive_sim
            if negative_attention_mask is not None:
                negative_sim = self.sim(sent_embedding[:batch_size], sent_embedding[2 * batch_size:].transpose(0, 1))
                cos_sim = torch.cat([positive_sim, negative_sim], dim=1)
            cos_sim = cos_sim / self.config.contrastive_temp
            contrastive_labels = torch.arange(batch_size, dtype=torch.long, device=self.device)
            contrastive_loss = nn.CrossEntropyLoss()(cos_sim, contrastive_labels)
            wandb.log({'train/contrastive_loss': contrastive_loss.item()})
        logits = None
        loss = 0
        if self.config.do_contrastive:
            loss += self.config.contrastive_weight * contrastive_loss
        if self.config.do_generative:
            loss += self.config.generative_weight * generative_loss
        wandb.log({'train/loss': loss})
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
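

# -----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original training pipeline).
# The config field names below are exactly the attributes DenoSentModel reads;
# the concrete values and the "bert-base-uncased" checkpoint are placeholder
# assumptions, not the authors' reference hyperparameters.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import PretrainedConfig

    config = PretrainedConfig(
        pooler="cls",
        hidden_size=768,
        decoder_num_heads=12,
        decoder_num_layers=1,
        decoder_noise_dropout=0.5,
        encoder_name_or_path="bert-base-uncased",  # assumed encoder checkpoint
        max_length=32,
        prompt_format="[X] means [MASK].",  # only used when pooler == "mask"
        do_contrastive=True,
        do_generative=True,
        contrastive_temp=0.05,
        contrastive_weight=1.0,
        generative_weight=1.0,
    )
    model = DenoSentModel(config)

    # Inference: sentence embeddings via encode().
    embeddings = model.encode(["A quick example sentence.", "Another one."], batch_size=2)
    print(embeddings.shape)  # -> (2, hidden_size)

    # One forward pass as it would run during training. forward() logs to wandb,
    # so start a disabled run first to keep the sketch self-contained.
    wandb.init(mode="disabled")
    model.train()
    batch = model.tokenizer(
        ["A quick example sentence.", "Another one."],
        padding=True, truncation=True, max_length=config.max_length, return_tensors="pt",
    )
    outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
    print(outputs.loss)  # weighted sum of the contrastive and generative losses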