Raghavan commited on
Commit
02d1844
1 Parent(s): a48eebf

Upload 7 files

Browse files
Files changed (1) hide show
  1. modeling_indictrans.py +1 -1
modeling_indictrans.py CHANGED
@@ -1249,7 +1249,7 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
1249
  # move labels to the correct device to enable PP
1250
  labels = labels.to(lm_logits.device)
1251
  loss_fct = nn.CrossEntropyLoss()
1252
- masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.decoder_vocab_size), labels.view(-1))
1253
 
1254
  if not return_dict:
1255
  output = (lm_logits,) + outputs[1:]
 
1249
  # move labels to the correct device to enable PP
1250
  labels = labels.to(lm_logits.device)
1251
  loss_fct = nn.CrossEntropyLoss()
1252
+ masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.decoder_vocab_size), labels.reshape(-1))
1253
 
1254
  if not return_dict:
1255
  output = (lm_logits,) + outputs[1:]