from transformers import AutoTokenizer, BertForMaskedLM, PreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput
import torch
from torch import nn
from torch.nn import TransformerDecoder, TransformerDecoderLayer

from typing import Optional

import wandb
import numpy as np

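# DenoSent-style sentence embedding model: a pretrained BERT encoder produces a pooled
# sentence embedding ('cls', 'mean', or 'mask' pooling); a shallow Transformer decoder is
# trained to reconstruct the (dropout-noised) token sequence from that single vector
# (generative objective), optionally alongside a temperature-scaled contrastive objective
# with in-batch (and optional hard) negatives.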
class DenoSentModel(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.pooler = config.pooler
        self.sent_embedding_projector = nn.Linear(config.hidden_size, config.hidden_size)
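        # Shallow Transformer decoder that reconstructs the token sequence, using the pooled
        # sentence embedding as its (length-1) memory.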
        self.decoder = TransformerDecoder(TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.decoder_num_heads, batch_first=True, dropout=0.1), num_layers=config.decoder_num_layers)
        self.decoder_noise_dropout = nn.Dropout(config.decoder_noise_dropout)
        self.sim = nn.CosineSimilarity(dim=-1)
        self.init_weights()
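        # Load a pretrained BERT-with-MLM-head checkpoint: the bare BERT body is kept as the
        # encoder, and its MLM head is reused as the decoder's vocabulary projection.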
        self.tokenizer = AutoTokenizer.from_pretrained(config.encoder_name_or_path)
        self.encoder = BertForMaskedLM.from_pretrained(config.encoder_name_or_path)
        self.prediction_head = self.encoder.cls
        self.encoder = self.encoder.bert
        self.post_init()

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
    
    def encode(self, sentences, batch_size=32, **kwargs):
        """ Returns a list of embeddings for the given sentences.
        Args:
            sentences (`List[str]`): List of sentences to encode
            batch_size (`int`): Batch size for the encoding

        Returns:
            `List[np.ndarray]` or `List[tensor]`: List of embeddings for the given sentences
        """
        self.eval()
        all_embeddings = []
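        # Sort by sentence length so each batch holds similar-length inputs (less padding);
        # the original order is restored before returning.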
        length_sorted_idx = np.argsort([len(sen) for sen in sentences])
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
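        # For the 'mask' pooler, truncate each sentence to max_length and wrap it in the prompt
        # template; the sentence embedding is then read from the [MASK] position.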
        prompt_length = 0
        if self.config.pooler == 'mask':
            prompt_length = len(self.tokenizer(self.config.prompt_format, add_special_tokens=False)['input_ids'])
            sentences_sorted = self.tokenizer.batch_decode(self.tokenizer(sentences_sorted, padding=True, truncation=True, max_length=self.config.max_length, return_tensors='pt').input_ids, skip_special_tokens=True)
            sentences_sorted = [self.config.prompt_format.replace('[X]', s).replace('[MASK]', self.tokenizer.mask_token) for s in sentences_sorted]
        for start_index in range(0, len(sentences), batch_size):
            sentences_batch = sentences_sorted[start_index:start_index+batch_size]
            inputs = self.tokenizer(sentences_batch, padding='max_length', truncation=True, return_tensors="pt", max_length=self.config.max_length+prompt_length)
            inputs = {k: v.to(self.device) for k,v in inputs.items()}
            with torch.no_grad():
                encoder_outputs = self.encoder(**inputs, output_hidden_states=True, output_attentions=True, return_dict=True)
                last_hidden_state = encoder_outputs.last_hidden_state
                if self.config.pooler == 'cls':
                    embeddings = last_hidden_state[:, 0, :]
                elif self.config.pooler == 'mean':
                    embeddings = (last_hidden_state * inputs['attention_mask'].unsqueeze(-1)).sum(1) / inputs['attention_mask'].sum(-1).unsqueeze(-1)
                elif self.config.pooler == 'mask':
                    embeddings = last_hidden_state[inputs['input_ids'] == self.tokenizer.mask_token_id]
                else:
                    raise NotImplementedError()
            all_embeddings.extend(embeddings.cpu().numpy())
        all_embeddings = torch.tensor(np.array([all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]))
        return all_embeddings

    def forward(
            self,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.LongTensor] = None,
            positive_input_ids: Optional[torch.LongTensor] = None,
            positive_attention_mask: Optional[torch.LongTensor] = None,
            negative_input_ids: Optional[torch.LongTensor] = None,
            negative_attention_mask: Optional[torch.LongTensor] = None,
            global_step: Optional[int] = None,
            max_steps: Optional[int] = None,
    ):
        batch_size = input_ids.size(0)
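        # Lay out the encoder batch as [anchors | positives | negatives] (each batch_size rows)
        # so each view can be recovered by slicing. With no explicit positives, the anchors are
        # simply duplicated and encoder dropout provides the second view (SimCSE-style).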
        if negative_input_ids is not None:
            encoder_input_ids = torch.cat([input_ids, positive_input_ids, negative_input_ids], dim=0).to(self.device)
            encoder_attention_mask = torch.cat([attention_mask, positive_attention_mask, negative_attention_mask], dim=0).to(self.device)
        elif positive_input_ids is not None:
            encoder_input_ids = torch.cat([input_ids, positive_input_ids], dim=0).to(self.device)
            encoder_attention_mask = torch.cat([attention_mask, positive_attention_mask], dim=0).to(self.device)
        elif self.config.do_contrastive:
            encoder_input_ids = torch.cat([input_ids, input_ids], dim=0).to(self.device)
            encoder_attention_mask = torch.cat([attention_mask, attention_mask], dim=0).to(self.device)
        elif self.config.do_generative and not self.config.do_contrastive:
            encoder_input_ids = input_ids.to(self.device)
            encoder_attention_mask = attention_mask.to(self.device)
        else:
            raise NotImplementedError()
        encoder_outputs = self.encoder(input_ids=encoder_input_ids, attention_mask=encoder_attention_mask, return_dict=True, output_hidden_states=True, output_attentions=True)
        if self.pooler == 'cls':
            sent_embedding = encoder_outputs.last_hidden_state[:, 0, :]
        elif self.pooler == 'mean':
            sent_embedding = ((encoder_outputs.last_hidden_state * encoder_attention_mask.unsqueeze(-1)).sum(1) / encoder_attention_mask.sum(-1).unsqueeze(-1))
        elif self.pooler == 'mask':
            sent_embedding = encoder_outputs.last_hidden_state[encoder_input_ids == self.tokenizer.mask_token_id]
        else:
            raise NotImplementedError()
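        # Add a sequence dimension so each pooled vector acts as a length-1 decoder memory,
        # then project it.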
        sent_embedding = sent_embedding.unsqueeze(1)
        sent_embedding = self.sent_embedding_projector(sent_embedding)

        logits = None  # ensure logits is always defined for the returned output
        if self.config.do_generative:
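            # Denoising reconstruction: the encoder's embedding-layer output (detached) is
            # corrupted with dropout and fed to the decoder as `tgt`; conditioned on the anchor
            # sentence embedding, the decoder predicts the original tokens (padding ignored).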
            if positive_input_ids is not None:
                tgt = encoder_outputs.hidden_states[0][batch_size:2*batch_size].detach()
                tgt_key_padding_mask = (positive_input_ids == self.tokenizer.pad_token_id)
                labels = positive_input_ids
            else:
                tgt = encoder_outputs.hidden_states[0][:batch_size].detach()
                tgt_key_padding_mask = (input_ids == self.tokenizer.pad_token_id)
                labels = input_ids
            tgt = self.decoder_noise_dropout(tgt)
            decoder_outputs = self.decoder(tgt=tgt, memory=sent_embedding[:batch_size], tgt_mask=None, tgt_key_padding_mask=tgt_key_padding_mask)
            logits = self.prediction_head(decoder_outputs)
            loss_fct = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
            generative_loss = loss_fct(logits.view(-1, self.encoder.config.vocab_size), labels.view(-1))
            wandb.log({'train/generative_loss': generative_loss})

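        # Contrastive objective: temperature-scaled cosine similarities between each anchor and
        # all positives (plus optional hard negatives); diagonal entries are the targets, so the
        # rest of the batch acts as in-batch negatives.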
        if self.config.do_contrastive:
            positive_sim = self.sim(sent_embedding[:batch_size], sent_embedding[batch_size:2*batch_size].transpose(0, 1))
            cos_sim = positive_sim
            if negative_attention_mask is not None:
                negative_sim = self.sim(sent_embedding[:batch_size], sent_embedding[2*batch_size:].transpose(0, 1))
                cos_sim = torch.cat([positive_sim, negative_sim], dim=1)
            cos_sim = cos_sim / self.config.contrastive_temp
            contrastive_labels = torch.arange(batch_size, dtype=torch.long, device=self.device)
            contrastive_loss = nn.CrossEntropyLoss()(cos_sim, contrastive_labels)
            wandb.log({'train/contrastive_loss': contrastive_loss.item()})
            logits = None
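        # Combine whichever objectives are enabled, weighted by their configured coefficients.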
        loss = 0
        if self.config.do_contrastive:
            loss += self.config.contrastive_weight * contrastive_loss
        if self.config.do_generative:
            loss += self.config.generative_weight * generative_loss
        wandb.log({'train/loss': loss})
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )