Upload 7 files
Browse files- modeling_indictrans.py +11 -5
modeling_indictrans.py
CHANGED
@@ -801,6 +801,12 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
|
|
801 |
# Initialize weights and apply final processing
|
802 |
self.post_init()
|
803 |
|
|
|
|
|
|
|
|
|
|
|
|
|
804 |
def forward(
|
805 |
self,
|
806 |
input_ids: Optional[torch.Tensor] = None,
|
@@ -1194,11 +1200,11 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
|
|
1194 |
"""
|
1195 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1196 |
|
1197 |
-
if labels is not None:
|
1198 |
-
|
1199 |
-
|
1200 |
-
|
1201 |
-
|
1202 |
|
1203 |
outputs = self.model(
|
1204 |
input_ids,
|
|
|
801 |
# Initialize weights and apply final processing
|
802 |
self.post_init()
|
803 |
|
804 |
+
def get_input_embeddings(self):
|
805 |
+
return self.embed_tokens.word_embeddings
|
806 |
+
|
807 |
+
def set_input_embeddings(self, value):
|
808 |
+
self.embed_tokens.word_embeddings = value
|
809 |
+
|
810 |
def forward(
|
811 |
self,
|
812 |
input_ids: Optional[torch.Tensor] = None,
|
|
|
1200 |
"""
|
1201 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1202 |
|
1203 |
+
# if labels is not None:
|
1204 |
+
# if decoder_input_ids is None:
|
1205 |
+
# decoder_input_ids = shift_tokens_right(
|
1206 |
+
# labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
1207 |
+
# )
|
1208 |
|
1209 |
outputs = self.model(
|
1210 |
input_ids,
|