Upload 7 files
Browse files- modeling_indictrans.py +1 -1
modeling_indictrans.py
CHANGED
@@ -1249,7 +1249,7 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
|
|
1249 |
# move labels to the correct device to enable PP
|
1250 |
labels = labels.to(lm_logits.device)
|
1251 |
loss_fct = nn.CrossEntropyLoss()
|
1252 |
-
masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.decoder_vocab_size), labels.
|
1253 |
|
1254 |
if not return_dict:
|
1255 |
output = (lm_logits,) + outputs[1:]
|
|
|
1249 |
# move labels to the correct device to enable PP
|
1250 |
labels = labels.to(lm_logits.device)
|
1251 |
loss_fct = nn.CrossEntropyLoss()
|
1252 |
+
masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.decoder_vocab_size), labels.reshape(-1))
|
1253 |
|
1254 |
if not return_dict:
|
1255 |
output = (lm_logits,) + outputs[1:]
|