Upload 7 files
Browse files
- modeling_indictrans.py +2 -4
modeling_indictrans.py
CHANGED
@@ -64,7 +64,7 @@ def prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask):
|
|
64 |
new_decoder_input_ids = decoder_input_ids.clone().detach()
|
65 |
new_decoder_attention_mask = decoder_attention_mask.clone().detach()
|
66 |
|
67 |
-
labels = torch.full(new_decoder_input_ids.size()
|
68 |
labels[:, :-1] = new_decoder_input_ids[:, 1:]
|
69 |
|
70 |
labels_mask = labels == 1
|
@@ -74,9 +74,7 @@ def prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask):
|
|
74 |
new_decoder_input_ids[mask] = 1
|
75 |
new_decoder_attention_mask[mask] = 0
|
76 |
|
77 |
-
return new_decoder_input_ids, new_decoder_attention_mask, labels
|
78 |
-
|
79 |
-
|
80 |
|
81 |
|
82 |
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
|
|
64 |
new_decoder_input_ids = decoder_input_ids.clone().detach()
|
65 |
new_decoder_attention_mask = decoder_attention_mask.clone().detach()
|
66 |
|
67 |
+
labels = torch.full(new_decoder_input_ids.size(), -100)
|
68 |
labels[:, :-1] = new_decoder_input_ids[:, 1:]
|
69 |
|
70 |
labels_mask = labels == 1
|
|
|
74 |
new_decoder_input_ids[mask] = 1
|
75 |
new_decoder_attention_mask[mask] = 0
|
76 |
|
77 |
+
return new_decoder_input_ids[:, :-1], new_decoder_attention_mask[:, :-1], labels[:, :-1]
|
|
|
|
|
78 |
|
79 |
|
80 |
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|