---
datasets:
- ffgcc/NEWS5M
- zen-E/NEWS5M-simcse-roberta-large-embeddings-pca-256
language:
- en
metrics:
- pearsonr
- spearmanr
library_name: transformers
---

The model was trained by knowledge distillation from `princeton-nlp/unsup-simcse-roberta-large` (teacher) into `prajjwal1/bert-mini` (student) on the `ffgcc/NEWS5M` dataset.

The model can be loaded for inference with `AutoModel`.
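The loading path can be sketched with the standard `transformers` API. This is a usage sketch, not code from the repository: the `encode` helper is hypothetical, `model_name` should be this repository's model id, and the use of the pooled output as the sentence embedding is an assumption based on the forward function shown later in this card.

```python
import torch
from transformers import AutoModel, AutoTokenizer

def encode(texts, model_name):
    """Embed a list of sentences; `model_name` is this repository's model id (hypothetical helper)."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    model.eval()
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        out = model(**batch)
    # assumption: the training forward pass uses the pooled output as the sentence embedding
    return out.pooler_output
```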

The model achieves Pearson and Spearman correlations of 0.825 and 0.83, respectively, on the STS-B test set.
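These correlations are the standard STS-B evaluation metrics; a minimal sketch of how they are computed with `scipy.stats`, using toy stand-ins for the model's cosine similarities (`preds`) and the human similarity scores (`gold`):

```python
from scipy.stats import pearsonr, spearmanr

# toy stand-ins for model cosine similarities and human STS-B scores (0-5 scale)
preds = [0.9, 0.1, 0.5, 0.7, 0.3]
gold = [4.8, 0.5, 2.6, 3.9, 1.2]

pearson_corr, _ = pearsonr(preds, gold)
spearman_corr, _ = spearmanr(preds, gold)
```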

For more training detail, the training config and the PyTorch forward function are as follows:

```python
config = {
    'epoch': 200,
    'learning_rate': 3e-4,
    'batch_size': 12288,
    'temperature': 0.05,
}
```

```python
def forward_cos_mse_kd_unsup(self, sentences, teacher_sentence_embs):
    """Forward function for the unsupervised NEWS5M dataset."""
    _, o = self.bert(**sentences)  # pooled sentence embeddings

    # cosine similarity between the first half and the second half of the batch
    half_batch = o.size(0) // 2
    higher_half = half_batch * 2  # skip the last datapoint when the batch size is odd

    cos_sim = cosine_sim(o[:half_batch], o[half_batch:higher_half])
    cos_sim_teacher = cosine_sim(teacher_sentence_embs[:half_batch],
                                 teacher_sentence_embs[half_batch:higher_half])

    # KL divergence between student and teacher similarity distributions
    soft_teacher_probs = F.softmax(cos_sim_teacher / self.temperature, dim=1)
    kd_contrastive_loss = F.kl_div(F.log_softmax(cos_sim / self.temperature, dim=1),
                                   soft_teacher_probs,
                                   reduction='batchmean')

    # MSE loss between student and teacher embeddings
    kd_mse_loss = nn.MSELoss()(o, teacher_sentence_embs) / 3

    # equal weight for the two losses
    total_loss = kd_contrastive_loss * 0.5 + kd_mse_loss * 0.5

    return total_loss, kd_contrastive_loss, kd_mse_loss
```
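The same loss can be exercised outside the model class on random tensors; a self-contained sketch, where the pairwise `cosine_sim` helper is an assumption reimplemented here (it is referenced but not defined in the forward function above):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def cosine_sim(a, b):
    """Pairwise cosine-similarity matrix between two batches of embeddings (assumed helper)."""
    a = F.normalize(a, dim=-1)
    b = F.normalize(b, dim=-1)
    return a @ b.T

torch.manual_seed(0)
temperature = 0.05
student = torch.randn(8, 256)  # toy student embeddings
teacher = torch.randn(8, 256)  # toy teacher embeddings

half = student.size(0) // 2
cos_sim = cosine_sim(student[:half], student[half:])
cos_sim_teacher = cosine_sim(teacher[:half], teacher[half:])

# KL divergence between temperature-scaled similarity distributions
soft_teacher_probs = F.softmax(cos_sim_teacher / temperature, dim=1)
kd_contrastive_loss = F.kl_div(F.log_softmax(cos_sim / temperature, dim=1),
                               soft_teacher_probs, reduction='batchmean')

# MSE loss between student and teacher embeddings, equal-weighted with the KL term
kd_mse_loss = nn.MSELoss()(student, teacher) / 3
total_loss = kd_contrastive_loss * 0.5 + kd_mse_loss * 0.5
```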