oweller2
commited on
Commit
•
7bf8a6c
1
Parent(s):
6369df1
logits
Browse files- modeling_flexbert.py +4 -3
modeling_flexbert.py
CHANGED
@@ -1702,7 +1702,8 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
|
|
1702 |
if self.pad_logits:
|
1703 |
# print(f"Padding logits: {logits.shape}")
|
1704 |
new_logits = self.pad_inputs(logits, indices, batch_size, seq_len)[0]
|
1705 |
-
|
|
|
1706 |
new_logits = new_logits.unsqueeze(0)
|
1707 |
return CausalLMOutput(
|
1708 |
loss=loss,
|
@@ -1711,8 +1712,8 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
|
|
1711 |
attentions=None,
|
1712 |
)
|
1713 |
else:
|
1714 |
-
|
1715 |
-
if
|
1716 |
logits = logits.unsqueeze(0)
|
1717 |
return CausalLMOutput(
|
1718 |
loss=loss,
|
|
|
1702 |
if self.pad_logits:
|
1703 |
# print(f"Padding logits: {logits.shape}")
|
1704 |
new_logits = self.pad_inputs(logits, indices, batch_size, seq_len)[0]
|
1705 |
+
# print(new_logits.shape)
|
1706 |
+
if new_logits.dim() == 1:
|
1707 |
new_logits = new_logits.unsqueeze(0)
|
1708 |
return CausalLMOutput(
|
1709 |
loss=loss,
|
|
|
1712 |
attentions=None,
|
1713 |
)
|
1714 |
else:
|
1715 |
+
print(f"Non-padding logits: {logits.shape}")
|
1716 |
+
if logits.dim() == 1:
|
1717 |
logits = logits.unsqueeze(0)
|
1718 |
return CausalLMOutput(
|
1719 |
loss=loss,
|