kiddothe2b
committed on
Commit
•
c1c87bf
1
Parent(s):
af99e83
Add HAT implementation files
Browse files - modelling_hat.py +4 -9
modelling_hat.py
CHANGED
@@ -1839,8 +1839,6 @@ class HATForSequenceClassification(HATPreTrainedModel):
|
|
1839 |
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
|
1840 |
)
|
1841 |
self.dropout = nn.Dropout(classifier_dropout)
|
1842 |
-
if self.pooling != 'cls':
|
1843 |
-
self.sentencizer = HATSentencizer(config)
|
1844 |
self.pooler = HATPooler(config, pooling=pooling)
|
1845 |
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
1846 |
|
@@ -1885,13 +1883,12 @@ class HATForSequenceClassification(HATPreTrainedModel):
|
|
1885 |
return_dict=return_dict,
|
1886 |
)
|
1887 |
sequence_output = outputs[0]
|
1888 |
-
if self.pooling
|
1889 |
-
sentence_outputs = self.sentencizer(sequence_output)
|
1890 |
-
pooled_output = self.pooler(sentence_outputs)
|
1891 |
-
elif self.pooling == 'first':
|
1892 |
pooled_output = self.pooler(torch.unsqueeze(sequence_output[:, 0, :], 1))
|
1893 |
elif self.pooling == 'last':
|
1894 |
pooled_output = self.pooler(torch.unsqueeze(sequence_output[:, -128, :], 1))
|
|
|
|
|
1895 |
|
1896 |
pooled_output = self.dropout(pooled_output)
|
1897 |
logits = self.classifier(pooled_output)
|
@@ -2051,8 +2048,6 @@ class HATForMultipleChoice(HATPreTrainedModel):
|
|
2051 |
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
|
2052 |
)
|
2053 |
self.dropout = nn.Dropout(classifier_dropout)
|
2054 |
-
if self.pooling not in ['first', 'last']:
|
2055 |
-
self.sentencizer = HATSentencizer(config)
|
2056 |
self.pooler = HATPooler(config, pooling=pooling)
|
2057 |
self.classifier = nn.Linear(config.hidden_size, 1)
|
2058 |
|
@@ -2113,7 +2108,7 @@ class HATForMultipleChoice(HATPreTrainedModel):
|
|
2113 |
elif self.pooling == 'last':
|
2114 |
pooled_output = self.pooler(torch.unsqueeze(sequence_output[:, -128, :], 1))
|
2115 |
else:
|
2116 |
-
pooled_output = self.pooler(self.
|
2117 |
|
2118 |
pooled_output = self.dropout(pooled_output)
|
2119 |
logits = self.classifier(pooled_output)
|
|
|
1839 |
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
|
1840 |
)
|
1841 |
self.dropout = nn.Dropout(classifier_dropout)
|
|
|
|
|
1842 |
self.pooler = HATPooler(config, pooling=pooling)
|
1843 |
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
1844 |
|
|
|
1883 |
return_dict=return_dict,
|
1884 |
)
|
1885 |
sequence_output = outputs[0]
|
1886 |
+
if self.pooling == 'first':
|
|
|
|
|
|
|
1887 |
pooled_output = self.pooler(torch.unsqueeze(sequence_output[:, 0, :], 1))
|
1888 |
elif self.pooling == 'last':
|
1889 |
pooled_output = self.pooler(torch.unsqueeze(sequence_output[:, -128, :], 1))
|
1890 |
+
else:
|
1891 |
+
pooled_output = self.pooler(sequence_output[:, ::self.max_sentence_length])
|
1892 |
|
1893 |
pooled_output = self.dropout(pooled_output)
|
1894 |
logits = self.classifier(pooled_output)
|
|
|
2048 |
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
|
2049 |
)
|
2050 |
self.dropout = nn.Dropout(classifier_dropout)
|
|
|
|
|
2051 |
self.pooler = HATPooler(config, pooling=pooling)
|
2052 |
self.classifier = nn.Linear(config.hidden_size, 1)
|
2053 |
|
|
|
2108 |
elif self.pooling == 'last':
|
2109 |
pooled_output = self.pooler(torch.unsqueeze(sequence_output[:, -128, :], 1))
|
2110 |
else:
|
2111 |
+
pooled_output = self.pooler(sequence_output[:, ::self.max_sentence_length])
|
2112 |
|
2113 |
pooled_output = self.dropout(pooled_output)
|
2114 |
logits = self.classifier(pooled_output)
|