oweller2 committed on
Commit
11a83af
1 Parent(s): f5a1962
Files changed (2)
  1. modeling_flexbert.py +9 -12
  2. pytorch_model.bin +1 -1
modeling_flexbert.py CHANGED
@@ -1666,42 +1666,39 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
         loss = None
         if labels is not None:
             if cu_seqlens is not None:
-                shift_labels = torch.full_like(input_ids, -100)
-                shift_labels[:-1] = input_ids[1:]
+                shift_labels = input_ids[1:].clone()
+                loss_logits = logits[:-1]  # Only shift for loss
 
                 # Mask boundaries, so eos doesn't predict bos
                 for i in range(len(cu_seqlens) - 1):
                     boundary_pos = cu_seqlens[i+1] - 1
-                    shift_labels[boundary_pos] = -100
+                    if boundary_pos < len(shift_labels):
+                        shift_labels[boundary_pos] = -100
 
                 # NOTE: no padding or mask in there for now
                 assert 50283 not in shift_labels, f"PAD token found in shift_labels: {shift_labels}"
                 assert 50284 not in shift_labels, f"MASK token found in shift_labels: {shift_labels}"
-                assert shift_labels.shape == logits.shape[:-1]  # Verify shapes align
-
+                assert shift_labels.shape[0] == loss_logits.shape[0]  # Verify shapes align
             else:
                 # Padded case: simple shift
                 shift_labels = input_ids[..., 1:].contiguous()
-                logits = logits[..., :-1, :].contiguous()
+                loss_logits = logits[..., :-1, :].contiguous()
                 # mask out PAD tokens in the shift_labels
                 mask = (shift_labels == 50283)
                 shift_labels = torch.where(mask, torch.tensor(-100, device=shift_labels.device), shift_labels)
-                assert shift_labels.shape == logits.shape[:-1]  # Verify shapes align
+                assert shift_labels.shape[0] == loss_logits.shape[0]  # Verify shapes align
 
             # For both cases, we'll use the shifted input_ids as our labels
             labels = shift_labels
 
             # Flatten the tokens
-            loss = self.loss_fn(
-                logits.view(-1, logits.size(-1)),
-                shift_labels.view(-1)
-            )
+            loss = self.loss_fn(loss_logits.view(-1, loss_logits.size(-1)), shift_labels.view(-1))
 
         if self.pad_logits:
             return CausalLMOutput(
                 loss=loss,
                 logits=self.pad_inputs(logits, indices, batch_size, seq_len)[0],
-                hidden_states=None,
+                hidden_states=hidden_states,
                 attentions=None,
             )
         else:
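
The commit rewrites the packed-sequence (cu_seqlens) branch: instead of filling a full-size label tensor with -100 and copying shifted tokens into it, it now takes a direct one-token shift of input_ids and pairs it with a separate loss_logits view, guards the boundary mask against the final position, and compares only the sequence dimension in the asserts. In the padded branch, using loss_logits also avoids overwriting logits, which the function still needs untouched for pad_inputs below. Roughly, the new packed-path logic behaves like the standalone sketch below; the toy tensors, the vocabulary size of 8, and nn.CrossEntropyLoss(ignore_index=-100) as a stand-in for self.loss_fn are illustrative assumptions, not part of the model code.

import torch
import torch.nn as nn

# Toy stand-ins (assumptions for illustration): two packed sequences of lengths
# 4 and 3 flattened into one unpadded stream, a vocab of 8, and random logits.
input_ids = torch.tensor([1, 5, 6, 2, 1, 7, 2])
cu_seqlens = torch.tensor([0, 4, 7])   # cumulative sequence boundaries
logits = torch.randn(input_ids.size(0), 8)

# New packed-path logic from this commit: shift labels left by one, drop the
# last logit, and mask each sequence's final position so eos never predicts
# the next sequence's bos.
shift_labels = input_ids[1:].clone()
loss_logits = logits[:-1]
for i in range(len(cu_seqlens) - 1):
    boundary_pos = cu_seqlens[i + 1] - 1
    if boundary_pos < len(shift_labels):
        shift_labels[boundary_pos] = -100

assert shift_labels.shape[0] == loss_logits.shape[0]

# nn.CrossEntropyLoss(ignore_index=-100) is an assumed stand-in for self.loss_fn;
# positions labeled -100 contribute nothing to the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
loss = loss_fn(loss_logits.view(-1, loss_logits.size(-1)), shift_labels.view(-1))
print(loss)

In this toy run, position 3 (the last token of the first sequence) is masked to -100, while the second sequence's boundary falls outside shift_labels and is skipped by the new guard, which is the case the added if-check handles.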
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c0bf65e93c1438333b19d7cc744b62913df117b318d7acb300c23f6c202a00e0
+oid sha256:823aa77eddb7f9291beddc92b7d093b34962c129c0f6d674b4390f4f54441081
 size 598685038