Fill-Mask · Transformers · PyTorch · Safetensors · English · nomic_bert · custom_code
Commit fcebeef (1 parent: 2a32621), committed by zpn

Update modeling_hf_nomic_bert.py

Files changed (1):
  1. modeling_hf_nomic_bert.py  +136 -32
modeling_hf_nomic_bert.py CHANGED
@@ -119,7 +119,7 @@ def filter_shapes(state_dict, model):
     return filtered_state_dict
 
 
-def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_weights=False, add_pooling_layer=False):
+def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_weights=False, add_pooling_layer=False):
     """
     Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
     """
@@ -225,6 +225,16 @@ def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_wei
 
     state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
 
+    if remove_cls_weights:
+        cls_weights = ["cls.predictions.decoder.bias",
+                       "cls.predictions.transform.dense.weight",
+                       "cls.predictions.transform.dense.bias",
+                       "cls.predictions.transform.layer_norm.weight",
+                       "cls.predictions.transform.layer_norm.bias",
+                       "cls.predictions.decoder.weight"]
+        for weight in cls_weights:
+            state_dict.pop(weight, None)
+
     # Word embedding
     pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
     if pad_vocab_size_multiple > 1:
@@ -232,18 +242,19 @@ def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_wei
         state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
             word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
         )
-        decoder_weight = state_dict["cls.predictions.decoder.weight"]
-        state_dict["cls.predictions.decoder.weight"] = F.pad(
-            decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
-        )
-        # If the vocab was padded, we want to set the decoder bias for those padded indices to be
-        # strongly negative (i.e. the decoder shouldn't predict those indices).
-        # TD [2022-05-09]: I don't think it affects the MLPerf training.
-        if "cls.predictions.decoder.bias" in state_dict:
-            decoder_bias = state_dict["cls.predictions.decoder.bias"]
-            state_dict["cls.predictions.decoder.bias"] = F.pad(
-                decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
-            )
+        if not remove_cls_weights:
+            decoder_weight = state_dict["cls.predictions.decoder.weight"]
+            state_dict["cls.predictions.decoder.weight"] = F.pad(
+                decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
+            )
+            # If the vocab was padded, we want to set the decoder bias for those padded indices to be
+            # strongly negative (i.e. the decoder shouldn't predict those indices).
+            # TD [2022-05-09]: I don't think it affects the MLPerf training.
+            if "cls.predictions.decoder.bias" in state_dict:
+                decoder_bias = state_dict["cls.predictions.decoder.bias"]
+                state_dict["cls.predictions.decoder.bias"] = F.pad(
+                    decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
+                )
 
     if add_pooling_layer is False:
         pooler_weights = ["bert.pooler.dense.weight",
@@ -252,16 +263,6 @@ def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_wei
     for key in pooler_weights:
         state_dict.pop(key, None)
 
-    if remove_cls_weights:
-        cls_weights = ["cls.predictions.decoder.bias",
-                       "cls.predictions.transform.dense.weight",
-                       "cls.predictions.transform.dense.bias",
-                       "cls.predictions.transform.layer_norm.weight",
-                       "cls.predictions.transform.layer_norm.bias",
-                       "cls.predictions.decoder.weight"]
-        for weight in cls_weights:
-            state_dict.pop(weight, None)
-
     if remove_bert:
         def remove_bert_prefix(key):
             key = re.sub(r"^bert.", "", key)
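
Taken together, the hunks above move the cls.predictions.* pruning ahead of the vocab-padding step and guard the decoder padding with `if not remove_cls_weights`, so checkpoints loaded without an MLM head no longer touch missing decoder tensors. As a standalone illustration (not part of the commit) of what the `value=-100.0` padding does to the decoder bias:

import torch
import torch.nn.functional as F

# Illustrative only: pad a 5-entry decoder bias up to a padded vocab of 8.
# The three new slots get -100.0 so the MLM head effectively never predicts them.
decoder_bias = torch.zeros(5)
padded_vocab_size = 8
padded_bias = F.pad(decoder_bias, (0, padded_vocab_size - decoder_bias.shape[0]), value=-100.0)
print(padded_bias)  # tensor([0., 0., 0., 0., 0., -100., -100., -100.])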
@@ -319,9 +320,21 @@ class NomicBertPreTrainedModel(PreTrainedModel):
         remove_bert_prefix = cls != NomicBertForPreTraining
         ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
         num_labels = kwargs.pop("num_labels", None)
+        rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
+        if rotary_scaling_factor:
+            config.rotary_scaling_factor = rotary_scaling_factor
+        if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
+            config.n_positions = 2048
         if num_labels:
             config.num_labels = num_labels
-        model = cls(config, *inputs)
+
+        if "add_pooling_layer" in kwargs:
+            model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
+        else:
+            if cls == NomicBertModel:
+                model = cls(config, *inputs, add_pooling_layer=False)
+            else:
+                model = cls(config, *inputs)
         # TODO: fix this
         # Assuming we know what we're doing when loading from disk
         # Prob a bad assumption but i'm tired and want to train this asap
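
With this hunk, `from_pretrained` accepts a `rotary_scaling_factor` kwarg, writes it into the config, defaults `n_positions` to 2048 when rotary embeddings are in use, and builds `NomicBertModel` without a pooling layer unless asked otherwise. A minimal usage sketch, assuming a checkpoint that ships this modeling code (the model id below is illustrative):

from transformers import AutoModel

# Hedged sketch: the kwarg is popped inside from_pretrained and stored on the config,
# which later selects NomicBertDynamicNTKRotaryEmbedding in NomicBertAttention.
model = AutoModel.from_pretrained(
    "nomic-ai/nomic-bert-2048",   # illustrative checkpoint id
    trust_remote_code=True,
    rotary_scaling_factor=2,
)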
@@ -379,9 +392,9 @@ class NomicBertEmbeddings(nn.Module):
         self.word_embeddings = nn.Embedding(
             config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
         )
-        self.max_position_embeddings = config.max_position_embeddings
+        self.max_position_embeddings = config.max_position_embeddings if config.rotary_emb_fraction <= 0 else 0
         self.type_vocab_size = config.type_vocab_size
-        if self.max_position_embeddings > 0:
+        if self.max_position_embeddings > 0 and config.rotary_emb_fraction <= 0:
             self.position_embeddings = nn.Embedding(
                 config.max_position_embeddings, config.hidden_size,
             )
@@ -542,6 +555,12 @@ class NomicBertRotaryEmbedding(nn.Module):
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.interleaved = interleaved
         self.scale_base = scale_base
+        scale = (
+            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
+            if scale_base is not None
+            else None
+        )
+        self.register_buffer("scale", scale, persistent=False)
 
         self._seq_len_cached = 0
         self._cos_cached = None
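
The base rotary class now precomputes a per-dimension scale ramp whenever `scale_base` is set (the xPos-style scaling path consumed later in `_update_cos_sin_cache`). A standalone sketch of the values it produces, with an assumed small `dim`:

import torch

# Illustrative only: the ramp grows from ~0.29 toward 1 across the rotary dimensions.
dim = 8
scale = (torch.arange(0, dim, 2, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
print(scale)  # tensor([0.2857, 0.4643, 0.6429, 0.8214])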
@@ -607,7 +626,9 @@ class NomicBertRotaryEmbedding(nn.Module):
         Apply rotary embedding *inplace* to qkv and / or kv.
         """
         seqlen = qkv.shape[1]
-        if max_seqlen is not None:
+        if seqlen > self._seq_len_cached:
+            self._update_cos_sin_cache(seqlen, device=qkv.device, dtype=qkv.dtype)
+        elif max_seqlen is not None:
             self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
         elif isinstance(seqlen_offset, int):
             self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
@@ -617,6 +638,79 @@ class NomicBertRotaryEmbedding(nn.Module):
         return torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
 
 
+class NomicBertDynamicNTKRotaryEmbedding(NomicBertRotaryEmbedding):
+    def __init__(self, rotary_scaling_factor, max_position_embeddings, **kwargs):
+        super().__init__(**kwargs)
+        self.rotary_scaling_factor = rotary_scaling_factor
+        self.max_position_embeddings = max_position_embeddings
+
+
+    def _compute_inv_freq(self, base=None, device=None):
+        if base is None:
+            base = self.base
+        return 1.0 / (
+            base
+            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
+        )
+
+    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
+        # Reset the tables if the sequence length has changed,
+        # if we're on a new device (possibly due to tracing for instance),
+        # or if we're switching from inference mode to training
+        if seqlen > self.max_position_embeddings:
+            base = self.base * (
+                (self.rotary_scaling_factor * seqlen / self.max_position_embeddings) - (self.rotary_scaling_factor - 1)
+            ) ** (self.dim / (self.dim - 2))
+            inv_freq = self._compute_inv_freq(base=base, device=device)
+            self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        if (
+            seqlen > self._seq_len_cached
+            or self._cos_cached is None
+            or self._cos_cached.device != device
+            or self._cos_cached.dtype != dtype
+            or (self.training and self._cos_cached.is_inference())
+        ):
+            self._seq_len_cached = seqlen
+            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
+            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
+            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
+            if self.pos_idx_in_fp32:
+                t = torch.arange(seqlen, device=device, dtype=torch.float32)
+                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
+                # will be large. Having it in bf16 will lose a lot of precision and cause the
+                # cos & sin output to change significantly.
+                # We want to recompute self.inv_freq if it was not loaded in fp32
+                if self.inv_freq.dtype != torch.float32:
+                    if seqlen > self.max_position_embeddings:
+                        base = self.base * (
+                            (self.scaling_factor * seqlen / self.max_position_embeddings) - (self.scaling_factor - 1)
+                        ) ** (self.dim / (self.dim - 2))
+                    else:
+                        base = self.base
+                    inv_freq = self._compute_inv_freq(device=device, base=base)
+                else:
+                    inv_freq = self.inv_freq
+            else:
+                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                inv_freq = self.inv_freq
+            # Don't do einsum, it converts fp32 to fp16 under AMP
+            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            freqs = torch.outer(t, inv_freq)
+            if self.scale is None:
+                self._cos_cached = torch.cos(freqs).to(dtype)
+                self._sin_cached = torch.sin(freqs).to(dtype)
+            else:
+                power = (
+                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
+                    - seqlen // 2
+                ) / self.scale_base
+                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
+                # We want the multiplication by scale to happen in fp32
+                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
+                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
+                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
+                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
 
 class NomicBertAttention(nn.Module):
     """Multi-head self-attention and cross-attention"""
@@ -651,12 +745,22 @@ class NomicBertAttention(nn.Module):
 
         self.rotary_emb_dim = self.head_dim * config.rotary_emb_fraction
         if self.rotary_emb_dim > 0:
-            self.rotary_emb = NomicBertRotaryEmbedding(
-                self.rotary_emb_dim,
-                base=config.rotary_emb_base,
-                scale_base=config.rotary_emb_scale_base,
-                interleaved=config.rotary_emb_interleaved,
-            )
+            if config.rotary_scaling_factor:
+                self.rotary_emb = NomicBertDynamicNTKRotaryEmbedding(
+                    dim=self.rotary_emb_dim,
+                    base=config.rotary_emb_base,
+                    scale_base=config.rotary_emb_scale_base,
+                    interleaved=config.rotary_emb_interleaved,
+                    rotary_scaling_factor=config.rotary_scaling_factor,
+                    max_position_embeddings=config.n_positions,
+                )
+            else:
+                self.rotary_emb = NomicBertRotaryEmbedding(
+                    dim=self.rotary_emb_dim,
+                    base=config.rotary_emb_base,
+                    scale_base=config.rotary_emb_scale_base,
+                    interleaved=config.rotary_emb_interleaved,
+                )
             # bug in xformers: https://github.com/facebookresearch/xformers/issues/841
             # uses the head dimension instead of the sequence dimension
             self.rotary_head_dim = getattr(config, "rotary_head_dim", False)
 