oweller2
commited on
Commit
·
c9fb228
1
Parent(s):
bf29787
modeling
Browse files- modeling_flexbert.py +4 -0
modeling_flexbert.py
CHANGED
@@ -1700,6 +1700,10 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
|
|
1700 |
shift_labels.view(-1)
|
1701 |
)
|
1702 |
|
|
|
|
|
|
|
|
|
1703 |
if self.pad_logits:
|
1704 |
# print(f"Padding logits: {logits.shape}")
|
1705 |
new_logits = self.pad_inputs(logits, indices, batch_size, seq_len)[0]
|
|
|
1700 |
shift_labels.view(-1)
|
1701 |
)
|
1702 |
|
1703 |
+
if self.unpad_embeddings:
|
1704 |
+
# reshape to batch size
|
1705 |
+
logits = logits.view(batch_size, -1, logits.size(-1))
|
1706 |
+
|
1707 |
if self.pad_logits:
|
1708 |
# print(f"Padding logits: {logits.shape}")
|
1709 |
new_logits = self.pad_inputs(logits, indices, batch_size, seq_len)[0]
|