Raghavan commited on
Commit
edfd39e
1 Parent(s): e67c1c1

Upload 7 files

Browse files
Files changed (1) hide show
  1. modeling_indictrans.py +11 -5
modeling_indictrans.py CHANGED
@@ -801,6 +801,12 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
801
  # Initialize weights and apply final processing
802
  self.post_init()
803
 
 
 
 
 
 
 
804
  def forward(
805
  self,
806
  input_ids: Optional[torch.Tensor] = None,
@@ -1194,11 +1200,11 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
1194
  """
1195
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1196
 
1197
- if labels is not None:
1198
- if decoder_input_ids is None:
1199
- decoder_input_ids = shift_tokens_right(
1200
- labels, self.config.pad_token_id, self.config.decoder_start_token_id
1201
- )
1202
 
1203
  outputs = self.model(
1204
  input_ids,
 
801
  # Initialize weights and apply final processing
802
  self.post_init()
803
 
804
+ def get_input_embeddings(self):
805
+ return self.embed_tokens.word_embeddings
806
+
807
+ def set_input_embeddings(self, value):
808
+ self.embed_tokens.word_embeddings = value
809
+
810
  def forward(
811
  self,
812
  input_ids: Optional[torch.Tensor] = None,
 
1200
  """
1201
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1202
 
1203
+ # if labels is not None:
1204
+ # if decoder_input_ids is None:
1205
+ # decoder_input_ids = shift_tokens_right(
1206
+ # labels, self.config.pad_token_id, self.config.decoder_start_token_id
1207
+ # )
1208
 
1209
  outputs = self.model(
1210
  input_ids,