kiddothe2b committed
Commit 895ac06
1 Parent(s): 278b6ef

Add HAT implementation files

Files changed (1)
  1. modelling_hat.py +25 -24
modelling_hat.py CHANGED
@@ -1078,7 +1078,7 @@ class HATForMaskedLM(HATPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
 
-        self.hat = HATModel(config)
+        self.hi_transformer = HATModel(config)
         self.lm_head = HATLMHead(config)
 
         # The LM head weights require special treatment only when they are tied with the word embeddings
@@ -1123,7 +1123,7 @@ class HATForMaskedLM(HATPreTrainedModel):
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        outputs = self.hat(
+        outputs = self.hi_transformer(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1161,7 +1161,7 @@ class HATModelForDocumentRepresentation(HATPreTrainedModel):
         self.num_labels = config.num_labels
         self.config = config
 
-        self.hat = HATModel(config)
+        self.hi_transformer = HATModel(config)
         self.pooler = HATPooler(config, pooling=pooling)
 
         # Initialize weights and apply final processing
@@ -1195,7 +1195,7 @@ class HATModelForDocumentRepresentation(HATPreTrainedModel):
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        outputs = self.hat(
+        outputs = self.hi_transformer(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1237,7 +1237,7 @@ class HATModelForMaskedSentenceRepresentation(HATPreTrainedModel):
         self.num_labels = config.num_labels
         self.config = config
 
-        self.hat = HATModel(config)
+        self.hi_transformer = HATModel(config)
         self.sentencizer = HATSentencizer(config)
 
         # Initialize weights and apply final processing
@@ -1271,7 +1271,7 @@ class HATModelForMaskedSentenceRepresentation(HATPreTrainedModel):
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        outputs = self.hat(
+        outputs = self.hi_transformer(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1313,7 +1313,7 @@ class HATModelForBoWPreTraining(HATPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
 
-        self.hat = HATModel(config)
+        self.hi_transformer = HATModel(config)
         if self.config.mlm or self.config.mslm:
             self.lm_head = HATLMHead(config)
         if self.config.srp or self.config.srp:
@@ -1346,7 +1346,7 @@ class HATModelForBoWPreTraining(HATPreTrainedModel):
     ):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        outputs = self.hat(
+        outputs = self.hi_transformer(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1447,7 +1447,7 @@ class HATModelForVICRegPreTraining(HATPreTrainedModel):
 
         self.document_regularization = document_regularization
         self.sentence_regularization = sentence_regularization
-        self.hat = HATModel(config)
+        self.hi_transformer = HATModel(config)
         if self.config.mlm:
             self.lm_head = HATLMHead(config)
         if self.config.sent_sim or self.config.doc_sim:
@@ -1474,7 +1474,7 @@ class HATModelForVICRegPreTraining(HATPreTrainedModel):
     ):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        primary_outputs = self.hat(
+        primary_outputs = self.hi_transformer(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1484,7 +1484,7 @@ class HATModelForVICRegPreTraining(HATPreTrainedModel):
             return_dict=return_dict,
         )
 
-        secondary_outputs = self.hat(
+        secondary_outputs = self.hi_transformer(
             secondary_input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1600,7 +1600,7 @@ class HATModelForSimCLRPreTraining(HATPreTrainedModel):
 
         self.document_regularization = document_regularization
         self.sentence_regularization = sentence_regularization
-        self.hat = HATModel(config)
+        self.hi_transformer = HATModel(config)
         if self.config.mlm:
             self.lm_head = HATLMHead(config)
         if self.config.sent_sim or self.config.doc_sim:
@@ -1626,7 +1626,7 @@ class HATModelForSimCLRPreTraining(HATPreTrainedModel):
     ):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        primary_outputs = self.hat(
+        primary_outputs = self.hi_transformer(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1636,7 +1636,7 @@ class HATModelForSimCLRPreTraining(HATPreTrainedModel):
             return_dict=return_dict,
         )
 
-        secondary_outputs = self.hat(
+        secondary_outputs = self.hi_transformer(
             secondary_input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1808,7 +1808,7 @@ class HATForSequenceClassification(HATPreTrainedModel):
         self.config = config
         self.pooling = pooling
 
-        self.hat = HATModel(config)
+        self.hi_transformer = HATModel(config)
         classifier_dropout = (
             config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
         )
@@ -1848,7 +1848,7 @@ class HATForSequenceClassification(HATPreTrainedModel):
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        outputs = self.hat(
+        outputs = self.hi_transformer(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1916,7 +1916,7 @@ class HATModelForSequentialSentenceClassification(HATPreTrainedModel):
         self.num_labels = config.num_labels
         self.config = config
 
-        self.hat = HATModel(config)
+        self.hi_transformer = HATModel(config)
         self.sentencizer = HATSentencizer(config)
         classifier_dropout = (
             config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
@@ -1954,7 +1954,7 @@ class HATModelForSequentialSentenceClassification(HATPreTrainedModel):
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        outputs = self.hat(
+        outputs = self.hi_transformer(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -2020,7 +2020,7 @@ class HATForMultipleChoice(HATPreTrainedModel):
         super().__init__(config)
 
         self.pooling = pooling
-        self.hat = HATModel(config)
+        self.hi_transformer = HATModel(config)
         classifier_dropout = (
             config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
         )
@@ -2071,7 +2071,7 @@ class HATForMultipleChoice(HATPreTrainedModel):
             else None
         )
 
-        outputs = self.hat(
+        outputs = self.hi_transformer(
             flat_input_ids,
             position_ids=flat_position_ids,
             token_type_ids=flat_token_type_ids,
@@ -2125,7 +2125,7 @@ class HATForTokenClassification(HATPreTrainedModel):
         super().__init__(config)
         self.num_labels = config.num_labels
 
-        self.hat = HATModel(config, add_pooling_layer=False)
+        self.hi_transformer = HATModel(config, add_pooling_layer=False)
         classifier_dropout = (
             config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
         )
@@ -2160,7 +2160,7 @@ class HATForTokenClassification(HATPreTrainedModel):
         """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        outputs = self.hat(
+        outputs = self.hi_transformer(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -2208,7 +2208,7 @@ class HATForQuestionAnswering(HATPreTrainedModel):
         super().__init__(config)
         self.num_labels = config.num_labels
 
-        self.hat = HATModel(config, add_pooling_layer=False)
+        self.hi_transformer = HATModel(config, add_pooling_layer=False)
         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
 
         # Initialize weights and apply final processing
@@ -2247,7 +2247,7 @@ class HATForQuestionAnswering(HATPreTrainedModel):
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        outputs = self.hat(
+        outputs = self.hi_transformer(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -2333,3 +2333,4 @@ def off_diagonal(x):
     assert n == m
     return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
 
+
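
Every hunk in this diff makes the same one-line change: the shared HAT encoder attribute is renamed from self.hat to self.hi_transformer in each task head, so the parameter prefix in saved state dicts changes with it. Below is a minimal sketch of remapping an older checkpoint to the new prefix, assuming its weights were stored under hat.-prefixed keys; the file names are hypothetical and not part of this commit.

import torch

# Hypothetical file names; only the key remapping matters here.
state_dict = torch.load("old_hat_checkpoint.bin", map_location="cpu")

# Rewrite "hat.*" parameter names to "hi_transformer.*" to match the renamed attribute.
remapped = {
    ("hi_transformer." + name[len("hat."):] if name.startswith("hat.") else name): tensor
    for name, tensor in state_dict.items()
}

torch.save(remapped, "remapped_hat_checkpoint.bin")

The remapped state dict can then be loaded into any of the heads above with the standard load_state_dict call (for example HATForMaskedLM(config).load_state_dict(remapped, strict=False)), with strict=False leaving any task-specific weights to their usual initialization.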