michael-guenther committed
Commit 95b4916 • 1 Parent(s): eb21270

add mlm model and adjust naming

README.md ADDED
@@ -0,0 +1,5 @@
+# Converting Weights
+
+```
+python3 -m "xlm-roberta-flash-implementation".convert_roberta_weights_to_flash --output pytorch_model_xlmr_flash.bin
+```
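The command writes a flash-attention-compatible state dict to `pytorch_model_xlmr_flash.bin`. A quick way to sanity-check the result is to inspect the remapped keys directly; this is only a sketch, where the key names come from the conversion script in this commit and the file path simply mirrors the `--output` argument above:

```python
# Sketch: inspect the converted checkpoint produced by the command above.
import torch

state_dict = torch.load("pytorch_model_xlmr_flash.bin", map_location="cpu")

# Per-layer q/k/v projections are fused into a single Wqkv tensor by the converter.
print(state_dict["roberta.encoder.layers.0.mixer.Wqkv.weight"].shape)
print(state_dict["roberta.embeddings.word_embeddings.weight"].shape)
```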
config.json CHANGED
@@ -1,9 +1,9 @@
 {
   "auto_map": {
-    "AutoConfig": "configuration_bert.XLMFlashConfig",
-    "AutoModel": "modeling_bert.BertModel",
-    "AutoModelForPreTraining": "modeling_bert.BertForPreTraining",
-    "AutoModelForMaskedLM": "modeling_bert.BertForPreTraining"
+    "AutoConfig": "configuration_xlm_roberta.XLMRobertaFlashConfig",
+    "AutoModel": "modeling_xlm_roberta.XLMRobertaModel",
+    "AutoModelForPreTraining": "modeling_xlm_roberta.XLMRobertaForPreTraining",
+    "AutoModelForMaskedLM": "modeling_xlm_roberta.XLMRobertaForMaskedLM"
   },
   "attention_probs_dropout_prob": 0.1,
   "bos_token_id": 0,
configuration_bert.py → configuration_xlm_roberta.py RENAMED
@@ -1,6 +1,6 @@
 from transformers import PretrainedConfig
 
-class XLMFlashConfig(PretrainedConfig):
+class XLMRobertaFlashConfig(PretrainedConfig):
     def __init__(
         self,
         vocab_size=30522,
convert_roberta_weights_to_flash.py CHANGED
@@ -1,9 +1,10 @@
 import re
 from collections import OrderedDict
-from transformers import BertConfig, PretrainedConfig
+from transformers import PretrainedConfig
 from transformers import XLMRobertaForMaskedLM
 
-from flash_attn.models.bert import BertModel
+from .configuration_xlm_roberta import XLMRobertaFlashConfig as BertConfig
+from .modeling_xlm_roberta import XLMRobertaForMaskedLM as BertModel
 import torch
 
 import click
@@ -16,12 +17,6 @@ def remap_state_dict(state_dict, config: PretrainedConfig):
     Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
     """
 
-    # Replace Roberta with Bert
-    def key_mapping_roberta(key):
-        return re.sub(r"^roberta.", "bert.", key)
-
-    state_dict = OrderedDict((key_mapping_roberta(k), v) for k, v in state_dict.items())
-
     # LayerNorm
     def key_mapping_ln_gamma_beta(key):
         key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
@@ -34,21 +29,21 @@ def remap_state_dict(state_dict, config: PretrainedConfig):
 
     # Layers
     def key_mapping_layers(key):
-        return re.sub(r"^bert.encoder.layer.", "bert.encoder.layers.", key)
+        return re.sub(r"^roberta.encoder.layer.", "roberta.encoder.layers.", key)
 
     state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
 
     # LayerNorm
     def key_mapping_ln(key):
-        key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key)
+        key = re.sub(r"^roberta.embeddings.LayerNorm.", "roberta.emb_ln.", key)
         key = re.sub(
-            r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
-            r"bert.encoder.layers.\1.norm1.\2",
+            r"^roberta.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
+            r"roberta.encoder.layers.\1.norm1.\2",
             key,
         )
         key = re.sub(
-            r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
-            r"bert.encoder.layers.\1.norm2.\2",
+            r"^roberta.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
+            r"roberta.encoder.layers.\1.norm2.\2",
             key,
         )
         key = re.sub(
@@ -63,13 +58,13 @@ def remap_state_dict(state_dict, config: PretrainedConfig):
     # MLP
     def key_mapping_mlp(key):
         key = re.sub(
-            r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
-            r"bert.encoder.layers.\1.mlp.fc1.\2",
+            r"^roberta.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
+            r"roberta.encoder.layers.\1.mlp.fc1.\2",
             key,
         )
         key = re.sub(
-            r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)",
-            r"bert.encoder.layers.\1.mlp.fc2.\2",
+            r"^roberta.encoder.layers.(\d+).output.dense.(weight|bias)",
+            r"roberta.encoder.layers.\1.mlp.fc2.\2",
             key,
         )
         return key
@@ -79,33 +74,33 @@ def remap_state_dict(state_dict, config: PretrainedConfig):
     # Attention
     last_layer_subset = getattr(config, "last_layer_subset", False)
     for d in range(config.num_hidden_layers):
-        Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight")
-        Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight")
-        Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight")
-        bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias")
-        bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
-        bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
+        Wq = state_dict.pop(f"roberta.encoder.layers.{d}.attention.self.query.weight")
+        Wk = state_dict.pop(f"roberta.encoder.layers.{d}.attention.self.key.weight")
+        Wv = state_dict.pop(f"roberta.encoder.layers.{d}.attention.self.value.weight")
+        bq = state_dict.pop(f"roberta.encoder.layers.{d}.attention.self.query.bias")
+        bk = state_dict.pop(f"roberta.encoder.layers.{d}.attention.self.key.bias")
+        bv = state_dict.pop(f"roberta.encoder.layers.{d}.attention.self.value.bias")
         if not (last_layer_subset and d == config.num_hidden_layers - 1):
-            state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat(
+            state_dict[f"roberta.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat(
                 [Wq, Wk, Wv], dim=0
             )
-            state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat(
+            state_dict[f"roberta.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat(
                 [bq, bk, bv], dim=0
            )
        else:
-            state_dict[f"bert.encoder.layers.{d}.mixer.Wq.weight"] = Wq
-            state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat(
+            state_dict[f"roberta.encoder.layers.{d}.mixer.Wq.weight"] = Wq
+            state_dict[f"roberta.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat(
                 [Wk, Wv], dim=0
             )
-            state_dict[f"bert.encoder.layers.{d}.mixer.Wq.bias"] = bq
-            state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat(
+            state_dict[f"roberta.encoder.layers.{d}.mixer.Wq.bias"] = bq
+            state_dict[f"roberta.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat(
                 [bk, bv], dim=0
             )
 
     def key_mapping_attn(key):
         return re.sub(
-            r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
-            r"bert.encoder.layers.\1.mixer.out_proj.\2",
+            r"^roberta.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
+            r"roberta.encoder.layers.\1.mixer.out_proj.\2",
            key,
        )
 
@@ -121,8 +116,8 @@ def remap_state_dict(state_dict, config: PretrainedConfig):
     # Word embedding
     pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
     if pad_vocab_size_multiple > 1:
-        word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
-        state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
+        word_embeddings = state_dict["roberta.embeddings.word_embeddings.weight"]
+        state_dict["roberta.embeddings.word_embeddings.weight"] = F.pad(
            word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
        )
     decoder_weight = state_dict["cls.predictions.decoder.weight"]
@@ -137,16 +132,6 @@ def remap_state_dict(state_dict, config: PretrainedConfig):
        decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
    )
 
-    # Embeddings
-    def key_remove_bert(key):
-        return re.sub(r"^bert.", "", key)
-
-    state_dict = OrderedDict(
-        (key_remove_bert(k), v)
-        for k, v in state_dict.items()
-        if not k.startswith('lm_head')
-    )
-
     return state_dict
 
 
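The remapping above is mostly mechanical key renaming (now targeting the `roberta.` prefix instead of `bert.`), plus one structural change: the separate query/key/value projections of each layer are fused into the single `Wqkv` parameter that the flash-attention `MHA` module expects. A self-contained sketch of that fusion on dummy tensors, with sizes chosen only for illustration:

```python
# Standalone illustration of the per-layer Wqkv fusion performed by the script above.
# hidden = 768 mirrors an XLM-R base-sized layer; the tensors are random stand-ins.
import torch

hidden = 768
Wq, Wk, Wv = (torch.randn(hidden, hidden) for _ in range(3))
bq, bk, bv = (torch.randn(hidden) for _ in range(3))

Wqkv = torch.cat([Wq, Wk, Wv], dim=0)   # (3 * hidden, hidden), as in the script
bqkv = torch.cat([bq, bk, bv], dim=0)   # (3 * hidden,)

# The fused projection applied once equals the three projections applied separately.
x = torch.randn(4, hidden)
q, k, v = (x @ Wqkv.T + bqkv).split(hidden, dim=-1)
assert torch.allclose(q, x @ Wq.T + bq, atol=1e-5)
assert torch.allclose(v, x @ Wv.T + bv, atol=1e-5)
```

The `last_layer_subset` branch of the script keeps `Wq` separate and fuses only `Wk`/`Wv` into `Wkv`, following the same pattern.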
embedding.py CHANGED
@@ -11,7 +11,7 @@ from torch import Tensor
 from transformers.models.xlm_roberta.modeling_xlm_roberta import create_position_ids_from_input_ids
 
 
-class BertEmbeddings(nn.Module):
+class XLMRobertaEmbeddings(nn.Module):
     def __init__(
         self,
         embed_dim,
modeling_bert.py → modeling_xlm_roberta.py RENAMED
@@ -13,28 +13,32 @@ import re
13
  from collections import OrderedDict
14
  from collections.abc import Sequence
15
  from functools import partial
16
- from typing import Any, Mapping
17
 
18
  import torch
19
  import torch.nn as nn
20
  import torch.nn.functional as F
21
  from einops import rearrange
22
- from transformers import BertConfig, PretrainedConfig
23
  from transformers.modeling_utils import PreTrainedModel
 
 
 
24
  from transformers.models.bert.modeling_bert import (
25
  BaseModelOutputWithPoolingAndCrossAttentions,
26
  BertForPreTrainingOutput,
27
  )
28
 
29
- from .bert_padding import (
 
 
30
  index_first_axis,
31
  index_first_axis_residual,
32
  pad_input,
33
  unpad_input,
34
  )
35
- from .configuration_bert import XLMFlashConfig
36
  from .block import Block
37
- from .embedding import BertEmbeddings
38
  from .mha import MHA
39
  from .mlp import FusedMLP, Mlp
40
 
@@ -155,8 +159,8 @@ def _init_weights(module, initializer_range=0.02):
155
  nn.init.zeros_(module.weight[module.padding_idx])
156
 
157
 
158
- class BertEncoder(nn.Module):
159
- def __init__(self, config: BertConfig):
160
  super().__init__()
161
  self.use_flash_attn = getattr(config, "use_flash_attn", False)
162
  self.layers = nn.ModuleList(
@@ -218,7 +222,7 @@ class BertEncoder(nn.Module):
218
  return hidden_states
219
 
220
 
221
- class BertPooler(nn.Module):
222
  def __init__(self, config):
223
  super().__init__()
224
  fused_bias_fc = getattr(config, "fused_bias_fc", False)
@@ -237,7 +241,7 @@ class BertPooler(nn.Module):
237
  return pooled_output
238
 
239
 
240
- class BertPredictionHeadTransform(nn.Module):
241
  def __init__(self, config):
242
  super().__init__()
243
  fused_bias_fc = getattr(config, "fused_bias_fc", False)
@@ -268,7 +272,7 @@ class BertPredictionHeadTransform(nn.Module):
268
  return hidden_states
269
 
270
 
271
- class BertLMPredictionHead(nn.Module):
272
  def __init__(self, config):
273
  super().__init__()
274
  fused_bias_fc = getattr(config, "fused_bias_fc", False)
@@ -276,7 +280,7 @@ class BertLMPredictionHead(nn.Module):
276
  raise ImportError("fused_dense is not installed")
277
  linear_cls = nn.Linear if not fused_bias_fc else FusedDense
278
 
279
- self.transform = BertPredictionHeadTransform(config)
280
 
281
  # The output weights are the same as the input embeddings, but there is
282
  # an output-only bias for each token.
@@ -288,10 +292,10 @@ class BertLMPredictionHead(nn.Module):
288
  return hidden_states
289
 
290
 
291
- class BertPreTrainingHeads(nn.Module):
292
  def __init__(self, config):
293
  super().__init__()
294
- self.predictions = BertLMPredictionHead(config)
295
  self.seq_relationship = nn.Linear(config.hidden_size, 2)
296
 
297
  def forward(self, sequence_output, pooled_output):
@@ -300,64 +304,22 @@ class BertPreTrainingHeads(nn.Module):
300
  return prediction_scores, seq_relationship_score
301
 
302
 
303
- # class BertPreTrainedModel(nn.Module):
304
- # """An abstract class to handle weights initialization and
305
- # a simple interface for dowloading and loading pretrained models.
306
- # """
307
- #
308
- # def __init__(self, config, *inputs, **kwargs):
309
- # super().__init__()
310
- # if not isinstance(config, BertConfig):
311
- # raise ValueError(
312
- # "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
313
- # "To create a model from a Google pretrained model use "
314
- # "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
315
- # self.__class__.__name__, self.__class__.__name__
316
- # )
317
- # )
318
- # self.config = config
319
- #
320
- # @classmethod
321
- # def from_pretrained(cls, model_name, config, *inputs, **kwargs):
322
- # """
323
- # Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
324
- # Download and cache the pre-trained model file if needed.
325
- #
326
- # Params:
327
- # pretrained_model_name_or_path: either:
328
- # - a path or url to a pretrained model archive containing:
329
- # . `bert_config.json` a configuration file for the model
330
- # . `pytorch_model.bin` a PyTorch dump of a BertForPretraining instance
331
- # - a path or url to a pretrained model archive containing:
332
- # . `bert_config.json` a configuration file for the model
333
- # . `model.chkpt` a TensorFlow checkpoint
334
- # *inputs, **kwargs: additional input for the specific Bert class
335
- # (ex: num_labels for BertForSequenceClassification)
336
- # """
337
- # # Instantiate model.
338
- # model = cls(config, *inputs, **kwargs)
339
- # load_return = model.load_state_dict(
340
- # remap_state_dict(state_dict_from_pretrained(model_name), config), strict=False
341
- # )
342
- # logger.info(load_return)
343
- # return model
344
-
345
- class BertPreTrainedModel(PreTrainedModel):
346
  """An abstract class to handle weights initialization and
347
  a simple interface for dowloading and loading pretrained models.
348
  """
349
- config_class = XLMFlashConfig
350
- base_model_prefix = "bert"
351
  supports_gradient_checkpointing = True
352
 
353
  def _set_gradient_checkpointing(self, module, value=False):
354
- if isinstance(module, BertEncoder):
355
  module.gradient_checkpointing = value
356
 
357
 
358
 
359
- class BertModel(BertPreTrainedModel):
360
- def __init__(self, config: BertConfig, add_pooling_layer=True):
361
  super().__init__(config)
362
  self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
363
  if config.vocab_size % self.pad_vocab_size_multiple != 0:
@@ -369,7 +331,7 @@ class BertModel(BertPreTrainedModel):
369
  raise ImportError("Triton is not installed")
370
  assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
371
 
372
- self.embeddings = BertEmbeddings(
373
  config.hidden_size,
374
  config.vocab_size,
375
  config.max_position_embeddings,
@@ -378,11 +340,12 @@ class BertModel(BertPreTrainedModel):
378
  )
379
  self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
380
  self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
381
- self.encoder = BertEncoder(config)
382
- self.pooler = BertPooler(config) if add_pooling_layer else None
383
 
384
  self.apply(partial(_init_weights, initializer_range=config.initializer_range))
385
 
 
386
  def forward(
387
  self,
388
  input_ids,
@@ -390,12 +353,22 @@ class BertModel(BertPreTrainedModel):
390
  token_type_ids=None,
391
  attention_mask=None,
392
  masked_tokens_mask=None,
 
 
393
  ):
394
- """If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
395
  we only want the output for the masked tokens. This means that we only compute the last
396
  layer output for these tokens.
397
  masked_tokens_mask: (batch, seqlen), dtype=torch.bool
398
  """
 
 
 
399
  hidden_states = self.embeddings(
400
  input_ids, position_ids=position_ids, token_type_ids=token_type_ids
401
  )
@@ -437,111 +410,200 @@ class BertModel(BertPreTrainedModel):
437
  sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
438
  pooled_output = self.pooler(pool_input, pool=False) if self.pooler is not None else None
439
 
 
 
 
440
  return BaseModelOutputWithPoolingAndCrossAttentions(
441
  last_hidden_state=sequence_output,
442
  pooler_output=pooled_output,
443
  )
444
 
445
 
446
- class BertForPreTraining(BertPreTrainedModel):
447
- def __init__(self, config: BertConfig):
448
- import pdb
449
- pdb.set_trace()
450
  super().__init__(config)
451
- # If dense_seq_output, we only need to pass the hidden states for the masked out tokens
452
- # (around 15%) to the classifier heads.
453
- self.dense_seq_output = getattr(config, "dense_seq_output", False)
454
- # If last_layer_subset, we only need the compute the last layer for a subset of tokens
455
- # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
456
- self.last_layer_subset = getattr(config, "last_layer_subset", False)
457
- if self.last_layer_subset:
458
- assert self.dense_seq_output, "last_layer_subset requires dense_seq_output"
459
- use_xentropy = getattr(config, "use_xentropy", False)
460
- if use_xentropy and CrossEntropyLoss is None:
461
- raise ImportError("xentropy_cuda is not installed")
462
- loss_cls = (
463
- nn.CrossEntropyLoss
464
- if not use_xentropy
465
- else partial(CrossEntropyLoss, inplace_backward=True)
466
- )
467
 
468
- self.bert = BertModel(config)
469
- self.cls = BertPreTrainingHeads(config)
470
- self.mlm_loss = loss_cls(ignore_index=0)
471
- self.nsp_loss = loss_cls(ignore_index=-1)
 
 
 
 
472
 
473
  # Initialize weights and apply final processing
474
- self.apply(partial(_init_weights, initializer_range=config.initializer_range))
475
- self.tie_weights()
 
 
476
 
477
- def tie_weights(self):
478
- self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
479
 
480
  def forward(
481
  self,
482
- input_ids,
483
- position_ids=None,
484
- token_type_ids=None,
485
- attention_mask=None,
486
- labels=None,
487
- next_sentence_label=None,
488
- ):
 
 
489
  """
490
- If labels are provided, they must be 0 for masked out tokens (as specified in the attention
491
- mask).
492
- Outputs:
493
- if `labels` and `next_sentence_label` are not `None`:
494
- Outputs the total_loss which is the sum of the masked language modeling loss and the next
495
- sentence classification loss.
496
- if `labels` or `next_sentence_label` is `None`:
497
- Outputs a tuple comprising
498
- - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
499
- - the next sentence classification logits of shape [batch_size, 2].
500
 
501
- """
502
- masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
503
- outputs = self.bert(
504
  input_ids,
505
- position_ids=position_ids,
506
  token_type_ids=token_type_ids,
507
- attention_mask=attention_mask.bool() if attention_mask is not None else None,
508
- masked_tokens_mask=masked_tokens_mask,
 
 
 
 
 
 
509
  )
510
- sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
511
- if self.dense_seq_output and labels is not None:
512
- masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
513
- if not self.last_layer_subset:
514
- sequence_output = index_first_axis(
515
- rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx
516
- )
517
- prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
518
-
519
- total_loss = None
520
- if labels is not None and next_sentence_label is not None:
521
- if (
522
- self.dense_seq_output and labels is not None
523
- ): # prediction_scores are already flattened
524
- masked_lm_loss = self.mlm_loss(
525
- prediction_scores, labels.flatten()[masked_token_idx]
526
- )
527
- else:
528
- masked_lm_loss = self.mlm_loss(
529
- rearrange(prediction_scores, "... v -> (...) v"),
530
- rearrange(labels, "... -> (...)"),
531
- )
532
- next_sentence_loss = self.nsp_loss(
533
- rearrange(seq_relationship_score, "... t -> (...) t"),
534
- rearrange(next_sentence_label, "... -> (...)"),
535
- )
536
- total_loss = masked_lm_loss.float() + next_sentence_loss.float()
537
-
538
- return BertForPreTrainingOutput(
539
- loss=total_loss,
540
- prediction_logits=prediction_scores,
541
- seq_relationship_logits=seq_relationship_score,
542
  )
543
 
544
 
 
 
545
  def remap_state_dict(state_dict, config: PretrainedConfig):
546
  """
547
  Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
 
13
  from collections import OrderedDict
14
  from collections.abc import Sequence
15
  from functools import partial
 
16
 
17
  import torch
18
  import torch.nn as nn
19
  import torch.nn.functional as F
20
  from einops import rearrange
21
+ from transformers import PretrainedConfig
22
  from transformers.modeling_utils import PreTrainedModel
23
+ from transformers.modeling_outputs import MaskedLMOutput
24
+ from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
25
+
26
  from transformers.models.bert.modeling_bert import (
27
  BaseModelOutputWithPoolingAndCrossAttentions,
28
  BertForPreTrainingOutput,
29
  )
30
 
31
+ from typing import Optional, Tuple, Union
32
+
33
+ from .xlm_padding import (
34
  index_first_axis,
35
  index_first_axis_residual,
36
  pad_input,
37
  unpad_input,
38
  )
39
+ from .configuration_xlm_roberta import XLMRobertaFlashConfig
40
  from .block import Block
41
+ from .embedding import XLMRobertaEmbeddings
42
  from .mha import MHA
43
  from .mlp import FusedMLP, Mlp
44
 
 
159
  nn.init.zeros_(module.weight[module.padding_idx])
160
 
161
 
162
+ class XLMRobertaEncoder(nn.Module):
163
+ def __init__(self, config: XLMRobertaFlashConfig):
164
  super().__init__()
165
  self.use_flash_attn = getattr(config, "use_flash_attn", False)
166
  self.layers = nn.ModuleList(
 
222
  return hidden_states
223
 
224
 
225
+ class XLMRobertaPooler(nn.Module):
226
  def __init__(self, config):
227
  super().__init__()
228
  fused_bias_fc = getattr(config, "fused_bias_fc", False)
 
241
  return pooled_output
242
 
243
 
244
+ class XLMRobertaPredictionHeadTransform(nn.Module):
245
  def __init__(self, config):
246
  super().__init__()
247
  fused_bias_fc = getattr(config, "fused_bias_fc", False)
 
272
  return hidden_states
273
 
274
 
275
+ class XLMRobertaLMPredictionHead(nn.Module):
276
  def __init__(self, config):
277
  super().__init__()
278
  fused_bias_fc = getattr(config, "fused_bias_fc", False)
 
280
  raise ImportError("fused_dense is not installed")
281
  linear_cls = nn.Linear if not fused_bias_fc else FusedDense
282
 
283
+ self.transform = XLMRobertaPredictionHeadTransform(config)
284
 
285
  # The output weights are the same as the input embeddings, but there is
286
  # an output-only bias for each token.
 
292
  return hidden_states
293
 
294
 
295
+ class XLMRobertaPreTrainingHeads(nn.Module):
296
  def __init__(self, config):
297
  super().__init__()
298
+ self.predictions = XLMRobertaLMPredictionHead(config)
299
  self.seq_relationship = nn.Linear(config.hidden_size, 2)
300
 
301
  def forward(self, sequence_output, pooled_output):
 
304
  return prediction_scores, seq_relationship_score
305
 
306
 
307
+ class XLMRobertaPreTrainedModel(PreTrainedModel):
 
 
308
  """An abstract class to handle weights initialization and
309
  a simple interface for dowloading and loading pretrained models.
310
  """
311
+ config_class = XLMRobertaFlashConfig
312
+ base_model_prefix = "roberta"
313
  supports_gradient_checkpointing = True
314
 
315
  def _set_gradient_checkpointing(self, module, value=False):
316
+ if isinstance(module, XLMRobertaEncoder):
317
  module.gradient_checkpointing = value
318
 
319
 
320
 
321
+ class XLMRobertaModel(XLMRobertaPreTrainedModel):
322
+ def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
323
  super().__init__(config)
324
  self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
325
  if config.vocab_size % self.pad_vocab_size_multiple != 0:
 
331
  raise ImportError("Triton is not installed")
332
  assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
333
 
334
+ self.embeddings = XLMRobertaEmbeddings(
335
  config.hidden_size,
336
  config.vocab_size,
337
  config.max_position_embeddings,
 
340
  )
341
  self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
342
  self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
343
+ self.encoder = XLMRobertaEncoder(config)
344
+ self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
345
 
346
  self.apply(partial(_init_weights, initializer_range=config.initializer_range))
347
 
348
+
349
  def forward(
350
  self,
351
  input_ids,
 
353
  token_type_ids=None,
354
  attention_mask=None,
355
  masked_tokens_mask=None,
356
+ return_dict=None,
357
+ **kwargs,
358
  ):
359
+ """If masked_tokens_mask is not None (i.e. last_layer_subset == True in XLMForPreTraining),
360
  we only want the output for the masked tokens. This means that we only compute the last
361
  layer output for these tokens.
362
  masked_tokens_mask: (batch, seqlen), dtype=torch.bool
363
  """
364
+
365
+ if kwargs:
366
+ for key, value in kwargs.items():
367
+ if value is not None:
368
+ logger.warning('Flash attention implementation does not support kwargs: %s', key)
369
+
370
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
371
+
372
  hidden_states = self.embeddings(
373
  input_ids, position_ids=position_ids, token_type_ids=token_type_ids
374
  )
 
410
  sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
411
  pooled_output = self.pooler(pool_input, pool=False) if self.pooler is not None else None
412
 
413
+ if not return_dict:
414
+ return sequence_output, pooled_output
415
+
416
  return BaseModelOutputWithPoolingAndCrossAttentions(
417
  last_hidden_state=sequence_output,
418
  pooler_output=pooled_output,
419
  )
420
 
421
 
422
+ class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
423
+ _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
424
+
425
+ def __init__(self, config):
426
  super().__init__(config)
 
 
427
 
428
+ if config.is_decoder:
429
+ logger.warning(
430
+ "If you want to use `XLMRobertaForMaskedLM` make sure `config.is_decoder=False` for "
431
+ "bi-directional self-attention."
432
+ )
433
+
434
+ self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
435
+ self.lm_head = XLMRobertaLMHead(config)
436
 
437
  # Initialize weights and apply final processing
438
+ self.post_init()
439
+
440
+ def get_input_embeddings(self):
441
+ return self.roberta.embeddings.word_embeddings
442
+
443
+ def get_output_embeddings(self):
444
+ return self.lm_head.decoder
445
+
446
+ def set_output_embeddings(self, new_embeddings):
447
+ self.lm_head.decoder = new_embeddings
448
 
 
 
449
 
450
  def forward(
451
  self,
452
+ input_ids: Optional[torch.LongTensor] = None,
453
+ attention_mask: Optional[torch.FloatTensor] = None,
454
+ token_type_ids: Optional[torch.LongTensor] = None,
455
+ position_ids: Optional[torch.LongTensor] = None,
456
+ head_mask: Optional[torch.FloatTensor] = None,
457
+ inputs_embeds: Optional[torch.FloatTensor] = None,
458
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
459
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
460
+ labels: Optional[torch.LongTensor] = None,
461
+ output_attentions: Optional[bool] = None,
462
+ output_hidden_states: Optional[bool] = None,
463
+ return_dict: Optional[bool] = None,
464
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
465
+ r"""
466
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
467
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
468
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
469
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
470
+ kwargs (`Dict[str, any]`, optional, defaults to *{}*):
471
+ Used to hide legacy arguments that have been deprecated.
472
  """
473
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
 
474
 
475
+ outputs = self.roberta(
 
 
476
  input_ids,
477
+ attention_mask=attention_mask,
478
  token_type_ids=token_type_ids,
479
+ position_ids=position_ids,
480
+ head_mask=head_mask,
481
+ inputs_embeds=inputs_embeds,
482
+ encoder_hidden_states=encoder_hidden_states,
483
+ encoder_attention_mask=encoder_attention_mask,
484
+ output_attentions=output_attentions,
485
+ output_hidden_states=output_hidden_states,
486
+ return_dict=return_dict,
487
  )
488
+ sequence_output = outputs[0]
489
+ prediction_scores = self.lm_head(sequence_output)
490
+
491
+ masked_lm_loss = None
492
+ if labels is not None:
493
+ # move labels to correct device to enable model parallelism
494
+ labels = labels.to(prediction_scores.device)
495
+ loss_fct = CrossEntropyLoss()
496
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
497
+
498
+ if not return_dict:
499
+ output = (prediction_scores,) + outputs[2:]
500
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
501
+
502
+ return MaskedLMOutput(
503
+ loss=masked_lm_loss,
504
+ logits=prediction_scores,
505
+ hidden_states=outputs.hidden_states,
506
+ attentions=outputs.attentions,
 
 
507
  )
508
 
509
 
510
+ # class XLMRobertaForPreTraining(XLMRobertaPreTrainedModel):
511
+ # def __init__(self, config: XLMRobertaFlashConfig):
512
+ # super().__init__(config)
513
+ # # If dense_seq_output, we only need to pass the hidden states for the masked out tokens
514
+ # # (around 15%) to the classifier heads.
515
+ # self.dense_seq_output = getattr(config, "dense_seq_output", False)
516
+ # # If last_layer_subset, we only need the compute the last layer for a subset of tokens
517
+ # # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
518
+ # self.last_layer_subset = getattr(config, "last_layer_subset", False)
519
+ # if self.last_layer_subset:
520
+ # assert self.dense_seq_output, "last_layer_subset requires dense_seq_output"
521
+ # use_xentropy = getattr(config, "use_xentropy", False)
522
+ # if use_xentropy and CrossEntropyLoss is None:
523
+ # raise ImportError("xentropy_cuda is not installed")
524
+ # loss_cls = (
525
+ # nn.CrossEntropyLoss
526
+ # if not use_xentropy
527
+ # else partial(CrossEntropyLoss, inplace_backward=True)
528
+ # )
529
+ #
530
+ # self.xlm = XLMRobertaModel(config)
531
+ # self.cls = XLMRobertaPreTrainingHeads(config)
532
+ # self.mlm_loss = loss_cls(ignore_index=0)
533
+ # self.nsp_loss = loss_cls(ignore_index=-1)
534
+ #
535
+ # # Initialize weights and apply final processing
536
+ # self.apply(partial(_init_weights, initializer_range=config.initializer_range))
537
+ # self.tie_weights()
538
+ #
539
+ # def tie_weights(self):
540
+ # self.cls.predictions.decoder.weight = self.xlm.embeddings.word_embeddings.weight
541
+ #
542
+ # def forward(
543
+ # self,
544
+ # input_ids,
545
+ # position_ids=None,
546
+ # token_type_ids=None,
547
+ # attention_mask=None,
548
+ # labels=None,
549
+ # next_sentence_label=None,
550
+ # ):
551
+ # """
552
+ # If labels are provided, they must be 0 for masked out tokens (as specified in the attention
553
+ # mask).
554
+ # Outputs:
555
+ # if `labels` and `next_sentence_label` are not `None`:
556
+ # Outputs the total_loss which is the sum of the masked language modeling loss and the next
557
+ # sentence classification loss.
558
+ # if `labels` or `next_sentence_label` is `None`:
559
+ # Outputs a tuple comprising
560
+ # - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
561
+ # - the next sentence classification logits of shape [batch_size, 2].
562
+ #
563
+ # """
564
+ # masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
565
+ # outputs = self.xlm(
566
+ # input_ids,
567
+ # position_ids=position_ids,
568
+ # token_type_ids=token_type_ids,
569
+ # attention_mask=attention_mask.bool() if attention_mask is not None else None,
570
+ # masked_tokens_mask=masked_tokens_mask,
571
+ # )
572
+ # sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
573
+ # if self.dense_seq_output and labels is not None:
574
+ # masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
575
+ # if not self.last_layer_subset:
576
+ # sequence_output = index_first_axis(
577
+ # rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx
578
+ # )
579
+ # prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
580
+ #
581
+ # total_loss = None
582
+ # if labels is not None and next_sentence_label is not None:
583
+ # if (
584
+ # self.dense_seq_output and labels is not None
585
+ # ): # prediction_scores are already flattened
586
+ # masked_lm_loss = self.mlm_loss(
587
+ # prediction_scores, labels.flatten()[masked_token_idx]
588
+ # )
589
+ # else:
590
+ # masked_lm_loss = self.mlm_loss(
591
+ # rearrange(prediction_scores, "... v -> (...) v"),
592
+ # rearrange(labels, "... -> (...)"),
593
+ # )
594
+ # next_sentence_loss = self.nsp_loss(
595
+ # rearrange(seq_relationship_score, "... t -> (...) t"),
596
+ # rearrange(next_sentence_label, "... -> (...)"),
597
+ # )
598
+ # total_loss = masked_lm_loss.float() + next_sentence_loss.float()
599
+ #
600
+ # return BertForPreTrainingOutput(
601
+ # loss=total_loss,
602
+ # prediction_logits=prediction_scores,
603
+ # seq_relationship_logits=seq_relationship_score,
604
+ # )
605
+
606
+
607
  def remap_state_dict(state_dict, config: PretrainedConfig):
608
  """
609
  Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
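With the masked-LM head now wired up in `modeling_xlm_roberta.py`, the model can be exercised end to end as a fill-mask model. The following is a sketch under two assumptions not stated in the commit: the repository is published on the Hub (placeholder id) and it uses the standard XLM-R tokenizer.

```python
# Sketch: exercise the new XLMRobertaForMaskedLM head on a single masked sentence.
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")  # assumed tokenizer
model = AutoModelForMaskedLM.from_pretrained("organization/repo-id", trust_remote_code=True)
model.eval()

text = f"Paris is the capital of {tokenizer.mask_token}."
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # MaskedLMOutput.logits, as returned above

# Look up the top predictions at the masked position.
mask_pos = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
top_ids = logits[0, mask_pos].topk(5, dim=-1).indices[0]
print(tokenizer.convert_ids_to_tokens(top_ids.tolist()))
```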
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:61bdee1ea6ae50618c387234ae94a500df9ce095e59d836b8aefef33e9d8884e
-size 1112222546
+oid sha256:cfa8fa7c7e120199548fe7149512c0adfe58f6bc13ce19f09b895aa25e8af910
+size 1113232188
bert_padding.py → xlm_padding.py RENAMED
File without changes