Singhoo commited on
Commit
644b4b6
·
verified ·
1 Parent(s): 2c1c1d0

Upload model

Browse files
Files changed (3) hide show
  1. config.json +4 -0
  2. config.py +33 -0
  3. model.py +158 -0
config.json CHANGED
@@ -3,6 +3,10 @@
3
  "architectures": [
4
  "DenoSentModel"
5
  ],
 
 
 
 
6
  "contrastive_temp": 0.05,
7
  "contrastive_weight": 5.0,
8
  "decoder_noise_dropout": 0.825,
 
3
  "architectures": [
4
  "DenoSentModel"
5
  ],
6
+ "auto_map": {
7
+ "AutoConfig": "config.DenoSentConfig",
8
+ "AutoModel": "model.DenoSentModel"
9
+ },
10
  "contrastive_temp": 0.05,
11
  "contrastive_weight": 5.0,
12
  "decoder_noise_dropout": 0.825,
config.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from typing import Optional
3
+
4
+ class DenoSentConfig(PretrainedConfig):
5
+ def __init__(self,
6
+ encoder_name_or_path:Optional[str]=None,
7
+ hidden_size:Optional[int]=768,
8
+ max_length:Optional[int]=32,
9
+ decoder_num_heads:Optional[int]=1,
10
+ decoder_num_layers:Optional[int]=16,
11
+ decoder_noise_dropout:Optional[float]=0.825,
12
+ pooler:Optional[str]='mask',
13
+ do_contrastive:Optional[bool]=False,
14
+ do_generative:Optional[bool]=False,
15
+ prompt_format:Optional[str]='[X] means [MASK]',
16
+ contrastive_weight:Optional[float]=1.0,
17
+ generative_weight:Optional[float]=1.0,
18
+ contrastive_temp: Optional[float]=0.05,
19
+ **kwargs):
20
+ super().__init__(**kwargs)
21
+ self.encoder_name_or_path = encoder_name_or_path
22
+ self.hidden_size = hidden_size
23
+ self.max_length = max_length
24
+ self.decoder_num_heads = decoder_num_heads
25
+ self.decoder_num_layers = decoder_num_layers
26
+ self.decoder_noise_dropout = decoder_noise_dropout
27
+ self.pooler = pooler
28
+ self.do_contrastive = do_contrastive
29
+ self.do_generative = do_generative
30
+ self.prompt_format = prompt_format
31
+ self.contrastive_weight = contrastive_weight
32
+ self.generative_weight = generative_weight
33
+ self.contrastive_temp = contrastive_temp
model.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, BertForMaskedLM
2
+ from transformers.models.bert.modeling_bert import BertForMaskedLM
3
+ from transformers.modeling_outputs import TokenClassifierOutput
4
+ from transformers import PreTrainedModel
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from torch.nn import CrossEntropyLoss, TransformerDecoder, TransformerDecoderLayer
9
+
10
+ from typing import Optional
11
+
12
+ import wandb
13
+ import numpy as np
14
+
15
+ class DenoSentModel(PreTrainedModel):
16
+ def __init__(self, config):
17
+ super().__init__(config)
18
+ self.pooler = config.pooler
19
+ self.sent_embedding_projector = nn.Linear(config.hidden_size, config.hidden_size)
20
+ self.decoder = TransformerDecoder(TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.decoder_num_heads, batch_first=True, dropout=0.1), num_layers=config.decoder_num_layers)
21
+ self.decoder_noise_dropout = nn.Dropout(config.decoder_noise_dropout)
22
+ self.sim = nn.CosineSimilarity(dim=-1)
23
+ self.init_weights()
24
+ self.tokenizer = AutoTokenizer.from_pretrained(config.encoder_name_or_path)
25
+ self.encoder = BertForMaskedLM.from_pretrained(config.encoder_name_or_path)
26
+ self.prediction_head = self.encoder.cls
27
+ self.encoder = self.encoder.bert
28
+ self.post_init()
29
+
30
+ def _init_weights(self, module):
31
+ """Initialize the weights"""
32
+ if isinstance(module, nn.Linear):
33
+ # Slightly different from the TF version which uses truncated_normal for initialization
34
+ # cf https://github.com/pytorch/pytorch/pull/5617
35
+ module.weight.data.normal_(mean=0.0, std=0.02)
36
+ if module.bias is not None:
37
+ module.bias.data.zero_()
38
+ elif isinstance(module, nn.Embedding):
39
+ module.weight.data.normal_(mean=0.0, std=0.02)
40
+ if module.padding_idx is not None:
41
+ module.weight.data[module.padding_idx].zero_()
42
+ elif isinstance(module, nn.LayerNorm):
43
+ module.bias.data.zero_()
44
+ module.weight.data.fill_(1.0)
45
+
46
+ def encode(self, sentences, batch_size=32, **kwargs):
47
+ """ Returns a list of embeddings for the given sentences.
48
+ Args:
49
+ sentences (`List[str]`): List of sentences to encode
50
+ batch_size (`int`): Batch size for the encoding
51
+
52
+ Returns:
53
+ `List[np.ndarray]` or `List[tensor]`: List of embeddings for the given sentences
54
+ """
55
+ self.eval()
56
+ all_embeddings = []
57
+ length_sorted_idx = np.argsort([len(sen) for sen in sentences])
58
+ sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
59
+ if self.config.pooler == 'mask':
60
+ prompt_length = len(self.tokenizer(self.config.prompt_format, add_special_tokens=False)['input_ids'])
61
+ sentences_sorted = self.tokenizer.batch_decode(self.tokenizer(sentences_sorted, padding=True, truncation=True, max_length=self.config.max_length, return_tensors='pt').input_ids, skip_special_tokens=True)
62
+ sentences_sorted = [self.config.prompt_format.replace('[X]', s).replace('[MASK]', self.tokenizer.mask_token) for s in sentences_sorted]
63
+ for start_index in range(0, len(sentences), batch_size):
64
+ sentences_batch = sentences_sorted[start_index:start_index+batch_size]
65
+ inputs = self.tokenizer(sentences_batch, padding='max_length', truncation=True, return_tensors="pt", max_length=self.config.max_length+prompt_length)
66
+ inputs = {k: v.to(self.device) for k,v in inputs.items()}
67
+ with torch.no_grad():
68
+ encoder_outputs = self.encoder(**inputs, output_hidden_states=True, output_attentions=True, return_dict=True)
69
+ last_hidden_state = encoder_outputs.last_hidden_state
70
+ if self.config.pooler == 'cls':
71
+ embeddings = last_hidden_state[:, 0, :]
72
+ elif self.config.pooler == 'mean':
73
+ embeddings = (last_hidden_state * inputs['attention_mask'].unsqueeze(-1)).sum(1) / inputs['attention_mask'].sum(-1).unsqueeze(-1)
74
+ elif self.pooler == 'mask':
75
+ embeddings = last_hidden_state[inputs['input_ids'] == self.tokenizer.mask_token_id]
76
+ else:
77
+ raise NotImplementedError()
78
+ all_embeddings.extend(embeddings.cpu().numpy())
79
+ all_embeddings = torch.tensor(np.array([all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]))
80
+ return all_embeddings
81
+
82
+ def forward(
83
+ self,
84
+ input_ids: Optional[torch.LongTensor] = None,
85
+ attention_mask: Optional[torch.LongTensor] = None,
86
+ positive_input_ids: Optional[torch.LongTensor] = None,
87
+ positive_attention_mask: Optional[torch.LongTensor] = None,
88
+ negative_input_ids: Optional[torch.LongTensor] = None,
89
+ negative_attention_mask: Optional[torch.LongTensor] = None,
90
+ global_step: Optional[int] = None,
91
+ max_steps: Optional[int] = None,
92
+ ):
93
+ batch_size = input_ids.size(0)
94
+ if negative_input_ids is not None:
95
+ encoder_input_ids = torch.cat([input_ids, positive_input_ids, negative_input_ids], dim=0).to(self.device)
96
+ encoder_attention_mask = torch.cat([attention_mask, positive_attention_mask, negative_attention_mask], dim=0).to(self.device)
97
+ elif positive_input_ids is not None:
98
+ encoder_input_ids = torch.cat([input_ids, positive_input_ids], dim=0).to(self.device)
99
+ encoder_attention_mask = torch.cat([attention_mask, positive_attention_mask], dim=0).to(self.device)
100
+ elif self.config.do_contrastive:
101
+ encoder_input_ids = torch.cat([input_ids, input_ids], dim=0).to(self.device)
102
+ encoder_attention_mask = torch.cat([attention_mask, attention_mask], dim=0).to(self.device)
103
+ elif self.config.do_generative and not self.config.do_contrastive:
104
+ encoder_input_ids = input_ids.to(self.device)
105
+ encoder_attention_mask = attention_mask.to(self.device)
106
+ else:
107
+ raise NotImplementedError()
108
+ encoder_outputs = self.encoder(input_ids=encoder_input_ids, attention_mask=encoder_attention_mask, return_dict=True, output_hidden_states=True, output_attentions=True)
109
+ if self.pooler == 'cls':
110
+ sent_embedding = encoder_outputs.last_hidden_state[:, 0, :]
111
+ elif self.pooler == 'mean':
112
+ sent_embedding = ((encoder_outputs.last_hidden_state * encoder_attention_mask.unsqueeze(-1)).sum(1) / encoder_attention_mask.sum(-1).unsqueeze(-1))
113
+ elif self.pooler == 'mask':
114
+ sent_embedding = encoder_outputs.last_hidden_state[encoder_input_ids == self.tokenizer.mask_token_id]
115
+ else:
116
+ raise NotImplementedError()
117
+ sent_embedding = sent_embedding.unsqueeze(1)
118
+ sent_embedding = self.sent_embedding_projector(sent_embedding)
119
+
120
+ if self.config.do_generative:
121
+ if positive_input_ids is not None:
122
+ tgt = encoder_outputs.hidden_states[0][batch_size:2*batch_size].detach()
123
+ tgt_key_padding_mask = (positive_input_ids == self.tokenizer.pad_token_id)
124
+ labels = positive_input_ids
125
+ else:
126
+ tgt = encoder_outputs.hidden_states[0][:batch_size].detach()
127
+ tgt_key_padding_mask = (input_ids == self.tokenizer.pad_token_id)
128
+ labels = input_ids
129
+ tgt = self.decoder_noise_dropout(tgt)
130
+ decoder_outputs = self.decoder(tgt=tgt, memory=sent_embedding[:batch_size], tgt_mask=None, tgt_key_padding_mask=tgt_key_padding_mask)
131
+ logits = self.prediction_head(decoder_outputs)
132
+ loss_fct = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
133
+ generative_loss = loss_fct(logits.view(-1, self.encoder.config.vocab_size), labels.view(-1))
134
+ wandb.log({'train/generative_loss': generative_loss})
135
+
136
+ if self.config.do_contrastive:
137
+ positive_sim = self.sim(sent_embedding[:batch_size], sent_embedding[batch_size:2*batch_size].transpose(0, 1))
138
+ cos_sim = positive_sim
139
+ if negative_attention_mask is not None:
140
+ negative_sim = self.sim(sent_embedding[:batch_size], sent_embedding[2*batch_size:].transpose(0, 1))
141
+ cos_sim = torch.cat([positive_sim, negative_sim], dim=1)
142
+ cos_sim = cos_sim / self.config.contrastive_temp
143
+ contrastive_labels = torch.arange(batch_size, dtype=torch.long, device=self.device)
144
+ contrastive_loss = nn.CrossEntropyLoss()(cos_sim, contrastive_labels)
145
+ wandb.log({'train/contrastive_loss': contrastive_loss.item()})
146
+ logits = None
147
+ loss = 0
148
+ if self.config.do_contrastive:
149
+ loss += self.config.contrastive_weight * contrastive_loss
150
+ if self.config.do_generative:
151
+ loss += self.config.generative_weight * generative_loss
152
+ wandb.log({'train/loss': loss})
153
+ return TokenClassifierOutput(
154
+ loss=loss,
155
+ logits=logits,
156
+ hidden_states=encoder_outputs.hidden_states,
157
+ attentions=encoder_outputs.attentions,
158
+ )