Upload 7 files
modeling_indictrans.py CHANGED (+14 -0)
@@ -40,6 +40,7 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "IndicTransConfig"
 
 INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
+eos_token_id = 2
 
 
 # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
@@ -59,6 +60,16 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
     return shifted_input_ids
 
 
+def prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask):
+    mask = (decoder_input_ids == eos_token_id)
+    decoder_input_ids[mask] = 1
+    decoder_attention_mask[mask] = 0
+
+    labels = decoder_input_ids[:, 1:]
+
+    return decoder_input_ids, decoder_attention_mask, labels
+
+
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(
     input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
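A minimal standalone sketch of what the new helper does on a toy batch (only eos_token_id = 2 and the helper body come from this commit; the tensor values are illustrative): every EOS token in decoder_input_ids is overwritten with token id 1 and masked out of decoder_attention_mask, and labels is the decoder input sequence shifted left by one position.

import torch

eos_token_id = 2  # module-level constant added in this commit

def prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask):
    # Replace EOS positions with token id 1 and mask them out (in place).
    mask = (decoder_input_ids == eos_token_id)
    decoder_input_ids[mask] = 1
    decoder_attention_mask[mask] = 0
    # Labels are the decoder inputs shifted left by one position.
    labels = decoder_input_ids[:, 1:]
    return decoder_input_ids, decoder_attention_mask, labels

# Toy batch: one sequence ending in EOS (id 2); values are illustrative.
decoder_input_ids = torch.tensor([[5, 7, 9, 2]])
decoder_attention_mask = torch.ones_like(decoder_input_ids)

ids, attn, labels = prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask)
print(ids)     # tensor([[5, 7, 9, 1]]) -- EOS replaced by 1
print(attn)    # tensor([[1, 1, 1, 0]]) -- EOS position masked out
print(labels)  # tensor([[7, 9, 1]])    -- inputs shifted left by one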
@@ -1206,6 +1217,9 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
         # labels, self.config.pad_token_id, self.config.decoder_start_token_id
         # )
 
+        decoder_input_ids, decoder_attention_mask, labels = prepare_decoder_input_ids_label(decoder_input_ids,
+                                                                                            decoder_attention_mask)
+
         outputs = self.model(
             input_ids,
             attention_mask=attention_mask,