Raghavan committed
Commit f6f7bf9
1 Parent(s): 9fd6e74

Upload 7 files

Files changed (1):
  1. modeling_indictrans.py +115 -104

modeling_indictrans.py CHANGED
@@ -14,7 +14,6 @@
# limitations under the License.
""" PyTorch IndicTrans model."""

-
import math
from typing import List, Optional, Tuple, Union

@@ -36,7 +35,6 @@ from transformers.modeling_utils import PreTrainedModel

from .configuration_indictrans import IndicTransConfig

-
logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "IndicTransConfig"
@@ -63,7 +61,7 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start

# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
-    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
@@ -147,7 +145,7 @@ class IndicTransSinusoidalPositionalEmbedding(nn.Module):

    @torch.no_grad()
    def forward(
-        self, input_ids: torch.Tensor = None, inputs_embeds: torch.Tensor = None, past_key_values_length: int = 0
+        self, input_ids: torch.Tensor = None, inputs_embeds: torch.Tensor = None, past_key_values_length: int = 0
    ):
        if input_ids is not None:
            bsz, seq_len = input_ids.size()
@@ -189,12 +187,12 @@ class IndicTransAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
-        self,
-        embed_dim: int,
-        num_heads: int,
-        dropout: float = 0.0,
-        is_decoder: bool = False,
-        bias: bool = True,
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
@@ -207,7 +205,7 @@ class IndicTransAttention(nn.Module):
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
-        self.scaling = self.head_dim**-0.5
+        self.scaling = self.head_dim ** -0.5
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@@ -219,13 +217,13 @@ class IndicTransAttention(nn.Module):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        layer_head_mask: Optional[torch.Tensor] = None,
-        output_attentions: bool = False,
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

@@ -242,9 +240,9 @@ class IndicTransAttention(nn.Module):
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
-            is_cross_attention
-            and past_key_value is not None
-            and past_key_value[0].shape[2] == key_value_states.shape[1]
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
@@ -359,11 +357,11 @@ class IndicTransEncoderLayer(nn.Module):
        self.normalize_before = config.encoder_normalize_before

    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        layer_head_mask: torch.Tensor,
-        output_attentions: bool = False,
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        layer_head_mask: torch.Tensor,
+        output_attentions: bool = False,
    ) -> torch.Tensor:
        """
        Args:
@@ -402,7 +400,7 @@ class IndicTransEncoderLayer(nn.Module):
            hidden_states = self.final_layer_norm(hidden_states)

        if hidden_states.dtype == torch.float16 and (
-            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
@@ -445,16 +443,16 @@ class IndicTransDecoderLayer(nn.Module):
        self.normalize_before = config.decoder_normalize_before

    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        encoder_attention_mask: Optional[torch.Tensor] = None,
-        layer_head_mask: Optional[torch.Tensor] = None,
-        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = True,
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = True,
    ) -> torch.Tensor:
        """
        Args:
@@ -606,15 +604,26 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
        # Initialize weights and apply final processing
        self.post_init()

+    def get_pooled_representation(self, hidden_states, attention_mask):
+        seqs = torch.clone(hidden_states)
+        seqs[attention_mask == 0] = 0
+        sentence_embedding = seqs.sum(dim=1)
+        weights = 1.0 / ((attention_mask != 0).float().sum(dim=1) + 1e-7)
+
+        sentence_embedding = torch.einsum(
+            "i...,i ->i...", sentence_embedding, weights
+        )
+        return sentence_embedding
+
    def forward(
-        self,
-        input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
    ):
        r"""
        Args:
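
The `get_pooled_representation` method added in this hunk is a masked mean pool over the encoder's token states: positions where `attention_mask == 0` are zeroed, the remaining vectors are summed over the sequence dimension, and the `einsum` rescales each batch element by `1 / (number of non-pad tokens + 1e-7)`. A minimal standalone sketch of the same computation follows; the helper name and tensor shapes are illustrative, not part of the commit.

    import torch

    def masked_mean_pooling(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq_len, hidden_size); attention_mask: (batch, seq_len), 1 for real tokens
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # (batch, seq_len, 1)
        summed = (hidden_states * mask).sum(dim=1)                   # zero out padding, sum over time
        counts = mask.sum(dim=1) + 1e-7                              # non-pad token count per sequence
        return summed / counts                                       # (batch, hidden_size)

    # Toy example: one pooled vector per sequence.
    hs = torch.randn(2, 5, 8)
    am = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
    print(masked_mean_pooling(hs, am).shape)  # torch.Size([2, 8])
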
@@ -745,6 +754,8 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

+        hidden_states = self.get_pooled_representation(hidden_states, attention_mask)
+
        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
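
With the call added above, `IndicTransEncoder.forward` returns a pooled sentence embedding of shape `(batch_size, hidden_size)` as `last_hidden_state`, rather than per-token states of shape `(batch_size, seq_len, hidden_size)`. A hedged usage sketch, assuming a checkpoint that ships this `modeling_indictrans.py` via `trust_remote_code` with a matching tokenizer and a seq2seq auto_map entry; the repository id below is a placeholder, not taken from this commit.

    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    repo_id = "<org>/<indictrans-checkpoint>"  # placeholder, substitute a real checkpoint

    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(repo_id, trust_remote_code=True)

    batch = tokenizer(["How are you?", "See you soon."], return_tensors="pt", padding=True)
    with torch.no_grad():
        enc_out = model.get_encoder()(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )

    # With this commit, the encoder output is one pooled vector per input sentence.
    print(enc_out.last_hidden_state.shape)  # expected: (batch_size, hidden_size)
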
@@ -791,19 +802,19 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
        self.post_init()

    def forward(
-        self,
-        input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        encoder_attention_mask: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        cross_attn_head_mask: Optional[torch.Tensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
    ):
        r"""
        Args:
@@ -1037,7 +1048,7 @@ class IndicTransModel(IndicTransPreTrainedModel):

    def __init__(self, config: IndicTransConfig):
        super().__init__(config)
-
+
        self.encoder = IndicTransEncoder(config)
        self.decoder = IndicTransDecoder(config)

@@ -1051,22 +1062,22 @@ class IndicTransModel(IndicTransPreTrainedModel):
        return self.decoder

    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        decoder_input_ids: Optional[torch.LongTensor] = None,
-        decoder_attention_mask: Optional[torch.LongTensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        decoder_head_mask: Optional[torch.Tensor] = None,
-        cross_attn_head_mask: Optional[torch.Tensor] = None,
-        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
@@ -1136,9 +1147,9 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):

        if config.share_decoder_input_output_embed:
            self.lm_head.weight = self.model.decoder.embed_tokens.weight
-
+
        self.post_init()
-
+
    def tie_weights(self):
        pass

@@ -1155,23 +1166,23 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
        self.lm_head = new_embeddings

    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        decoder_input_ids: Optional[torch.LongTensor] = None,
-        decoder_attention_mask: Optional[torch.LongTensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        decoder_head_mask: Optional[torch.Tensor] = None,
-        cross_attn_head_mask: Optional[torch.Tensor] = None,
-        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1232,16 +1243,16 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
        )

    def prepare_inputs_for_generation(
-        self,
-        decoder_input_ids,
-        past_key_values=None,
-        attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        use_cache=None,
-        encoder_outputs=None,
-        **kwargs,
+        self,
+        decoder_input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        **kwargs,
    ):
        # cut decoder_input_ids if past is used
        if past_key_values is not None: