Punyajoy committed on
Commit
1947f7c
1 Parent(s): a3f464a

Create models.py

Browse files
Files changed (1) hide show
  1. models.py +67 -0
models.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForTokenClassification, AutoModelForSequenceClassification, AdamW, get_linear_schedule_with_warmup
3
+ from transformers import BertForTokenClassification, BertForSequenceClassification,BertPreTrainedModel, BertModel
4
+ import torch.nn as nn
5
+ from .utils import *
6
+ import torch.nn.functional as F
7
+
8
+
9
+
10
class Model_Rational_Label(BertPreTrainedModel):
    """BERT encoder with two heads trained jointly:

    * a per-token rationale head (2-way classification over each token), and
    * a sequence-level label head over the pooled representation.

    ``forward`` returns ``(label_logits, token_logits)`` when no supervision
    is given, and ``(label_logits, token_logits, total_loss)`` when either
    token-level rationales (``attn``) or sequence labels (``labels``) are
    supplied. The token loss is scaled by ``params['rationale_impact']``
    before being added to the label loss.
    """

    def __init__(self, config, params):
        super().__init__(config)
        # Number of sequence-level classes and weight of the rationale loss
        # come from the caller-supplied params dict.
        self.num_labels = params['num_classes']
        self.impact_factor = params['rationale_impact']
        # Pooling is done by a separate BertPooler, so the encoder's own
        # pooling layer is disabled.
        self.bert = BertModel(config, add_pooling_layer=False)
        self.bert_pooler = BertPooler(config)
        self.token_dropout = nn.Dropout(0.1)
        # Binary head: is each token part of the rationale or not.
        self.token_classifier = nn.Linear(config.hidden_size, 2)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
        self.init_weights()

    def forward(self, input_ids=None, mask=None, attn=None, labels=None):
        """Run the encoder and both heads.

        Args:
            input_ids: token id tensor fed to BERT.
            mask: attention mask; also used to select which positions count
                toward the token-level loss.
            attn: optional per-token rationale targets (values in {0, 1} —
                TODO confirm against the data pipeline).
            labels: optional sequence-level class targets.

        Returns:
            ``(label_logits, token_logits, total_loss)`` if any loss was
            computed, otherwise ``(label_logits, token_logits)``.
        """
        # [0] is the last hidden state of shape (batch, seq_len, hidden).
        sequence_output = self.bert(input_ids, mask)[0]
        token_logits = self.token_classifier(self.token_dropout(sequence_output))

        pooled = self.bert_pooler(sequence_output)
        label_logits = self.classifier(self.dropout(pooled))

        total_loss = None

        # Token-level (rationale) loss, restricted to non-padding positions.
        if attn is not None:
            ce = nn.CrossEntropyLoss()
            if mask is not None:
                # Replace targets at padded positions with ignore_index so
                # they do not contribute to the loss.
                keep = mask.view(-1) == 1
                flat_logits = token_logits.view(-1, 2)
                flat_targets = torch.where(
                    keep,
                    attn.view(-1),
                    torch.tensor(ce.ignore_index).type_as(attn),
                )
                token_loss = ce(flat_logits, flat_targets)
            else:
                token_loss = ce(token_logits.view(-1, 2), attn.view(-1))
            total_loss = self.impact_factor * token_loss

        # Sequence-level classification loss, accumulated onto any token loss.
        if labels is not None:
            ce = nn.CrossEntropyLoss()
            label_loss = ce(label_logits.view(-1, self.num_labels), labels.view(-1))
            total_loss = label_loss if total_loss is None else total_loss + label_loss

        if total_loss is not None:
            return label_logits, token_logits, total_loss
        return label_logits, token_logits