Update modeling_hf_nomic_bert.py
modeling_hf_nomic_bert.py  CHANGED  (+16 -12)
@@ -105,7 +105,13 @@ def filter_shapes(state_dict, model):
     return filtered_state_dict


-def remap_bert_state_dict(
+def remap_bert_state_dict(
+    state_dict,
+    config,
+    remove_bert=False,
+    remove_cls_weights=False,
+    add_pooling_layer=False,
+):
     """
     Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
     """
@@ -305,13 +311,12 @@ class NomicBertPreTrainedModel(PreTrainedModel):
         if config is None:
             config = cls.config_class.from_pretrained(model_name)
         remove_cls = cls != NomicBertForPreTraining
-        remove_bert_prefix = cls != NomicBertForPreTraining
+        remove_bert_prefix = cls != NomicBertForPreTraining and cls != NomicBertForSequenceClassification
         ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
         num_labels = kwargs.pop("num_labels", None)
         rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
-
-
-
+        strict = kwargs.pop("strict", True)
+        config.rotary_scaling_factor = rotary_scaling_factor
         if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
             config.n_positions = 2048
         if num_labels:
@@ -320,10 +325,7 @@ class NomicBertPreTrainedModel(PreTrainedModel):
         if "add_pooling_layer" in kwargs:
             model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
         else:
-
-                model = cls(config, *inputs, add_pooling_layer=False)
-            else:
-                model = cls(config, *inputs)
+            model = cls(config, *inputs)
         # TODO: fix this
         # Assuming we know what we're doing when loading from disk
         # Prob a bad assumption but i'm tired and want to train this asap
@@ -342,7 +344,7 @@ class NomicBertPreTrainedModel(PreTrainedModel):
             load_return = model.load_state_dict(state_dict, strict=False)
         else:
             # TODO: can probably check config class and see if we need to remap from a bert model
-            state_dict = state_dict_from_pretrained(model_name
+            state_dict = state_dict_from_pretrained(model_name)
             state_dict = remap_bert_state_dict(
                 state_dict,
                 config,
@@ -353,7 +355,7 @@ class NomicBertPreTrainedModel(PreTrainedModel):
             if ignore_mismatched_shapes:
                 state_dict = filter_shapes(state_dict, model)

-            load_return = model.load_state_dict(state_dict, strict=
+            load_return = model.load_state_dict(state_dict, strict=strict)
         logger.warning(load_return)
         return model

@@ -859,6 +861,7 @@ class NomicBertBlock(nn.Module):
         max_seq_len: Optional[int] = None,
     ):
         r"""Pass the input through the encoder layer.
+
         Args:
             hidden_states: the sequence to the encoder layer (required).
             residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
@@ -1116,6 +1119,7 @@ class NomicBertForPreTraining(NomicBertPreTrainedModel):
         Outputs a tuple comprising
         - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
         - the next sentence classification logits of shape [batch_size, 2].
+
         """
         outputs = self.bert(
             input_ids,
@@ -1220,4 +1224,4 @@ class NomicBertForSequenceClassification(NomicBertPreTrainedModel):
             logits=logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-        )
+        )
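For reference, the keyword arguments touched in from_pretrained above (num_labels, rotary_scaling_factor, strict, add_pooling_layer) are all consumed via kwargs.pop before the model is constructed, and the final load_state_dict call now uses the popped strict value, so mismatches can be tolerated or enforced per call. A minimal usage sketch, assuming the module is importable under its file name and using a placeholder checkpoint and placeholder values:

from modeling_hf_nomic_bert import NomicBertForSequenceClassification

# Placeholder checkpoint name and values, shown only to illustrate how the
# kwargs handled in from_pretrained above are passed in.
model = NomicBertForSequenceClassification.from_pretrained(
    "some-org/some-nomic-bert-checkpoint",  # placeholder
    num_labels=2,                 # applied to the config when truthy
    rotary_scaling_factor=2.0,    # now always written to config.rotary_scaling_factor
    strict=False,                 # forwarded to model.load_state_dict(..., strict=strict)
)

With strict=False, PyTorch's load_state_dict reports missing and unexpected keys in its return value instead of raising, and the existing logger.warning(load_return) line surfaces them.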
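filter_shapes itself is not changed here; it appears only as hunk context and in the state_dict = filter_shapes(state_dict, model) call guarded by ignore_mismatched_shapes. As an illustration only (not this file's implementation), a shape filter of this kind can be as small as:

def filter_shapes_sketch(state_dict, model):
    # Keep only tensors whose shapes match the corresponding entries in the
    # model's own state_dict, so the subsequent load_state_dict call does not
    # fail on size mismatches.
    model_state = model.state_dict()
    return {
        name: tensor
        for name, tensor in state_dict.items()
        if name in model_state and model_state[name].shape == tensor.shape
    }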
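Likewise, only the signature and docstring of remap_bert_state_dict appear in this diff. A rough sketch of what the remove_bert and remove_cls_weights flags typically imply, with generic key names and no claim that this is the exact remapping the file performs (per its docstring, the real function maps a Huggingface BERT state_dict to a flash_attn-compatible layout, which is not reproduced here):

def remap_sketch(state_dict, remove_bert=False, remove_cls_weights=False):
    # Illustrative prefix stripping and head removal only.
    remapped = {}
    for name, tensor in state_dict.items():
        if remove_cls_weights and name.startswith("cls."):
            # Drop pretraining-head weights when the target class has no such head.
            continue
        if remove_bert and name.startswith("bert."):
            # Strip the Huggingface "bert." prefix so keys match the target module tree.
            name = name[len("bert.") :]
        remapped[name] = tensor
    return remapped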