Upload 7 files
modeling_indictrans.py  CHANGED  +115 -104

@@ -14,7 +14,6 @@
 # limitations under the License.
 """ PyTorch IndicTrans model."""

-
 import math
 from typing import List, Optional, Tuple, Union

@@ -36,7 +35,6 @@ from transformers.modeling_utils import PreTrainedModel

 from .configuration_indictrans import IndicTransConfig

-
 logger = logging.get_logger(__name__)

 _CONFIG_FOR_DOC = "IndicTransConfig"
@@ -63,7 +61,7 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start

 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(
-
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
 ):
     """
     Make causal mask used for bi-directional self-attention.
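
For reference, the mask built by this helper blocks attention to future positions. A minimal standalone sketch of the same idea (illustrative only, not the function from this diff):

import torch

# Toy causal mask: 0 where attention is allowed (current and past positions),
# the most negative representable value where it is blocked (future positions).
def toy_causal_mask(seq_len: int, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    mask = torch.full((seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype)
    positions = torch.arange(seq_len)
    mask.masked_fill_(positions < (positions + 1).view(seq_len, 1), 0)
    return mask  # (seq_len, seq_len)
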
@@ -147,7 +145,7 @@ class IndicTransSinusoidalPositionalEmbedding(nn.Module):

     @torch.no_grad()
     def forward(
-
+        self, input_ids: torch.Tensor = None, inputs_embeds: torch.Tensor = None, past_key_values_length: int = 0
     ):
         if input_ids is not None:
             bsz, seq_len = input_ids.size()
@@ -189,12 +187,12 @@ class IndicTransAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

     def __init__(
-
-
-
-
-
-
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        bias: bool = True,
     ):
         super().__init__()
         self.embed_dim = embed_dim
@@ -207,7 +205,7 @@ class IndicTransAttention(nn.Module):
                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                 f" and `num_heads`: {num_heads})."
             )
-        self.scaling = self.head_dim
+        self.scaling = self.head_dim ** -0.5
         self.is_decoder = is_decoder

         self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
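
The `scaling` value set here is the standard 1/sqrt(head_dim) factor of scaled dot-product attention; in BART-style attention modules like this one, the query states are multiplied by it before the dot product with the keys. A minimal sketch of where such a factor enters the score computation (illustrative only, not this module's exact code):

import torch

def scaled_attention_scores(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # q, k: (batch, num_heads, seq_len, head_dim)
    scaling = q.size(-1) ** -0.5                 # same form as self.scaling above
    return (q * scaling) @ k.transpose(-2, -1)   # (batch, num_heads, q_len, k_len)
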
@@ -219,13 +217,13 @@ class IndicTransAttention(nn.Module):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

     def forward(
-
-
-
-
-
-
-
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""

@@ -242,9 +240,9 @@ class IndicTransAttention(nn.Module):
         # is checking that the `sequence_length` of the `past_key_value` is the same as
         # the provided `key_value_states` to support prefix tuning
         if (
-
-
-
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
         ):
             # reuse k,v, cross_attentions
             key_states = past_key_value[0]
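
The condition above reuses the cached cross-attention keys/values only when their cached sequence length still matches the encoder output passed in as `key_value_states` (the prefix-tuning caveat from the comment). A small shape-level illustration with made-up sizes (not code from this diff):

import torch

past_key = torch.randn(2, 8, 7, 64)          # cached keys: (bsz, num_heads, src_len, head_dim)
key_value_states = torch.randn(2, 7, 512)    # encoder output: (bsz, src_len, embed_dim)

# lengths match, so the cached projections can be reused instead of recomputed
reuse_cache = past_key.shape[2] == key_value_states.shape[1]
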
@@ -359,11 +357,11 @@ class IndicTransEncoderLayer(nn.Module):
         self.normalize_before = config.encoder_normalize_before

     def forward(
-
-
-
-
-
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        layer_head_mask: torch.Tensor,
+        output_attentions: bool = False,
     ) -> torch.Tensor:
         """
         Args:
@@ -402,7 +400,7 @@ class IndicTransEncoderLayer(nn.Module):
         hidden_states = self.final_layer_norm(hidden_states)

         if hidden_states.dtype == torch.float16 and (
-
+            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
         ):
             clamp_value = torch.finfo(hidden_states.dtype).max - 1000
             hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
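
The new guard only clamps when the float16 activations have actually overflowed to inf (or become NaN); clamping to just below the dtype maximum replaces the overflowed values with large finite ones. A tiny standalone illustration with assumed values (not from this diff):

import torch

x = torch.tensor([70000.0, -70000.0, 1.5], dtype=torch.float16)   # 70000 overflows fp16 to inf
if x.dtype == torch.float16 and (torch.isinf(x).any() or torch.isnan(x).any()):
    clamp_value = torch.finfo(x.dtype).max - 1000                  # 65504 - 1000
    x = torch.clamp(x, min=-clamp_value, max=clamp_value)
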
@@ -445,16 +443,16 @@ class IndicTransDecoderLayer(nn.Module):
         self.normalize_before = config.decoder_normalize_before

     def forward(
-
-
-
-
-
-
-
-
-
-
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = True,
     ) -> torch.Tensor:
         """
         Args:
@@ -606,15 +604,26 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

+    def get_pooled_representation(self, hidden_states, attention_mask):
+        seqs = torch.clone(hidden_states)
+        seqs[attention_mask == 0] = 0
+        sentence_embedding = seqs.sum(dim=1)
+        weights = 1.0 / ((attention_mask != 0).float().sum(dim=1) + 1e-7)
+
+        sentence_embedding = torch.einsum(
+            "i...,i ->i...", sentence_embedding, weights
+        )
+        return sentence_embedding
+
     def forward(
-
-
-
-
-
-
-
-
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ):
         r"""
         Args:
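
The added get_pooled_representation is a masked mean pool over the time axis: padded positions are zeroed out, token vectors are summed per sequence, and the sum is divided by the number of non-pad tokens (the einsum simply broadcasts the per-sequence weight over the hidden dimension, and the 1e-7 avoids division by zero). The encoder forward pass in the next hunk then replaces the token-level hidden states with this pooled vector before building its output. An equivalent standalone sketch with made-up shapes (illustrative only, not the diff's exact code):

import torch

hidden_states = torch.randn(2, 5, 8)                    # (batch, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])        # 1 = real token, 0 = padding

masked = hidden_states * attention_mask.unsqueeze(-1)   # zero out padded positions
counts = attention_mask.sum(dim=1, keepdim=True).clamp(min=1)
sentence_embedding = masked.sum(dim=1) / counts         # (batch, hidden), one vector per sequence
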
@@ -745,6 +754,8 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
         if output_hidden_states:
             encoder_states = encoder_states + (hidden_states,)

+        hidden_states = self.get_pooled_representation(hidden_states, attention_mask)
+
         if not return_dict:
             return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
         return BaseModelOutput(
@@ -791,19 +802,19 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
         self.post_init()

     def forward(
-
-
-
-
-
-
-
-
-
-
-
-
-
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ):
         r"""
         Args:
@@ -1037,7 +1048,7 @@ class IndicTransModel(IndicTransPreTrainedModel):

     def __init__(self, config: IndicTransConfig):
         super().__init__(config)
-
+
         self.encoder = IndicTransEncoder(config)
         self.decoder = IndicTransDecoder(config)

@@ -1051,22 +1062,22 @@ class IndicTransModel(IndicTransPreTrainedModel):
         return self.decoder

     def forward(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1136,9 +1147,9 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):

         if config.share_decoder_input_output_embed:
             self.lm_head.weight = self.model.decoder.embed_tokens.weight
-
+
         self.post_init()
-
+
     def tie_weights(self):
         pass

@@ -1155,23 +1166,23 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
         self.lm_head = new_embeddings

     def forward(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1232,16 +1243,16 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
         )

     def prepare_inputs_for_generation(
-
-
-
-
-
-
-
-
-
-
+        self,
+        decoder_input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        **kwargs,
     ):
         # cut decoder_input_ids if past is used
         if past_key_values is not None:
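
When past_key_values is supplied during incremental decoding, only the newest decoder token needs to be fed through the model, so this method trims decoder_input_ids to its last position. A small sketch of that trimming step (illustrative, mirroring the common transformers pattern rather than this file's exact body):

import torch

decoder_input_ids = torch.tensor([[2, 57, 912, 4031]])   # tokens generated so far (made-up IDs)
past_key_values_present = True                            # stands in for `past_key_values is not None`

if past_key_values_present:
    # cached keys/values already cover the earlier tokens; decode only the newest one
    decoder_input_ids = decoder_input_ids[:, -1:]
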
|