oweller2 committed on
Commit
6a75052
1 Parent(s): b06bc52
Files changed (1) hide show
  1. modeling_flexbert.py +2 -1
modeling_flexbert.py CHANGED
@@ -1529,6 +1529,7 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
1529
  self.unpad_embeddings = config.unpad_embeddings
1530
  self.pad_logits = config.pad_logits
1531
  self.compile_model = config.compile_model
 
1532
  # self.masked_prediction = config.masked_prediction
1533
 
1534
  # Initialize weights and apply final processing
@@ -1702,7 +1703,7 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
1702
 
1703
  if self.unpad_embeddings:
1704
  # reshape to batch size
1705
- logits = logits.view(batch_size, -1, logits.size(-1))
1706
 
1707
  if self.pad_logits:
1708
  # print(f"Padding logits: {logits.shape}")
 
1529
  self.unpad_embeddings = config.unpad_embeddings
1530
  self.pad_logits = config.pad_logits
1531
  self.compile_model = config.compile_model
1532
+ self.vocab_size = config.vocab_size
1533
  # self.masked_prediction = config.masked_prediction
1534
 
1535
  # Initialize weights and apply final processing
 
1703
 
1704
  if self.unpad_embeddings:
1705
  # reshape to batch size
1706
+ logits = logits.view(batch_size, -1, self.vocab_size)
1707
 
1708
  if self.pad_logits:
1709
  # print(f"Padding logits: {logits.shape}")