VarunGumma
commited on
Update modeling_rotary_indictrans.py
Browse files- modeling_rotary_indictrans.py +26 -25
modeling_rotary_indictrans.py
CHANGED
@@ -31,16 +31,22 @@ from transformers.generation import GenerationMixin
|
|
31 |
from transformers.modeling_utils import PreTrainedModel
|
32 |
from .configuration_rotary_indictrans import RotaryIndicTransConfig
|
33 |
|
34 |
-
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
35 |
-
from flash_attn.bert_padding import (
|
36 |
-
index_first_axis,
|
37 |
-
pad_input,
|
38 |
-
unpad_input,
|
39 |
-
)
|
40 |
-
|
41 |
logger = logging.get_logger(__name__)
|
42 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
46 |
def _get_unpad_data(attention_mask):
|
@@ -1401,8 +1407,6 @@ class RotaryIndicTransDecoder(RotaryIndicTransPreTrainedModel):
|
|
1401 |
|
1402 |
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100Model->RotaryIndicTrans
|
1403 |
class RotaryIndicTransModel(RotaryIndicTransPreTrainedModel):
|
1404 |
-
_tied_weights_keys = None
|
1405 |
-
|
1406 |
def __init__(self, config: RotaryIndicTransConfig):
|
1407 |
super().__init__(config)
|
1408 |
|
@@ -1497,10 +1501,11 @@ class RotaryIndicTransModel(RotaryIndicTransPreTrainedModel):
|
|
1497 |
|
1498 |
|
1499 |
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->RotaryIndicTrans
|
1500 |
-
class RotaryIndicTransForConditionalGeneration(
|
|
|
|
|
1501 |
base_model_prefix = "model"
|
1502 |
-
_tied_weights_keys =
|
1503 |
-
_label_smoothing = 0.0
|
1504 |
|
1505 |
def __init__(self, config: RotaryIndicTransConfig):
|
1506 |
super().__init__(config)
|
@@ -1509,19 +1514,16 @@ class RotaryIndicTransForConditionalGeneration(RotaryIndicTransPreTrainedModel,
|
|
1509 |
config.decoder_embed_dim, config.decoder_vocab_size, bias=False
|
1510 |
)
|
1511 |
|
1512 |
-
if config.share_decoder_input_output_embed:
|
1513 |
-
self.lm_head.weight = self.model.decoder.embed_tokens.weight
|
1514 |
-
|
1515 |
self.post_init()
|
1516 |
|
1517 |
-
def tie_weights(self):
|
1518 |
-
pass
|
1519 |
-
|
1520 |
def get_encoder(self):
|
1521 |
-
return self.model.
|
1522 |
|
1523 |
def get_decoder(self):
|
1524 |
-
return self.model.
|
|
|
|
|
|
|
1525 |
|
1526 |
def get_output_embeddings(self):
|
1527 |
return self.lm_head
|
@@ -1529,8 +1531,9 @@ class RotaryIndicTransForConditionalGeneration(RotaryIndicTransPreTrainedModel,
|
|
1529 |
def set_output_embeddings(self, new_embeddings):
|
1530 |
self.lm_head = new_embeddings
|
1531 |
|
1532 |
-
def
|
1533 |
-
self.
|
|
|
1534 |
|
1535 |
def forward(
|
1536 |
self,
|
@@ -1594,8 +1597,6 @@ class RotaryIndicTransForConditionalGeneration(RotaryIndicTransPreTrainedModel,
|
|
1594 |
masked_lm_loss = F.cross_entropy(
|
1595 |
input=lm_logits.view(-1, self.config.decoder_vocab_size),
|
1596 |
target=labels.view(-1),
|
1597 |
-
ignore_index=-100,
|
1598 |
-
label_smoothing=self._label_smoothing,
|
1599 |
)
|
1600 |
|
1601 |
if not return_dict:
|
@@ -1652,4 +1653,4 @@ class RotaryIndicTransForConditionalGeneration(RotaryIndicTransPreTrainedModel,
|
|
1652 |
past_state.index_select(0, beam_idx) for past_state in layer_past
|
1653 |
),
|
1654 |
)
|
1655 |
-
return reordered_past
|
|
|
31 |
from transformers.modeling_utils import PreTrainedModel
|
32 |
from .configuration_rotary_indictrans import RotaryIndicTransConfig
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
logger = logging.get_logger(__name__)
|
35 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
36 |
|
37 |
+
try:
|
38 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
39 |
+
from flash_attn.bert_padding import (
|
40 |
+
index_first_axis,
|
41 |
+
pad_input,
|
42 |
+
unpad_input,
|
43 |
+
)
|
44 |
+
except ImportError:
|
45 |
+
logger.warning(
|
46 |
+
"It is highly recommended to use `flash_attention_2` for better performance with RotaryIndicTrans."
|
47 |
+
"Falling back to the default `eager` implementation."
|
48 |
+
)
|
49 |
+
|
50 |
|
51 |
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
52 |
def _get_unpad_data(attention_mask):
|
|
|
1407 |
|
1408 |
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100Model->RotaryIndicTrans
|
1409 |
class RotaryIndicTransModel(RotaryIndicTransPreTrainedModel):
|
|
|
|
|
1410 |
def __init__(self, config: RotaryIndicTransConfig):
|
1411 |
super().__init__(config)
|
1412 |
|
|
|
1501 |
|
1502 |
|
1503 |
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->RotaryIndicTrans
|
1504 |
+
class RotaryIndicTransForConditionalGeneration(
|
1505 |
+
RotaryIndicTransPreTrainedModel, GenerationMixin
|
1506 |
+
):
|
1507 |
base_model_prefix = "model"
|
1508 |
+
_tied_weights_keys = ["decoder.embed_tokens.weight", "lm_head.weight"]
|
|
|
1509 |
|
1510 |
def __init__(self, config: RotaryIndicTransConfig):
|
1511 |
super().__init__(config)
|
|
|
1514 |
config.decoder_embed_dim, config.decoder_vocab_size, bias=False
|
1515 |
)
|
1516 |
|
|
|
|
|
|
|
1517 |
self.post_init()
|
1518 |
|
|
|
|
|
|
|
1519 |
def get_encoder(self):
|
1520 |
+
return self.model.encoder
|
1521 |
|
1522 |
def get_decoder(self):
|
1523 |
+
return self.model.decoder
|
1524 |
+
|
1525 |
+
def get_input_embeddings(self):
|
1526 |
+
return self.model.encoder.embed_tokens
|
1527 |
|
1528 |
def get_output_embeddings(self):
|
1529 |
return self.lm_head
|
|
|
1531 |
def set_output_embeddings(self, new_embeddings):
|
1532 |
self.lm_head = new_embeddings
|
1533 |
|
1534 |
+
def tie_weights(self):
|
1535 |
+
if self.config.share_decoder_input_output_embed:
|
1536 |
+
self._tie_or_clone_weights(self.model.decoder.embed_tokens, self.lm_head)
|
1537 |
|
1538 |
def forward(
|
1539 |
self,
|
|
|
1597 |
masked_lm_loss = F.cross_entropy(
|
1598 |
input=lm_logits.view(-1, self.config.decoder_vocab_size),
|
1599 |
target=labels.view(-1),
|
|
|
|
|
1600 |
)
|
1601 |
|
1602 |
if not return_dict:
|
|
|
1653 |
past_state.index_select(0, beam_idx) for past_state in layer_past
|
1654 |
),
|
1655 |
)
|
1656 |
+
return reordered_past
|