oweller2
committed on
Commit
•
bf29787
1
Parent(s):
7bf8a6c
update
Browse files
- modeling_flexbert.py +6 -4
modeling_flexbert.py
CHANGED
@@ -1657,6 +1657,7 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
|
|
1657 |
cu_seqlens=cu_seqlens,
|
1658 |
max_seqlen=max_seqlen,
|
1659 |
)
|
|
|
1660 |
|
1661 |
if self.compile_model:
|
1662 |
logits = self.compiled_lm_head(hidden_states)
|
@@ -1703,8 +1704,8 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
|
|
1703 |
# print(f"Padding logits: {logits.shape}")
|
1704 |
new_logits = self.pad_inputs(logits, indices, batch_size, seq_len)[0]
|
1705 |
# print(new_logits.shape)
|
1706 |
-
if new_logits.dim() ==
|
1707 |
-
|
1708 |
return CausalLMOutput(
|
1709 |
loss=loss,
|
1710 |
logits=new_logits,
|
@@ -1713,8 +1714,9 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
|
|
1713 |
)
|
1714 |
else:
|
1715 |
print(f"Non-padding logits: {logits.shape}")
|
1716 |
-
|
1717 |
-
|
|
|
1718 |
return CausalLMOutput(
|
1719 |
loss=loss,
|
1720 |
logits=logits,
|
|
|
1657 |
cu_seqlens=cu_seqlens,
|
1658 |
max_seqlen=max_seqlen,
|
1659 |
)
|
1660 |
+
print(hidden_states.shape)
|
1661 |
|
1662 |
if self.compile_model:
|
1663 |
logits = self.compiled_lm_head(hidden_states)
|
|
|
1704 |
# print(f"Padding logits: {logits.shape}")
|
1705 |
new_logits = self.pad_inputs(logits, indices, batch_size, seq_len)[0]
|
1706 |
# print(new_logits.shape)
|
1707 |
+
# if new_logits.dim() == 2:
|
1708 |
+
# new_logits = new_logits.unsqueeze(0)
|
1709 |
return CausalLMOutput(
|
1710 |
loss=loss,
|
1711 |
logits=new_logits,
|
|
|
1714 |
)
|
1715 |
else:
|
1716 |
print(f"Non-padding logits: {logits.shape}")
|
1717 |
+
logits = logits.view(-1, logits.size(-1))
|
1718 |
+
# if logits.dim() == 2:
|
1719 |
+
# logits = logits.unsqueeze(0)
|
1720 |
return CausalLMOutput(
|
1721 |
loss=loss,
|
1722 |
logits=logits,
|