kiddothe2b
committed
Commit
895ac06
Parent(s):
278b6ef
Add HAT implementation files
modelling_hat.py CHANGED (+25 -24)
@@ -1078,7 +1078,7 @@ class HATForMaskedLM(HATPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
 
-        self.
+        self.hi_transformer = HATModel(config)
         self.lm_head = HATLMHead(config)
 
         # The LM head weights require special treatment only when they are tied with the word embeddings
@@ -1123,7 +1123,7 @@ class HATForMaskedLM(HATPreTrainedModel):
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        outputs = self.
+        outputs = self.hi_transformer(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1161,7 +1161,7 @@ class HATModelForDocumentRepresentation(HATPreTrainedModel):
         self.num_labels = config.num_labels
         self.config = config
 
-        self.
+        self.hi_transformer = HATModel(config)
         self.pooler = HATPooler(config, pooling=pooling)
 
         # Initialize weights and apply final processing
@@ -1195,7 +1195,7 @@ class HATModelForDocumentRepresentation(HATPreTrainedModel):
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        outputs = self.
+        outputs = self.hi_transformer(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1237,7 +1237,7 @@ class HATModelForMaskedSentenceRepresentation(HATPreTrainedModel):
         self.num_labels = config.num_labels
         self.config = config
 
-        self.
+        self.hi_transformer = HATModel(config)
         self.sentencizer = HATSentencizer(config)
 
         # Initialize weights and apply final processing
@@ -1271,7 +1271,7 @@ class HATModelForMaskedSentenceRepresentation(HATPreTrainedModel):
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        outputs = self.
+        outputs = self.hi_transformer(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1313,7 +1313,7 @@ class HATModelForBoWPreTraining(HATPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
 
-        self.
+        self.hi_transformer = HATModel(config)
         if self.config.mlm or self.config.mslm:
             self.lm_head = HATLMHead(config)
         if self.config.srp or self.config.srp:
@@ -1346,7 +1346,7 @@ class HATModelForBoWPreTraining(HATPreTrainedModel):
     ):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        outputs = self.
+        outputs = self.hi_transformer(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1447,7 +1447,7 @@ class HATModelForVICRegPreTraining(HATPreTrainedModel):
 
         self.document_regularization = document_regularization
         self.sentence_regularization = sentence_regularization
-        self.
+        self.hi_transformer = HATModel(config)
         if self.config.mlm:
             self.lm_head = HATLMHead(config)
         if self.config.sent_sim or self.config.doc_sim:
@@ -1474,7 +1474,7 @@ class HATModelForVICRegPreTraining(HATPreTrainedModel):
     ):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        primary_outputs = self.
+        primary_outputs = self.hi_transformer(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1484,7 +1484,7 @@ class HATModelForVICRegPreTraining(HATPreTrainedModel):
             return_dict=return_dict,
         )
 
-        secondary_outputs = self.
+        secondary_outputs = self.hi_transformer(
             secondary_input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1600,7 +1600,7 @@ class HATModelForSimCLRPreTraining(HATPreTrainedModel):
 
         self.document_regularization = document_regularization
         self.sentence_regularization = sentence_regularization
-        self.
+        self.hi_transformer = HATModel(config)
         if self.config.mlm:
             self.lm_head = HATLMHead(config)
         if self.config.sent_sim or self.config.doc_sim:
@@ -1626,7 +1626,7 @@ class HATModelForSimCLRPreTraining(HATPreTrainedModel):
     ):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        primary_outputs = self.
+        primary_outputs = self.hi_transformer(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1636,7 +1636,7 @@ class HATModelForSimCLRPreTraining(HATPreTrainedModel):
             return_dict=return_dict,
         )
 
-        secondary_outputs = self.
+        secondary_outputs = self.hi_transformer(
             secondary_input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1808,7 +1808,7 @@ class HATForSequenceClassification(HATPreTrainedModel):
         self.config = config
         self.pooling = pooling
 
-        self.
+        self.hi_transformer = HATModel(config)
         classifier_dropout = (
             config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
         )
@@ -1848,7 +1848,7 @@ class HATForSequenceClassification(HATPreTrainedModel):
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        outputs = self.
+        outputs = self.hi_transformer(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1916,7 +1916,7 @@ class HATModelForSequentialSentenceClassification(HATPreTrainedModel):
         self.num_labels = config.num_labels
         self.config = config
 
-        self.
+        self.hi_transformer = HATModel(config)
         self.sentencizer = HATSentencizer(config)
         classifier_dropout = (
             config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
@@ -1954,7 +1954,7 @@ class HATModelForSequentialSentenceClassification(HATPreTrainedModel):
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        outputs = self.
+        outputs = self.hi_transformer(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -2020,7 +2020,7 @@ class HATForMultipleChoice(HATPreTrainedModel):
         super().__init__(config)
 
         self.pooling = pooling
-        self.
+        self.hi_transformer = HATModel(config)
         classifier_dropout = (
             config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
         )
@@ -2071,7 +2071,7 @@ class HATForMultipleChoice(HATPreTrainedModel):
             else None
         )
 
-        outputs = self.
+        outputs = self.hi_transformer(
             flat_input_ids,
             position_ids=flat_position_ids,
             token_type_ids=flat_token_type_ids,
@@ -2125,7 +2125,7 @@ class HATForTokenClassification(HATPreTrainedModel):
         super().__init__(config)
         self.num_labels = config.num_labels
 
-        self.
+        self.hi_transformer = HATModel(config, add_pooling_layer=False)
         classifier_dropout = (
             config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
         )
@@ -2160,7 +2160,7 @@ class HATForTokenClassification(HATPreTrainedModel):
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        outputs = self.
+        outputs = self.hi_transformer(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -2208,7 +2208,7 @@ class HATForQuestionAnswering(HATPreTrainedModel):
         super().__init__(config)
         self.num_labels = config.num_labels
 
-        self.
+        self.hi_transformer = HATModel(config, add_pooling_layer=False)
         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
 
         # Initialize weights and apply final processing
@@ -2247,7 +2247,7 @@ class HATForQuestionAnswering(HATPreTrainedModel):
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        outputs = self.
+        outputs = self.hi_transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
@@ -2333,3 +2333,4 @@ def off_diagonal(x):
     assert n == m
     return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
 
+
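Every hunk above makes the same change: each task-specific head builds and calls a shared HATModel backbone through self.hi_transformer (created with add_pooling_layer=False for the token-classification and question-answering heads). A minimal sketch of how one of the updated heads could be exercised follows; HATConfig, the configuration_hat module name, the dummy input shapes, and the outputs.logits attribute are assumptions for illustration and are not part of this commit.

# Minimal usage sketch, assuming a HATConfig companion class and dummy inputs.
import torch

from configuration_hat import HATConfig   # assumed companion module, not shown in this diff
from modelling_hat import HATForMaskedLM

config = HATConfig()                       # default hyper-parameters assumed
model = HATForMaskedLM(config)             # internally builds self.hi_transformer = HATModel(config)

# Dummy batch: one document of 128 token ids drawn from the assumed vocabulary size.
input_ids = torch.randint(0, config.vocab_size, (1, 128))
attention_mask = torch.ones_like(input_ids)

outputs = model(input_ids, attention_mask=attention_mask, return_dict=True)
print(outputs.logits.shape)                # expected (batch, seq_len, vocab_size) if a MaskedLM-style output is returned

The same pattern applies to the other heads in the diff: they forward input_ids, attention_mask, and token_type_ids to the shared self.hi_transformer and then apply their own head (lm_head, pooler, sentencizer, classifier, or qa_outputs) on top of its outputs.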