Update modeling_hf_nomic_bert.py
modeling_hf_nomic_bert.py CHANGED (+136 -32)

@@ -119,7 +119,7 @@ def filter_shapes(state_dict, model):
     return filtered_state_dict


-def remap_bert_state_dict(state_dict, config,
+def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_weights=False, add_pooling_layer=False):
    """
    Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
    """
@@ -225,6 +225,16 @@ def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_wei

    state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())

+    if remove_cls_weights:
+        cls_weights = ["cls.predictions.decoder.bias",
+                       "cls.predictions.transform.dense.weight",
+                       "cls.predictions.transform.dense.bias",
+                       "cls.predictions.transform.layer_norm.weight",
+                       "cls.predictions.transform.layer_norm.bias",
+                       "cls.predictions.decoder.weight"]
+        for weight in cls_weights:
+            state_dict.pop(weight, None)
+
    # Word embedding
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    if pad_vocab_size_multiple > 1:
@@ -232,18 +242,19 @@ def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_wei
        state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
            word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
        )
-        decoder_weight = state_dict["cls.predictions.decoder.weight"]
-        state_dict["cls.predictions.decoder.weight"] = F.pad(
-            decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
-        )
-        # If the vocab was padded, we want to set the decoder bias for those padded indices to be
-        # strongly negative (i.e. the decoder shouldn't predict those indices).
-        # TD [2022-05-09]: I don't think it affects the MLPerf training.
-        if "cls.predictions.decoder.bias" in state_dict:
-            decoder_bias = state_dict["cls.predictions.decoder.bias"]
-            state_dict["cls.predictions.decoder.bias"] = F.pad(
-                decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
-            )
+        if not remove_cls_weights:
+            decoder_weight = state_dict["cls.predictions.decoder.weight"]
+            state_dict["cls.predictions.decoder.weight"] = F.pad(
+                decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
+            )
+            # If the vocab was padded, we want to set the decoder bias for those padded indices to be
+            # strongly negative (i.e. the decoder shouldn't predict those indices).
+            # TD [2022-05-09]: I don't think it affects the MLPerf training.
+            if "cls.predictions.decoder.bias" in state_dict:
+                decoder_bias = state_dict["cls.predictions.decoder.bias"]
+                state_dict["cls.predictions.decoder.bias"] = F.pad(
+                    decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
+                )

    if add_pooling_layer is False:
        pooler_weights = ["bert.pooler.dense.weight",
@@ -252,16 +263,6 @@ def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_wei
        for key in pooler_weights:
            state_dict.pop(key, None)

-    if remove_cls_weights:
-        cls_weights = ["cls.predictions.decoder.bias",
-                       "cls.predictions.transform.dense.weight",
-                       "cls.predictions.transform.dense.bias",
-                       "cls.predictions.transform.layer_norm.weight",
-                       "cls.predictions.transform.layer_norm.bias",
-                       "cls.predictions.decoder.weight"]
-        for weight in cls_weights:
-            state_dict.pop(weight, None)
-
    if remove_bert:
        def remove_bert_prefix(key):
            key = re.sub(r"^bert.", "", key)
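
For context on how the reworked remap is meant to be driven: a minimal, hedged sketch of converting a stock Hugging Face BERT checkpoint with the new flags. The checkpoint id is a placeholder, and passing the checkpoint's own config is purely illustrative (the full function may read attributes a plain BertConfig does not have).

    from transformers import AutoModelForMaskedLM

    # Illustrative only: "bert-base-uncased" is a placeholder checkpoint, and the
    # checkpoint's own config is used as a stand-in for a NomicBert-style config.
    mlm = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
    remapped = remap_bert_state_dict(
        mlm.state_dict(),
        mlm.config,
        remove_bert=True,         # strip the "bert." prefix from remaining keys
        remove_cls_weights=True,  # drop the MLM head weights entirely
        add_pooling_layer=False,  # drop the pooler weights as well
    )
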
@@ -319,9 +320,21 @@ class NomicBertPreTrainedModel(PreTrainedModel):
        remove_bert_prefix = cls != NomicBertForPreTraining
        ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
        num_labels = kwargs.pop("num_labels", None)
+        rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
+        if rotary_scaling_factor:
+            config.rotary_scaling_factor = rotary_scaling_factor
+        if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
+            config.n_positions = 2048
        if num_labels:
            config.num_labels = num_labels

+        if "add_pooling_layer" in kwargs:
+            model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
+        else:
+            if cls == NomicBertModel:
+                model = cls(config, *inputs, add_pooling_layer=False)
+            else:
+                model = cls(config, *inputs)
        # TODO: fix this
        # Assuming we know what we're doing when loading from disk
        # Prob a bad assumption but i'm tired and want to train this asap
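
A hedged sketch of how the new from_pretrained kwargs are consumed; the checkpoint id and scaling factor below are illustrative values, not taken from this commit. Setting rotary_scaling_factor stores it on the config, which (per the attention changes further down) switches the model to the dynamic-NTK rotary embedding, with config.n_positions (defaulted to 2048 above when unset) acting as the pre-scaling context length.

    # Illustrative only: placeholder checkpoint id and scaling factor.
    model = NomicBertModel.from_pretrained(
        "nomic-ai/nomic-bert-2048",   # placeholder checkpoint id
        rotary_scaling_factor=2,      # stretch the rotary context window
        add_pooling_layer=False,      # handled by the branch added above
    )
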
@@ -379,9 +392,9 @@ class NomicBertEmbeddings(nn.Module):
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
        )
-        self.max_position_embeddings = config.max_position_embeddings
+        self.max_position_embeddings = config.max_position_embeddings if config.rotary_emb_fraction <= 0 else 0
        self.type_vocab_size = config.type_vocab_size
-        if self.max_position_embeddings > 0:
+        if self.max_position_embeddings > 0 and config.rotary_emb_fraction <= 0:
            self.position_embeddings = nn.Embedding(
                config.max_position_embeddings, config.hidden_size,
            )
@@ -542,6 +555,12 @@ class NomicBertRotaryEmbedding(nn.Module):
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.interleaved = interleaved
        self.scale_base = scale_base
+        scale = (
+            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
+            if scale_base is not None
+            else None
+        )
+        self.register_buffer("scale", scale, persistent=False)

        self._seq_len_cached = 0
        self._cos_cached = None
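
The scale buffer added above has the same form as the xPos-style per-frequency scale used in flash_attn's rotary embedding; a self-contained sketch of what gets registered (dim=64 is an arbitrary example value). When scale_base is None the buffer is simply None and the plain cos/sin path is used.

    import torch

    dim = 64  # arbitrary example value, not a default from this file
    scale = (torch.arange(0, dim, 2, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
    print(scale.shape)       # torch.Size([32]) -- one entry per rotary frequency pair
    print(scale[0].item())   # ~0.2857 for the first pair
    print(scale[-1].item())  # ~0.9777 for the last pair
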
@@ -607,7 +626,9 @@ class NomicBertRotaryEmbedding(nn.Module):
        Apply rotary embedding *inplace* to qkv and / or kv.
        """
        seqlen = qkv.shape[1]
-        if max_seqlen is not None:
+        if seqlen > self._seq_len_cached:
+            self._update_cos_sin_cache(seqlen, device=qkv.device, dtype=qkv.dtype)
+        elif max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
        elif isinstance(seqlen_offset, int):
            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
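
The reordering above makes the current batch's length take priority when deciding how far to extend the cos/sin cache, which matters once the cache may need rebuilding for longer inputs. A small hypothetical helper (not part of the file) that mirrors the new branch order:

    # Hypothetical helper mirroring the branch order introduced above; the real
    # method updates the cached cos/sin tables in place rather than returning a length.
    def rotary_cache_target(seqlen, cached_seqlen, max_seqlen=None, seqlen_offset=0):
        if seqlen > cached_seqlen:        # new first check: grow to the batch length
            return seqlen
        elif max_seqlen is not None:      # previously consulted first
            return max_seqlen
        elif isinstance(seqlen_offset, int):
            return seqlen + seqlen_offset
        return cached_seqlen

    print(rotary_cache_target(seqlen=4096, cached_seqlen=2048, max_seqlen=2048))  # 4096
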
@@ -617,6 +638,79 @@ class NomicBertRotaryEmbedding(nn.Module):
        return torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)


+class NomicBertDynamicNTKRotaryEmbedding(NomicBertRotaryEmbedding):
+    def __init__(self, rotary_scaling_factor, max_position_embeddings, **kwargs):
+        super().__init__(**kwargs)
+        self.rotary_scaling_factor = rotary_scaling_factor
+        self.max_position_embeddings = max_position_embeddings
+
+
+    def _compute_inv_freq(self, base=None, device=None):
+        if base is None:
+            base = self.base
+        return 1.0 / (
+            base
+            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
+        )
+
+    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
+        # Reset the tables if the sequence length has changed,
+        # if we're on a new device (possibly due to tracing for instance),
+        # or if we're switching from inference mode to training
+        if seqlen > self.max_position_embeddings:
+            base = self.base * (
+                (self.rotary_scaling_factor * seqlen / self.max_position_embeddings) - (self.rotary_scaling_factor - 1)
+            ) ** (self.dim / (self.dim - 2))
+            inv_freq = self._compute_inv_freq(base=base, device=device)
+            self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        if (
+            seqlen > self._seq_len_cached
+            or self._cos_cached is None
+            or self._cos_cached.device != device
+            or self._cos_cached.dtype != dtype
+            or (self.training and self._cos_cached.is_inference())
+        ):
+            self._seq_len_cached = seqlen
+            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
+            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
+            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
+            if self.pos_idx_in_fp32:
+                t = torch.arange(seqlen, device=device, dtype=torch.float32)
+                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
+                # will be large. Having it in bf16 will lose a lot of precision and cause the
+                # cos & sin output to change significantly.
+                # We want to recompute self.inv_freq if it was not loaded in fp32
+                if self.inv_freq.dtype != torch.float32:
+                    if seqlen > self.max_position_embeddings:
+                        base = self.base * (
+                            (self.rotary_scaling_factor * seqlen / self.max_position_embeddings) - (self.rotary_scaling_factor - 1)
+                        ) ** (self.dim / (self.dim - 2))
+                    else:
+                        base = self.base
+                    inv_freq = self._compute_inv_freq(device=device, base=base)
+                else:
+                    inv_freq = self.inv_freq
+            else:
+                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                inv_freq = self.inv_freq
+            # Don't do einsum, it converts fp32 to fp16 under AMP
+            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            freqs = torch.outer(t, inv_freq)
+            if self.scale is None:
+                self._cos_cached = torch.cos(freqs).to(dtype)
+                self._sin_cached = torch.sin(freqs).to(dtype)
+            else:
+                power = (
+                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
+                    - seqlen // 2
+                ) / self.scale_base
+                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
+                # We want the multiplication by scale to happen in fp32
+                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
+                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
+                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
+                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

class NomicBertAttention(nn.Module):
    """Multi-head self-attention and cross-attention"""
@@ -651,12 +745,22 @@ class NomicBertAttention(nn.Module):

        self.rotary_emb_dim = self.head_dim * config.rotary_emb_fraction
        if self.rotary_emb_dim > 0:
-            self.rotary_emb = NomicBertRotaryEmbedding(
-                dim=self.rotary_emb_dim,
-                base=config.rotary_emb_base,
-                scale_base=config.rotary_emb_scale_base,
-                interleaved=config.rotary_emb_interleaved,
-            )
+            if config.rotary_scaling_factor:
+                self.rotary_emb = NomicBertDynamicNTKRotaryEmbedding(
+                    dim=self.rotary_emb_dim,
+                    base=config.rotary_emb_base,
+                    scale_base=config.rotary_emb_scale_base,
+                    interleaved=config.rotary_emb_interleaved,
+                    rotary_scaling_factor=config.rotary_scaling_factor,
+                    max_position_embeddings=config.n_positions,
+                )
+            else:
+                self.rotary_emb = NomicBertRotaryEmbedding(
+                    dim=self.rotary_emb_dim,
+                    base=config.rotary_emb_base,
+                    scale_base=config.rotary_emb_scale_base,
+                    interleaved=config.rotary_emb_interleaved,
+                )
        # bug in xformers: https://github.com/facebookresearch/xformers/issues/841
        # uses the head dimension instead of the sequence dimension
        self.rotary_head_dim = getattr(config, "rotary_head_dim", False)
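
Putting the wiring together: when config.rotary_scaling_factor is set (for instance via the from_pretrained kwarg above), the attention layer builds the dynamic-NTK variant with exactly the keyword arguments shown. A hedged sketch constructing it directly; dim=64, base 10000 and the 2048-token window are illustrative values, and the call to the private cache method is only there to show the longer context being materialized.

    import torch

    # Illustrative construction mirroring the call in NomicBertAttention.__init__;
    # the specific values are placeholders, not defaults from this commit.
    rotary_emb = NomicBertDynamicNTKRotaryEmbedding(
        dim=64,
        base=10000,
        scale_base=None,
        interleaved=False,
        rotary_scaling_factor=2.0,
        max_position_embeddings=2048,
    )
    rotary_emb._update_cos_sin_cache(8192, device=torch.device("cpu"), dtype=torch.float32)
    print(rotary_emb._cos_cached.shape)  # torch.Size([8192, 32])
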
|