prajdabre VarunGumma commited on
Commit
d32aed5
·
verified ·
1 Parent(s): 70746a7

Update modeling_rotary_indictrans.py (#3)

Browse files

- Update modeling_rotary_indictrans.py (e4e81d26aef8d1bfe441cf04c74a0476f8ea9947)


Co-authored-by: Varun Gumma <VarunGumma@users.noreply.huggingface.co>

Files changed (1) hide show
  1. modeling_rotary_indictrans.py +2 -1
modeling_rotary_indictrans.py CHANGED
@@ -27,6 +27,7 @@ from einops import rearrange, repeat
27
  from torch.amp import autocast
28
  from torch import einsum
29
 
 
30
  from transformers.modeling_utils import PreTrainedModel
31
  from .configuration_rotary_indictrans import RotaryIndicTransConfig
32
 
@@ -1496,7 +1497,7 @@ class RotaryIndicTransModel(RotaryIndicTransPreTrainedModel):
1496
 
1497
 
1498
  # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->RotaryIndicTrans
1499
- class RotaryIndicTransForConditionalGeneration(RotaryIndicTransPreTrainedModel):
1500
  base_model_prefix = "model"
1501
  _tied_weights_keys = None
1502
  _label_smoothing = 0.0
 
27
  from torch.amp import autocast
28
  from torch import einsum
29
 
30
+ from transformers.generation import GenerationMixin
31
  from transformers.modeling_utils import PreTrainedModel
32
  from .configuration_rotary_indictrans import RotaryIndicTransConfig
33
 
 
1497
 
1498
 
1499
  # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->RotaryIndicTrans
1500
+ class RotaryIndicTransForConditionalGeneration(RotaryIndicTransPreTrainedModel, GenerationMixin):
1501
  base_model_prefix = "model"
1502
  _tied_weights_keys = None
1503
  _label_smoothing = 0.0