oweller2 committed
Commit 11a83af
Parent(s): f5a1962

fix

Files changed:
- modeling_flexbert.py +9 -12
- pytorch_model.bin +1 -1
modeling_flexbert.py CHANGED

@@ -1666,42 +1666,39 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
         loss = None
         if labels is not None:
             if cu_seqlens is not None:
-                shift_labels =
-
+                shift_labels = input_ids[1:].clone()
+                loss_logits = logits[:-1]  # Only shift for loss
 
                 # Mask boundaries, so eos doesn't predict bos
                 for i in range(len(cu_seqlens) - 1):
                     boundary_pos = cu_seqlens[i+1] - 1
-
+                    if boundary_pos < len(shift_labels):
+                        shift_labels[boundary_pos] = -100
 
                 # NOTE: no padding or mask in there for now
                 assert 50283 not in shift_labels, f"PAD token found in shift_labels: {shift_labels}"
                 assert 50284 not in shift_labels, f"MASK token found in shift_labels: {shift_labels}"
-                assert shift_labels.shape ==
-
+                assert shift_labels.shape[0] == loss_logits.shape[0]  # Verify shapes align
             else:
                 # Padded case: simple shift
                 shift_labels = input_ids[..., 1:].contiguous()
-
+                loss_logits = logits[..., :-1, :].contiguous()
                 # mask out PAD tokens in the shift_labels
                 mask = (shift_labels == 50283)
                 shift_labels = torch.where(mask, torch.tensor(-100, device=shift_labels.device), shift_labels)
-                assert shift_labels.shape ==
+                assert shift_labels.shape[0] == loss_logits.shape[0]  # Verify shapes align
 
             # For both cases, we'll use the shifted input_ids as our labels
             labels = shift_labels
 
             # Flatten the tokens
-            loss = self.loss_fn(
-                logits.view(-1, logits.size(-1)),
-                shift_labels.view(-1)
-            )
+            loss = self.loss_fn(loss_logits.view(-1, loss_logits.size(-1)), shift_labels.view(-1))
 
         if self.pad_logits:
             return CausalLMOutput(
                 loss=loss,
                 logits=self.pad_inputs(logits, indices, batch_size, seq_len)[0],
-                hidden_states=
+                hidden_states=hidden_states,
                 attentions=None,
             )
         else:
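For readers following the change: in the unpadded (cu_seqlens) branch, the labels are now built once by shifting input_ids left and the logits are trimmed to match, and the last position of every packed sequence is set to -100 so the eos of one document is never trained to predict the bos of the next. The sketch below is a minimal, standalone illustration of that branch, not the model code itself: it assumes loss_fn is a CrossEntropyLoss with ignore_index=-100 (not shown in the diff), and the token ids, vocabulary size, and cu_seqlens values are toy examples.

import torch

# Toy setup: two sequences of lengths 3 and 4, packed into one unpadded batch.
vocab_size = 8
input_ids = torch.tensor([5, 6, 7, 1, 2, 3, 4])
cu_seqlens = torch.tensor([0, 3, 7])              # cumulative sequence boundaries
logits = torch.randn(len(input_ids), vocab_size)  # stand-in for the model output

shift_labels = input_ids[1:].clone()  # position t is trained to predict token t+1
loss_logits = logits[:-1]             # the final position has no next token

# Mask boundaries, so eos doesn't predict the bos of the next packed sequence
for i in range(len(cu_seqlens) - 1):
    boundary_pos = cu_seqlens[i + 1] - 1
    if boundary_pos < len(shift_labels):
        shift_labels[boundary_pos] = -100

assert shift_labels.shape[0] == loss_logits.shape[0]  # shapes align after the shift

# Assumption: the model's loss_fn ignores -100 labels, as CrossEntropyLoss does by default.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)
loss = loss_fn(loss_logits.view(-1, loss_logits.size(-1)), shift_labels.view(-1))
print(shift_labels.tolist(), float(loss))

With cu_seqlens = [0, 3, 7], shift_labels comes out as [6, 7, -100, 2, 3, 4]: position 2 is the boundary between the two packed sequences, so its label is ignored by the loss, which is exactly what the boundary-masking loop in the diff accomplishes.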
pytorch_model.bin CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:823aa77eddb7f9291beddc92b7d093b34962c129c0f6d674b4390f4f54441081
 size 598685038