Markus28 commited on
Commit
ba24fb1
·
1 Parent(s): ca5f516

feat: added BertForSequenceClassification

Browse files
Files changed (1) hide show
  1. modeling_for_glue.py +104 -0
modeling_for_glue.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Union, Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
6
+ from transformers.modeling_outputs import SequenceClassifierOutput
7
+
8
+ from .modeling_bert import BertPreTrainedModel, BertModel
9
+ from .configuration_bert import JinaBertConfig
10
+
11
+
12
+ class BertForSequenceClassification(BertPreTrainedModel):
13
+ def __init__(self, config: JinaBertConfig):
14
+ super().__init__(config)
15
+ self.num_labels = config.num_labels
16
+ self.config = config
17
+
18
+ self.bert = BertModel(config)
19
+ classifier_dropout = (
20
+ config.classifier_dropout
21
+ if config.classifier_dropout is not None
22
+ else config.hidden_dropout_prob
23
+ )
24
+ self.dropout = nn.Dropout(classifier_dropout)
25
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
26
+
27
+ # Initialize weights and apply final processing
28
+ self.post_init()
29
+
30
+
31
+ def forward(
32
+ self,
33
+ input_ids: Optional[torch.Tensor] = None,
34
+ attention_mask: Optional[torch.Tensor] = None,
35
+ token_type_ids: Optional[torch.Tensor] = None,
36
+ position_ids: Optional[torch.Tensor] = None,
37
+ head_mask: Optional[torch.Tensor] = None,
38
+ inputs_embeds: Optional[torch.Tensor] = None,
39
+ labels: Optional[torch.Tensor] = None,
40
+ output_attentions: Optional[bool] = None,
41
+ output_hidden_states: Optional[bool] = None,
42
+ return_dict: Optional[bool] = None,
43
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
44
+ r"""
45
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
46
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
47
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
48
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
49
+ """
50
+ return_dict = (
51
+ return_dict if return_dict is not None else self.config.use_return_dict
52
+ )
53
+
54
+ outputs = self.bert(
55
+ input_ids,
56
+ attention_mask=attention_mask,
57
+ token_type_ids=token_type_ids,
58
+ position_ids=position_ids,
59
+ head_mask=head_mask,
60
+ inputs_embeds=inputs_embeds,
61
+ output_attentions=output_attentions,
62
+ output_hidden_states=output_hidden_states,
63
+ return_dict=return_dict,
64
+ )
65
+
66
+ pooled_output = outputs[1]
67
+
68
+ pooled_output = self.dropout(pooled_output)
69
+ logits = self.classifier(pooled_output)
70
+
71
+ loss = None
72
+ if labels is not None:
73
+ if self.config.problem_type is None:
74
+ if self.num_labels == 1:
75
+ self.config.problem_type = "regression"
76
+ elif self.num_labels > 1 and (
77
+ labels.dtype == torch.long or labels.dtype == torch.int
78
+ ):
79
+ self.config.problem_type = "single_label_classification"
80
+ else:
81
+ self.config.problem_type = "multi_label_classification"
82
+
83
+ if self.config.problem_type == "regression":
84
+ loss_fct = MSELoss()
85
+ if self.num_labels == 1:
86
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
87
+ else:
88
+ loss = loss_fct(logits, labels)
89
+ elif self.config.problem_type == "single_label_classification":
90
+ loss_fct = CrossEntropyLoss()
91
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
92
+ elif self.config.problem_type == "multi_label_classification":
93
+ loss_fct = BCEWithLogitsLoss()
94
+ loss = loss_fct(logits, labels)
95
+ if not return_dict:
96
+ output = (logits,) + outputs[2:]
97
+ return ((loss,) + output) if loss is not None else output
98
+
99
+ return SequenceClassifierOutput(
100
+ loss=loss,
101
+ logits=logits,
102
+ hidden_states=outputs.hidden_states,
103
+ attentions=outputs.attentions,
104
+ )