Raghavan commited on
Commit
1c5c454
1 Parent(s): b1e9935

Upload 7 files

Browse files
Files changed (1) hide show
  1. modeling_indictrans.py +2 -4
modeling_indictrans.py CHANGED
@@ -64,7 +64,7 @@ def prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask):
64
  new_decoder_input_ids = decoder_input_ids.clone().detach()
65
  new_decoder_attention_mask = decoder_attention_mask.clone().detach()
66
 
67
- labels = torch.full(new_decoder_input_ids.size(),-100)
68
  labels[:, :-1] = new_decoder_input_ids[:, 1:]
69
 
70
  labels_mask = labels == 1
@@ -74,9 +74,7 @@ def prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask):
74
  new_decoder_input_ids[mask] = 1
75
  new_decoder_attention_mask[mask] = 0
76
 
77
- return new_decoder_input_ids, new_decoder_attention_mask, labels
78
-
79
-
80
 
81
 
82
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
 
64
  new_decoder_input_ids = decoder_input_ids.clone().detach()
65
  new_decoder_attention_mask = decoder_attention_mask.clone().detach()
66
 
67
+ labels = torch.full(new_decoder_input_ids.size(), -100)
68
  labels[:, :-1] = new_decoder_input_ids[:, 1:]
69
 
70
  labels_mask = labels == 1
 
74
  new_decoder_input_ids[mask] = 1
75
  new_decoder_attention_mask[mask] = 0
76
 
77
+ return new_decoder_input_ids[:, :-1], new_decoder_attention_mask[:, :-1], labels[:, :-1]
 
 
78
 
79
 
80
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask