Crystalcareai commited on
Commit
89c0aaa
1 Parent(s): 2f71525

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +5 -2
modeling_quiet.py CHANGED
@@ -1847,8 +1847,11 @@ class QuietForCausalLM(QuietPreTrainedModel):
1847
  [shift_labels, padding],
1848
  dim=-1
1849
  )
1850
- probabilities_2d = F.one_hot(new_rm_tokens, num_classes=self.vocab_size).reshape(-1, self.vocab_size).to(probabilities_2d.dtype).to(probabilities_2d.device)
1851
- skip_sampling = True
 
 
 
1852
  else:
1853
  continue
1854
  temperature = self.gumbel_temperature if self.training else 0.001
 
1847
  [shift_labels, padding],
1848
  dim=-1
1849
  )
1850
+ # Before converting rm tokens to one-hot, clamp the values to ensure they are within the valid index range
1851
+ new_rm_tokens = torch.clamp(new_rm_tokens, 0, self.vocab_size - 1)
1852
+
1853
+ # Now safely convert rm tokens to one-hot
1854
+ probabilities_2d = F.one_hot(new_rm_tokens, num_classes=self.vocab_size).reshape(-1, self.vocab_size).to(probabilities_2d.dtype)
1855
  else:
1856
  continue
1857
  temperature = self.gumbel_temperature if self.training else 0.001