Fill-Mask
Transformers
PyTorch
Safetensors
English
nomic_bert
custom_code
zpn commited on
Commit
7cd983f
1 Parent(s): 0f627bd

Update modeling_hf_nomic_bert.py

Browse files
Files changed (1) hide show
  1. modeling_hf_nomic_bert.py +16 -12
modeling_hf_nomic_bert.py CHANGED
@@ -105,7 +105,13 @@ def filter_shapes(state_dict, model):
105
  return filtered_state_dict
106
 
107
 
108
- def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_weights=False, add_pooling_layer=False):
 
 
 
 
 
 
109
  """
110
  Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
111
  """
@@ -305,13 +311,12 @@ class NomicBertPreTrainedModel(PreTrainedModel):
305
  if config is None:
306
  config = cls.config_class.from_pretrained(model_name)
307
  remove_cls = cls != NomicBertForPreTraining
308
- remove_bert_prefix = cls != NomicBertForPreTraining
309
  ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
310
  num_labels = kwargs.pop("num_labels", None)
311
  rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
312
- if rotary_scaling_factor:
313
- config.rotary_scaling_factor = rotary_scaling_factor
314
-
315
  if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
316
  config.n_positions = 2048
317
  if num_labels:
@@ -320,10 +325,7 @@ class NomicBertPreTrainedModel(PreTrainedModel):
320
  if "add_pooling_layer" in kwargs:
321
  model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
322
  else:
323
- if cls == NomicBertModel:
324
- model = cls(config, *inputs, add_pooling_layer=False)
325
- else:
326
- model = cls(config, *inputs)
327
  # TODO: fix this
328
  # Assuming we know what we're doing when loading from disk
329
  # Prob a bad assumption but i'm tired and want to train this asap
@@ -342,7 +344,7 @@ class NomicBertPreTrainedModel(PreTrainedModel):
342
  load_return = model.load_state_dict(state_dict, strict=False)
343
  else:
344
  # TODO: can probably check config class and see if we need to remap from a bert model
345
- state_dict = state_dict_from_pretrained(model_name, safe_serialization=kwargs.get("safe_serialization", False))
346
  state_dict = remap_bert_state_dict(
347
  state_dict,
348
  config,
@@ -353,7 +355,7 @@ class NomicBertPreTrainedModel(PreTrainedModel):
353
  if ignore_mismatched_shapes:
354
  state_dict = filter_shapes(state_dict, model)
355
 
356
- load_return = model.load_state_dict(state_dict, strict=True)
357
  logger.warning(load_return)
358
  return model
359
 
@@ -859,6 +861,7 @@ class NomicBertBlock(nn.Module):
859
  max_seq_len: Optional[int] = None,
860
  ):
861
  r"""Pass the input through the encoder layer.
 
862
  Args:
863
  hidden_states: the sequence to the encoder layer (required).
864
  residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
@@ -1116,6 +1119,7 @@ class NomicBertForPreTraining(NomicBertPreTrainedModel):
1116
  Outputs a tuple comprising
1117
  - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
1118
  - the next sentence classification logits of shape [batch_size, 2].
 
1119
  """
1120
  outputs = self.bert(
1121
  input_ids,
@@ -1220,4 +1224,4 @@ class NomicBertForSequenceClassification(NomicBertPreTrainedModel):
1220
  logits=logits,
1221
  hidden_states=outputs.hidden_states,
1222
  attentions=outputs.attentions,
1223
- )
 
105
  return filtered_state_dict
106
 
107
 
108
+ def remap_bert_state_dict(
109
+ state_dict,
110
+ config,
111
+ remove_bert=False,
112
+ remove_cls_weights=False,
113
+ add_pooling_layer=False,
114
+ ):
115
  """
116
  Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
117
  """
 
311
  if config is None:
312
  config = cls.config_class.from_pretrained(model_name)
313
  remove_cls = cls != NomicBertForPreTraining
314
+ remove_bert_prefix = cls != NomicBertForPreTraining and cls != NomicBertForSequenceClassification
315
  ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
316
  num_labels = kwargs.pop("num_labels", None)
317
  rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
318
+ strict = kwargs.pop("strict", True)
319
+ config.rotary_scaling_factor = rotary_scaling_factor
 
320
  if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
321
  config.n_positions = 2048
322
  if num_labels:
 
325
  if "add_pooling_layer" in kwargs:
326
  model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
327
  else:
328
+ model = cls(config, *inputs)
 
 
 
329
  # TODO: fix this
330
  # Assuming we know what we're doing when loading from disk
331
  # Prob a bad assumption but i'm tired and want to train this asap
 
344
  load_return = model.load_state_dict(state_dict, strict=False)
345
  else:
346
  # TODO: can probably check config class and see if we need to remap from a bert model
347
+ state_dict = state_dict_from_pretrained(model_name)
348
  state_dict = remap_bert_state_dict(
349
  state_dict,
350
  config,
 
355
  if ignore_mismatched_shapes:
356
  state_dict = filter_shapes(state_dict, model)
357
 
358
+ load_return = model.load_state_dict(state_dict, strict=strict)
359
  logger.warning(load_return)
360
  return model
361
 
 
861
  max_seq_len: Optional[int] = None,
862
  ):
863
  r"""Pass the input through the encoder layer.
864
+
865
  Args:
866
  hidden_states: the sequence to the encoder layer (required).
867
  residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
 
1119
  Outputs a tuple comprising
1120
  - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
1121
  - the next sentence classification logits of shape [batch_size, 2].
1122
+
1123
  """
1124
  outputs = self.bert(
1125
  input_ids,
 
1224
  logits=logits,
1225
  hidden_states=outputs.hidden_states,
1226
  attentions=outputs.attentions,
1227
+ )