GEETHANAYAGI
commited on
Commit
•
f9d7028
1
Parent(s):
5e885fd
Upload 79 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- IndicTrans2/.gitignore +148 -0
- IndicTrans2/LICENSE +21 -0
- IndicTrans2/README.md +523 -0
- IndicTrans2/apply_sentence_piece.sh +48 -0
- IndicTrans2/baseline_eval/azure_translate.py +183 -0
- IndicTrans2/baseline_eval/google_translate.py +129 -0
- IndicTrans2/baseline_eval/m2m100_inference.py +148 -0
- IndicTrans2/baseline_eval/mbart_inference.py +159 -0
- IndicTrans2/baseline_eval/nllb_moe_cpu_inference.py +157 -0
- IndicTrans2/compute_comet_score.sh +84 -0
- IndicTrans2/compute_metrics.sh +29 -0
- IndicTrans2/compute_metrics_significance.sh +66 -0
- IndicTrans2/eval.sh +54 -0
- IndicTrans2/eval_rev.sh +55 -0
- IndicTrans2/finetune.sh +54 -0
- IndicTrans2/huggingface_interface/.gitignore +1 -0
- IndicTrans2/huggingface_interface/README.md +119 -0
- IndicTrans2/huggingface_interface/colab_inference.ipynb +458 -0
- IndicTrans2/huggingface_interface/configuration_indictrans.py +309 -0
- IndicTrans2/huggingface_interface/convert_indictrans_checkpoint_to_pytorch.py +107 -0
- IndicTrans2/huggingface_interface/example.py +275 -0
- IndicTrans2/huggingface_interface/install.sh +49 -0
- IndicTrans2/huggingface_interface/modeling_indictrans.py +1801 -0
- IndicTrans2/huggingface_interface/train_lora.py +355 -0
- IndicTrans2/huggingface_interface/train_lora.sh +35 -0
- IndicTrans2/inference/__init__.py +0 -0
- IndicTrans2/inference/custom_interactive.py +304 -0
- IndicTrans2/inference/download.py +5 -0
- IndicTrans2/inference/engine.py +472 -0
- IndicTrans2/inference/flores_codes_map_indic.py +83 -0
- IndicTrans2/inference/indic_num_map.py +117 -0
- IndicTrans2/inference/model_configs/__init__.py +1 -0
- IndicTrans2/inference/model_configs/custom_transformer.py +82 -0
- IndicTrans2/inference/normalize-punctuation.perl +90 -0
- IndicTrans2/inference/normalize_punctuation.py +60 -0
- IndicTrans2/inference/normalize_punctuation.sh +33 -0
- IndicTrans2/inference/normalize_regex_inference.py +105 -0
- IndicTrans2/inference/requirements.txt +11 -0
- IndicTrans2/inference/triton_server/Dockerfile +25 -0
- IndicTrans2/inference/triton_server/README.md +22 -0
- IndicTrans2/inference/triton_server/azure_ml/README.md +56 -0
- IndicTrans2/inference/triton_server/azure_ml/deployment.yml +13 -0
- IndicTrans2/inference/triton_server/azure_ml/endpoint.yml +3 -0
- IndicTrans2/inference/triton_server/azure_ml/environment.yml +14 -0
- IndicTrans2/inference/triton_server/azure_ml/model.yml +5 -0
- IndicTrans2/inference/triton_server/client.py +55 -0
- IndicTrans2/inference/triton_server/dhruva/ulca_model.json +0 -0
- IndicTrans2/inference/triton_server/triton_repo/nmt/1/model.py +167 -0
- IndicTrans2/inference/triton_server/triton_repo/nmt/config.pbtxt +32 -0
- IndicTrans2/inference/utils.map_token_lang.tsv +26 -0
IndicTrans2/.gitignore
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ignore libs and data folder we use
|
2 |
+
indic_nlp_library
|
3 |
+
indic_nlp_resources
|
4 |
+
fairseq
|
5 |
+
devtest
|
6 |
+
checkpoints
|
7 |
+
eval_benchmarks
|
8 |
+
|
9 |
+
# Byte-compiled / optimized / DLL files
|
10 |
+
__pycache__/
|
11 |
+
*.py[cod]
|
12 |
+
*$py.class
|
13 |
+
|
14 |
+
# C extensions
|
15 |
+
*.so
|
16 |
+
|
17 |
+
# Distribution / packaging
|
18 |
+
.Python
|
19 |
+
build/
|
20 |
+
develop-eggs/
|
21 |
+
dist/
|
22 |
+
downloads/
|
23 |
+
eggs/
|
24 |
+
.eggs/
|
25 |
+
lib/
|
26 |
+
lib64/
|
27 |
+
parts/
|
28 |
+
sdist/
|
29 |
+
var/
|
30 |
+
wheels/
|
31 |
+
share/python-wheels/
|
32 |
+
*.egg-info/
|
33 |
+
.installed.cfg
|
34 |
+
*.egg
|
35 |
+
MANIFEST
|
36 |
+
|
37 |
+
# PyInstaller
|
38 |
+
# Usually these files are written by a python script from a template
|
39 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
40 |
+
*.manifest
|
41 |
+
*.spec
|
42 |
+
|
43 |
+
# Installer logs
|
44 |
+
pip-log.txt
|
45 |
+
pip-delete-this-directory.txt
|
46 |
+
|
47 |
+
# Unit test / coverage reports
|
48 |
+
htmlcov/
|
49 |
+
.tox/
|
50 |
+
.nox/
|
51 |
+
.coverage
|
52 |
+
.coverage.*
|
53 |
+
.cache
|
54 |
+
nosetests.xml
|
55 |
+
coverage.xml
|
56 |
+
*.cover
|
57 |
+
*.py,cover
|
58 |
+
.hypothesis/
|
59 |
+
.pytest_cache/
|
60 |
+
cover/
|
61 |
+
|
62 |
+
# Translations
|
63 |
+
*.mo
|
64 |
+
*.pot
|
65 |
+
|
66 |
+
# Django stuff:
|
67 |
+
*.log
|
68 |
+
local_settings.py
|
69 |
+
db.sqlite3
|
70 |
+
db.sqlite3-journal
|
71 |
+
|
72 |
+
# Flask stuff:
|
73 |
+
instance/
|
74 |
+
.webassets-cache
|
75 |
+
|
76 |
+
# Scrapy stuff:
|
77 |
+
.scrapy
|
78 |
+
|
79 |
+
# Sphinx documentation
|
80 |
+
docs/_build/
|
81 |
+
|
82 |
+
# PyBuilder
|
83 |
+
.pybuilder/
|
84 |
+
target/
|
85 |
+
|
86 |
+
# Jupyter Notebook
|
87 |
+
.ipynb_checkpoints
|
88 |
+
|
89 |
+
# IPython
|
90 |
+
profile_default/
|
91 |
+
ipython_config.py
|
92 |
+
|
93 |
+
# pyenv
|
94 |
+
# For a library or package, you might want to ignore these files since the code is
|
95 |
+
# intended to run in multiple environments; otherwise, check them in:
|
96 |
+
# .python-version
|
97 |
+
|
98 |
+
# pipenv
|
99 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
100 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
101 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
102 |
+
# install all needed dependencies.
|
103 |
+
#Pipfile.lock
|
104 |
+
|
105 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
106 |
+
__pypackages__/
|
107 |
+
|
108 |
+
# Celery stuff
|
109 |
+
celerybeat-schedule
|
110 |
+
celerybeat.pid
|
111 |
+
|
112 |
+
# SageMath parsed files
|
113 |
+
*.sage.py
|
114 |
+
|
115 |
+
# Environments
|
116 |
+
.env
|
117 |
+
.venv
|
118 |
+
env/
|
119 |
+
venv/
|
120 |
+
ENV/
|
121 |
+
env.bak/
|
122 |
+
venv.bak/
|
123 |
+
|
124 |
+
# Spyder project settings
|
125 |
+
.spyderproject
|
126 |
+
.spyproject
|
127 |
+
|
128 |
+
# Rope project settings
|
129 |
+
.ropeproject
|
130 |
+
|
131 |
+
# mkdocs documentation
|
132 |
+
/site
|
133 |
+
|
134 |
+
# mypy
|
135 |
+
.mypy_cache/
|
136 |
+
.dmypy.json
|
137 |
+
dmypy.json
|
138 |
+
|
139 |
+
# Pyre type checker
|
140 |
+
.pyre/
|
141 |
+
|
142 |
+
# pytype static type analyzer
|
143 |
+
.pytype/
|
144 |
+
|
145 |
+
# Cython debug symbols
|
146 |
+
cython_debug/
|
147 |
+
|
148 |
+
.DS_Store
|
IndicTrans2/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) AI4Bharat.
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE
|
IndicTrans2/README.md
ADDED
@@ -0,0 +1,523 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# IndicTrans2
|
2 |
+
|
3 |
+
[📜 Paper](https://arxiv.org/abs/2305.16307) | [🌐 Website](https://ai4bharat.iitm.ac.in/indic-trans2) | [▶️ Demo](https://models.ai4bharat.org/#/nmt/v2) | [🤗 HF Interface](https://github.com/AI4Bharat/IndicTrans2/tree/main/huggingface_interface) | [![colab link](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI4Bharat/IndicTrans2/blob/main/huggingface_interface/colab_inference.ipynb)
|
4 |
+
|
5 |
+
IndicTrans2 is the first open-source transformer-based multilingual NMT model that supports high-quality translations across all the 22 scheduled Indic languages — including multiple scripts for low-resouce languages like Kashmiri, Manipuri and Sindhi. It adopts script unification wherever feasible to leverage transfer learning by lexical sharing between languages. Overall, the model supports five scripts Perso-Arabic (Kashmiri, Sindhi, Urdu), Ol Chiki (Santali), Meitei (Manipuri), Latin (English), and Devanagari (used for all the remaining languages).
|
6 |
+
|
7 |
+
We open-souce all our training dataset (BPCC), back-translation data (BPCC-BT), final IndicTrans2 models, evaluation benchmarks (IN22, which includes IN22-Gen and IN22-Conv) and training and inference scripts for easier use and adoption within the research community. We hope that this will foster even more research in low-resource Indic languages, leading to further improvements in the quality of low-resource translation through contributions from the research community.
|
8 |
+
|
9 |
+
This code repository contains instructions for downloading the artifacts associated with IndicTrans2, as well as the code for training/fine-tuning the multilingual NMT models.
|
10 |
+
|
11 |
+
Here is the list of languages supported by the IndicTrans2 models:
|
12 |
+
|
13 |
+
<table>
|
14 |
+
<tbody>
|
15 |
+
<tr>
|
16 |
+
<td>Assamese (asm_Beng)</td>
|
17 |
+
<td>Kashmiri (Arabic) (kas_Arab)</td>
|
18 |
+
<td>Punjabi (pan_Guru)</td>
|
19 |
+
</tr>
|
20 |
+
<tr>
|
21 |
+
<td>Bengali (ben_Beng)</td>
|
22 |
+
<td>Kashmiri (Devanagari) (kas_Deva)</td>
|
23 |
+
<td>Sanskrit (san_Deva)</td>
|
24 |
+
</tr>
|
25 |
+
<tr>
|
26 |
+
<td>Bodo (brx_Deva)</td>
|
27 |
+
<td>Maithili (mai_Deva)</td>
|
28 |
+
<td>Santali (sat_Olck)</td>
|
29 |
+
</tr>
|
30 |
+
<tr>
|
31 |
+
<td>Dogri (doi_Deva)</td>
|
32 |
+
<td>Malayalam (mal_Mlym)</td>
|
33 |
+
<td>Sindhi (Arabic) (snd_Arab)</td>
|
34 |
+
</tr>
|
35 |
+
<tr>
|
36 |
+
<td>English (eng_Latn)</td>
|
37 |
+
<td>Marathi (mar_Deva)</td>
|
38 |
+
<td>Sindhi (Devanagari) (snd_Deva)</td>
|
39 |
+
</tr>
|
40 |
+
<tr>
|
41 |
+
<td>Konkani (gom_Deva)</td>
|
42 |
+
<td>Manipuri (Bengali) (mni_Beng)</td>
|
43 |
+
<td>Tamil (tam_Taml)</td>
|
44 |
+
</tr>
|
45 |
+
<tr>
|
46 |
+
<td>Gujarati (guj_Gujr)</td>
|
47 |
+
<td>Manipuri (Meitei) (mni_Mtei)</td>
|
48 |
+
<td>Telugu (tel_Telu)</td>
|
49 |
+
</tr>
|
50 |
+
<tr>
|
51 |
+
<td>Hindi (hin_Deva)</td>
|
52 |
+
<td>Nepali (npi_Deva)</td>
|
53 |
+
<td>Urdu (urd_Arab)</td>
|
54 |
+
</tr>
|
55 |
+
<tr>
|
56 |
+
<td>Kannada (kan_Knda)</td>
|
57 |
+
<td>Odia (ory_Orya)</td>
|
58 |
+
<td></td>
|
59 |
+
</tr>
|
60 |
+
</tbody>
|
61 |
+
</table>
|
62 |
+
|
63 |
+
## Updates
|
64 |
+
|
65 |
+
- 🚨 Dec 30, 2023 - Migrated IndicTrans2 tokenizer for HF compatible IndicTrans2 models to [IndicTransTokenizer](https://github.com/VarunGumma/IndicTransTokenizer) and will be maintained separately there from now onwards. Add LoRA fine-tuning scripts for our IndicTrans2 models in [huggingface_interface](https://github.com/AI4Bharat/IndicTrans2/tree/main/huggingface_interface).
|
66 |
+
- 🚨 Dec 1, 2023 - Release of Indic-Indic model and corresponding distilled variants for each base model. Please refer to the [Download section](https://github.com/AI4Bharat/IndicTrans2#multilingual-translation-models) for the checkpoints.
|
67 |
+
- 🚨 Sep 9, 2023 - Added HF compatible IndicTrans2 models. Please refer to the [README](https://github.com/AI4Bharat/IndicTrans2/tree/main/huggingface_interface) for detailed example usage.
|
68 |
+
|
69 |
+
## Tables of Contents
|
70 |
+
|
71 |
+
- [Download Models and Other Artifacts](#download-models-and-other-artifacts)
|
72 |
+
- [Multilingual Translation Models](#multilingual-translation-models)
|
73 |
+
- [Training Data](#training-data)
|
74 |
+
- [Evaluation Data](#evaluation-data)
|
75 |
+
- [Installation](#installation)
|
76 |
+
- [Data](#data)
|
77 |
+
- [Training](#training)
|
78 |
+
- [Evaluation](#evaluation)
|
79 |
+
- [Preparing Data for Training](#preparing-data-for-training)
|
80 |
+
- [Using our SPM model and Fairseq dictionary](#using-our-spm-model-and-fairseq-dictionary)
|
81 |
+
- [Training your own SPM models and learning Fairseq dictionary](#training-your-own-spm-models-and-learning-fairseq-dictionary)
|
82 |
+
- [Training / Fine-tuning](#training--fine-tuning)
|
83 |
+
- [Inference](#inference)
|
84 |
+
- [Fairseq Inference](#fairseq-inference)
|
85 |
+
- [CT2 Inference](#ct2-inference)
|
86 |
+
- [Evaluations](#evaluations)
|
87 |
+
- [Baseline Evaluation](#baseline-evaluation)
|
88 |
+
- [LICENSE](#license)
|
89 |
+
- [Citation](#citation)
|
90 |
+
|
91 |
+
## Download Models and Other Artifacts
|
92 |
+
|
93 |
+
### Multilingual Translation Models
|
94 |
+
|
95 |
+
| Model | En-Indic | Indic-En | Indic-Indic | Evaluations |
|
96 |
+
| ---------------------------- | ----------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
97 |
+
| Base (used for benchmarking) | [download](https://indictrans2-public.objectstore.e2enetworks.net/it2_preprint_ckpts/en-indic-preprint.zip) | [download](https://indictrans2-public.objectstore.e2enetworks.net/it2_preprint_ckpts/indic-en-preprint.zip) | [download](https://indictrans2-public.objectstore.e2enetworks.net/it2_preprint_ckpts/indic-indic.zip) | [translations](https://indictrans2-public.objectstore.e2enetworks.net/translation_outputs.zip) (as of May 10, 2023), [metrics](https://drive.google.com/drive/folders/1lOOdaU0VdRSBgJEsNav5zC7wwLBis9NI?usp=sharing) |
|
98 |
+
| Distilled | [download](https://indictrans2-public.objectstore.e2enetworks.net/it2_distilled_ckpts/en-indic.zip) | [download](https://indictrans2-public.objectstore.e2enetworks.net/it2_distilled_ckpts/indic-en.zip) | [download](https://indictrans2-public.objectstore.e2enetworks.net/it2_distilled_ckpts/indic-indic.zip) |
|
99 |
+
|
100 |
+
### Training Data
|
101 |
+
|
102 |
+
| Data | URL |
|
103 |
+
| ---------------------------------------- | ------------------------------------------------------------------------------ |
|
104 |
+
| Bharat Parallel Corpus Collection (BPCC) | [download](https://indictrans2-public.objectstore.e2enetworks.net/BPCC.zip) |
|
105 |
+
| Back-translation (BPCC-BT) | [download](https://indictrans2-public.objectstore.e2enetworks.net/BT_data.zip) |
|
106 |
+
|
107 |
+
### Evaluation Data
|
108 |
+
|
109 |
+
| Data | URL |
|
110 |
+
| ----------------------- | ------------------------------------------------------------------------------------ |
|
111 |
+
| IN22 test set | [download](https://indictrans2-public.objectstore.e2enetworks.net/IN22_testset.zip) |
|
112 |
+
| FLORES-22 Indic dev set | [download](https://indictrans2-public.objectstore.e2enetworks.net/flores-22_dev.zip) |
|
113 |
+
|
114 |
+
## Installation
|
115 |
+
|
116 |
+
Instructions to setup and install everything before running the code.
|
117 |
+
|
118 |
+
```bash
|
119 |
+
# Clone the github repository and navigate to the project directory.
|
120 |
+
git clone https://github.com/AI4Bharat/IndicTrans2
|
121 |
+
cd IndicTrans2
|
122 |
+
|
123 |
+
# Install all the dependencies and requirements associated with the project.
|
124 |
+
source install.sh
|
125 |
+
```
|
126 |
+
|
127 |
+
Note: We recommend creating a virtual environment with python>=3.7.
|
128 |
+
|
129 |
+
### Additional notes about Installation
|
130 |
+
The ``prepare_data_joint_finetuning.sh`` and ``prepare_data_joint_training.sh`` scripts expect that the sentencepiece commandline utility and GNU parallel are installed.
|
131 |
+
1. To install the sentencepiece command line utility, please follow the instructions [here](https://github.com/google/sentencepiece?tab=readme-ov-file#build-and-install-sentencepiece-command-line-tools-from-c-source).
|
132 |
+
2. Please check if GNU parallel is installed, if not please install the same or alternatively in case of installation issues, remove ``parallel --pipe --keep-order`` from the respective training / finetuning script as well as ``apply_sentence_piece.sh``.
|
133 |
+
|
134 |
+
|
135 |
+
## Data
|
136 |
+
|
137 |
+
### Training
|
138 |
+
|
139 |
+
Bharat Parallel Corpus Collection (BPCC) is a comprehensive and publicly available parallel corpus that includes both existing and new data for all 22 scheduled Indic languages. It is comprised of two parts: BPCC-Mined and BPCC-Human, totaling approximately 230 million bitext pairs. BPCC-Mined contains about 228 million pairs, with nearly 126 million pairs newly added as a part of this work. On the other hand, BPCC-Human consists of 2.2 million gold standard English-Indic pairs, with an additional 644K bitext pairs from English Wikipedia sentences (forming the BPCC-H-Wiki subset) and 139K sentences covering everyday use cases (forming the BPCC-H-Daily subset). It is worth highlighting that BPCC provides the first available datasets for 7 languages and significantly increases the available data for all languages covered.
|
140 |
+
|
141 |
+
You can find the contribution from different sources in the following table:
|
142 |
+
|
143 |
+
<table>
|
144 |
+
<tbody>
|
145 |
+
<tr>
|
146 |
+
<td rowspan="4">BPCC-Mined</th>
|
147 |
+
<td rowspan="2">Existing</th>
|
148 |
+
<td>Samanantar</th>
|
149 |
+
<td>19.4M</th>
|
150 |
+
</tr>
|
151 |
+
<tr>
|
152 |
+
<td>NLLB</th>
|
153 |
+
<td>85M</th>
|
154 |
+
</tr>
|
155 |
+
<tr>
|
156 |
+
<td rowspan="2">Newly Added</th>
|
157 |
+
<td>Samanantar++</th>
|
158 |
+
<td>121.6M</th>
|
159 |
+
</tr>
|
160 |
+
<tr>
|
161 |
+
<td>Comparable</th>
|
162 |
+
<td>4.3M</th>
|
163 |
+
</tr>
|
164 |
+
<tr>
|
165 |
+
<td rowspan="5">BPCC-Human</td>
|
166 |
+
<td rowspan="3">Existing</td>
|
167 |
+
<td>NLLB</td>
|
168 |
+
<td>18.5K</td>
|
169 |
+
</tr>
|
170 |
+
<tr>
|
171 |
+
<td>ILCI</td>
|
172 |
+
<td>1.3M</td>
|
173 |
+
</tr>
|
174 |
+
<tr>
|
175 |
+
<td>Massive</td>
|
176 |
+
<td>115K</td>
|
177 |
+
</tr>
|
178 |
+
<tr>
|
179 |
+
<td rowspan="2">Newly Added</td>
|
180 |
+
<td>Wiki</td>
|
181 |
+
<td>644K</td>
|
182 |
+
</tr>
|
183 |
+
<tr>
|
184 |
+
<td>Daily</td>
|
185 |
+
<td>139K</td>
|
186 |
+
</tr>
|
187 |
+
</tbody>
|
188 |
+
</table>
|
189 |
+
|
190 |
+
Additionally, we provide augmented back-translation data generated by our intermediate IndicTrans2 models for training purposes. Please refer our paper for more details on the selection of sample proportions and sources.
|
191 |
+
|
192 |
+
<table>
|
193 |
+
<tbody>
|
194 |
+
<tr>
|
195 |
+
<td>English BT data (English Original)</td>
|
196 |
+
<td>401.9M</td>
|
197 |
+
</tr>
|
198 |
+
<tr>
|
199 |
+
<td>Indic BT data (Indic Original)</td>
|
200 |
+
<td>400.9M</td>
|
201 |
+
</tr>
|
202 |
+
</tbody>
|
203 |
+
</table>
|
204 |
+
|
205 |
+
<br>
|
206 |
+
|
207 |
+
### Evaluation
|
208 |
+
|
209 |
+
IN22 test set is a newly created comprehensive benchmark for evaluating machine translation performance in multi-domain, n-way parallel contexts across 22 Indic languages. It has been created from three distinct subsets, namely IN22-Wiki, IN22-Web and IN22-Conv. The Wikipedia and Web sources subsets offer diverse content spanning news, entertainment, culture, legal, and India-centric topics. IN22-Wiki and IN22-Web have been combined and considered for evaluation purposes and released as IN22-Gen. Meanwhile, IN22-Conv the conversation domain subset is designed to assess translation quality in typical day-to-day conversational-style applications.
|
210 |
+
|
211 |
+
<table>
|
212 |
+
<tbody>
|
213 |
+
<tr>
|
214 |
+
<td>IN22-Gen (IN22-Wiki + IN22-Web)</td>
|
215 |
+
<td>1024 sentences</td>
|
216 |
+
<td>🤗 <a href="https://huggingface.co/datasets/ai4bharat/IN22-Gen">ai4bharat/IN22-Gen</td>
|
217 |
+
</tr>
|
218 |
+
<tr>
|
219 |
+
<td>IN22-Conv</td>
|
220 |
+
<td>1503 sentences</td>
|
221 |
+
<td>🤗 <a href="https://huggingface.co/datasets/ai4bharat/IN22-Conv">ai4bharat/IN22-Conv</td>
|
222 |
+
</tr>
|
223 |
+
</tbody>
|
224 |
+
</table>
|
225 |
+
|
226 |
+
You can download the data artifacts released as a part of this work from the [following section](#download-models-and-other-artifacts).
|
227 |
+
|
228 |
+
## Preparing Data for Training
|
229 |
+
|
230 |
+
BPCC data is organized under different subsets as described above, where each subset contains language pair subdirectories with the sentences pairs. We also provide LaBSE and LASER for the mined subsets of BPCC. In order to replicate our training setup, you will need to combine the data for corresponding language pairs from different subsets and remove overlapping bitext pairs if any.
|
231 |
+
|
232 |
+
Here is the expected directory structure of the data:
|
233 |
+
|
234 |
+
```bash
|
235 |
+
BPCC
|
236 |
+
├── eng_Latn-asm_Beng
|
237 |
+
│ ├── train.eng_Latn
|
238 |
+
│ └── train.asm_Beng
|
239 |
+
├── eng_Latn-ben_Beng
|
240 |
+
└── ...
|
241 |
+
```
|
242 |
+
|
243 |
+
While we provide deduplicated subsets with the current available benchmarks, we highly recommend performing deduplication using the combined monolingual side of all the benchmarks. You can use the following command for deduplication once you combine the monolingual side of all the benchmarks in the directory.
|
244 |
+
|
245 |
+
```python3
|
246 |
+
python3 scripts/dedup_benchmark.py <in_data_dir> <out_data_dir> <benchmark_dir>
|
247 |
+
```
|
248 |
+
|
249 |
+
- `<in_data_dir>`: path to the directory containing train data for each language pair in the format `{src_lang}-{tgt_lang}`
|
250 |
+
- `<out_data_dir>`: path to the directory where the deduplicated train data will be written for each language pair in the format `{src_lang}-{tgt_lang}`
|
251 |
+
- `<benchmark_dir>`: path to the directory containing the language-wise monolingual side of dev/test set, with monolingual files named as `test.{lang}`
|
252 |
+
|
253 |
+
### Using our SPM model and Fairseq dictionary
|
254 |
+
|
255 |
+
Once you complete the deduplication of the training data with the available benchmarks, you can preprocess and binarize the data for training models. Please download our trained SPM model and learned Fairseq dictionary using the following links for your experiments.
|
256 |
+
|
257 |
+
| | En-Indic | Indic-En | Indic-Indic |
|
258 |
+
| ------------------ | -------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------- |
|
259 |
+
| SPM model | [download](https://indictrans2-public.objectstore.e2enetworks.net/en-indic-spm.zip) | [download](https://indictrans2-public.objectstore.e2enetworks.net/indic-en-spm.zip) | [download](https://indictrans2-public.objectstore.e2enetworks.net/indic-indic-spm.zip) |
|
260 |
+
| Fairseq dictionary | [download](https://indictrans2-public.objectstore.e2enetworks.net/en-indic-fairseq-dict.zip) | [download](https://indictrans2-public.objectstore.e2enetworks.net/indic-en-fairseq-dict.zip) | [download](https://indictrans2-public.objectstore.e2enetworks.net/indic-indic-fairseq-dict.zip) |
|
261 |
+
|
262 |
+
To prepare the data for training En-Indic model, please do the following:
|
263 |
+
|
264 |
+
1. Download the SPM model in the experiment directory and rename it as `vocab`.
|
265 |
+
2. Download the Fairseq dictionary in the experiment directory and rename it as `final_bin`.
|
266 |
+
|
267 |
+
Here is the expected directory for training En-Indic model:
|
268 |
+
|
269 |
+
```bash
|
270 |
+
en-indic-exp
|
271 |
+
├── train
|
272 |
+
│ ├── eng_Latn-asm_Beng
|
273 |
+
│ │ ├── train.eng_Latn
|
274 |
+
│ │ └── train.asm_Beng
|
275 |
+
│ ├── eng_Latn-ben_Beng
|
276 |
+
│ └── ...
|
277 |
+
├── devtest
|
278 |
+
│ └── all
|
279 |
+
│ ├── eng_Latn-asm_Beng
|
280 |
+
│ │ ├── dev.eng_Latn
|
281 |
+
│ │ └── dev.asm_Beng
|
282 |
+
│ ├── eng_Latn-ben_Beng
|
283 |
+
│ └── ...
|
284 |
+
├── vocab
|
285 |
+
│ ├── model.SRC
|
286 |
+
│ ├── model.TGT
|
287 |
+
│ ├── vocab.SRC
|
288 |
+
│ └── vocab.TGT
|
289 |
+
└── final_bin
|
290 |
+
├── dict.SRC.txt
|
291 |
+
└── dict.TGT.txt
|
292 |
+
```
|
293 |
+
|
294 |
+
To prepare data for training the Indic-En model, you should reverse the language pair directories within the train and devtest directories. Additionally, make sure to download the corresponding SPM model and Fairseq dictionary and put them in the experiment directory, similar to the procedure mentioned above for En-Indic model training.
|
295 |
+
|
296 |
+
You can binarize the data for model training using the following:
|
297 |
+
|
298 |
+
```bash
|
299 |
+
bash prepare_data_joint_finetuning.sh <exp_dir>
|
300 |
+
```
|
301 |
+
|
302 |
+
- `<exp_dir>`: path to the directory containing the raw data for binarization
|
303 |
+
|
304 |
+
You will need to follow the same steps for data preparation in case of fine-tuning models.
|
305 |
+
|
306 |
+
### Training your own SPM models and learning Fairseq dictionary
|
307 |
+
|
308 |
+
If you want to train your own SPM model and learn Fairseq dictionary, then please do the following:
|
309 |
+
|
310 |
+
1. Collect a balanced amount of English and Indic monolingual data (we use around 3 million sentences per language-script combination). If some languages have limited data available, increase their representation to achieve a fair distribution of tokens across languages.
|
311 |
+
2. Perform script unification for Indic languages wherever possible using `scripts/preprocess_translate.py` and concatenate all Indic data into a single file.
|
312 |
+
3. Train two SPM models, one for English and other for Indic side using the following:
|
313 |
+
|
314 |
+
```bash
|
315 |
+
spm_train --input=train.indic --model_prefix=<model_name> --vocab_size=<vocab_size> --character_coverage=1.0 --model_type=BPE
|
316 |
+
```
|
317 |
+
|
318 |
+
4. Copy the trained SPM models in the experiment directory mentioned earlier and learn the Fairseq dictionary using the following:
|
319 |
+
|
320 |
+
```bash
|
321 |
+
bash prepare_data_joint_training.sh <exp_dir>
|
322 |
+
```
|
323 |
+
|
324 |
+
5. You will need to use the same Fairseq dictionary for any subsequent fine-tuning experiments and refer to the steps described above ([link](#using-our-spm-model-and-fairseq-dictionary)).
|
325 |
+
|
326 |
+
## Training / Fine-tuning
|
327 |
+
|
328 |
+
After binarizing the data, you can use train.sh to train the models. We provide the default hyperparameters used in this work. You can modify the hyperparameters as per your requirement if needed. If you want to train the model on a customized architecture, then please define the architecture in `model_configs/custom_transformer.py`. You can start the model training with the following command:
|
329 |
+
|
330 |
+
```bash
|
331 |
+
bash train.sh <exp_dir> <model_arch>
|
332 |
+
```
|
333 |
+
|
334 |
+
- `<exp_dir>`: path to the directory containing the binarized data
|
335 |
+
- `<model_arch>`: custom transformer architecture used for model training
|
336 |
+
|
337 |
+
For fine-tuning, the initial steps remain the same. However, the `finetune.sh` script includes an additional argument, `pretrained_ckpt`, which specifies the model checkpoint to be loaded for further fine-tuning. You can perform fine-tuning using the following command:
|
338 |
+
|
339 |
+
```bash
|
340 |
+
bash finetune.sh <exp_dir> <model_arch> <pretrained_ckpt>
|
341 |
+
```
|
342 |
+
|
343 |
+
- `<exp_dir>`: path to the directory containing the binarized data
|
344 |
+
- `<model_arch>`: custom transformer architecture used for model training
|
345 |
+
- `transformer_18_18` - For IT2 Base models
|
346 |
+
- `transformer_base18L` - For IT2 Distilled models
|
347 |
+
- `<pretrained_ckpt>`: path to the fairseq model checkpoint to be loaded for further fine-tuning
|
348 |
+
|
349 |
+
You can download the model artifacts released as a part of this work from the [following section](#download-models-and-other-artifacts).
|
350 |
+
|
351 |
+
The pretrained checkpoints have 3 directories, a fairseq model directory and 2 CT-ported model directories. Please note that the CT2 models are provided only for efficient inference. For fine-tuning purposes you should use the `fairseq_model`. Post that you can use the [fairseq-ct2-converter](https://opennmt.net/CTranslate2/guides/fairseq.html) to port your fine-tuned checkpoints to CT2 for faster inference.
|
352 |
+
|
353 |
+
## Inference
|
354 |
+
|
355 |
+
### Fairseq Inference
|
356 |
+
|
357 |
+
In order to run inference on our pretrained models using bash interface, please use the following:
|
358 |
+
|
359 |
+
```bash
|
360 |
+
bash joint_translate.sh <infname> <outfname> <src_lang> <tgt_lang> <ckpt_dir>
|
361 |
+
```
|
362 |
+
|
363 |
+
- `infname`: path to the input file containing sentences
|
364 |
+
- `outfname`: path to the output file where the translations should be stored
|
365 |
+
- `src_lang`: source language
|
366 |
+
- `tgt_lang`: target language
|
367 |
+
- `ckpt_dir`: path to the fairseq model checkpoint directory
|
368 |
+
|
369 |
+
If you want to run the inference using python interface then please execute the following block of code from the root directory:
|
370 |
+
|
371 |
+
```python3
|
372 |
+
from inference.engine import Model
|
373 |
+
|
374 |
+
model = Model(ckpt_dir, model_type="fairseq")
|
375 |
+
|
376 |
+
sents = [sent1, sent2,...]
|
377 |
+
|
378 |
+
# for a batch of sentences
|
379 |
+
model.batch_translate(sents, src_lang, tgt_lang)
|
380 |
+
|
381 |
+
# for a paragraph
|
382 |
+
model.translate_paragraph(text, src_lang, tgt_lang)
|
383 |
+
```
|
384 |
+
|
385 |
+
### CT2 Inference
|
386 |
+
|
387 |
+
In order to run inference on CT2-ported model using python inference then please execute the following block of code from the root directory:
|
388 |
+
|
389 |
+
```python3
|
390 |
+
from inference.engine import Model
|
391 |
+
|
392 |
+
model = Model(ckpt_dir, model_type="ctranslate2")
|
393 |
+
|
394 |
+
sents = [sent1, sent2,...]
|
395 |
+
|
396 |
+
# for a batch of sentences
|
397 |
+
model.batch_translate(sents, src_lang, tgt_lang)
|
398 |
+
|
399 |
+
# for a paragraph
|
400 |
+
model.translate_paragraph(text, src_lang, tgt_lang)
|
401 |
+
```
|
402 |
+
|
403 |
+
## Evaluations
|
404 |
+
|
405 |
+
We consider the chrF++ as our primary metric. Additionally, we also report the BLEU and Comet scores.
|
406 |
+
We also perform statistical significance tests for each metric to ascertain whether the differences are statistically significant.
|
407 |
+
|
408 |
+
In order to run our evaluation scripts, you will need to organize the evaluation test sets into the following directory structure:
|
409 |
+
|
410 |
+
```bash
|
411 |
+
eval_benchmarks
|
412 |
+
├── flores
|
413 |
+
│ └── eng_Latn-asm_Beng
|
414 |
+
│ ├── test.eng_Latn
|
415 |
+
│ └── test.asm_Beng
|
416 |
+
├── in22-gen
|
417 |
+
├── in22-conv
|
418 |
+
├── ntrex
|
419 |
+
└── ...
|
420 |
+
```
|
421 |
+
|
422 |
+
To compute the BLEU and chrF++ scores for prediction file, you can use the following command:
|
423 |
+
|
424 |
+
```bash
|
425 |
+
bash compute_metrics.sh <pred_fname> <ref_fname> <tgt_lang>
|
426 |
+
```
|
427 |
+
|
428 |
+
- `pred_fname`: path to the model translations
|
429 |
+
- `ref_fname`: path to the reference translations
|
430 |
+
- `tgt_lang`: target language
|
431 |
+
|
432 |
+
In order to automate the inference over the individual test sets for En-Indic, you can use the following command:
|
433 |
+
|
434 |
+
```bash
|
435 |
+
bash eval.sh <devtest_data_dir> <ckpt_dir> <system>
|
436 |
+
```
|
437 |
+
|
438 |
+
- `<devtest_data_dir>`: path to the evaluation set with language pair subdirectories (for example, flores directory in the above tree structure)
|
439 |
+
- `<ckpt_dir>`: path to the fairseq model checkpoint directory
|
440 |
+
- `<system>`: system name suffix to store the predictions in the format `test.{lang}.pred.{system}`
|
441 |
+
|
442 |
+
In case of Indic-En evaluation, please use the following command:
|
443 |
+
|
444 |
+
```bash
|
445 |
+
bash eval_rev.sh <devtest_data_dir> <ckpt_dir> <system>
|
446 |
+
```
|
447 |
+
|
448 |
+
- `<devtest_data_dir>`: path to the evaluation set with language pair subdirectories (for example, flores directory in the above tree structure)
|
449 |
+
- `<ckpt_dir>`: path to the fairseq model checkpoint directory
|
450 |
+
- `<system>`: system name suffix to store the predictions in the format `test.{lang}.pred.{system}`
|
451 |
+
|
452 |
+
**_Note: You don’t need to reverse the test set directions for each language pair._**
|
453 |
+
|
454 |
+
In case of Indic-Indic evaluation, please use the following command:
|
455 |
+
|
456 |
+
```bash
|
457 |
+
bash pivot_eval.sh <devtest_data_dir> <pivot_lang> <src2pivot_ckpt_dir> <pivot2tgt_ckpt_dir> <system>
|
458 |
+
```
|
459 |
+
|
460 |
+
- `<devtest_data_dir>`: path to the evaluation set with language pair subdirectories (for example, flores directory in the above tree structure)
|
461 |
+
- `<pivot_lang>`: pivot language (default should be `eng_Latn`)
|
462 |
+
- `<src2pivot_ckpt_dir>`: path to the fairseq Indic-En model checkpoint directory
|
463 |
+
- `<pivot2tgt_ckpt_dir>`: path to the fairseq En-Indic model checkpoint directory
|
464 |
+
- `<system>`: system name suffix to store the predictions in the format test.{lang}.pred.{system}
|
465 |
+
|
466 |
+
In order to perform significance testing for BLEU and chrF++ metrics after you have the predictions for different systems, you can use the following command:
|
467 |
+
|
468 |
+
```bash
|
469 |
+
bash compute_comet_metrics_significance.sh <devtest_data_dir>
|
470 |
+
```
|
471 |
+
|
472 |
+
- `<devtest_data_dir>`: path to the evaluation set with language pair subdirectories (for example, flores directory in the above tree structure)
|
473 |
+
|
474 |
+
Similarly, to compute the COMET scores and perform significance testing on predictions of different systems, you can use the following command.
|
475 |
+
|
476 |
+
```bash
|
477 |
+
bash compute_comet_score.sh <devtest_data_dir>
|
478 |
+
```
|
479 |
+
|
480 |
+
- `<devtest_data_dir>`: path to the evaluation set with language pair subdirectories (for example, flores directory in the above tree structure)
|
481 |
+
|
482 |
+
Please note that as we compute significance tests with the same script and automate everything, it is best to have all the predictions for all the systems in place to avoid repeating anything.
|
483 |
+
Also, we define the systems in the script itself, if you want to try out other systems, make sure to edit it there itself.
|
484 |
+
|
485 |
+
### Baseline Evaluation
|
486 |
+
|
487 |
+
To generate the translation results for baseline models such as M2M-100, MBART, Azure, Google, and NLLB MoE, you can check the scripts provided in the "baseline_eval" directory of this repository. For NLLB distilled, you can either modify NLLB_MoE eval or use this [repository](https://github.com/pluiez/NLLB-inference). Similarly, for IndicTrans inference, please refer to this [repository](https://github.com/ai4bharat/IndicTrans).
|
488 |
+
|
489 |
+
You can download the translation outputs released as a part of this work from the [following section](#download-models-and-other-artifacts).
|
490 |
+
|
491 |
+
## LICENSE
|
492 |
+
|
493 |
+
The following table lists the licenses associated with the different artifacts released as a part of this work:
|
494 |
+
|
495 |
+
| Artifact | LICENSE |
|
496 |
+
| ----------------------------------------------------- | --------------------------------------------------------------------- |
|
497 |
+
| Existing Mined Corpora (NLLB & Samanantar) | [CC0](https://creativecommons.org/share-your-work/public-domain/cc0/) |
|
498 |
+
| Existing Seed Corpora (NLLB-Seed, ILCI, MASSIVE) | [CC0](https://creativecommons.org/share-your-work/public-domain/cc0/) |
|
499 |
+
| Newly Added Mined Corpora (Samanantar++ & Comparable) | [CC0](https://creativecommons.org/share-your-work/public-domain/cc0/) |
|
500 |
+
| Newly Added Seed Corpora (BPCC-H-Wiki & BPCC-H-Daily) | [CC-BY-4.0](https://creativecommons.org/licenses/by/4.0/) |
|
501 |
+
| Newly Created IN-22 test set (IN22-Gen & IN22-Conv) | [CC-BY-4.0](https://creativecommons.org/licenses/by/4.0/) |
|
502 |
+
| Back-translation data (BPCC-BT) | [CC0](https://creativecommons.org/share-your-work/public-domain/cc0/) |
|
503 |
+
| Model checkpoints | [MIT](https://github.com/ai4bharat/IndicTrans2/blob/main/LICENSE) |
|
504 |
+
|
505 |
+
The mined corpora collection (BPCC-Mined), existing seed corpora (NLLB-Seed, ILCI, MASSIVE), Backtranslation data (BPCC-BT), are released under the following licensing scheme:
|
506 |
+
|
507 |
+
- We do not own any of the text from which this data has been extracted.
|
508 |
+
- We license the actual packaging of this data under the Creative Commons [CC0 license (“no rights reserved”)](https://creativecommons.org/share-your-work/public-domain/cc0/).
|
509 |
+
- To the extent possible under law, [AI4Bharat](https://ai4bharat.iitm.ac.in/) has waived all copyright and related or neighboring rights to BPCC-Mined, existing seed corpora (NLLB-Seed, ILCI, MASSIVE) and BPCC-BT.
|
510 |
+
|
511 |
+
## Citation
|
512 |
+
|
513 |
+
```bibtex
|
514 |
+
@article{gala2023indictrans,
|
515 |
+
title={IndicTrans2: Towards High-Quality and Accessible Machine Translation Models for all 22 Scheduled Indian Languages},
|
516 |
+
author={Jay Gala and Pranjal A Chitale and A K Raghavan and Varun Gumma and Sumanth Doddapaneni and Aswanth Kumar M and Janki Atul Nawale and Anupama Sujatha and Ratish Puduppully and Vivek Raghavan and Pratyush Kumar and Mitesh M Khapra and Raj Dabre and Anoop Kunchukuttan},
|
517 |
+
journal={Transactions on Machine Learning Research},
|
518 |
+
issn={2835-8856},
|
519 |
+
year={2023},
|
520 |
+
url={https://openreview.net/forum?id=vfT4YuzAYA},
|
521 |
+
note={}
|
522 |
+
}
|
523 |
+
```
|
IndicTrans2/apply_sentence_piece.sh
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# This script tokenizes the preprocessed train and dev set using the trained spm models.
|
4 |
+
|
5 |
+
|
6 |
+
echo `date`
|
7 |
+
exp_dir=$1 # path to the experiment directory
|
8 |
+
data_dir=$2 # path to the data directory where all lang pairs are concatenated
|
9 |
+
bpe_dir=$3 # path to the tokenized data directory
|
10 |
+
src_lang=$4 # source language
|
11 |
+
tgt_lang=$5 # target language
|
12 |
+
split=$6 # name of the split
|
13 |
+
parallel_installed=${7:-false} # If GNU Parallel is installed or not
|
14 |
+
|
15 |
+
in_split_dir=$data_dir/$split
|
16 |
+
out_split_dir=$bpe_dir/$split
|
17 |
+
|
18 |
+
echo "Apply Sentence Piece tokenization to SRC corpus"
|
19 |
+
# for very large datasets, it is recommended to use gnu-parallel to speed up applying bpe
|
20 |
+
|
21 |
+
if $parallel_installed; then
|
22 |
+
parallel --pipe --keep-order \
|
23 |
+
spm_encode --model=$exp_dir/vocab/model.SRC \
|
24 |
+
--output_format=piece \
|
25 |
+
< $in_split_dir.$src_lang \
|
26 |
+
> $out_split_dir.$src_lang
|
27 |
+
else
|
28 |
+
spm_encode --model=$exp_dir/vocab/model.SRC \
|
29 |
+
--output_format=piece \
|
30 |
+
< $in_split_dir.$src_lang \
|
31 |
+
> $out_split_dir.$src_lang
|
32 |
+
fi
|
33 |
+
|
34 |
+
echo "Apply Sentence Piece tokenization to TGT corpus"
|
35 |
+
# for very large datasets, it is recommended to use gnu-parallel to speed up applying bpe
|
36 |
+
|
37 |
+
if $parallel_installed; then
|
38 |
+
parallel --pipe --keep-order \
|
39 |
+
spm_encode --model=$exp_dir/vocab/model.TGT \
|
40 |
+
--output_format=piece \
|
41 |
+
< $in_split_dir.$tgt_lang \
|
42 |
+
> $out_split_dir.$tgt_lang
|
43 |
+
else
|
44 |
+
spm_encode --model=$exp_dir/vocab/model.TGT \
|
45 |
+
--output_format=piece \
|
46 |
+
< $in_split_dir.$tgt_lang \
|
47 |
+
> $out_split_dir.$tgt_lang
|
48 |
+
fi
|
IndicTrans2/baseline_eval/azure_translate.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import glob
|
4 |
+
import requests
|
5 |
+
from urllib.parse import urlencode
|
6 |
+
from dotenv import dotenv_values
|
7 |
+
import traceback
|
8 |
+
import time
|
9 |
+
|
10 |
+
flores_to_iso = {
|
11 |
+
"asm_Beng": "as",
|
12 |
+
"ben_Beng": "bn",
|
13 |
+
"brx_Deva": "brx",
|
14 |
+
"doi_Deva": "doi",
|
15 |
+
"eng_Latn": "en",
|
16 |
+
"gom_Deva": "gom",
|
17 |
+
"guj_Gujr": "gu",
|
18 |
+
"hin_Deva": "hi",
|
19 |
+
"kan_Knda": "kn",
|
20 |
+
"kas_Arab": "ks",
|
21 |
+
"kas_Deva": "ks_Deva",
|
22 |
+
"mai_Deva": "mai",
|
23 |
+
"mal_Mlym": "ml",
|
24 |
+
"mar_Deva": "mr",
|
25 |
+
"mni_Beng": "mni_Beng",
|
26 |
+
"mni_Mtei": "mni",
|
27 |
+
"npi_Deva": "ne",
|
28 |
+
"ory_Orya": "or",
|
29 |
+
"pan_Guru": "pa",
|
30 |
+
"san_Deva": "sa",
|
31 |
+
"sat_Olck": "sat",
|
32 |
+
"snd_Arab": "sd",
|
33 |
+
"snd_Deva": "sd_Deva",
|
34 |
+
"tam_Taml": "ta",
|
35 |
+
"tel_Telu": "te",
|
36 |
+
"urd_Arab": "ur",
|
37 |
+
}
|
38 |
+
|
39 |
+
|
40 |
+
class AzureTranslator:
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
subscription_key: str,
|
44 |
+
region: str,
|
45 |
+
endpoint: str = "https://api.cognitive.microsofttranslator.com",
|
46 |
+
) -> None:
|
47 |
+
self.http_headers = {
|
48 |
+
"Ocp-Apim-Subscription-Key": subscription_key,
|
49 |
+
"Ocp-Apim-Subscription-Region": region,
|
50 |
+
}
|
51 |
+
self.translate_endpoint = endpoint + "/translate?api-version=3.0&"
|
52 |
+
self.languages_endpoint = endpoint + "/languages?api-version=3.0"
|
53 |
+
|
54 |
+
self.supported_languages = self.get_supported_languages()
|
55 |
+
|
56 |
+
def get_supported_languages(self) -> dict:
|
57 |
+
return requests.get(self.languages_endpoint).json()["translation"]
|
58 |
+
|
59 |
+
def batch_translate(self, texts: list, src_lang: str, tgt_lang: str) -> list:
|
60 |
+
if not texts:
|
61 |
+
return texts
|
62 |
+
|
63 |
+
src_lang = flores_to_iso[src_lang]
|
64 |
+
tgt_lang = flores_to_iso[tgt_lang]
|
65 |
+
|
66 |
+
if src_lang not in self.supported_languages:
|
67 |
+
raise NotImplementedError(
|
68 |
+
f"Source language code: `{src_lang}` not supported!"
|
69 |
+
)
|
70 |
+
|
71 |
+
if tgt_lang not in self.supported_languages:
|
72 |
+
raise NotImplementedError(
|
73 |
+
f"Target language code: `{tgt_lang}` not supported!"
|
74 |
+
)
|
75 |
+
|
76 |
+
body = [{"text": text} for text in texts]
|
77 |
+
query_string = urlencode(
|
78 |
+
{
|
79 |
+
"from": src_lang,
|
80 |
+
"to": tgt_lang,
|
81 |
+
}
|
82 |
+
)
|
83 |
+
|
84 |
+
try:
|
85 |
+
response = requests.post(
|
86 |
+
self.translate_endpoint + query_string,
|
87 |
+
headers=self.http_headers,
|
88 |
+
json=body,
|
89 |
+
)
|
90 |
+
except:
|
91 |
+
traceback.print_exc()
|
92 |
+
return None
|
93 |
+
|
94 |
+
try:
|
95 |
+
response = response.json()
|
96 |
+
except:
|
97 |
+
traceback.print_exc()
|
98 |
+
print("Response:", response.text)
|
99 |
+
return None
|
100 |
+
|
101 |
+
return [payload["translations"][0]["text"] for payload in response]
|
102 |
+
|
103 |
+
def text_translate(self, text: str, src_lang: str, tgt_lang: str) -> str:
|
104 |
+
return self.batch_translate([text], src_lang, tgt_lang)[0]
|
105 |
+
|
106 |
+
|
107 |
+
if __name__ == "__main__":
|
108 |
+
root_dir = sys.argv[1]
|
109 |
+
|
110 |
+
# Expects a .env file containing the API credentials.
|
111 |
+
config = dotenv_values(os.path.join(os.path.dirname(__file__), ".env"))
|
112 |
+
|
113 |
+
t = AzureTranslator(
|
114 |
+
config["AZURE_TRANSLATOR_TEXT_SUBSCRIPTION_KEY"],
|
115 |
+
config["AZURE_TRANSLATOR_TEXT_REGION"],
|
116 |
+
config["AZURE_TRANSLATOR_TEXT_ENDPOINT"],
|
117 |
+
)
|
118 |
+
|
119 |
+
pairs = sorted(glob.glob(os.path.join(root_dir, "*")))
|
120 |
+
|
121 |
+
for i, pair in enumerate(pairs):
|
122 |
+
basename = os.path.basename(pair)
|
123 |
+
|
124 |
+
print(pair)
|
125 |
+
|
126 |
+
src_lang, tgt_lang = basename.split("-")
|
127 |
+
|
128 |
+
print(f"{src_lang} - {tgt_lang}")
|
129 |
+
|
130 |
+
# source to target translations
|
131 |
+
src_infname = os.path.join(pair, f"test.{src_lang}")
|
132 |
+
tgt_outfname = os.path.join(pair, f"test.{tgt_lang}.pred.azure")
|
133 |
+
if not os.path.exists(src_infname):
|
134 |
+
continue
|
135 |
+
|
136 |
+
src_sents = [
|
137 |
+
sent.replace("\n", "").strip()
|
138 |
+
for sent in open(src_infname, "r").read().split("\n")
|
139 |
+
if sent
|
140 |
+
]
|
141 |
+
|
142 |
+
if not os.path.exists(tgt_outfname):
|
143 |
+
try:
|
144 |
+
translations = []
|
145 |
+
for i in range(0, len(src_sents), 128):
|
146 |
+
start, end = i, int(min(i + 128, len(src_sents)))
|
147 |
+
translations.extend(
|
148 |
+
t.batch_translate(src_sents[start:end], src_lang, tgt_lang)
|
149 |
+
)
|
150 |
+
with open(tgt_outfname, "w") as f:
|
151 |
+
f.write("\n".join(translations))
|
152 |
+
|
153 |
+
time.sleep(10)
|
154 |
+
except Exception as e:
|
155 |
+
print(e)
|
156 |
+
continue
|
157 |
+
|
158 |
+
# target to source translations
|
159 |
+
tgt_infname = os.path.join(pair, f"test.{tgt_lang}")
|
160 |
+
src_outfname = os.path.join(pair, f"test.{src_lang}.pred.azure")
|
161 |
+
if not os.path.exists(tgt_infname):
|
162 |
+
continue
|
163 |
+
|
164 |
+
tgt_sents = [
|
165 |
+
sent.replace("\n", "").strip()
|
166 |
+
for sent in open(tgt_infname, "r").read().split("\n")
|
167 |
+
if sent
|
168 |
+
]
|
169 |
+
|
170 |
+
if not os.path.exists(src_outfname):
|
171 |
+
try:
|
172 |
+
translations = []
|
173 |
+
for i in range(0, len(tgt_sents), 128):
|
174 |
+
start, end = i, int(min(i + 128, len(tgt_sents)))
|
175 |
+
translations.extend(
|
176 |
+
t.batch_translate(tgt_sents[start:end], tgt_lang, src_lang)
|
177 |
+
)
|
178 |
+
with open(src_outfname, "w") as f:
|
179 |
+
f.write("\n".join(translations))
|
180 |
+
except Exception as e:
|
181 |
+
continue
|
182 |
+
|
183 |
+
time.sleep(10)
|
IndicTrans2/baseline_eval/google_translate.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import glob
|
4 |
+
from tqdm import tqdm
|
5 |
+
from google.cloud import translate
|
6 |
+
|
7 |
+
# Expects a json file containing the API credentials.
|
8 |
+
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.join(
|
9 |
+
os.path.dirname(__file__), r"api_key.json"
|
10 |
+
)
|
11 |
+
|
12 |
+
flores_to_iso = {
|
13 |
+
"asm_Beng": "as",
|
14 |
+
"ben_Beng": "bn",
|
15 |
+
"doi_Deva": "doi",
|
16 |
+
"eng_Latn": "en",
|
17 |
+
"gom_Deva": "gom",
|
18 |
+
"guj_Gujr": "gu",
|
19 |
+
"hin_Deva": "hi",
|
20 |
+
"kan_Knda": "kn",
|
21 |
+
"mai_Deva": "mai",
|
22 |
+
"mal_Mlym": "ml",
|
23 |
+
"mar_Deva": "mr",
|
24 |
+
"mni_Mtei": "mni_Mtei",
|
25 |
+
"npi_Deva": "ne",
|
26 |
+
"ory_Orya": "or",
|
27 |
+
"pan_Guru": "pa",
|
28 |
+
"san_Deva": "sa",
|
29 |
+
"sat_Olck": "sat",
|
30 |
+
"snd_Arab": "sd",
|
31 |
+
"tam_Taml": "ta",
|
32 |
+
"tel_Telu": "te",
|
33 |
+
"urd_Arab": "ur",
|
34 |
+
}
|
35 |
+
|
36 |
+
|
37 |
+
# Copy the project id from the json file containing API credentials
|
38 |
+
def translate_text(text, src_lang, tgt_lang, project_id="project_id"):
|
39 |
+
|
40 |
+
src_lang = flores_to_iso[src_lang]
|
41 |
+
tgt_lang = flores_to_iso[tgt_lang]
|
42 |
+
|
43 |
+
if src_lang == "mni_Mtei":
|
44 |
+
src_lang = "mni-Mtei"
|
45 |
+
|
46 |
+
if tgt_lang == "mni_Mtei":
|
47 |
+
tgt_lang = "mni-Mtei"
|
48 |
+
|
49 |
+
client = translate.TranslationServiceClient()
|
50 |
+
|
51 |
+
location = "global"
|
52 |
+
|
53 |
+
parent = f"projects/{project_id}/locations/{location}"
|
54 |
+
|
55 |
+
response = client.translate_text(
|
56 |
+
request={
|
57 |
+
"parent": parent,
|
58 |
+
"contents": [text],
|
59 |
+
"mime_type": "text/plain", # mime types: text/plain, text/html
|
60 |
+
"source_language_code": src_lang,
|
61 |
+
"target_language_code": tgt_lang,
|
62 |
+
}
|
63 |
+
)
|
64 |
+
|
65 |
+
translated_text = ""
|
66 |
+
for translation in response.translations:
|
67 |
+
translated_text += translation.translated_text
|
68 |
+
|
69 |
+
return translated_text
|
70 |
+
|
71 |
+
|
72 |
+
if __name__ == "__main__":
|
73 |
+
root_dir = sys.argv[1]
|
74 |
+
|
75 |
+
pairs = sorted(glob.glob(os.path.join(root_dir, "*")))
|
76 |
+
|
77 |
+
for pair in pairs:
|
78 |
+
|
79 |
+
print(pair)
|
80 |
+
|
81 |
+
basename = os.path.basename(pair)
|
82 |
+
|
83 |
+
src_lang, tgt_lang = basename.split("-")
|
84 |
+
if src_lang not in flores_to_iso.keys() or tgt_lang not in flores_to_iso.keys():
|
85 |
+
continue
|
86 |
+
|
87 |
+
if src_lang == "eng_Latn":
|
88 |
+
lang = tgt_lang
|
89 |
+
else:
|
90 |
+
lang = src_lang
|
91 |
+
|
92 |
+
lang = flores_to_iso[lang]
|
93 |
+
|
94 |
+
if lang not in "as bn doi gom gu hi kn mai ml mni_Mtei mr ne or pa sa sd ta te ur":
|
95 |
+
continue
|
96 |
+
|
97 |
+
print(f"{src_lang} - {tgt_lang}")
|
98 |
+
|
99 |
+
# source to target translations
|
100 |
+
|
101 |
+
src_infname = os.path.join(pair, f"test.{src_lang}")
|
102 |
+
tgt_outfname = os.path.join(pair, f"test.{tgt_lang}.pred.google")
|
103 |
+
if os.path.exists(src_infname) and not os.path.exists(tgt_outfname):
|
104 |
+
src_sents = [
|
105 |
+
sent.replace("\n", "").strip()
|
106 |
+
for sent in open(src_infname, "r").read().split("\n")
|
107 |
+
if sent
|
108 |
+
]
|
109 |
+
translations = [
|
110 |
+
translate_text(text, src_lang, tgt_lang).strip() for text in tqdm(src_sents)
|
111 |
+
]
|
112 |
+
with open(tgt_outfname, "w") as f:
|
113 |
+
f.write("\n".join(translations))
|
114 |
+
|
115 |
+
# # target to source translations
|
116 |
+
tgt_infname = os.path.join(pair, f"test.{tgt_lang}")
|
117 |
+
src_outfname = os.path.join(pair, f"test.{src_lang}.pred.google")
|
118 |
+
if os.path.exists(tgt_infname) and not os.path.exists(src_outfname):
|
119 |
+
tgt_sents = [
|
120 |
+
sent.replace("\n", "").strip()
|
121 |
+
for sent in open(tgt_infname, "r").read().split("\n")
|
122 |
+
if sent
|
123 |
+
]
|
124 |
+
translations = [
|
125 |
+
translate_text(text, tgt_lang, src_lang).strip() for text in tqdm(tgt_sents)
|
126 |
+
]
|
127 |
+
|
128 |
+
with open(src_outfname, "w") as f:
|
129 |
+
f.write("\n".join(translations))
|
IndicTrans2/baseline_eval/m2m100_inference.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import sys
|
4 |
+
from tqdm import tqdm
|
5 |
+
import torch
|
6 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
7 |
+
|
8 |
+
|
9 |
+
# dictionary mapping flores codes to M2M-100 supported codes
|
10 |
+
langs_supported = {
|
11 |
+
"eng_Latn": "en",
|
12 |
+
"ben_Beng": "bn",
|
13 |
+
"guj_Gujr": "gu",
|
14 |
+
"hin_Deva": "hi",
|
15 |
+
"kan_Knda": "kn",
|
16 |
+
"mal_Mlym": "ml",
|
17 |
+
"mar_Deva": "mr",
|
18 |
+
"npi_Deva": "ne",
|
19 |
+
"ory_Orya": "or",
|
20 |
+
"pan_Guru": "pa",
|
21 |
+
"snd_Arab": "sd",
|
22 |
+
"tam_Taml": "ta",
|
23 |
+
"urd_Arab": "ur",
|
24 |
+
}
|
25 |
+
|
26 |
+
|
27 |
+
def predict(batch, tokenizer, model, bos_token_id):
|
28 |
+
encoded_batch = tokenizer(batch, padding=True, return_tensors="pt").to(model.device)
|
29 |
+
generated_tokens = model.generate(
|
30 |
+
**encoded_batch,
|
31 |
+
num_beams=5,
|
32 |
+
max_length=256,
|
33 |
+
min_length=0,
|
34 |
+
forced_bos_token_id=bos_token_id,
|
35 |
+
)
|
36 |
+
hypothesis = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
|
37 |
+
return hypothesis
|
38 |
+
|
39 |
+
|
40 |
+
def main(devtest_data_dir, batch_size):
|
41 |
+
# load the pre-trained M2M-100 tokenizer and model
|
42 |
+
model_name = "facebook/m2m100-12B-last-ckpt"
|
43 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
44 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
45 |
+
model.eval()
|
46 |
+
|
47 |
+
# iterate over a list of language pairs from `devtest_data_dir`
|
48 |
+
for pair in sorted(os.listdir(devtest_data_dir)):
|
49 |
+
if "-" not in pair:
|
50 |
+
continue
|
51 |
+
|
52 |
+
src_lang, tgt_lang = pair.split("-")
|
53 |
+
|
54 |
+
# check if the source and target languages are supported
|
55 |
+
if (
|
56 |
+
src_lang not in langs_supported.keys()
|
57 |
+
or tgt_lang not in langs_supported.keys()
|
58 |
+
):
|
59 |
+
print(f"Skipping {src_lang}-{tgt_lang} ...")
|
60 |
+
continue
|
61 |
+
|
62 |
+
# -------------------------------------------------------------------
|
63 |
+
# source to target evaluation
|
64 |
+
# -------------------------------------------------------------------
|
65 |
+
print(f"Evaluating {src_lang}-{tgt_lang} ...")
|
66 |
+
|
67 |
+
infname = os.path.join(devtest_data_dir, pair, f"test.{src_lang}")
|
68 |
+
outfname = os.path.join(devtest_data_dir, pair, f"test.{tgt_lang}.pred.m2m100")
|
69 |
+
|
70 |
+
with open(infname, "r") as f:
|
71 |
+
src_sents = f.read().split("\n")
|
72 |
+
|
73 |
+
add_new_line = False
|
74 |
+
if src_sents[-1] == "":
|
75 |
+
add_new_line = True
|
76 |
+
src_sents = src_sents[:-1]
|
77 |
+
|
78 |
+
# set the source language for tokenization
|
79 |
+
tokenizer.src_lang = langs_supported[src_lang]
|
80 |
+
|
81 |
+
# process sentences in batches and generate predictions
|
82 |
+
hypothesis = []
|
83 |
+
for i in tqdm(range(0, len(src_sents), batch_size)):
|
84 |
+
start, end = i, int(min(len(src_sents), i + batch_size))
|
85 |
+
batch = src_sents[start:end]
|
86 |
+
bos_token_id = tokenizer.lang_code_to_id[langs_supported[tgt_lang]]
|
87 |
+
hypothesis += predict(batch, tokenizer, model, bos_token_id)
|
88 |
+
|
89 |
+
assert len(hypothesis) == len(src_sents)
|
90 |
+
|
91 |
+
hypothesis = [
|
92 |
+
re.sub("\s+", " ", x.replace("\n", " ").replace("\t", " ")).strip()
|
93 |
+
for x in hypothesis
|
94 |
+
]
|
95 |
+
if add_new_line:
|
96 |
+
hypothesis = hypothesis
|
97 |
+
|
98 |
+
with open(outfname, "w") as f:
|
99 |
+
f.write("\n".join(hypothesis))
|
100 |
+
|
101 |
+
# -------------------------------------------------------------------
|
102 |
+
# target to source evaluation
|
103 |
+
# -------------------------------------------------------------------
|
104 |
+
infname = os.path.join(devtest_data_dir, pair, f"test.{tgt_lang}")
|
105 |
+
outfname = os.path.join(devtest_data_dir, pair, f"test.{src_lang}.pred.m2m100")
|
106 |
+
|
107 |
+
with open(infname, "r") as f:
|
108 |
+
src_sents = f.read().split("\n")
|
109 |
+
|
110 |
+
add_new_line = False
|
111 |
+
if src_sents[-1] == "":
|
112 |
+
add_new_line = True
|
113 |
+
src_sents = src_sents[:-1]
|
114 |
+
|
115 |
+
# set the source language for tokenization
|
116 |
+
tokenizer.src_lang = langs_supported[tgt_lang]
|
117 |
+
|
118 |
+
# process sentences in batches and generate predictions
|
119 |
+
hypothesis = []
|
120 |
+
for i in tqdm(range(0, len(src_sents), batch_size)):
|
121 |
+
start, end = i, int(min(len(src_sents), i + batch_size))
|
122 |
+
batch = src_sents[start:end]
|
123 |
+
bos_token_id = tokenizer.lang_code_to_id[langs_supported[src_lang]]
|
124 |
+
hypothesis += predict(batch, tokenizer, model, bos_token_id)
|
125 |
+
|
126 |
+
assert len(hypothesis) == len(src_sents)
|
127 |
+
|
128 |
+
hypothesis = [
|
129 |
+
re.sub("\s+", " ", x.replace("\n", " ").replace("\t", " ")).strip()
|
130 |
+
for x in hypothesis
|
131 |
+
]
|
132 |
+
if add_new_line:
|
133 |
+
hypothesis = hypothesis
|
134 |
+
|
135 |
+
with open(outfname, "w") as f:
|
136 |
+
f.write("\n".join(hypothesis))
|
137 |
+
|
138 |
+
|
139 |
+
if __name__ == "__main__":
|
140 |
+
# expects En-X subdirectories pairs within the devtest data directory
|
141 |
+
devtest_data_dir = sys.argv[1]
|
142 |
+
batch_size = int(sys.argv[2])
|
143 |
+
|
144 |
+
if not torch.cuda.is_available():
|
145 |
+
print("No GPU available")
|
146 |
+
sys.exit(1)
|
147 |
+
|
148 |
+
main(devtest_data_dir, batch_size)
|
IndicTrans2/baseline_eval/mbart_inference.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import sys
|
4 |
+
from tqdm import tqdm
|
5 |
+
import torch
|
6 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
7 |
+
|
8 |
+
|
9 |
+
# dictionary mapping flores codes to mBART supported codes
|
10 |
+
langs_supported = {
|
11 |
+
"eng_Latn": "en_XX",
|
12 |
+
"guj_Gujr": "gu_IN",
|
13 |
+
"hin_Deva": "hi_IN",
|
14 |
+
"npi_Deva": "ne_NP",
|
15 |
+
"ben_Beng": "bn_IN",
|
16 |
+
"mal_Mlym": "ml_IN",
|
17 |
+
"mar_Deva": "mr_IN",
|
18 |
+
"tam_Taml": "ta_IN",
|
19 |
+
"tel_Telu": "te_IN",
|
20 |
+
"urd_Arab": "ur_PK",
|
21 |
+
}
|
22 |
+
|
23 |
+
|
24 |
+
def predict(batch, tokenizer, model, bos_token_id):
|
25 |
+
encoded_batch = tokenizer(batch, padding=True, return_tensors="pt").to(model.device)
|
26 |
+
generated_tokens = model.generate(
|
27 |
+
**encoded_batch,
|
28 |
+
num_beams=5,
|
29 |
+
max_length=256,
|
30 |
+
min_length=0,
|
31 |
+
forced_bos_token_id=bos_token_id,
|
32 |
+
)
|
33 |
+
hypothesis = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
|
34 |
+
return hypothesis
|
35 |
+
|
36 |
+
|
37 |
+
def main(devtest_data_dir, batch_size):
|
38 |
+
# load the pre-trained mBART tokenizers and models for English-XX and XX-English translation
|
39 |
+
enxx_model_name = "facebook/mbart-large-50-one-to-many-mmt"
|
40 |
+
xxen_model_name = "facebook/mbart-large-50-many-to-one-mmt"
|
41 |
+
tokenizers = {
|
42 |
+
"enxx": AutoTokenizer.from_pretrained(enxx_model_name),
|
43 |
+
"xxen": AutoTokenizer.from_pretrained(xxen_model_name),
|
44 |
+
}
|
45 |
+
models = {
|
46 |
+
"enxx": AutoModelForSeq2SeqLM.from_pretrained(enxx_model_name).cuda(),
|
47 |
+
"xxen": AutoModelForSeq2SeqLM.from_pretrained(xxen_model_name).cuda(),
|
48 |
+
}
|
49 |
+
|
50 |
+
# set the models to evaluation mode
|
51 |
+
for model_name in models:
|
52 |
+
models[model_name].eval()
|
53 |
+
|
54 |
+
# iterate over a list of language pairs from `devtest_data_dir`
|
55 |
+
for pair in sorted(os.listdir(devtest_data_dir)):
|
56 |
+
if "-" not in pair:
|
57 |
+
continue
|
58 |
+
|
59 |
+
src_lang, tgt_lang = pair.split("-")
|
60 |
+
|
61 |
+
# check if the source and target languages are supported
|
62 |
+
if (
|
63 |
+
src_lang not in langs_supported.keys()
|
64 |
+
or tgt_lang not in langs_supported.keys()
|
65 |
+
):
|
66 |
+
print(f"Skipping {src_lang}-{tgt_lang} ...")
|
67 |
+
continue
|
68 |
+
|
69 |
+
# -------------------------------------------------------------------
|
70 |
+
# source to target evaluation
|
71 |
+
# -------------------------------------------------------------------
|
72 |
+
print(f"Evaluating {src_lang}-{tgt_lang} ...")
|
73 |
+
|
74 |
+
infname = os.path.join(devtest_data_dir, pair, f"test.{src_lang}")
|
75 |
+
outfname = os.path.join(devtest_data_dir, pair, f"test.{tgt_lang}.pred.mbart50")
|
76 |
+
|
77 |
+
with open(infname, "r") as f:
|
78 |
+
src_sents = f.read().split("\n")
|
79 |
+
|
80 |
+
add_new_line = False
|
81 |
+
if src_sents[-1] == "":
|
82 |
+
add_new_line = True
|
83 |
+
src_sents = src_sents[:-1]
|
84 |
+
|
85 |
+
# set the source language for tokenization
|
86 |
+
tokenizers["enxx"].src_lang = langs_supported[src_lang]
|
87 |
+
|
88 |
+
# process sentences in batches and generate predictions
|
89 |
+
hypothesis = []
|
90 |
+
for i in tqdm(range(0, len(src_sents), batch_size)):
|
91 |
+
start, end = i, int(min(len(src_sents), i + batch_size))
|
92 |
+
batch = src_sents[start:end]
|
93 |
+
bos_token_id = tokenizers["enxx"].lang_code_to_id[langs_supported[tgt_lang]]
|
94 |
+
hypothesis += predict(
|
95 |
+
batch, tokenizers["enxx"], models["enxx"], bos_token_id
|
96 |
+
)
|
97 |
+
|
98 |
+
assert len(hypothesis) == len(src_sents)
|
99 |
+
|
100 |
+
hypothesis = [
|
101 |
+
re.sub("\s+", " ", x.replace("\n", " ").replace("\t", " ")).strip()
|
102 |
+
for x in hypothesis
|
103 |
+
]
|
104 |
+
if add_new_line:
|
105 |
+
hypothesis = hypothesis
|
106 |
+
|
107 |
+
with open(outfname, "w") as f:
|
108 |
+
f.write("\n".join(hypothesis))
|
109 |
+
|
110 |
+
# -------------------------------------------------------------------
|
111 |
+
# target to source evaluation
|
112 |
+
# -------------------------------------------------------------------
|
113 |
+
infname = os.path.join(devtest_data_dir, pair, f"test.{tgt_lang}")
|
114 |
+
outfname = os.path.join(devtest_data_dir, pair, f"test.{src_lang}.pred.mbart50")
|
115 |
+
|
116 |
+
with open(infname, "r") as f:
|
117 |
+
src_sents = f.read().split("\n")
|
118 |
+
|
119 |
+
add_new_line = False
|
120 |
+
if src_sents[-1] == "":
|
121 |
+
add_new_line = True
|
122 |
+
src_sents = src_sents[:-1]
|
123 |
+
|
124 |
+
# set the source language for tokenization
|
125 |
+
tokenizers["xxen"].src_lang = langs_supported[tgt_lang]
|
126 |
+
|
127 |
+
# process sentences in batches and generate predictions
|
128 |
+
hypothesis = []
|
129 |
+
for i in tqdm(range(0, len(src_sents), batch_size)):
|
130 |
+
start, end = i, int(min(len(src_sents), i + batch_size))
|
131 |
+
batch = src_sents[start:end]
|
132 |
+
bos_token_id = tokenizers["xxen"].lang_code_to_id[langs_supported[src_lang]]
|
133 |
+
hypothesis += predict(
|
134 |
+
batch, tokenizers["xxen"], models["xxen"], bos_token_id
|
135 |
+
)
|
136 |
+
|
137 |
+
assert len(hypothesis) == len(src_sents)
|
138 |
+
|
139 |
+
hypothesis = [
|
140 |
+
re.sub("\s+", " ", x.replace("\n", " ").replace("\t", " ")).strip()
|
141 |
+
for x in hypothesis
|
142 |
+
]
|
143 |
+
if add_new_line:
|
144 |
+
hypothesis = hypothesis
|
145 |
+
|
146 |
+
with open(outfname, "w") as f:
|
147 |
+
f.write("\n".join(hypothesis))
|
148 |
+
|
149 |
+
|
150 |
+
if __name__ == "__main__":
|
151 |
+
# expects En-X subdirectories pairs within the devtest data directory
|
152 |
+
devtest_data_dir = sys.argv[1]
|
153 |
+
batch_size = int(sys.argv[2])
|
154 |
+
|
155 |
+
if not torch.cuda.is_available():
|
156 |
+
print("No GPU available")
|
157 |
+
sys.exit(1)
|
158 |
+
|
159 |
+
main(devtest_data_dir, batch_size)
|
IndicTrans2/baseline_eval/nllb_moe_cpu_inference.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import sys
|
4 |
+
from tqdm import tqdm
|
5 |
+
import torch
|
6 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
7 |
+
|
8 |
+
langs_supported = [
|
9 |
+
"asm_Beng",
|
10 |
+
"ben_Beng",
|
11 |
+
"guj_Gujr",
|
12 |
+
"eng_Latn",
|
13 |
+
"hin_Deva",
|
14 |
+
"kas_Deva",
|
15 |
+
"kas_Arab",
|
16 |
+
"kan_Knda",
|
17 |
+
"mal_Mlym",
|
18 |
+
"mai_Deva",
|
19 |
+
"mar_Deva",
|
20 |
+
"mni_Beng",
|
21 |
+
"npi_Deva",
|
22 |
+
"ory_Orya",
|
23 |
+
"pan_Guru",
|
24 |
+
"san_Deva",
|
25 |
+
"snd_Arab",
|
26 |
+
"sat_Olck",
|
27 |
+
"tam_Taml",
|
28 |
+
"tel_Telu",
|
29 |
+
"urd_Arab",
|
30 |
+
]
|
31 |
+
|
32 |
+
|
33 |
+
def predict(batch, tokenizer, model, bos_token_id):
|
34 |
+
encoded_batch = tokenizer(batch, padding=True, return_tensors="pt").to(model.device)
|
35 |
+
generated_tokens = model.generate(
|
36 |
+
**encoded_batch,
|
37 |
+
num_beams=5,
|
38 |
+
max_length=256,
|
39 |
+
min_length=0,
|
40 |
+
forced_bos_token_id=bos_token_id,
|
41 |
+
)
|
42 |
+
hypothesis = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
|
43 |
+
return hypothesis
|
44 |
+
|
45 |
+
|
46 |
+
def main(devtest_data_dir, batch_size):
|
47 |
+
# load the pre-trained NLLB tokenizer and model
|
48 |
+
model_name = "facebook/nllb-moe-54b"
|
49 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
50 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
51 |
+
model.eval()
|
52 |
+
|
53 |
+
# iterate over a list of language pairs from `devtest_data_dir`
|
54 |
+
for pair in sorted(os.listdir(devtest_data_dir)):
|
55 |
+
if "-" not in pair:
|
56 |
+
continue
|
57 |
+
|
58 |
+
src_lang, tgt_lang = pair.split("-")
|
59 |
+
|
60 |
+
# check if the source and target languages are supported
|
61 |
+
if (
|
62 |
+
src_lang not in langs_supported.keys()
|
63 |
+
or tgt_lang not in langs_supported.keys()
|
64 |
+
):
|
65 |
+
print(f"Skipping {src_lang}-{tgt_lang} ...")
|
66 |
+
continue
|
67 |
+
|
68 |
+
# -------------------------------------------------------------------
|
69 |
+
# source to target evaluation
|
70 |
+
# -------------------------------------------------------------------
|
71 |
+
print(f"Evaluating {src_lang}-{tgt_lang} ...")
|
72 |
+
|
73 |
+
infname = os.path.join(devtest_data_dir, pair, f"test.{src_lang}")
|
74 |
+
outfname = os.path.join(
|
75 |
+
devtest_data_dir, pair, f"test.{tgt_lang}.pred.nllb_moe"
|
76 |
+
)
|
77 |
+
|
78 |
+
with open(infname, "r") as f:
|
79 |
+
src_sents = f.read().split("\n")
|
80 |
+
|
81 |
+
add_new_line = False
|
82 |
+
if src_sents[-1] == "":
|
83 |
+
add_new_line = True
|
84 |
+
src_sents = src_sents[:-1]
|
85 |
+
|
86 |
+
# set the source language for tokenization
|
87 |
+
tokenizer.src_lang = src_lang
|
88 |
+
|
89 |
+
# process sentences in batches and generate predictions
|
90 |
+
hypothesis = []
|
91 |
+
for i in tqdm(range(0, len(src_sents), batch_size)):
|
92 |
+
start, end = i, int(min(len(src_sents), i + batch_size))
|
93 |
+
batch = src_sents[start:end]
|
94 |
+
if tgt_lang == "sat_Olck":
|
95 |
+
bos_token_id = tokenizer.lang_code_to_id["sat_Beng"]
|
96 |
+
else:
|
97 |
+
bos_token_id = tokenizer.lang_code_to_id[tgt_lang]
|
98 |
+
hypothesis += predict(batch, tokenizer, model, bos_token_id)
|
99 |
+
|
100 |
+
assert len(hypothesis) == len(src_sents)
|
101 |
+
|
102 |
+
hypothesis = [
|
103 |
+
re.sub("\s+", " ", x.replace("\n", " ").replace("\t", " ")).strip()
|
104 |
+
for x in hypothesis
|
105 |
+
]
|
106 |
+
if add_new_line:
|
107 |
+
hypothesis = hypothesis
|
108 |
+
|
109 |
+
with open(outfname, "w") as f:
|
110 |
+
f.write("\n".join(hypothesis))
|
111 |
+
|
112 |
+
# -------------------------------------------------------------------
|
113 |
+
# target to source evaluation
|
114 |
+
# -------------------------------------------------------------------
|
115 |
+
infname = os.path.join(devtest_data_dir, pair, f"test.{tgt_lang}")
|
116 |
+
outfname = os.path.join(
|
117 |
+
devtest_data_dir, pair, f"test.{src_lang}.pred.nllb_moe"
|
118 |
+
)
|
119 |
+
|
120 |
+
with open(infname, "r") as f:
|
121 |
+
src_sents = f.read().split("\n")
|
122 |
+
|
123 |
+
add_new_line = False
|
124 |
+
if src_sents[-1] == "":
|
125 |
+
add_new_line = True
|
126 |
+
src_sents = src_sents[:-1]
|
127 |
+
|
128 |
+
# set the source language for tokenization
|
129 |
+
tokenizer.src_lang = "sat_Beng" if tgt_lang == "sat_Olck" else tgt_lang
|
130 |
+
|
131 |
+
# process sentences in batches and generate predictions
|
132 |
+
hypothesis = []
|
133 |
+
for i in tqdm(range(0, len(src_sents), batch_size)):
|
134 |
+
start, end = i, int(min(len(src_sents), i + batch_size))
|
135 |
+
batch = src_sents[start:end]
|
136 |
+
bos_token_id = tokenizer.lang_code_to_id[langs_supported[src_lang]]
|
137 |
+
hypothesis += predict(batch, tokenizer, model, bos_token_id)
|
138 |
+
|
139 |
+
assert len(hypothesis) == len(src_sents)
|
140 |
+
|
141 |
+
hypothesis = [
|
142 |
+
re.sub("\s+", " ", x.replace("\n", " ").replace("\t", " ")).strip()
|
143 |
+
for x in hypothesis
|
144 |
+
]
|
145 |
+
if add_new_line:
|
146 |
+
hypothesis = hypothesis
|
147 |
+
|
148 |
+
with open(outfname, "w") as f:
|
149 |
+
f.write("\n".join(hypothesis))
|
150 |
+
|
151 |
+
|
152 |
+
if __name__ == "__main__":
|
153 |
+
# expects En-X subdirectories pairs within the devtest data directory
|
154 |
+
devtest_data_dir = sys.argv[1]
|
155 |
+
batch_size = int(sys.argv[2])
|
156 |
+
|
157 |
+
main(devtest_data_dir, batch_size)
|
IndicTrans2/compute_comet_score.sh
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# This script computes COMET metrics and also performs significance testing on the evaluation set
|
4 |
+
# where each subdirectory contains En-X pair
|
5 |
+
|
6 |
+
|
7 |
+
echo `date`
|
8 |
+
devtest_data_dir=$1 # path to the evaluation directory
|
9 |
+
model_name=${2-"Unbabel/wmt22-comet-da"} # name of the model checkpoint
|
10 |
+
|
11 |
+
# predefined list of languages supported by COMET
|
12 |
+
langs=(asm_Beng ben_Beng guj_Gujr hin_Deva kan_Knda mal_Mlym mar_Deva ory_Orya pan_Guru tam_Taml tel_Telu urd_Arab)
|
13 |
+
|
14 |
+
# we predefine a set of systems which we consider for evaluation
|
15 |
+
# feel free to change the below line in case you want to add or remove any system
|
16 |
+
system=(google azure nllb mbart50 m2m100 it1 it2)
|
17 |
+
|
18 |
+
|
19 |
+
# iterate over the list of predefined languages
|
20 |
+
for lang in "${langs[@]}"; do
|
21 |
+
|
22 |
+
mkdir -p "$devtest_data_dir/eng_Latn-$lang/comet"
|
23 |
+
|
24 |
+
# --------------------------------------------------------------
|
25 |
+
# COMET score computation
|
26 |
+
# --------------------------------------------------------------
|
27 |
+
|
28 |
+
# iterate over the list of predefined systems
|
29 |
+
for sys in "${system[@]}"; do
|
30 |
+
|
31 |
+
echo "${sys}"
|
32 |
+
|
33 |
+
# en - indic direction
|
34 |
+
if [ -f "$devtest_data_dir/eng_Latn-$lang/test.$lang.pred.$sys" ]; then
|
35 |
+
echo "eng_Latn-${lang}"
|
36 |
+
|
37 |
+
src_fname=$devtest_data_dir/eng_Latn-$lang/test.eng_Latn
|
38 |
+
pred_fname=$devtest_data_dir/eng_Latn-$lang/test.$lang.pred.$sys
|
39 |
+
ref_fname=$devtest_data_dir/eng_Latn-$lang/test.$lang
|
40 |
+
out_fname=$devtest_data_dir/eng_Latn-$lang/comet/eng_Latn_${lang}_${sys}_comet.txt
|
41 |
+
|
42 |
+
# Compute COMET scores using the `comet-score`
|
43 |
+
comet-score -s $src_fname -t $pred_fname -r $ref_fname --gpus 1 --model $model_name --quiet --only_system > $out_fname
|
44 |
+
fi
|
45 |
+
|
46 |
+
# indic - en direction
|
47 |
+
if [ -f "$devtest_data_dir/eng_Latn-$lang/test.eng_Latn.pred.$sys" ]; then
|
48 |
+
echo "${lang}-eng_Latn"
|
49 |
+
|
50 |
+
src_fname=$devtest_data_dir/eng_Latn-$lang/test.$lang
|
51 |
+
pred_fname=$devtest_data_dir/eng_Latn-$lang/test.eng_Latn.pred.$sys
|
52 |
+
ref_fname=$devtest_data_dir/eng_Latn-$lang/test.eng_Latn
|
53 |
+
out_fname=$devtest_data_dir/eng_Latn-$lang/comet/${lang}_eng_Latn_${sys}_comet.txt
|
54 |
+
|
55 |
+
# Compute COMET scores using the `comet-score`
|
56 |
+
comet-score -s $src_fname -t $pred_fname -r $ref_fname --gpus 1 --model $model_name --quiet --only_system > $out_fname
|
57 |
+
fi
|
58 |
+
|
59 |
+
done
|
60 |
+
|
61 |
+
# --------------------------------------------------------------
|
62 |
+
# COMET significance testing
|
63 |
+
# --------------------------------------------------------------
|
64 |
+
|
65 |
+
# en - indic direction
|
66 |
+
src_fname=$devtest_data_dir/eng_Latn-$lang/test.eng_Latn
|
67 |
+
pred_fname=$devtest_data_dir/eng_Latn-$lang/test.$lang.pred.*
|
68 |
+
ref_fname=$devtest_data_dir/eng_Latn-$lang/test.$lang
|
69 |
+
out_fname=$devtest_data_dir/eng_Latn-$lang/comet/eng_Latn_${lang}_comet_stat.txt
|
70 |
+
|
71 |
+
# Compute COMET significance scores using the `comet-compare`
|
72 |
+
comet-compare -s $src_fname -t $pred_fname -r $ref_fname > $out_fname
|
73 |
+
|
74 |
+
|
75 |
+
# indic-en direction
|
76 |
+
src_fname=$devtest_data_dir/eng_Latn-$lang/test.$lang
|
77 |
+
pred_fname=$devtest_data_dir/eng_Latn-$lang/test.eng_Latn.pred.*
|
78 |
+
ref_fname=$devtest_data_dir/eng_Latn-$lang/test.eng_Latn
|
79 |
+
out_fname=$devtest_data_dir/eng_Latn-$lang/comet/${lang}_eng_Latn_comet_stat.txt
|
80 |
+
|
81 |
+
# Compute COMET significance scores using the `comet-compare`
|
82 |
+
comet-compare -s $src_fname -t $pred_fname -r $ref_fname > $out_fname
|
83 |
+
|
84 |
+
done
|
IndicTrans2/compute_metrics.sh
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# This script compute the evaluation metrics such as BLEU, chrF, chrF++ using the
|
4 |
+
# detokenized predictions of the translation systems using sacrebleu (version 2.3.1).
|
5 |
+
# If the target language is:
|
6 |
+
# English: directly use Moses tokenizer that is internally supported (`mteval-v13a`)
|
7 |
+
# Indic: use IndicNLP tokenizers and skip tokenization step in sacrebleu.
|
8 |
+
|
9 |
+
|
10 |
+
echo `date`
|
11 |
+
pred_fname=$1 # path to the predction file
|
12 |
+
ref_fname=$2 # path to the reference file
|
13 |
+
tgt_lang=$3 # target language
|
14 |
+
|
15 |
+
|
16 |
+
if [ $tgt_lang == 'eng_Latn' ]; then
|
17 |
+
# directly tokenize the prediction and reference files using sacrebleu and compute the metric
|
18 |
+
sacrebleu $ref_fname < $pred_fname -m bleu chrf
|
19 |
+
sacrebleu $ref_fname < $pred_fname -m chrf --chrf-word-order 2
|
20 |
+
else
|
21 |
+
|
22 |
+
# indicnlp tokenize prediction and reference files before evaluation
|
23 |
+
input_size=`python scripts/preprocess_translate.py $ref_fname $ref_fname.tok $tgt_lang false false`
|
24 |
+
input_size=`python scripts/preprocess_translate.py $pred_fname $pred_fname.tok $tgt_lang false false`
|
25 |
+
|
26 |
+
# since we are tokenizing with indicnlp separately, we are setting tokenize to none here
|
27 |
+
sacrebleu --tokenize none $ref_fname.tok < $pred_fname.tok -m bleu chrf
|
28 |
+
sacrebleu --tokenize none $ref_fname.tok < $pred_fname.tok -m chrf --chrf-word-order 2
|
29 |
+
fi
|
IndicTrans2/compute_metrics_significance.sh
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# This script performs significance testing for metrics such as BLEU, chrF++ using sacrebleu on the evaluation set
|
4 |
+
# where each subdirectory contains En-X pair
|
5 |
+
|
6 |
+
|
7 |
+
echo `date`
|
8 |
+
devtest_data_dir=$1 # path to the evaluation directory
|
9 |
+
|
10 |
+
# we predefine a set of systems which we consider for evaluation
|
11 |
+
# feel free to change the below line in case you want to add or remove any system
|
12 |
+
system=(google azure nllb mbart50 m2m100 it1 it2)
|
13 |
+
|
14 |
+
|
15 |
+
# get a list of language pairs in the `devtest_data_dir`
|
16 |
+
pairs=$(ls -d $devtest_data_dir/eng_Latn-* | sort)
|
17 |
+
|
18 |
+
|
19 |
+
# iterate over each language pair
|
20 |
+
for pair in ${pairs[@]}; do
|
21 |
+
# extract the source and target languages from the pair name
|
22 |
+
pair=$(basename $pair)
|
23 |
+
src_lang=$(echo "$pair" | cut -d "-" -f 1)
|
24 |
+
tgt_lang=$(echo "$pair" | cut -d "-" -f 2)
|
25 |
+
|
26 |
+
if [[ $src_lang == "eng_Latn" ]]; then
|
27 |
+
|
28 |
+
# ----------------------------------------------------------------------
|
29 |
+
# en - indic direction
|
30 |
+
# ----------------------------------------------------------------------
|
31 |
+
echo "${src_lang} - ${tgt_lang}"
|
32 |
+
|
33 |
+
# find all the prediction files for different systems and tokenize it using IndicNLP
|
34 |
+
pred_fnames=$devtest_data_dir/$pair/test.${tgt_lang}.pred.*
|
35 |
+
ref_fname=$devtest_data_dir/$pair/test.${tgt_lang}
|
36 |
+
|
37 |
+
for pred_fname in $(find . -type f -name $pred_fnames); do
|
38 |
+
input_size=`python scripts/preprocess_translate.py $pred_fname $pred_fname.tok $tgt_lang false false`
|
39 |
+
done
|
40 |
+
|
41 |
+
input_size=`python scripts/preprocess_translate.py $ref_fname $ref_fname.tok $tgt_lang false false`
|
42 |
+
|
43 |
+
ref_fname=$devtest_data_dir/$pair/test.${tgt_lang}.tok
|
44 |
+
it2_fname=$devtest_data_dir/$pair/test.${tgt_lang}.pred.it2.tok
|
45 |
+
sys_fnames=$devtest_data_dir/$pair/test.${tgt_lang}.pred.*.tok
|
46 |
+
bleu_out_fname=$devtest_data_dir/$pair/${src_lang}_${tgt_lang}_bleu_significance.txt
|
47 |
+
chrF_out_fname=$devtest_data_dir/$pair/${src_lang}_${tgt_lang}_chrF++_significance.txt
|
48 |
+
|
49 |
+
sacrebleu --tokenize none $ref_fname -i $it2_fname $sys_fnames --paired-bs -m bleu --format text > $bleu_out_fname
|
50 |
+
sacrebleu --tokenize none $it2_fname $sys_fnames --paired-bs -m chrf --chrf-word-order 2 --format text > $chrF_out_fname
|
51 |
+
|
52 |
+
# ----------------------------------------------------------------------
|
53 |
+
# indic - en direction
|
54 |
+
# ----------------------------------------------------------------------
|
55 |
+
echo "${tgt_lang} - ${src_lang}"
|
56 |
+
|
57 |
+
ref_fname=$devtest_data_dir/$pair/test.${src_lang}
|
58 |
+
it2_fname=$devtest_data_dir/$pair/test.${src_lang}.pred.it2
|
59 |
+
sys_fnames=$devtest_data_dir/$pair/test.${src_lang}.pred.*
|
60 |
+
bleu_out_fname=$devtest_data_dir/$pair/${tgt_lang}_${src_lang}_bleu_significance.txt
|
61 |
+
chrF_out_fname=$devtest_data_dir/$pair/${tgt_lang}_${src_lang}_chrF++_significance.txt
|
62 |
+
|
63 |
+
sacrebleu --tokenize none $ref_fname -i $it2_fname $sys_fnames --paired-bs -m bleu --format text > $bleu_out_fname
|
64 |
+
sacrebleu --tokenize none $it2_fname $sys_fnames --paired-bs -m chrf --chrf-word-order 2 --format text > $chrF_out_fname
|
65 |
+
|
66 |
+
fi
|
IndicTrans2/eval.sh
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# This script evaluates the performance of a machine translation system
|
4 |
+
# on a evaluation set in forward direction. For example, if the evaluation set
|
5 |
+
# consists of language pairs, such as En-X, where En represents the English language
|
6 |
+
# and X represents the target Indic language then this script accesses the translation
|
7 |
+
# system from the English language (En) to the target Indic language (X) direction.
|
8 |
+
|
9 |
+
|
10 |
+
echo `date`
|
11 |
+
devtest_data_dir=$1 # path to the evaluation directory
|
12 |
+
ckpt_dir=$2 # path to the checkpoint directory
|
13 |
+
system=${3:-"it2"} # name of the machine translation system
|
14 |
+
|
15 |
+
|
16 |
+
# get a list of language pairs in the `devtest_data_dir`
|
17 |
+
pairs=$(ls -d $devtest_data_dir/* | sort)
|
18 |
+
|
19 |
+
|
20 |
+
# iterate over each language pair
|
21 |
+
for pair in ${pairs[@]}; do
|
22 |
+
# extract the source and target languages from the pair name
|
23 |
+
pair=$(basename $pair)
|
24 |
+
src_lang=$(echo "$pair" | cut -d "-" -f 1)
|
25 |
+
tgt_lang=$(echo "$pair" | cut -d "-" -f 2)
|
26 |
+
|
27 |
+
src_fname=$devtest_data_dir/$src_lang-$tgt_lang/test.$src_lang
|
28 |
+
tgt_fname=$devtest_data_dir/$src_lang-$tgt_lang/test.$tgt_lang
|
29 |
+
|
30 |
+
# check if the source and target files exists
|
31 |
+
if [ -f "$src_fname" ] && [ -f "$tgt_fname" ]; then
|
32 |
+
echo "Evaluating $src_lang-$tgt_lang ..."
|
33 |
+
else
|
34 |
+
echo "Skipping $src_lang-$tgt_lang ..."
|
35 |
+
continue
|
36 |
+
fi
|
37 |
+
|
38 |
+
# generate translations if the system name contains "it2"
|
39 |
+
if [[ $system == *"it2"* ]]; then
|
40 |
+
echo "Generating Translations"
|
41 |
+
bash joint_translate.sh $src_fname $tgt_fname.pred.$system $src_lang $tgt_lang $ckpt_dir
|
42 |
+
fi
|
43 |
+
|
44 |
+
# compute automatic string-based metrics if the prediction exists for the system
|
45 |
+
if [[ -f "${tgt_fname}.pred.${system}" ]]; then
|
46 |
+
echo "Computing Metrics"
|
47 |
+
bash compute_metrics.sh $tgt_fname.pred.$system $tgt_fname $tgt_lang > $devtest_data_dir/$src_lang-$tgt_lang/${src_lang}_${tgt_lang}_${system}_scores.txt
|
48 |
+
fi
|
49 |
+
|
50 |
+
# remove the intermediate files
|
51 |
+
rm -rf $tgt_fname.pred.$system.*
|
52 |
+
rm -rf $devtest_data_dir/$src_lang-$tgt_lang/*.tok
|
53 |
+
|
54 |
+
done
|
IndicTrans2/eval_rev.sh
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# This script evaluates the performance of a machine translation system
|
4 |
+
# on a evaluation set in forward direction. For example, if the evaluation set
|
5 |
+
# consists of language pairs, such as En-X, where En represents the English language
|
6 |
+
# and X represents the target Indic language then this script accesses the translation
|
7 |
+
# system from the target Indic language (X) to the English language (En) direction.
|
8 |
+
|
9 |
+
|
10 |
+
echo `date`
|
11 |
+
devtest_data_dir=$1 # path to the evaluation directory
|
12 |
+
ckpt_dir=$2 # path to the checkpoint directory
|
13 |
+
system=${3:-"it2"} # name of the machine translation system
|
14 |
+
|
15 |
+
|
16 |
+
# get a list of language pairs in the `devtest_data_dir`
|
17 |
+
pairs=$(ls -d $devtest_data_dir/* | sort)
|
18 |
+
|
19 |
+
|
20 |
+
# iterate over each language pair
|
21 |
+
for pair in ${pairs[@]}; do
|
22 |
+
# extract the source and target languages from the pair name
|
23 |
+
pair=$(basename $pair)
|
24 |
+
src_lang=$(echo "$pair" | cut -d "-" -f 1)
|
25 |
+
tgt_lang=$(echo "$pair" | cut -d "-" -f 2)
|
26 |
+
|
27 |
+
src_fname=$devtest_data_dir/$src_lang-$tgt_lang/test.$tgt_lang
|
28 |
+
tgt_fname=$devtest_data_dir/$src_lang-$tgt_lang/test.$src_lang
|
29 |
+
|
30 |
+
# check if the source and target files exists
|
31 |
+
# in this case, we flip the actual target file as source and vice-versa
|
32 |
+
if [ -f "$src_fname" ] && [ -f "$tgt_fname" ]; then
|
33 |
+
echo "Evaluating $src_lang-$tgt_lang ..."
|
34 |
+
else
|
35 |
+
echo "Skipping $src_lang-$tgt_lang ..."
|
36 |
+
continue
|
37 |
+
fi
|
38 |
+
|
39 |
+
# generate translations if the system name contains "it2"
|
40 |
+
if [[ $system == *"it2"* ]]; then
|
41 |
+
echo "Generating Translations"
|
42 |
+
bash joint_translate.sh $src_fname $tgt_fname.pred.$system $tgt_lang $src_lang $ckpt_dir
|
43 |
+
fi
|
44 |
+
|
45 |
+
# compute automatic string-based metrics if the prediction exists for the system
|
46 |
+
if [[ -f "${tgt_fname}.pred.${system}" ]]; then
|
47 |
+
echo "Computing Metrics"
|
48 |
+
bash compute_metrics.sh $tgt_fname.pred.$system $tgt_fname $src_lang > $devtest_data_dir/$src_lang-$tgt_lang/${tgt_lang}_${src_lang}_${system}_scores.txt
|
49 |
+
fi
|
50 |
+
|
51 |
+
# remove the intermediate files
|
52 |
+
rm -rf $tgt_fname.pred.$system.*
|
53 |
+
rm -rf $devtest_data_dir/$src_lang-$tgt_lang/*.tok
|
54 |
+
|
55 |
+
done
|
IndicTrans2/finetune.sh
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# This script finetunes the pretrained translation model on the binarized data using fairseq.
|
4 |
+
|
5 |
+
|
6 |
+
echo `date`
|
7 |
+
exp_dir=$1 # path of the experiment directory
|
8 |
+
model_arch=${2:-"transformer_18_18"} # model architecture (defaults to `transformer_18_18`)
|
9 |
+
pretrained_ckpt=$3 # path to the pretrained checkpoint `.pt` file
|
10 |
+
|
11 |
+
|
12 |
+
fairseq-train $exp_dir/final_bin \
|
13 |
+
--max-source-positions=256 \
|
14 |
+
--max-target-positions=256 \
|
15 |
+
--source-lang=SRC \
|
16 |
+
--target-lang=TGT \
|
17 |
+
--max-update=1000000 \
|
18 |
+
--save-interval-updates=1000 \
|
19 |
+
--arch=$model_arch \
|
20 |
+
--activation-fn gelu \
|
21 |
+
--criterion=label_smoothed_cross_entropy \
|
22 |
+
--label-smoothing=0.1 \
|
23 |
+
--optimizer adam \
|
24 |
+
--adam-betas "(0.9, 0.98)" \
|
25 |
+
--lr-scheduler=inverse_sqrt \
|
26 |
+
--clip-norm 1.0 \
|
27 |
+
--warmup-init-lr 1e-07 \
|
28 |
+
--lr 3e-5 \
|
29 |
+
--warmup-updates 2000 \
|
30 |
+
--dropout 0.2 \
|
31 |
+
--save-dir $exp_dir/model \
|
32 |
+
--keep-last-epochs 5 \
|
33 |
+
--keep-interval-updates 3 \
|
34 |
+
--patience 10 \
|
35 |
+
--skip-invalid-size-inputs-valid-test \
|
36 |
+
--fp16 \
|
37 |
+
--user-dir model_configs \
|
38 |
+
--update-freq=4 \
|
39 |
+
--distributed-world-size 8 \
|
40 |
+
--num-workers 24 \
|
41 |
+
--max-tokens 1024 \
|
42 |
+
--eval-bleu \
|
43 |
+
--eval-bleu-args "{\"beam\": 1, \"lenpen\": 1.0, \"max_len_a\": 1.2, \"max_len_b\": 10}" \
|
44 |
+
--eval-bleu-detok moses \
|
45 |
+
--eval-bleu-remove-bpe sentencepiece \
|
46 |
+
--eval-bleu-print-samples \
|
47 |
+
--best-checkpoint-metric bleu \
|
48 |
+
--maximize-best-checkpoint-metric \
|
49 |
+
--restore-file $pretrained_ckpt \
|
50 |
+
--reset-lr-scheduler \
|
51 |
+
--reset-meters \
|
52 |
+
--reset-dataloader \
|
53 |
+
--reset-optimizer \
|
54 |
+
--task translation
|
IndicTrans2/huggingface_interface/.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
IndicTransTokenizer
|
IndicTrans2/huggingface_interface/README.md
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# IndicTrans2 HF Compatible Models
|
2 |
+
|
3 |
+
[![colab link](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI4Bharat/IndicTrans2/blob/main/huggingface_interface/colab_inference.ipynb)
|
4 |
+
|
5 |
+
In this section, we provide details on how to use our [IndicTrans2](https://github.com/AI4Bharat/IndicTrans2) models which were originally trained with the [fairseq](https://github.com/facebookresearch/fairseq) to [HuggingFace transformers](https://huggingface.co/docs/transformers/index) for inference purpose. Our scripts for HuggingFace compatible models are adapted from [M2M100 repository](https://github.com/huggingface/transformers/tree/main/src/transformers/models/m2m_100).
|
6 |
+
|
7 |
+
> Note: We have migrated IndicTrans2 tokenizer for HF compatible IndicTrans2 models to [IndicTransToolkit](https://github.com/VarunGumma/IndicTransToolkit) and will be maintained separately there from now onwards. This is automatically installed when you call `install.sh` script in `huggingface_interface`.
|
8 |
+
|
9 |
+
### Setup
|
10 |
+
|
11 |
+
To get started, follow these steps to set up the environment:
|
12 |
+
|
13 |
+
```
|
14 |
+
# Clone the github repository and navigate to the project directory.
|
15 |
+
git clone https://github.com/AI4Bharat/IndicTrans2
|
16 |
+
cd IndicTrans2/huggingface_interface
|
17 |
+
|
18 |
+
# Install all the dependencies and requirements associated with the project for running HF compatible models.
|
19 |
+
source install.sh
|
20 |
+
```
|
21 |
+
|
22 |
+
> Note: The `install.sh` script in this directory is specifically for running HF compatible models for inference.
|
23 |
+
|
24 |
+
### Converting
|
25 |
+
|
26 |
+
In order to convert the fairseq checkpoint to a PyTorch checkpoint that is compatible with HuggingFace Transformers, use the following command:
|
27 |
+
|
28 |
+
```bash
|
29 |
+
python3 convert_indictrans_checkpoint_to_pytorch.py --fairseq_path <fairseq_checkpoint_best.pt> --pytorch_dump_folder_path <hf_output_dir>
|
30 |
+
```
|
31 |
+
|
32 |
+
- `<fairseq_checkpoint_best.pt>`: path to the fairseq `checkpoint_best.pt` that needs to be converted to HF compatible models
|
33 |
+
- `<hf_output_dir>`: path to the output directory where the HF compatible models will be saved
|
34 |
+
|
35 |
+
### Models
|
36 |
+
|
37 |
+
| Model | 🤗 HuggingFace Checkpoints |
|
38 |
+
| -------------------------------- | ----------------------------------------------------------------------------------------------------------------- |
|
39 |
+
| En-Indic | [ai4bharat/indictrans2-en-indic-1B](https://huggingface.co/ai4bharat/indictrans2-en-indic-1B) |
|
40 |
+
| Indic-En | [ai4bharat/indictrans2-indic-en-1B](https://huggingface.co/ai4bharat/indictrans2-indic-en-1B) |
|
41 |
+
| Distilled En-Indic | [ai4bharat/indictrans2-en-indic-dist-200M](https://huggingface.co/ai4bharat/indictrans2-en-indic-dist-200M) |
|
42 |
+
| Distilled Indic-En | [ai4bharat/indictrans2-indic-en-dist-200M](https://huggingface.co/ai4bharat/indictrans2-indic-en-dist-200M) |
|
43 |
+
| Indic-Indic (Stitched) | [ai4bharat/indictrans2-indic-indic-1B](https://huggingface.co/ai4bharat/indictrans2-indic-indic-1B) |
|
44 |
+
| Distilled Indic-Indic (Stitched) | [ai4bharat/indictrans2-indic-indic-dist-320M](https://huggingface.co/ai4bharat/indictrans2-indic-indic-dist-320M) |
|
45 |
+
|
46 |
+
### Inference
|
47 |
+
|
48 |
+
With the conversion complete, you can now perform inference using the HuggingFace Transformers.
|
49 |
+
|
50 |
+
You can start with the provided `example.py` script and customize it for your specific translation use case:
|
51 |
+
|
52 |
+
```bash
|
53 |
+
python3 example.py
|
54 |
+
```
|
55 |
+
|
56 |
+
Feel free to modify the `example.py` script to suit your translation needs.
|
57 |
+
|
58 |
+
### Fine-tuning with LoRA
|
59 |
+
|
60 |
+
Before starting with fine-tuning IndicTrans2 models, you will need to restructure the training data in the following format.
|
61 |
+
|
62 |
+
```
|
63 |
+
en-indic-exp
|
64 |
+
├── train
|
65 |
+
│ ├── eng_Latn-asm_Beng
|
66 |
+
│ │ ├── train.eng_Latn
|
67 |
+
│ │ └── train.asm_Beng
|
68 |
+
│ ├── eng_Latn-ben_Beng
|
69 |
+
│ │ └── ...
|
70 |
+
│ └── {src_lang}-{tgt_lang}
|
71 |
+
│ ├── train.{src_lang}
|
72 |
+
│ └── train.{tgt_lang}
|
73 |
+
└── dev
|
74 |
+
├── eng_Latn-asm_Beng
|
75 |
+
│ ├── dev.eng_Latn
|
76 |
+
│ └── dev.asm_Beng
|
77 |
+
├── eng_Latn-ben_Beng
|
78 |
+
│ └── ...
|
79 |
+
└── {src_lang}-{tgt_lang}
|
80 |
+
├── dev.{src_lang}
|
81 |
+
└── dev.{tgt_lang}
|
82 |
+
```
|
83 |
+
|
84 |
+
Once you have data ready in above specified format, use the following command.
|
85 |
+
|
86 |
+
```bash
|
87 |
+
bash train_lora.sh <data_dir> <model_name> <output_dir> <direction> <src_lang_list> <tgt_lang_list>
|
88 |
+
```
|
89 |
+
|
90 |
+
We recommend you to refer to `train_lora.sh` for defaults arguments for fine-tuning. Please note that the specified hyperparameters may not be optimal and might require tuning for your use case.
|
91 |
+
|
92 |
+
### Inference with LoRA
|
93 |
+
|
94 |
+
You can load the LoRA adapters with the base model for inference by modifying the model initialization in `example.py` script.
|
95 |
+
|
96 |
+
```python
|
97 |
+
from transformers import AutoModelForSeq2SeqLM
|
98 |
+
from peft import PeftConfig, PeftModel
|
99 |
+
|
100 |
+
base_ckpt_dir = "ai4bharat/indictrans2-en-indic-1B" # you will need to change as per your use case
|
101 |
+
base_model = AutoModelForSeq2SeqLM.from_pretrained(base_ckpt_dir, trust_remote_code=True)
|
102 |
+
lora_model = PeftModel.from_pretrained(base_model, lora_ckpt_dir)
|
103 |
+
```
|
104 |
+
|
105 |
+
> Note: Please feel free to open issues on the GitHub repo in case of any queries/issues.
|
106 |
+
|
107 |
+
### Citation
|
108 |
+
|
109 |
+
```bibtex
|
110 |
+
@article{gala2023indictrans,
|
111 |
+
title={IndicTrans2: Towards High-Quality and Accessible Machine Translation Models for all 22 Scheduled Indian Languages},
|
112 |
+
author={Jay Gala and Pranjal A Chitale and A K Raghavan and Varun Gumma and Sumanth Doddapaneni and Aswanth Kumar M and Janki Atul Nawale and Anupama Sujatha and Ratish Puduppully and Vivek Raghavan and Pratyush Kumar and Mitesh M Khapra and Raj Dabre and Anoop Kunchukuttan},
|
113 |
+
journal={Transactions on Machine Learning Research},
|
114 |
+
issn={2835-8856},
|
115 |
+
year={2023},
|
116 |
+
url={https://openreview.net/forum?id=vfT4YuzAYA},
|
117 |
+
note={}
|
118 |
+
}
|
119 |
+
```
|
IndicTrans2/huggingface_interface/colab_inference.ipynb
ADDED
@@ -0,0 +1,458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "8Aa-nRCzPVdF"
|
7 |
+
},
|
8 |
+
"source": [
|
9 |
+
"# IndicTrans2 HF Inference\n",
|
10 |
+
"\n",
|
11 |
+
"We provide an example notebook on how to use our IndicTrans2 models which were originally trained with the fairseq to HuggingFace transformers for inference purpose.\n"
|
12 |
+
]
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"cell_type": "markdown",
|
16 |
+
"metadata": {
|
17 |
+
"id": "Cfsv02IeP2It"
|
18 |
+
},
|
19 |
+
"source": [
|
20 |
+
"## Setup\n",
|
21 |
+
"\n",
|
22 |
+
"Please run the cells below to install the necessary dependencies.\n"
|
23 |
+
]
|
24 |
+
},
|
25 |
+
{
|
26 |
+
"cell_type": "code",
|
27 |
+
"execution_count": null,
|
28 |
+
"metadata": {
|
29 |
+
"id": "qKcYlUZYGLrt"
|
30 |
+
},
|
31 |
+
"outputs": [],
|
32 |
+
"source": [
|
33 |
+
"%%capture\n",
|
34 |
+
"!git clone https://github.com/AI4Bharat/IndicTrans2.git"
|
35 |
+
]
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "code",
|
39 |
+
"execution_count": null,
|
40 |
+
"metadata": {
|
41 |
+
"id": "U3vs7FkIGSxK"
|
42 |
+
},
|
43 |
+
"outputs": [],
|
44 |
+
"source": [
|
45 |
+
"%%capture\n",
|
46 |
+
"%cd /content/IndicTrans2/huggingface_interface"
|
47 |
+
]
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"cell_type": "code",
|
51 |
+
"execution_count": null,
|
52 |
+
"metadata": {
|
53 |
+
"id": "ddkRAXQ2Git0"
|
54 |
+
},
|
55 |
+
"outputs": [],
|
56 |
+
"source": [
|
57 |
+
"%%capture\n",
|
58 |
+
"!python3 -m pip install nltk sacremoses pandas regex mock transformers>=4.33.2 mosestokenizer\n",
|
59 |
+
"!python3 -c \"import nltk; nltk.download('punkt')\"\n",
|
60 |
+
"!python3 -m pip install bitsandbytes scipy accelerate datasets\n",
|
61 |
+
"!python3 -m pip install sentencepiece\n",
|
62 |
+
"\n",
|
63 |
+
"!git clone https://github.com/VarunGumma/IndicTransToolkit.git\n",
|
64 |
+
"%cd IndicTransToolkit\n",
|
65 |
+
"!python3 -m pip install --editable ./\n",
|
66 |
+
"%cd .."
|
67 |
+
]
|
68 |
+
},
|
69 |
+
{
|
70 |
+
"cell_type": "markdown",
|
71 |
+
"metadata": {
|
72 |
+
"id": "hjN7ub1tO33H"
|
73 |
+
},
|
74 |
+
"source": [
|
75 |
+
"**IMPORTANT : Restart your run-time first and then run the cells below.**"
|
76 |
+
]
|
77 |
+
},
|
78 |
+
{
|
79 |
+
"cell_type": "markdown",
|
80 |
+
"metadata": {
|
81 |
+
"id": "_SLBIw6rQB-0"
|
82 |
+
},
|
83 |
+
"source": [
|
84 |
+
"## Inference\n"
|
85 |
+
]
|
86 |
+
},
|
87 |
+
{
|
88 |
+
"cell_type": "code",
|
89 |
+
"execution_count": null,
|
90 |
+
"metadata": {
|
91 |
+
"id": "fYczM2U6G1Zv"
|
92 |
+
},
|
93 |
+
"outputs": [],
|
94 |
+
"source": [
|
95 |
+
"import torch\n",
|
96 |
+
"from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig, AutoTokenizer\n",
|
97 |
+
"from IndicTransToolkit import IndicProcessor\n",
|
98 |
+
"\n",
|
99 |
+
"BATCH_SIZE = 4\n",
|
100 |
+
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
101 |
+
"quantization = None"
|
102 |
+
]
|
103 |
+
},
|
104 |
+
{
|
105 |
+
"cell_type": "code",
|
106 |
+
"execution_count": null,
|
107 |
+
"metadata": {
|
108 |
+
"id": "xj1WCNjuHG-d"
|
109 |
+
},
|
110 |
+
"outputs": [],
|
111 |
+
"source": [
|
112 |
+
"def initialize_model_and_tokenizer(ckpt_dir, quantization):\n",
|
113 |
+
" if quantization == \"4-bit\":\n",
|
114 |
+
" qconfig = BitsAndBytesConfig(\n",
|
115 |
+
" load_in_4bit=True,\n",
|
116 |
+
" bnb_4bit_use_double_quant=True,\n",
|
117 |
+
" bnb_4bit_compute_dtype=torch.bfloat16,\n",
|
118 |
+
" )\n",
|
119 |
+
" elif quantization == \"8-bit\":\n",
|
120 |
+
" qconfig = BitsAndBytesConfig(\n",
|
121 |
+
" load_in_8bit=True,\n",
|
122 |
+
" bnb_8bit_use_double_quant=True,\n",
|
123 |
+
" bnb_8bit_compute_dtype=torch.bfloat16,\n",
|
124 |
+
" )\n",
|
125 |
+
" else:\n",
|
126 |
+
" qconfig = None\n",
|
127 |
+
"\n",
|
128 |
+
" tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, trust_remote_code=True)\n",
|
129 |
+
" model = AutoModelForSeq2SeqLM.from_pretrained(\n",
|
130 |
+
" ckpt_dir,\n",
|
131 |
+
" trust_remote_code=True,\n",
|
132 |
+
" low_cpu_mem_usage=True,\n",
|
133 |
+
" quantization_config=qconfig,\n",
|
134 |
+
" )\n",
|
135 |
+
"\n",
|
136 |
+
" if qconfig == None:\n",
|
137 |
+
" model = model.to(DEVICE)\n",
|
138 |
+
" if DEVICE == \"cuda\":\n",
|
139 |
+
" model.half()\n",
|
140 |
+
"\n",
|
141 |
+
" model.eval()\n",
|
142 |
+
"\n",
|
143 |
+
" return tokenizer, model\n",
|
144 |
+
"\n",
|
145 |
+
"\n",
|
146 |
+
"def batch_translate(input_sentences, src_lang, tgt_lang, model, tokenizer, ip):\n",
|
147 |
+
" translations = []\n",
|
148 |
+
" for i in range(0, len(input_sentences), BATCH_SIZE):\n",
|
149 |
+
" batch = input_sentences[i : i + BATCH_SIZE]\n",
|
150 |
+
"\n",
|
151 |
+
" # Preprocess the batch and extract entity mappings\n",
|
152 |
+
" batch = ip.preprocess_batch(batch, src_lang=src_lang, tgt_lang=tgt_lang)\n",
|
153 |
+
"\n",
|
154 |
+
" # Tokenize the batch and generate input encodings\n",
|
155 |
+
" inputs = tokenizer(\n",
|
156 |
+
" batch,\n",
|
157 |
+
" truncation=True,\n",
|
158 |
+
" padding=\"longest\",\n",
|
159 |
+
" return_tensors=\"pt\",\n",
|
160 |
+
" return_attention_mask=True,\n",
|
161 |
+
" ).to(DEVICE)\n",
|
162 |
+
"\n",
|
163 |
+
" # Generate translations using the model\n",
|
164 |
+
" with torch.no_grad():\n",
|
165 |
+
" generated_tokens = model.generate(\n",
|
166 |
+
" **inputs,\n",
|
167 |
+
" use_cache=True,\n",
|
168 |
+
" min_length=0,\n",
|
169 |
+
" max_length=256,\n",
|
170 |
+
" num_beams=5,\n",
|
171 |
+
" num_return_sequences=1,\n",
|
172 |
+
" )\n",
|
173 |
+
"\n",
|
174 |
+
" # Decode the generated tokens into text\n",
|
175 |
+
"\n",
|
176 |
+
" with tokenizer.as_target_tokenizer():\n",
|
177 |
+
" generated_tokens = tokenizer.batch_decode(\n",
|
178 |
+
" generated_tokens.detach().cpu().tolist(),\n",
|
179 |
+
" skip_special_tokens=True,\n",
|
180 |
+
" clean_up_tokenization_spaces=True,\n",
|
181 |
+
" )\n",
|
182 |
+
"\n",
|
183 |
+
" # Postprocess the translations, including entity replacement\n",
|
184 |
+
" translations += ip.postprocess_batch(generated_tokens, lang=tgt_lang)\n",
|
185 |
+
"\n",
|
186 |
+
" del inputs\n",
|
187 |
+
" torch.cuda.empty_cache()\n",
|
188 |
+
"\n",
|
189 |
+
" return translations"
|
190 |
+
]
|
191 |
+
},
|
192 |
+
{
|
193 |
+
"cell_type": "markdown",
|
194 |
+
"metadata": {
|
195 |
+
"id": "erNCuZTEMt49"
|
196 |
+
},
|
197 |
+
"source": [
|
198 |
+
"### English to Indic Example\n"
|
199 |
+
]
|
200 |
+
},
|
201 |
+
{
|
202 |
+
"cell_type": "code",
|
203 |
+
"execution_count": null,
|
204 |
+
"metadata": {
|
205 |
+
"colab": {
|
206 |
+
"base_uri": "https://localhost:8080/"
|
207 |
+
},
|
208 |
+
"id": "6OG3Bw-sHnf3",
|
209 |
+
"outputId": "a204f50e-9456-4fb1-900a-e60680b97b99"
|
210 |
+
},
|
211 |
+
"outputs": [
|
212 |
+
{
|
213 |
+
"name": "stdout",
|
214 |
+
"output_type": "stream",
|
215 |
+
"text": [
|
216 |
+
"\n",
|
217 |
+
"eng_Latn - hin_Deva\n",
|
218 |
+
"eng_Latn: When I was young, I used to go to the park every day.\n",
|
219 |
+
"hin_Deva: जब मैं छोटा था, मैं हर दिन पार्क जाता था।\n",
|
220 |
+
"eng_Latn: He has many old books, which he inherited from his ancestors.\n",
|
221 |
+
"hin_Deva: उनके पास कई पुरानी किताबें हैं, जो उन्हें अपने पूर्वजों से विरासत में मिली हैं।\n",
|
222 |
+
"eng_Latn: I can't figure out how to solve my problem.\n",
|
223 |
+
"hin_Deva: मुझे समझ नहीं आ रहा है कि मैं अपनी समस्या का समाधान कैसे करूं।\n",
|
224 |
+
"eng_Latn: She is very hardworking and intelligent, which is why she got all the good marks.\n",
|
225 |
+
"hin_Deva: वह बहुत मेहनती और बुद्धिमान है, यही कारण है कि उसे सभी अच्छे अंक मिले।\n",
|
226 |
+
"eng_Latn: We watched a new movie last week, which was very inspiring.\n",
|
227 |
+
"hin_Deva: हमने पिछले हफ्ते एक नई फिल्म देखी, जो बहुत प्रेरणादायक थी।\n",
|
228 |
+
"eng_Latn: If you had met me at that time, we would have gone out to eat.\n",
|
229 |
+
"hin_Deva: अगर आप उस समय मुझसे मिलते तो हम बाहर खाना खाने जाते।\n",
|
230 |
+
"eng_Latn: She went to the market with her sister to buy a new sari.\n",
|
231 |
+
"hin_Deva: वह अपनी बहन के साथ नई साड़ी खरीदने के लिए बाजार गई थी।\n",
|
232 |
+
"eng_Latn: Raj told me that he is going to his grandmother's house next month.\n",
|
233 |
+
"hin_Deva: राज ने मुझे बताया कि वह अगले महीने अपनी दादी के घर जा रहा है।\n",
|
234 |
+
"eng_Latn: All the kids were having fun at the party and were eating lots of sweets.\n",
|
235 |
+
"hin_Deva: पार्टी में सभी बच्चे खूब मस्ती कर रहे थे और खूब मिठाइयां खा रहे थे।\n",
|
236 |
+
"eng_Latn: My friend has invited me to his birthday party, and I will give him a gift.\n",
|
237 |
+
"hin_Deva: मेरे दोस्त ने मुझे अपने जन्मदिन की पार्टी में आमंत्रित किया है, और मैं उसे एक उपहार दूंगा।\n"
|
238 |
+
]
|
239 |
+
}
|
240 |
+
],
|
241 |
+
"source": [
|
242 |
+
"en_indic_ckpt_dir = \"ai4bharat/indictrans2-en-indic-1B\" # ai4bharat/indictrans2-en-indic-dist-200M\n",
|
243 |
+
"en_indic_tokenizer, en_indic_model = initialize_model_and_tokenizer(en_indic_ckpt_dir, quantization)\n",
|
244 |
+
"\n",
|
245 |
+
"ip = IndicProcessor(inference=True)\n",
|
246 |
+
"\n",
|
247 |
+
"en_sents = [\n",
|
248 |
+
" \"When I was young, I used to go to the park every day.\",\n",
|
249 |
+
" \"He has many old books, which he inherited from his ancestors.\",\n",
|
250 |
+
" \"I can't figure out how to solve my problem.\",\n",
|
251 |
+
" \"She is very hardworking and intelligent, which is why she got all the good marks.\",\n",
|
252 |
+
" \"We watched a new movie last week, which was very inspiring.\",\n",
|
253 |
+
" \"If you had met me at that time, we would have gone out to eat.\",\n",
|
254 |
+
" \"She went to the market with her sister to buy a new sari.\",\n",
|
255 |
+
" \"Raj told me that he is going to his grandmother's house next month.\",\n",
|
256 |
+
" \"All the kids were having fun at the party and were eating lots of sweets.\",\n",
|
257 |
+
" \"My friend has invited me to his birthday party, and I will give him a gift.\",\n",
|
258 |
+
"]\n",
|
259 |
+
"\n",
|
260 |
+
"src_lang, tgt_lang = \"eng_Latn\", \"hin_Deva\"\n",
|
261 |
+
"hi_translations = batch_translate(en_sents, src_lang, tgt_lang, en_indic_model, en_indic_tokenizer, ip)\n",
|
262 |
+
"\n",
|
263 |
+
"print(f\"\\n{src_lang} - {tgt_lang}\")\n",
|
264 |
+
"for input_sentence, translation in zip(en_sents, hi_translations):\n",
|
265 |
+
" print(f\"{src_lang}: {input_sentence}\")\n",
|
266 |
+
" print(f\"{tgt_lang}: {translation}\")\n",
|
267 |
+
"\n",
|
268 |
+
"# flush the models to free the GPU memory\n",
|
269 |
+
"del en_indic_tokenizer, en_indic_model"
|
270 |
+
]
|
271 |
+
},
|
272 |
+
{
|
273 |
+
"cell_type": "markdown",
|
274 |
+
"metadata": {
|
275 |
+
"id": "OM_1pbPtMpV9"
|
276 |
+
},
|
277 |
+
"source": [
|
278 |
+
"### Indic to English Example"
|
279 |
+
]
|
280 |
+
},
|
281 |
+
{
|
282 |
+
"cell_type": "code",
|
283 |
+
"execution_count": null,
|
284 |
+
"metadata": {
|
285 |
+
"colab": {
|
286 |
+
"base_uri": "https://localhost:8080/"
|
287 |
+
},
|
288 |
+
"id": "PLCEWJKvGG9I",
|
289 |
+
"outputId": "ab9d8726-67c7-490b-ecb3-208df1c0f741"
|
290 |
+
},
|
291 |
+
"outputs": [
|
292 |
+
{
|
293 |
+
"name": "stdout",
|
294 |
+
"output_type": "stream",
|
295 |
+
"text": [
|
296 |
+
"\n",
|
297 |
+
"hin_Deva - eng_Latn\n",
|
298 |
+
"hin_Deva: जब मैं छोटा था, मैं हर रोज़ पार्क जाता था।\n",
|
299 |
+
"eng_Latn: When I was young, I used to go to the park every day.\n",
|
300 |
+
"hin_Deva: उसके पास बहुत सारी पुरानी किताबें हैं, जिन्हें उसने अपने दादा-परदादा से विरासत में पाया।\n",
|
301 |
+
"eng_Latn: She has a lot of old books, which she inherited from her grandparents.\n",
|
302 |
+
"hin_Deva: मुझे समझ में नहीं आ रहा कि मैं अपनी समस्या का समाधान कैसे ढूंढूं।\n",
|
303 |
+
"eng_Latn: I don't know how to find a solution to my problem.\n",
|
304 |
+
"hin_Deva: वह बहुत मेहनती और समझदार है, इसलिए उसे सभी अच्छे मार्क्स मिले।\n",
|
305 |
+
"eng_Latn: He is very hardworking and understanding, so he got all the good marks.\n",
|
306 |
+
"hin_Deva: हमने पिछले सप्ताह एक नई फिल्म देखी जो कि बहुत प्रेरणादायक थी।\n",
|
307 |
+
"eng_Latn: We saw a new movie last week that was very inspiring.\n",
|
308 |
+
"hin_Deva: अगर तुम मुझे उस समय पास मिलते, तो हम बाहर खाना खाने चलते।\n",
|
309 |
+
"eng_Latn: If you'd given me a pass at that time, we'd have gone out to eat.\n",
|
310 |
+
"hin_Deva: वह अपनी दीदी के साथ बाजार गयी थी ताकि वह नई साड़ी खरीद सके।\n",
|
311 |
+
"eng_Latn: She had gone to the market with her sister so that she could buy a new sari.\n",
|
312 |
+
"hin_Deva: राज ने मुझसे कहा कि वह अगले महीने अपनी नानी के घर जा रहा है।\n",
|
313 |
+
"eng_Latn: Raj told me that he was going to his grandmother's house next month.\n",
|
314 |
+
"hin_Deva: सभी बच्चे पार्टी में मज़ा कर रहे थे और खूब सारी मिठाइयाँ खा रहे थे।\n",
|
315 |
+
"eng_Latn: All the children were having fun at the party and eating a lot of sweets.\n",
|
316 |
+
"hin_Deva: मेरे मित्र ने मुझे उसके जन्मदिन की पार्टी में बुलाया है, और मैं उसे एक तोहफा दूंगा।\n",
|
317 |
+
"eng_Latn: My friend has invited me to her birthday party, and I'll give her a present.\n"
|
318 |
+
]
|
319 |
+
}
|
320 |
+
],
|
321 |
+
"source": [
|
322 |
+
"indic_en_ckpt_dir = \"ai4bharat/indictrans2-indic-en-1B\" # ai4bharat/indictrans2-indic-en-dist-200M\n",
|
323 |
+
"indic_en_tokenizer, indic_en_model = initialize_model_and_tokenizer(indic_en_ckpt_dir, quantization)\n",
|
324 |
+
"\n",
|
325 |
+
"ip = IndicProcessor(inference=True)\n",
|
326 |
+
"\n",
|
327 |
+
"hi_sents = [\n",
|
328 |
+
" \"जब मैं छोटा था, मैं हर रोज़ पार्क जाता था।\",\n",
|
329 |
+
" \"उसके पास बहुत सारी पुरानी किताबें हैं, जिन्हें उसने अपने दादा-परदादा से विरासत में पाया।\",\n",
|
330 |
+
" \"मुझे समझ में नहीं आ रहा कि मैं अपनी समस्या का समाधान कैसे ढूंढूं।\",\n",
|
331 |
+
" \"वह बहुत मेहनती और समझदार है, इसलिए उसे सभी अच्छे मार्क्स मिले।\",\n",
|
332 |
+
" \"हमने पिछले सप्ताह एक नई फिल्म देखी जो कि बहुत प्रेरणादायक थी।\",\n",
|
333 |
+
" \"अगर तुम मुझे उस समय पास मिलते, तो हम बाहर खाना खाने चलते।\",\n",
|
334 |
+
" \"वह अपनी दीदी के साथ बाजार गयी थी ताकि वह नई साड़ी खरीद सके।\",\n",
|
335 |
+
" \"राज ने मुझसे कहा कि वह अगले महीने अपनी नानी के घर जा रहा है।\",\n",
|
336 |
+
" \"सभी बच्चे पार्टी में मज़ा कर रहे थे और खूब सारी मिठाइयाँ खा रहे थे।\",\n",
|
337 |
+
" \"मेरे मित्र ने मुझे उसके जन्मदिन की पार्टी में बुलाया है, और मैं उसे एक तोहफा दूंगा।\",\n",
|
338 |
+
"]\n",
|
339 |
+
"src_lang, tgt_lang = \"hin_Deva\", \"eng_Latn\"\n",
|
340 |
+
"en_translations = batch_translate(hi_sents, src_lang, tgt_lang, indic_en_model, indic_en_tokenizer, ip)\n",
|
341 |
+
"\n",
|
342 |
+
"\n",
|
343 |
+
"print(f\"\\n{src_lang} - {tgt_lang}\")\n",
|
344 |
+
"for input_sentence, translation in zip(hi_sents, en_translations):\n",
|
345 |
+
" print(f\"{src_lang}: {input_sentence}\")\n",
|
346 |
+
" print(f\"{tgt_lang}: {translation}\")\n",
|
347 |
+
"\n",
|
348 |
+
"# flush the models to free the GPU memory\n",
|
349 |
+
"del indic_en_tokenizer, indic_en_model"
|
350 |
+
]
|
351 |
+
},
|
352 |
+
{
|
353 |
+
"cell_type": "markdown",
|
354 |
+
"metadata": {
|
355 |
+
"id": "7VCAkyKBGtnV"
|
356 |
+
},
|
357 |
+
"source": [
|
358 |
+
"### Indic to Indic Example\n"
|
359 |
+
]
|
360 |
+
},
|
361 |
+
{
|
362 |
+
"cell_type": "code",
|
363 |
+
"execution_count": null,
|
364 |
+
"metadata": {
|
365 |
+
"colab": {
|
366 |
+
"base_uri": "https://localhost:8080/"
|
367 |
+
},
|
368 |
+
"id": "_7TxTTCoKjti",
|
369 |
+
"outputId": "df1a750b-0f32-478d-cfc9-e445f669f3ee"
|
370 |
+
},
|
371 |
+
"outputs": [
|
372 |
+
{
|
373 |
+
"name": "stdout",
|
374 |
+
"output_type": "stream",
|
375 |
+
"text": [
|
376 |
+
"\n",
|
377 |
+
"hin_Deva - mar_Deva\n",
|
378 |
+
"hin_Deva: जब मैं छोटा था, मैं हर रोज़ पार्क जाता था।\n",
|
379 |
+
"mar_Deva: मी लहान होतो तेव्हा मी दररोज उद्यानाला जायचे.\n",
|
380 |
+
"hin_Deva: उसके पास बहुत सारी पुरानी किताबें हैं, जिन्हें उसने अपने दादा-परदादा से विरासत में पाया।\n",
|
381 |
+
"mar_Deva: तिच्याकडे बरीच जुनी पुस्तके आहेत, जी तिला तिच्या आजोबांकडून वारशाने मिळाली आहेत.\n",
|
382 |
+
"hin_Deva: मुझे समझ में नहीं आ रहा कि मैं अपनी समस्या का समाधान कैसे ढूंढूं।\n",
|
383 |
+
"mar_Deva: माझ्या समस्येवर तोडगा कसा काढायचा हे मला समजत नाही.\n",
|
384 |
+
"hin_Deva: वह बहुत मेहनती और समझदार है, इसलिए उसे सभी अच्छे मार्क्स मिले।\n",
|
385 |
+
"mar_Deva: तो खूप मेहनती आणि बुद्धिमान आहे, त्यामुळे त्याला सर्व चांगले गुण मिळाले.\n",
|
386 |
+
"hin_Deva: हमने पिछले सप्ताह एक नई फिल्म देखी जो कि बहुत प्रेरणादायक थी।\n",
|
387 |
+
"mar_Deva: आम्ही गेल्या आठवड्यात एक नवीन चित्रपट पाहिला जो खूप प्रेरणादायी होता.\n",
|
388 |
+
"hin_Deva: अगर तुम मुझे उस समय पास मिलते, तो हम बाहर खाना खाने चलते।\n",
|
389 |
+
"mar_Deva: जर तुम्हाला त्या वेळी मला पास मिळाला तर आम्ही बाहेर जेवायला जाऊ.\n",
|
390 |
+
"hin_Deva: वह अपनी दीदी के साथ बाजार गयी थी ताकि वह नई साड़ी खरीद सके।\n",
|
391 |
+
"mar_Deva: ती तिच्या बहिणीसोबत बाजारात गेली होती जेणेकरून ती नवीन साडी खरेदी करू शकेल.\n",
|
392 |
+
"hin_Deva: राज ने मुझसे कहा कि वह अगले महीने अपनी नानी के घर जा रहा है।\n",
|
393 |
+
"mar_Deva: राजने मला सांगितले की तो पुढच्या महिन्यात त्याच्या आजीच्या घरी जात आहे.\n",
|
394 |
+
"hin_Deva: सभी बच्चे पार्टी में मज़ा कर रहे थे और खूब सारी मिठाइयाँ खा रहे थे।\n",
|
395 |
+
"mar_Deva: सर्व मुले पार्टीचा आनंद घेत होती आणि भरपूर मिठाई खात होती.\n",
|
396 |
+
"hin_Deva: मेरे मित्र ने मुझे उसके जन्मदिन की पार्टी में बुलाया है, और मैं उसे एक तोहफा दूंगा।\n",
|
397 |
+
"mar_Deva: माझ्या मित्राने मला त्याच्या वाढदिवसाच्या मेजवानीसाठी आमंत्रित केले आहे आणि मी त्याला भेटवस्तू देईन.\n"
|
398 |
+
]
|
399 |
+
}
|
400 |
+
],
|
401 |
+
"source": [
|
402 |
+
"indic_indic_ckpt_dir = \"ai4bharat/indictrans2-indic-indic-1B\" # ai4bharat/indictrans2-indic-indic-dist-320M\n",
|
403 |
+
"indic_indic_tokenizer, indic_indic_model = initialize_model_and_tokenizer(indic_indic_ckpt_dir, quantization)\n",
|
404 |
+
"\n",
|
405 |
+
"ip = IndicProcessor(inference=True)\n",
|
406 |
+
"\n",
|
407 |
+
"hi_sents = [\n",
|
408 |
+
" \"जब मैं छोटा था, मैं हर रोज़ पार्क जाता था।\",\n",
|
409 |
+
" \"उसके पास बहुत सारी पुरानी किताबें हैं, जिन्हें उसने अपने दादा-परदादा से विरासत में पाया।\",\n",
|
410 |
+
" \"मुझे समझ में नहीं आ रहा कि मैं अपनी समस्या का समाधान कैसे ढूंढूं।\",\n",
|
411 |
+
" \"वह बहुत मेहनती और समझदार है, इसलिए उसे सभी अच्छे मार्क्स मिले।\",\n",
|
412 |
+
" \"हमने पिछले सप्ताह एक नई फिल्म देखी जो कि बहुत प्रेरणादायक थी।\",\n",
|
413 |
+
" \"अगर तुम मुझे उस समय पास मिलते, तो हम बाहर खाना खाने चलते।\",\n",
|
414 |
+
" \"वह अपनी दीदी के साथ बाजार गयी थी ताकि वह नई साड़ी खरीद सके।\",\n",
|
415 |
+
" \"राज ने मुझसे कहा कि वह अगले महीने अपनी नानी के घर जा रहा है।\",\n",
|
416 |
+
" \"सभी बच्चे पार्टी में मज़ा कर रहे थे और खूब सारी मिठाइयाँ खा रहे थे।\",\n",
|
417 |
+
" \"मेरे मित्र ने मुझे उसके जन्मदिन की पार्टी में बुलाया है, और मैं उसे एक तोहफा दूंगा।\",\n",
|
418 |
+
"]\n",
|
419 |
+
"src_lang, tgt_lang = \"hin_Deva\", \"mar_Deva\"\n",
|
420 |
+
"mr_translations = batch_translate(hi_sents, src_lang, tgt_lang, indic_indic_model, indic_indic_tokenizer, ip)\n",
|
421 |
+
"\n",
|
422 |
+
"print(f\"\\n{src_lang} - {tgt_lang}\")\n",
|
423 |
+
"for input_sentence, translation in zip(hi_sents, mr_translations):\n",
|
424 |
+
" print(f\"{src_lang}: {input_sentence}\")\n",
|
425 |
+
" print(f\"{tgt_lang}: {translation}\")\n",
|
426 |
+
"\n",
|
427 |
+
"# flush the models to free the GPU memory\n",
|
428 |
+
"del indic_indic_tokenizer, indic_indic_model"
|
429 |
+
]
|
430 |
+
},
|
431 |
+
{
|
432 |
+
"cell_type": "code",
|
433 |
+
"execution_count": null,
|
434 |
+
"metadata": {
|
435 |
+
"id": "uyxXpt--Ma6n"
|
436 |
+
},
|
437 |
+
"outputs": [],
|
438 |
+
"source": []
|
439 |
+
}
|
440 |
+
],
|
441 |
+
"metadata": {
|
442 |
+
"accelerator": "GPU",
|
443 |
+
"colab": {
|
444 |
+
"gpuType": "T4",
|
445 |
+
"provenance": [],
|
446 |
+
"toc_visible": true
|
447 |
+
},
|
448 |
+
"kernelspec": {
|
449 |
+
"display_name": "Python 3",
|
450 |
+
"name": "python3"
|
451 |
+
},
|
452 |
+
"language_info": {
|
453 |
+
"name": "python"
|
454 |
+
}
|
455 |
+
},
|
456 |
+
"nbformat": 4,
|
457 |
+
"nbformat_minor": 0
|
458 |
+
}
|
IndicTrans2/huggingface_interface/configuration_indictrans.py
ADDED
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The IndicTrans2 Authors and AI4Bharat team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch IndicTrans config."""
|
16 |
+
|
17 |
+
|
18 |
+
from collections import OrderedDict
|
19 |
+
from typing import Any, Mapping, Optional
|
20 |
+
|
21 |
+
from transformers import PreTrainedTokenizer
|
22 |
+
from transformers.configuration_utils import PretrainedConfig
|
23 |
+
from transformers.onnx import OnnxConfig, OnnxSeq2SeqConfigWithPast
|
24 |
+
from transformers.onnx.utils import compute_effective_axis_dimension
|
25 |
+
from transformers.utils import TensorType, is_torch_available
|
26 |
+
|
27 |
+
|
28 |
+
# Copied from transformers.models.m2m_100.configuration_m2m_100.M2M100Config->IndicTrans
|
29 |
+
class IndicTransConfig(PretrainedConfig):
|
30 |
+
r"""
|
31 |
+
This is the configuration class to store the configuration of a [`IT2Model`]. It is used to instantiate an
|
32 |
+
IT2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
33 |
+
with the defaults will yield a similar configuration to that of the IT2
|
34 |
+
|
35 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
36 |
+
documentation from [`PretrainedConfig`] for more information.
|
37 |
+
|
38 |
+
|
39 |
+
Args:
|
40 |
+
vocab_size (`int`, *optional*, defaults to 50265):
|
41 |
+
Vocabulary size of the IT2 model. Defines the number of different tokens that can be represented by the
|
42 |
+
`inputs_ids` passed when calling [`IT2Model`] or
|
43 |
+
d_model (`int`, *optional*, defaults to 1024):
|
44 |
+
Dimensionality of the layers and the pooler layer.
|
45 |
+
encoder_layers (`int`, *optional*, defaults to 12):
|
46 |
+
Number of encoder layers.
|
47 |
+
decoder_layers (`int`, *optional*, defaults to 12):
|
48 |
+
Number of decoder layers.
|
49 |
+
encoder_attention_heads (`int`, *optional*, defaults to 16):
|
50 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
51 |
+
decoder_attention_heads (`int`, *optional*, defaults to 16):
|
52 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
53 |
+
decoder_ffn_dim (`int`, *optional*, defaults to 4096):
|
54 |
+
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
55 |
+
encoder_ffn_dim (`int`, *optional*, defaults to 4096):
|
56 |
+
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
57 |
+
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
|
58 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
59 |
+
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
60 |
+
dropout (`float`, *optional*, defaults to 0.1):
|
61 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
62 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
63 |
+
The dropout ratio for the attention probabilities.
|
64 |
+
activation_dropout (`float`, *optional*, defaults to 0.0):
|
65 |
+
The dropout ratio for activations inside the fully connected layer.
|
66 |
+
classifier_dropout (`float`, *optional*, defaults to 0.0):
|
67 |
+
The dropout ratio for classifier.
|
68 |
+
max_position_embeddings (`int`, *optional*, defaults to 1024):
|
69 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
70 |
+
just in case (e.g., 512 or 1024 or 2048).
|
71 |
+
init_std (`float`, *optional*, defaults to 0.02):
|
72 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
73 |
+
encoder_layerdrop (`float`, *optional*, defaults to 0.0):
|
74 |
+
The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
75 |
+
for more details.
|
76 |
+
decoder_layerdrop (`float`, *optional*, defaults to 0.0):
|
77 |
+
The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
78 |
+
for more details.
|
79 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
80 |
+
Whether or not the model should return the last key/values attentions (not used by all models).
|
81 |
+
```"""
|
82 |
+
model_type = "IndicTrans"
|
83 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
84 |
+
attribute_map = {
|
85 |
+
"num_attention_heads": "encoder_attention_heads",
|
86 |
+
"hidden_size": "d_model",
|
87 |
+
}
|
88 |
+
|
89 |
+
def __init__(
|
90 |
+
self,
|
91 |
+
encoder_vocab_size=None,
|
92 |
+
decoder_vocab_size=None,
|
93 |
+
encoder_embed_dim=512,
|
94 |
+
decoder_embed_dim=512,
|
95 |
+
max_source_positions=210,
|
96 |
+
max_target_positions=210,
|
97 |
+
encoder_layers=6,
|
98 |
+
encoder_ffn_dim=2048,
|
99 |
+
encoder_attention_heads=8,
|
100 |
+
decoder_layers=6,
|
101 |
+
decoder_ffn_dim=2048,
|
102 |
+
decoder_attention_heads=8,
|
103 |
+
encoder_layerdrop=0.00,
|
104 |
+
decoder_layerdrop=0.00,
|
105 |
+
use_cache=True,
|
106 |
+
is_encoder_decoder=True,
|
107 |
+
activation_function="relu",
|
108 |
+
encoder_normalize_before=False,
|
109 |
+
decoder_normalize_before=False,
|
110 |
+
layernorm_embedding=False,
|
111 |
+
share_decoder_input_output_embed=False,
|
112 |
+
dropout=0.1,
|
113 |
+
attention_dropout=0.0,
|
114 |
+
activation_dropout=0.0,
|
115 |
+
init_std=0.02,
|
116 |
+
scale_embedding=True,
|
117 |
+
decoder_start_token_id=2,
|
118 |
+
pad_token_id=1,
|
119 |
+
bos_token_id=0,
|
120 |
+
eos_token_id=2,
|
121 |
+
attn_implementation="eager",
|
122 |
+
**kwargs,
|
123 |
+
):
|
124 |
+
self.encoder_vocab_size = encoder_vocab_size
|
125 |
+
self.decoder_vocab_size = decoder_vocab_size
|
126 |
+
self.encoder_normalize_before = encoder_normalize_before
|
127 |
+
self.decoder_normalize_before = decoder_normalize_before
|
128 |
+
self.layernorm_embedding = layernorm_embedding
|
129 |
+
self.max_source_positions = max_source_positions
|
130 |
+
self.max_target_positions = max_target_positions
|
131 |
+
self.encoder_embed_dim = encoder_embed_dim
|
132 |
+
self.decoder_embed_dim = decoder_embed_dim
|
133 |
+
self.encoder_ffn_dim = encoder_ffn_dim
|
134 |
+
self.encoder_layers = encoder_layers
|
135 |
+
self.encoder_attention_heads = encoder_attention_heads
|
136 |
+
self.decoder_ffn_dim = decoder_ffn_dim
|
137 |
+
self.decoder_layers = decoder_layers
|
138 |
+
self.decoder_attention_heads = decoder_attention_heads
|
139 |
+
self.dropout = dropout
|
140 |
+
self.attention_dropout = attention_dropout
|
141 |
+
self.activation_dropout = activation_dropout
|
142 |
+
self.activation_function = activation_function
|
143 |
+
self.init_std = init_std
|
144 |
+
self.encoder_layerdrop = encoder_layerdrop
|
145 |
+
self.decoder_layerdrop = decoder_layerdrop
|
146 |
+
self.use_cache = use_cache
|
147 |
+
self.num_hidden_layers = encoder_layers
|
148 |
+
self.scale_embedding = scale_embedding
|
149 |
+
self.share_decoder_input_output_embed = share_decoder_input_output_embed
|
150 |
+
self.attn_implementation = attn_implementation
|
151 |
+
|
152 |
+
super().__init__(
|
153 |
+
pad_token_id=pad_token_id,
|
154 |
+
bos_token_id=bos_token_id,
|
155 |
+
eos_token_id=eos_token_id,
|
156 |
+
is_encoder_decoder=is_encoder_decoder,
|
157 |
+
decoder_start_token_id=decoder_start_token_id,
|
158 |
+
**kwargs,
|
159 |
+
)
|
160 |
+
|
161 |
+
|
162 |
+
class IndicTransOnnxConfig(OnnxSeq2SeqConfigWithPast):
|
163 |
+
@property
|
164 |
+
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
165 |
+
common_inputs = OrderedDict(
|
166 |
+
[
|
167 |
+
("input_ids", {0: "batch", 1: "encoder_sequence"}),
|
168 |
+
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
|
169 |
+
]
|
170 |
+
)
|
171 |
+
|
172 |
+
if self.use_past:
|
173 |
+
common_inputs["decoder_input_ids"] = {0: "batch"}
|
174 |
+
common_inputs["decoder_attention_mask"] = {
|
175 |
+
0: "batch",
|
176 |
+
1: "past_decoder_sequence + sequence",
|
177 |
+
}
|
178 |
+
else:
|
179 |
+
common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
|
180 |
+
common_inputs["decoder_attention_mask"] = {
|
181 |
+
0: "batch",
|
182 |
+
1: "decoder_sequence",
|
183 |
+
}
|
184 |
+
|
185 |
+
if self.use_past:
|
186 |
+
self.fill_with_past_key_values_(common_inputs, direction="inputs")
|
187 |
+
return common_inputs
|
188 |
+
|
189 |
+
# Copied from BartOnnxConfig._generate_dummy_inputs_for_sequence_classification_and_question_answering
|
190 |
+
# A better name would be _generate_dummy_inputs_for_encoder_and_decoder because sequence classification and question
|
191 |
+
# answering are not supported for IT2, but this name is preserved to be able to check that the copy matches what
|
192 |
+
# was done for BART so that it can be updated if need be.
|
193 |
+
def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
194 |
+
self,
|
195 |
+
tokenizer: PreTrainedTokenizer,
|
196 |
+
batch_size: int = -1,
|
197 |
+
seq_length: int = -1,
|
198 |
+
is_pair: bool = False,
|
199 |
+
framework: Optional[TensorType] = None,
|
200 |
+
) -> Mapping[str, Any]:
|
201 |
+
# Copied from OnnxConfig.generate_dummy_inputs
|
202 |
+
# Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
|
203 |
+
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
|
204 |
+
batch_size = compute_effective_axis_dimension(
|
205 |
+
batch_size,
|
206 |
+
fixed_dimension=OnnxConfig.default_fixed_batch,
|
207 |
+
num_token_to_add=0,
|
208 |
+
)
|
209 |
+
|
210 |
+
# If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
|
211 |
+
token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
|
212 |
+
seq_length = compute_effective_axis_dimension(
|
213 |
+
seq_length,
|
214 |
+
fixed_dimension=OnnxConfig.default_fixed_sequence,
|
215 |
+
num_token_to_add=token_to_add,
|
216 |
+
)
|
217 |
+
|
218 |
+
# Generate dummy inputs according to compute batch and sequence
|
219 |
+
dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
|
220 |
+
common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
|
221 |
+
return common_inputs
|
222 |
+
|
223 |
+
# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig._generate_dummy_inputs_for_default_and_seq2seq_lm
|
224 |
+
def _generate_dummy_inputs_for_default_and_seq2seq_lm(
|
225 |
+
self,
|
226 |
+
tokenizer: PreTrainedTokenizer,
|
227 |
+
batch_size: int = -1,
|
228 |
+
seq_length: int = -1,
|
229 |
+
is_pair: bool = False,
|
230 |
+
framework: Optional[TensorType] = None,
|
231 |
+
) -> Mapping[str, Any]:
|
232 |
+
encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
233 |
+
tokenizer, batch_size, seq_length, is_pair, framework
|
234 |
+
)
|
235 |
+
|
236 |
+
# Generate decoder inputs
|
237 |
+
decoder_seq_length = seq_length if not self.use_past else 1
|
238 |
+
decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
239 |
+
tokenizer, batch_size, decoder_seq_length, is_pair, framework
|
240 |
+
)
|
241 |
+
decoder_inputs = {
|
242 |
+
f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()
|
243 |
+
}
|
244 |
+
common_inputs = dict(**encoder_inputs, **decoder_inputs)
|
245 |
+
|
246 |
+
if self.use_past:
|
247 |
+
if not is_torch_available():
|
248 |
+
raise ValueError(
|
249 |
+
"Cannot generate dummy past_keys inputs without PyTorch installed."
|
250 |
+
)
|
251 |
+
else:
|
252 |
+
import torch
|
253 |
+
batch, encoder_seq_length = common_inputs["input_ids"].shape
|
254 |
+
decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
|
255 |
+
(
|
256 |
+
num_encoder_attention_heads,
|
257 |
+
num_decoder_attention_heads,
|
258 |
+
) = self.num_attention_heads
|
259 |
+
encoder_shape = (
|
260 |
+
batch,
|
261 |
+
num_encoder_attention_heads,
|
262 |
+
encoder_seq_length,
|
263 |
+
self._config.hidden_size // num_encoder_attention_heads,
|
264 |
+
)
|
265 |
+
decoder_past_length = decoder_seq_length + 3
|
266 |
+
decoder_shape = (
|
267 |
+
batch,
|
268 |
+
num_decoder_attention_heads,
|
269 |
+
decoder_past_length,
|
270 |
+
self._config.hidden_size // num_decoder_attention_heads,
|
271 |
+
)
|
272 |
+
|
273 |
+
common_inputs["decoder_attention_mask"] = torch.cat(
|
274 |
+
[
|
275 |
+
common_inputs["decoder_attention_mask"],
|
276 |
+
torch.ones(batch, decoder_past_length),
|
277 |
+
],
|
278 |
+
dim=1,
|
279 |
+
)
|
280 |
+
|
281 |
+
common_inputs["past_key_values"] = []
|
282 |
+
# If the number of encoder and decoder layers are present in the model configuration, both are considered
|
283 |
+
num_encoder_layers, num_decoder_layers = self.num_layers
|
284 |
+
min_num_layers = min(num_encoder_layers, num_decoder_layers)
|
285 |
+
max_num_layers = (
|
286 |
+
max(num_encoder_layers, num_decoder_layers) - min_num_layers
|
287 |
+
)
|
288 |
+
remaining_side_name = (
|
289 |
+
"encoder" if num_encoder_layers > num_decoder_layers else "decoder"
|
290 |
+
)
|
291 |
+
|
292 |
+
for _ in range(min_num_layers):
|
293 |
+
common_inputs["past_key_values"].append(
|
294 |
+
(
|
295 |
+
torch.zeros(decoder_shape),
|
296 |
+
torch.zeros(decoder_shape),
|
297 |
+
torch.zeros(encoder_shape),
|
298 |
+
torch.zeros(encoder_shape),
|
299 |
+
)
|
300 |
+
)
|
301 |
+
# TODO: test this.
|
302 |
+
shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
|
303 |
+
for _ in range(min_num_layers, max_num_layers):
|
304 |
+
common_inputs["past_key_values"].append(
|
305 |
+
(torch.zeros(shape), torch.zeros(shape))
|
306 |
+
)
|
307 |
+
return common_inputs
|
308 |
+
|
309 |
+
generate_dummy_inputs = _generate_dummy_inputs_for_default_and_seq2seq_lm
|
IndicTrans2/huggingface_interface/convert_indictrans_checkpoint_to_pytorch.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import argparse
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
|
20 |
+
from configuration_indictrans import IndicTransConfig
|
21 |
+
from modeling_indictrans import IndicTransForConditionalGeneration
|
22 |
+
|
23 |
+
|
24 |
+
def remove_ignore_keys_(state_dict):
|
25 |
+
ignore_keys = [
|
26 |
+
"encoder.version",
|
27 |
+
"decoder.version",
|
28 |
+
"model.encoder.version",
|
29 |
+
"model.decoder.version",
|
30 |
+
"_float_tensor",
|
31 |
+
"encoder.embed_positions._float_tensor",
|
32 |
+
"decoder.embed_positions._float_tensor",
|
33 |
+
]
|
34 |
+
for k in ignore_keys:
|
35 |
+
state_dict.pop(k, None)
|
36 |
+
|
37 |
+
|
38 |
+
def make_linear_from_emb(emb):
|
39 |
+
vocab_size, emb_size = emb.shape
|
40 |
+
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
|
41 |
+
lin_layer.weight.data = emb.data
|
42 |
+
return lin_layer
|
43 |
+
|
44 |
+
|
45 |
+
def convert_fairseq_IT2_checkpoint_from_disk(checkpoint_path):
|
46 |
+
model = torch.load(checkpoint_path, map_location="cpu")
|
47 |
+
args = model["args"] or model["cfg"]["model"]
|
48 |
+
state_dict = model["model"]
|
49 |
+
remove_ignore_keys_(state_dict)
|
50 |
+
encoder_vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
|
51 |
+
decoder_vocab_size = state_dict["decoder.embed_tokens.weight"].shape[0]
|
52 |
+
|
53 |
+
config = IndicTransConfig(
|
54 |
+
encoder_vocab_size=encoder_vocab_size,
|
55 |
+
decoder_vocab_size=decoder_vocab_size,
|
56 |
+
max_source_positions=args.max_source_positions,
|
57 |
+
max_target_positions=args.max_target_positions,
|
58 |
+
encoder_layers=args.encoder_layers,
|
59 |
+
decoder_layers=args.decoder_layers,
|
60 |
+
layernorm_embedding=args.layernorm_embedding,
|
61 |
+
encoder_normalize_before=args.encoder_normalize_before,
|
62 |
+
decoder_normalize_before=args.decoder_normalize_before,
|
63 |
+
encoder_attention_heads=args.encoder_attention_heads,
|
64 |
+
decoder_attention_heads=args.decoder_attention_heads,
|
65 |
+
encoder_ffn_dim=args.encoder_ffn_embed_dim,
|
66 |
+
decoder_ffn_dim=args.decoder_ffn_embed_dim,
|
67 |
+
encoder_embed_dim=args.encoder_embed_dim,
|
68 |
+
decoder_embed_dim=args.decoder_embed_dim,
|
69 |
+
encoder_layerdrop=args.encoder_layerdrop,
|
70 |
+
decoder_layerdrop=args.decoder_layerdrop,
|
71 |
+
dropout=args.dropout,
|
72 |
+
attention_dropout=args.attention_dropout,
|
73 |
+
activation_dropout=args.activation_dropout,
|
74 |
+
activation_function=args.activation_fn,
|
75 |
+
share_decoder_input_output_embed=args.share_decoder_input_output_embed,
|
76 |
+
scale_embedding=not args.no_scale_embedding,
|
77 |
+
)
|
78 |
+
|
79 |
+
model = IndicTransForConditionalGeneration(config)
|
80 |
+
model.model.load_state_dict(state_dict, strict=False)
|
81 |
+
if not args.share_decoder_input_output_embed:
|
82 |
+
model.lm_head = make_linear_from_emb(
|
83 |
+
state_dict["decoder.output_projection.weight"]
|
84 |
+
)
|
85 |
+
print(model)
|
86 |
+
return model
|
87 |
+
|
88 |
+
|
89 |
+
if __name__ == "__main__":
|
90 |
+
parser = argparse.ArgumentParser()
|
91 |
+
# Required parameters
|
92 |
+
parser.add_argument(
|
93 |
+
"--fairseq_path",
|
94 |
+
default="indic-en/model/checkpoint_best.pt",
|
95 |
+
type=str,
|
96 |
+
help="path to a model.pt on local filesystem.",
|
97 |
+
)
|
98 |
+
parser.add_argument(
|
99 |
+
"--pytorch_dump_folder_path",
|
100 |
+
default="indic-en/hf_model",
|
101 |
+
type=str,
|
102 |
+
help="Path to the output PyTorch model.",
|
103 |
+
)
|
104 |
+
|
105 |
+
args = parser.parse_args()
|
106 |
+
model = convert_fairseq_IT2_checkpoint_from_disk(args.fairseq_path)
|
107 |
+
model.save_pretrained(args.pytorch_dump_folder_path)
|
IndicTrans2/huggingface_interface/example.py
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import torch
|
3 |
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig
|
4 |
+
from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10
|
5 |
+
from IndicTransToolkit import IndicProcessor
|
6 |
+
from mosestokenizer import MosesSentenceSplitter
|
7 |
+
from nltk import sent_tokenize
|
8 |
+
from indicnlp.tokenize.sentence_tokenize import sentence_split, DELIM_PAT_NO_DANDA
|
9 |
+
|
10 |
+
|
11 |
+
en_indic_ckpt_dir = "ai4bharat/indictrans2-en-indic-1B" # ai4bharat/indictrans2-en-indic-dist-200M
|
12 |
+
indic_en_ckpt_dir = "ai4bharat/indictrans2-indic-en-1B" # ai4bharat/indictrans2-indic-en-dist-200M
|
13 |
+
indic_indic_ckpt_dir = (
|
14 |
+
"ai4bharat/indictrans2-indic-indic-dist-320M" # ai4bharat/indictrans2-indic-indic-dist-320M
|
15 |
+
)
|
16 |
+
BATCH_SIZE = 4
|
17 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
18 |
+
|
19 |
+
if len(sys.argv) > 1:
|
20 |
+
quantization = sys.argv[1]
|
21 |
+
attn_implementation = sys.argv[2]
|
22 |
+
else:
|
23 |
+
quantization = ""
|
24 |
+
attn_implementation = "eager"
|
25 |
+
|
26 |
+
|
27 |
+
# FLORES language code mapping to 2 letter ISO language code for compatibility
|
28 |
+
# with Indic NLP Library (https://github.com/anoopkunchukuttan/indic_nlp_library)
|
29 |
+
flores_codes = {
|
30 |
+
"asm_Beng": "as",
|
31 |
+
"awa_Deva": "hi",
|
32 |
+
"ben_Beng": "bn",
|
33 |
+
"bho_Deva": "hi",
|
34 |
+
"brx_Deva": "hi",
|
35 |
+
"doi_Deva": "hi",
|
36 |
+
"eng_Latn": "en",
|
37 |
+
"gom_Deva": "kK",
|
38 |
+
"guj_Gujr": "gu",
|
39 |
+
"hin_Deva": "hi",
|
40 |
+
"hne_Deva": "hi",
|
41 |
+
"kan_Knda": "kn",
|
42 |
+
"kas_Arab": "ur",
|
43 |
+
"kas_Deva": "hi",
|
44 |
+
"kha_Latn": "en",
|
45 |
+
"lus_Latn": "en",
|
46 |
+
"mag_Deva": "hi",
|
47 |
+
"mai_Deva": "hi",
|
48 |
+
"mal_Mlym": "ml",
|
49 |
+
"mar_Deva": "mr",
|
50 |
+
"mni_Beng": "bn",
|
51 |
+
"mni_Mtei": "hi",
|
52 |
+
"npi_Deva": "ne",
|
53 |
+
"ory_Orya": "or",
|
54 |
+
"pan_Guru": "pa",
|
55 |
+
"san_Deva": "hi",
|
56 |
+
"sat_Olck": "or",
|
57 |
+
"snd_Arab": "ur",
|
58 |
+
"snd_Deva": "hi",
|
59 |
+
"tam_Taml": "ta",
|
60 |
+
"tel_Telu": "te",
|
61 |
+
"urd_Arab": "ur",
|
62 |
+
}
|
63 |
+
|
64 |
+
|
65 |
+
def split_sentences(input_text, lang):
|
66 |
+
if lang == "eng_Latn":
|
67 |
+
input_sentences = sent_tokenize(input_text)
|
68 |
+
with MosesSentenceSplitter(flores_codes[lang]) as splitter:
|
69 |
+
sents_moses = splitter([input_text])
|
70 |
+
sents_nltk = sent_tokenize(input_text)
|
71 |
+
if len(sents_nltk) < len(sents_moses):
|
72 |
+
input_sentences = sents_nltk
|
73 |
+
else:
|
74 |
+
input_sentences = sents_moses
|
75 |
+
input_sentences = [sent.replace("\xad", "") for sent in input_sentences]
|
76 |
+
else:
|
77 |
+
input_sentences = sentence_split(
|
78 |
+
input_text, lang=flores_codes[lang], delim_pat=DELIM_PAT_NO_DANDA
|
79 |
+
)
|
80 |
+
return input_sentences
|
81 |
+
|
82 |
+
|
83 |
+
def initialize_model_and_tokenizer(ckpt_dir, quantization, attn_implementation):
|
84 |
+
if quantization == "4-bit":
|
85 |
+
qconfig = BitsAndBytesConfig(
|
86 |
+
load_in_4bit=True,
|
87 |
+
bnb_4bit_use_double_quant=True,
|
88 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
89 |
+
)
|
90 |
+
elif quantization == "8-bit":
|
91 |
+
qconfig = BitsAndBytesConfig(
|
92 |
+
load_in_8bit=True,
|
93 |
+
bnb_8bit_use_double_quant=True,
|
94 |
+
bnb_8bit_compute_dtype=torch.bfloat16,
|
95 |
+
)
|
96 |
+
else:
|
97 |
+
qconfig = None
|
98 |
+
|
99 |
+
if attn_implementation == "flash_attention_2":
|
100 |
+
if is_flash_attn_2_available() and is_flash_attn_greater_or_equal_2_10():
|
101 |
+
attn_implementation = "flash_attention_2"
|
102 |
+
else:
|
103 |
+
attn_implementation = "eager"
|
104 |
+
|
105 |
+
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, trust_remote_code=True)
|
106 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(
|
107 |
+
ckpt_dir,
|
108 |
+
trust_remote_code=True,
|
109 |
+
attn_implementation=attn_implementation,
|
110 |
+
low_cpu_mem_usage=True,
|
111 |
+
quantization_config=qconfig,
|
112 |
+
)
|
113 |
+
|
114 |
+
if qconfig == None:
|
115 |
+
model = model.to(DEVICE)
|
116 |
+
model.half()
|
117 |
+
|
118 |
+
model.eval()
|
119 |
+
|
120 |
+
return tokenizer, model
|
121 |
+
|
122 |
+
|
123 |
+
def batch_translate(input_sentences, src_lang, tgt_lang, model, tokenizer, ip):
|
124 |
+
translations = []
|
125 |
+
for i in range(0, len(input_sentences), BATCH_SIZE):
|
126 |
+
batch = input_sentences[i : i + BATCH_SIZE]
|
127 |
+
|
128 |
+
# Preprocess the batch and extract entity mappings
|
129 |
+
batch = ip.preprocess_batch(batch, src_lang=src_lang, tgt_lang=tgt_lang)
|
130 |
+
|
131 |
+
# Tokenize the batch and generate input encodings
|
132 |
+
inputs = tokenizer(
|
133 |
+
batch,
|
134 |
+
truncation=True,
|
135 |
+
padding="longest",
|
136 |
+
return_tensors="pt",
|
137 |
+
return_attention_mask=True,
|
138 |
+
).to(DEVICE)
|
139 |
+
|
140 |
+
# Generate translations using the model
|
141 |
+
with torch.no_grad():
|
142 |
+
generated_tokens = model.generate(
|
143 |
+
**inputs,
|
144 |
+
use_cache=True,
|
145 |
+
min_length=0,
|
146 |
+
max_length=256,
|
147 |
+
num_beams=5,
|
148 |
+
num_return_sequences=1,
|
149 |
+
)
|
150 |
+
|
151 |
+
# Decode the generated tokens into text
|
152 |
+
with tokenizer.as_target_tokenizer():
|
153 |
+
generated_tokens = tokenizer.batch_decode(
|
154 |
+
generated_tokens.detach().cpu().tolist(),
|
155 |
+
skip_special_tokens=True,
|
156 |
+
clean_up_tokenization_spaces=True,
|
157 |
+
)
|
158 |
+
|
159 |
+
# Postprocess the translations, including entity replacement
|
160 |
+
translations += ip.postprocess_batch(generated_tokens, lang=tgt_lang)
|
161 |
+
|
162 |
+
del inputs
|
163 |
+
torch.cuda.empty_cache()
|
164 |
+
|
165 |
+
return translations
|
166 |
+
|
167 |
+
|
168 |
+
def translate_paragraph(input_text, src_lang, tgt_lang, model, tokenizer, ip):
|
169 |
+
input_sentences = split_sentences(input_text, src_lang)
|
170 |
+
translated_text = batch_translate(input_sentences, src_lang, tgt_lang, model, tokenizer, ip)
|
171 |
+
return " ".join(translated_text)
|
172 |
+
|
173 |
+
|
174 |
+
ip = IndicProcessor(inference=True)
|
175 |
+
|
176 |
+
en_indic_tokenizer, en_indic_model = initialize_model_and_tokenizer(
|
177 |
+
en_indic_ckpt_dir, quantization, attn_implementation
|
178 |
+
)
|
179 |
+
|
180 |
+
indic_en_tokenizer, indic_en_model = initialize_model_and_tokenizer(
|
181 |
+
indic_en_ckpt_dir, quantization, attn_implementation
|
182 |
+
)
|
183 |
+
|
184 |
+
indic_indic_tokenizer, indic_indic_model = initialize_model_and_tokenizer(
|
185 |
+
indic_indic_ckpt_dir, quantization, attn_implementation
|
186 |
+
)
|
187 |
+
|
188 |
+
# ---------------------------------------------------------------------------
|
189 |
+
# Hindi to English
|
190 |
+
# ---------------------------------------------------------------------------
|
191 |
+
hi_sents = [
|
192 |
+
"जब मैं छोटा था, मैं हर रोज़ पार्क जाता था।",
|
193 |
+
"उसके पास बहुत सारी पुरानी किताबें हैं, जिन्हें उसने अपने दादा-परदादा से विरासत में पाया।",
|
194 |
+
"मुझे समझ में नहीं आ रहा कि मैं अपनी समस्या का समाधान कैसे ढूंढूं।",
|
195 |
+
"वह बहुत मेहनती और समझदार है, इसलिए उसे सभी अच्छे मार्क्स मिले।",
|
196 |
+
"हमने पिछले सप्ताह एक नई फिल्म देखी जो कि बहुत प्रेरणादायक थी।",
|
197 |
+
"अगर तुम मुझे उस समय पास मिलते, तो हम बाहर खाना खाने चलते।",
|
198 |
+
"वह अपनी दीदी के साथ बाजार गयी थी ताकि वह नई साड़ी खरीद सके।",
|
199 |
+
"राज ने मुझसे कहा कि वह अगले महीने अपनी नानी के घर जा रहा है।",
|
200 |
+
"सभी बच्चे पार्टी में मज़ा कर रहे थे और खूब सारी मिठाइयाँ खा रहे थे।",
|
201 |
+
"मेरे मित्र ने मुझे उसके जन्मदिन की पार्टी में बुलाया है, और मैं उसे एक तोहफा दूंगा।",
|
202 |
+
]
|
203 |
+
src_lang, tgt_lang = "hin_Deva", "eng_Latn"
|
204 |
+
en_translations = batch_translate(
|
205 |
+
hi_sents, src_lang, tgt_lang, indic_en_model, indic_en_tokenizer, ip
|
206 |
+
)
|
207 |
+
|
208 |
+
print(f"\n{src_lang} - {tgt_lang}")
|
209 |
+
for input_sentence, translation in zip(hi_sents, en_translations):
|
210 |
+
print(f"{src_lang}: {input_sentence}")
|
211 |
+
print(f"{tgt_lang}: {translation}")
|
212 |
+
|
213 |
+
|
214 |
+
# ---------------------------------------------------------------------------
|
215 |
+
# English to Hindi
|
216 |
+
# ---------------------------------------------------------------------------
|
217 |
+
en_sents = [
|
218 |
+
"When I was young, I used to go to the park every day.",
|
219 |
+
"He has many old books, which he inherited from his ancestors.",
|
220 |
+
"I can't figure out how to solve my problem.",
|
221 |
+
"She is very hardworking and intelligent, which is why she got all the good marks.",
|
222 |
+
"We watched a new movie last week, which was very inspiring.",
|
223 |
+
"If you had met me at that time, we would have gone out to eat.",
|
224 |
+
"She went to the market with her sister to buy a new sari.",
|
225 |
+
"Raj told me that he is going to his grandmother's house next month.",
|
226 |
+
"All the kids were having fun at the party and were eating lots of sweets.",
|
227 |
+
"My friend has invited me to his birthday party, and I will give him a gift.",
|
228 |
+
]
|
229 |
+
src_lang, tgt_lang = "eng_Latn", "hin_Deva"
|
230 |
+
hi_translations = batch_translate(
|
231 |
+
en_sents, src_lang, tgt_lang, en_indic_model, en_indic_tokenizer, ip
|
232 |
+
)
|
233 |
+
|
234 |
+
print(f"\n{src_lang} - {tgt_lang}")
|
235 |
+
for input_sentence, translation in zip(en_sents, hi_translations):
|
236 |
+
print(f"{src_lang}: {input_sentence}")
|
237 |
+
print(f"{tgt_lang}: {translation}")
|
238 |
+
|
239 |
+
|
240 |
+
# ---------------------------------------------------------------------------
|
241 |
+
# Hindi to Marathi
|
242 |
+
# ---------------------------------------------------------------------------
|
243 |
+
hi_sents = [
|
244 |
+
"��ब मैं छोटा था, मैं हर रोज़ पार्क जाता था।",
|
245 |
+
"उसके पास बहुत सारी पुरानी किताबें हैं, जिन्हें उसने अपने दादा-परदादा से विरासत में पाया।",
|
246 |
+
"मुझे समझ में नहीं आ रहा कि मैं अपनी समस्या का समाधान कैसे ढूंढूं।",
|
247 |
+
"वह बहुत मेहनती और समझदार है, इसलिए उसे सभी अच्छे मार्क्स मिले।",
|
248 |
+
"हमने पिछले सप्ताह एक नई फिल्म देखी जो कि बहुत प्रेरणादायक थी।",
|
249 |
+
"अगर तुम मुझे उस समय पास मिलते, तो हम बाहर खाना खाने चलते।",
|
250 |
+
"वह अपनी दीदी के साथ बाजार गयी थी ताकि वह नई साड़ी खरीद सके।",
|
251 |
+
"राज ने मुझसे कहा कि वह अगले महीने अपनी नानी के घर जा रहा है।",
|
252 |
+
"सभी बच्चे पार्टी में मज़ा कर रहे थे और खूब सारी मिठाइयाँ खा रहे थे।",
|
253 |
+
"मेरे मित्र ने मुझे उसके जन्मदिन की पार्टी में बुलाया है, और मैं उसे एक तोहफा दूंगा।",
|
254 |
+
]
|
255 |
+
src_lang, tgt_lang = "hin_Deva", "mar_Deva"
|
256 |
+
mr_translations = batch_translate(
|
257 |
+
hi_sents, src_lang, tgt_lang, indic_indic_model, indic_indic_tokenizer, ip
|
258 |
+
)
|
259 |
+
|
260 |
+
print(f"\n{src_lang} - {tgt_lang}")
|
261 |
+
for input_sentence, translation in zip(hi_sents, mr_translations):
|
262 |
+
print(f"{src_lang}: {input_sentence}")
|
263 |
+
print(f"{tgt_lang}: {translation}")
|
264 |
+
|
265 |
+
|
266 |
+
# ---------------------------------------------------------------------------
|
267 |
+
# Paragraph translation
|
268 |
+
# ---------------------------------------------------------------------------
|
269 |
+
src_lang, tgt_lang = "hin_Deva", "eng_Latn"
|
270 |
+
hi_text = "यहाँ एक पाराग्राफ है जो हिंदी में लिखा गया है। हिंदी एक सुंदर भाषा है और भारत की राष्ट्रीय भाषा है। इसका विकास विभिन्न कालों में हुआ है और यह विशेषतः भारतीय उपमहाद्वीप में बोली जाती है। हिंदी भाषा का साहित्य, संस्कृति और इतिहास भी बहुत गर्वनीय है।"
|
271 |
+
en_translated_text = translate_paragraph(
|
272 |
+
hi_text, src_lang, tgt_lang, indic_en_model, indic_en_tokenizer, ip
|
273 |
+
)
|
274 |
+
print(f"{src_lang}: {hi_text}")
|
275 |
+
print(f"{tgt_lang}: {en_translated_text}")
|
IndicTrans2/huggingface_interface/install.sh
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#/bin/bash
|
2 |
+
|
3 |
+
root_dir=$(pwd)
|
4 |
+
echo "Setting up the environment in the $root_dir"
|
5 |
+
|
6 |
+
# --------------------------------------------------------------
|
7 |
+
# create and activate the virtual environment
|
8 |
+
# --------------------------------------------------------------
|
9 |
+
echo "Creating a virtual environment with python3"
|
10 |
+
conda create -n itv2_hf python=3.9 -y
|
11 |
+
conda activate itv2_hf
|
12 |
+
|
13 |
+
echo "Installing all the dependencies"
|
14 |
+
conda install pip
|
15 |
+
python3 -m pip install --upgrade pip
|
16 |
+
|
17 |
+
|
18 |
+
# --------------------------------------------------------------
|
19 |
+
# PyTorch Installation
|
20 |
+
# --------------------------------------------------------------
|
21 |
+
python3 -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu118
|
22 |
+
|
23 |
+
|
24 |
+
# --------------------------------------------------------------
|
25 |
+
# Install additional utility packages
|
26 |
+
# --------------------------------------------------------------
|
27 |
+
python3 -m pip install nltk sacremoses pandas regex mock transformers>=4.33.2 mosestokenizer
|
28 |
+
python3 -c "import nltk; nltk.download('punkt')"
|
29 |
+
python3 -m pip install bitsandbytes scipy accelerate datasets flash-attn>=2.1
|
30 |
+
|
31 |
+
|
32 |
+
# --------------------------------------------------------------
|
33 |
+
# Sentencepiece for tokenization
|
34 |
+
# --------------------------------------------------------------
|
35 |
+
# build the cpp binaries from the source repo in order to use the command line utility
|
36 |
+
# source repo: https://github.com/google/sentencepiece
|
37 |
+
python3 -m pip install sentencepiece
|
38 |
+
|
39 |
+
|
40 |
+
# -----------------------------------------------------------------
|
41 |
+
# Install IndicTrans2 tokenizer and its dependencies
|
42 |
+
# -----------------------------------------------------------------
|
43 |
+
git clone https://github.com/VarunGumma/IndicTransToolkit
|
44 |
+
cd IndicTransToolkit
|
45 |
+
python3 -m pip install --editable ./
|
46 |
+
cd $root_dir
|
47 |
+
|
48 |
+
|
49 |
+
echo "Setup completed!"
|
IndicTrans2/huggingface_interface/modeling_indictrans.py
ADDED
@@ -0,0 +1,1801 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The IndicTrans2 Authors and AI4Bharat team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch IndicTrans model."""
|
16 |
+
|
17 |
+
|
18 |
+
import math
|
19 |
+
from typing import List, Optional, Tuple, Union
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn as nn
|
23 |
+
from torch.nn import functional as F
|
24 |
+
|
25 |
+
from transformers.activations import ACT2FN
|
26 |
+
|
27 |
+
from transformers.modeling_attn_mask_utils import (
|
28 |
+
_prepare_4d_attention_mask,
|
29 |
+
_prepare_4d_attention_mask_for_sdpa,
|
30 |
+
_prepare_4d_causal_attention_mask,
|
31 |
+
_prepare_4d_causal_attention_mask_for_sdpa,
|
32 |
+
)
|
33 |
+
|
34 |
+
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
35 |
+
from transformers.modeling_outputs import (
|
36 |
+
BaseModelOutput,
|
37 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
38 |
+
Seq2SeqLMOutput,
|
39 |
+
Seq2SeqModelOutput
|
40 |
+
)
|
41 |
+
|
42 |
+
from transformers.utils import (
|
43 |
+
logging,
|
44 |
+
is_flash_attn_2_available,
|
45 |
+
is_flash_attn_greater_or_equal_2_10,
|
46 |
+
)
|
47 |
+
|
48 |
+
from transformers.modeling_utils import PreTrainedModel
|
49 |
+
|
50 |
+
from .configuration_indictrans import IndicTransConfig
|
51 |
+
|
52 |
+
|
53 |
+
logger = logging.get_logger(__name__)
|
54 |
+
|
55 |
+
INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
|
56 |
+
|
57 |
+
try:
|
58 |
+
if is_flash_attn_2_available():
|
59 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
60 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
61 |
+
except:
|
62 |
+
pass
|
63 |
+
|
64 |
+
|
65 |
+
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
66 |
+
def _get_unpad_data(attention_mask):
|
67 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
68 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
69 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
70 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
71 |
+
return (
|
72 |
+
indices,
|
73 |
+
cu_seqlens,
|
74 |
+
max_seqlen_in_batch,
|
75 |
+
)
|
76 |
+
|
77 |
+
|
78 |
+
# Copied from transformers.models.bart.modeling_bart.shift_tokens_right
|
79 |
+
def shift_tokens_right(
|
80 |
+
input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int
|
81 |
+
):
|
82 |
+
"""
|
83 |
+
Shift input ids one token to the right.
|
84 |
+
"""
|
85 |
+
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
86 |
+
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
|
87 |
+
shifted_input_ids[:, 0] = decoder_start_token_id
|
88 |
+
|
89 |
+
if pad_token_id is None:
|
90 |
+
raise ValueError("self.model.config.pad_token_id has to be defined.")
|
91 |
+
# replace possible -100 values in labels by `pad_token_id`
|
92 |
+
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
|
93 |
+
|
94 |
+
return shifted_input_ids
|
95 |
+
|
96 |
+
|
97 |
+
def create_position_ids_from_input_ids(
|
98 |
+
input_ids, padding_idx, past_key_values_length=0
|
99 |
+
):
|
100 |
+
"""
|
101 |
+
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
|
102 |
+
are ignored. This is modified from fairseq's `utils.make_positions`.
|
103 |
+
"""
|
104 |
+
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
105 |
+
mask = input_ids.ne(padding_idx).int()
|
106 |
+
incremental_indices = (
|
107 |
+
torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length
|
108 |
+
) * mask
|
109 |
+
return incremental_indices.long() + padding_idx
|
110 |
+
|
111 |
+
|
112 |
+
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding->IndicTrans
|
113 |
+
class IndicTransSinusoidalPositionalEmbedding(nn.Module):
|
114 |
+
"""This module produces sinusoidal positional embeddings of any length."""
|
115 |
+
|
116 |
+
def __init__(
|
117 |
+
self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None
|
118 |
+
):
|
119 |
+
super().__init__()
|
120 |
+
self.offset = 2
|
121 |
+
self.embedding_dim = embedding_dim
|
122 |
+
self.padding_idx = padding_idx
|
123 |
+
self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
|
124 |
+
|
125 |
+
def make_weights(
|
126 |
+
self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
|
127 |
+
):
|
128 |
+
emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
|
129 |
+
if hasattr(self, "weights"):
|
130 |
+
# in forward put the weights on the correct dtype and device of the param
|
131 |
+
emb_weights = emb_weights.to(
|
132 |
+
dtype=self.weights.dtype, device=self.weights.device
|
133 |
+
)
|
134 |
+
|
135 |
+
self.register_buffer("weights", emb_weights, persistent=False)
|
136 |
+
|
137 |
+
@staticmethod
|
138 |
+
def get_embedding(
|
139 |
+
num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
|
140 |
+
):
|
141 |
+
"""
|
142 |
+
Build sinusoidal embeddings.
|
143 |
+
|
144 |
+
This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
|
145 |
+
"Attention Is All You Need".
|
146 |
+
"""
|
147 |
+
half_dim = embedding_dim // 2
|
148 |
+
emb = math.log(10000) / (half_dim - 1)
|
149 |
+
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
|
150 |
+
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
|
151 |
+
1
|
152 |
+
) * emb.unsqueeze(0)
|
153 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
|
154 |
+
num_embeddings, -1
|
155 |
+
)
|
156 |
+
if embedding_dim % 2 == 1:
|
157 |
+
# zero pad
|
158 |
+
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
|
159 |
+
if padding_idx is not None:
|
160 |
+
emb[padding_idx, :] = 0
|
161 |
+
|
162 |
+
return emb.to(torch.get_default_dtype())
|
163 |
+
|
164 |
+
@torch.no_grad()
|
165 |
+
def forward(
|
166 |
+
self,
|
167 |
+
input_ids: torch.Tensor = None,
|
168 |
+
inputs_embeds: torch.Tensor = None,
|
169 |
+
past_key_values_length: int = 0,
|
170 |
+
):
|
171 |
+
if input_ids is not None:
|
172 |
+
bsz, seq_len = input_ids.size()
|
173 |
+
# Create the position ids from the input token ids. Any padded tokens remain padded.
|
174 |
+
position_ids = create_position_ids_from_input_ids(
|
175 |
+
input_ids, self.padding_idx, past_key_values_length
|
176 |
+
).to(input_ids.device)
|
177 |
+
else:
|
178 |
+
bsz, seq_len = inputs_embeds.size()[:-1]
|
179 |
+
position_ids = self.create_position_ids_from_inputs_embeds(
|
180 |
+
inputs_embeds, past_key_values_length
|
181 |
+
)
|
182 |
+
|
183 |
+
# expand embeddings if needed
|
184 |
+
max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
|
185 |
+
if max_pos > self.weights.size(0):
|
186 |
+
self.make_weights(
|
187 |
+
max_pos + self.offset, self.embedding_dim, self.padding_idx
|
188 |
+
)
|
189 |
+
|
190 |
+
return (
|
191 |
+
self.weights.index_select(0, position_ids.view(-1))
|
192 |
+
.view(bsz, seq_len, self.weights.shape[-1])
|
193 |
+
.detach()
|
194 |
+
)
|
195 |
+
|
196 |
+
def create_position_ids_from_inputs_embeds(
|
197 |
+
self, inputs_embeds, past_key_values_length
|
198 |
+
):
|
199 |
+
"""
|
200 |
+
We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
|
201 |
+
|
202 |
+
Args:
|
203 |
+
inputs_embeds: torch.Tensor
|
204 |
+
|
205 |
+
Returns: torch.Tensor
|
206 |
+
"""
|
207 |
+
input_shape = inputs_embeds.size()[:-1]
|
208 |
+
sequence_length = input_shape[1]
|
209 |
+
|
210 |
+
position_ids = torch.arange(
|
211 |
+
self.padding_idx + 1,
|
212 |
+
sequence_length + self.padding_idx + 1,
|
213 |
+
dtype=torch.long,
|
214 |
+
device=inputs_embeds.device,
|
215 |
+
)
|
216 |
+
return (
|
217 |
+
position_ids.unsqueeze(0).expand(input_shape).contiguous()
|
218 |
+
+ past_key_values_length
|
219 |
+
)
|
220 |
+
|
221 |
+
|
222 |
+
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->IndicTrans
|
223 |
+
class IndicTransAttention(nn.Module):
|
224 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
225 |
+
|
226 |
+
def __init__(
|
227 |
+
self,
|
228 |
+
embed_dim: int,
|
229 |
+
num_heads: int,
|
230 |
+
dropout: float = 0.0,
|
231 |
+
is_decoder: bool = False,
|
232 |
+
bias: bool = True,
|
233 |
+
is_causal: bool = False,
|
234 |
+
config: Optional[IndicTransConfig] = None,
|
235 |
+
):
|
236 |
+
super().__init__()
|
237 |
+
self.embed_dim = embed_dim
|
238 |
+
self.num_heads = num_heads
|
239 |
+
self.dropout = dropout
|
240 |
+
self.head_dim = embed_dim // num_heads
|
241 |
+
self.config = config
|
242 |
+
|
243 |
+
if (self.head_dim * num_heads) != self.embed_dim:
|
244 |
+
raise ValueError(
|
245 |
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
246 |
+
f" and `num_heads`: {num_heads})."
|
247 |
+
)
|
248 |
+
self.scaling = self.head_dim**-0.5
|
249 |
+
self.is_decoder = is_decoder
|
250 |
+
self.is_causal = is_causal
|
251 |
+
|
252 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
253 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
254 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
255 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
256 |
+
|
257 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
258 |
+
return (
|
259 |
+
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
|
260 |
+
.transpose(1, 2)
|
261 |
+
.contiguous()
|
262 |
+
)
|
263 |
+
|
264 |
+
def forward(
|
265 |
+
self,
|
266 |
+
hidden_states: torch.Tensor,
|
267 |
+
key_value_states: Optional[torch.Tensor] = None,
|
268 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
269 |
+
attention_mask: Optional[torch.Tensor] = None,
|
270 |
+
layer_head_mask: Optional[torch.Tensor] = None,
|
271 |
+
output_attentions: bool = False,
|
272 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
273 |
+
"""Input shape: Batch x Time x Channel"""
|
274 |
+
|
275 |
+
# if key_value_states are provided this layer is used as a cross-attention layer
|
276 |
+
# for the decoder
|
277 |
+
is_cross_attention = key_value_states is not None
|
278 |
+
|
279 |
+
bsz, tgt_len, _ = hidden_states.size()
|
280 |
+
|
281 |
+
# get query proj
|
282 |
+
query_states = self.q_proj(hidden_states) * self.scaling
|
283 |
+
# get key, value proj
|
284 |
+
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
285 |
+
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
286 |
+
# the provided `key_value_states` to support prefix tuning
|
287 |
+
if (
|
288 |
+
is_cross_attention
|
289 |
+
and past_key_value is not None
|
290 |
+
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
291 |
+
):
|
292 |
+
# reuse k,v, cross_attentions
|
293 |
+
key_states = past_key_value[0]
|
294 |
+
value_states = past_key_value[1]
|
295 |
+
elif is_cross_attention:
|
296 |
+
# cross_attentions
|
297 |
+
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
298 |
+
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
299 |
+
elif past_key_value is not None:
|
300 |
+
# reuse k, v, self_attention
|
301 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
302 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
303 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
304 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
305 |
+
else:
|
306 |
+
# self_attention
|
307 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
308 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
309 |
+
|
310 |
+
if self.is_decoder:
|
311 |
+
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
312 |
+
# Further calls to cross_attention layer can then reuse all cross-attention
|
313 |
+
# key/value_states (first "if" case)
|
314 |
+
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
315 |
+
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
316 |
+
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
317 |
+
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
318 |
+
past_key_value = (key_states, value_states)
|
319 |
+
|
320 |
+
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
321 |
+
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
322 |
+
key_states = key_states.reshape(*proj_shape)
|
323 |
+
value_states = value_states.reshape(*proj_shape)
|
324 |
+
|
325 |
+
src_len = key_states.size(1)
|
326 |
+
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
327 |
+
|
328 |
+
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
329 |
+
raise ValueError(
|
330 |
+
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
331 |
+
f" {attn_weights.size()}"
|
332 |
+
)
|
333 |
+
|
334 |
+
if attention_mask is not None:
|
335 |
+
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
336 |
+
raise ValueError(
|
337 |
+
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
338 |
+
)
|
339 |
+
attn_weights = (
|
340 |
+
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
341 |
+
+ attention_mask
|
342 |
+
)
|
343 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
344 |
+
|
345 |
+
attn_weights = F.softmax(attn_weights, dim=-1)
|
346 |
+
|
347 |
+
if layer_head_mask is not None:
|
348 |
+
if layer_head_mask.size() != (self.num_heads,):
|
349 |
+
raise ValueError(
|
350 |
+
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
|
351 |
+
f" {layer_head_mask.size()}"
|
352 |
+
)
|
353 |
+
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
|
354 |
+
bsz, self.num_heads, tgt_len, src_len
|
355 |
+
)
|
356 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
357 |
+
|
358 |
+
if output_attentions:
|
359 |
+
# this operation is a bit awkward, but it's required to
|
360 |
+
# make sure that attn_weights keeps its gradient.
|
361 |
+
# In order to do so, attn_weights have to be reshaped
|
362 |
+
# twice and have to be reused in the following
|
363 |
+
attn_weights_reshaped = attn_weights.view(
|
364 |
+
bsz, self.num_heads, tgt_len, src_len
|
365 |
+
)
|
366 |
+
attn_weights = attn_weights_reshaped.view(
|
367 |
+
bsz * self.num_heads, tgt_len, src_len
|
368 |
+
)
|
369 |
+
else:
|
370 |
+
attn_weights_reshaped = None
|
371 |
+
|
372 |
+
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
|
373 |
+
|
374 |
+
attn_output = torch.bmm(attn_probs, value_states)
|
375 |
+
|
376 |
+
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
377 |
+
raise ValueError(
|
378 |
+
f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
|
379 |
+
f" {attn_output.size()}"
|
380 |
+
)
|
381 |
+
|
382 |
+
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
383 |
+
attn_output = attn_output.transpose(1, 2)
|
384 |
+
|
385 |
+
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
386 |
+
# partitioned across GPUs when using tensor-parallelism.
|
387 |
+
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
388 |
+
|
389 |
+
attn_output = self.out_proj(attn_output)
|
390 |
+
|
391 |
+
return attn_output, attn_weights_reshaped, past_key_value
|
392 |
+
|
393 |
+
|
394 |
+
class IndicTransFlashAttention2(IndicTransAttention):
|
395 |
+
"""
|
396 |
+
IndicTrans flash attention module. This module inherits from `IndicTransAttention` as the weights of the module stays
|
397 |
+
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
398 |
+
flash attention and deal with padding tokens in case the input contains any of them.
|
399 |
+
"""
|
400 |
+
|
401 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
|
402 |
+
def __init__(self, *args, **kwargs):
|
403 |
+
super().__init__(*args, **kwargs)
|
404 |
+
|
405 |
+
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
406 |
+
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
407 |
+
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
408 |
+
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
409 |
+
|
410 |
+
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
411 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
|
412 |
+
|
413 |
+
def forward(
|
414 |
+
self,
|
415 |
+
hidden_states: torch.Tensor,
|
416 |
+
key_value_states: Optional[torch.Tensor] = None,
|
417 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
418 |
+
attention_mask: Optional[torch.Tensor] = None,
|
419 |
+
layer_head_mask: Optional[torch.Tensor] = None,
|
420 |
+
output_attentions: bool = False,
|
421 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
422 |
+
# IndicTransFlashAttention2 attention does not support output_attentions
|
423 |
+
if output_attentions:
|
424 |
+
raise ValueError("IndicTransFlashAttention2 attention does not support output_attentions")
|
425 |
+
|
426 |
+
# if key_value_states are provided this layer is used as a cross-attention layer
|
427 |
+
# for the decoder
|
428 |
+
is_cross_attention = key_value_states is not None
|
429 |
+
|
430 |
+
bsz, q_len, _ = hidden_states.size()
|
431 |
+
|
432 |
+
# get query proj
|
433 |
+
query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
|
434 |
+
# get key, value proj
|
435 |
+
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
436 |
+
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
437 |
+
# the provided `key_value_states` to support prefix tuning
|
438 |
+
if (
|
439 |
+
is_cross_attention
|
440 |
+
and past_key_value is not None
|
441 |
+
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
442 |
+
):
|
443 |
+
# reuse k,v, cross_attentions
|
444 |
+
key_states = past_key_value[0].transpose(1, 2)
|
445 |
+
value_states = past_key_value[1].transpose(1, 2)
|
446 |
+
elif is_cross_attention:
|
447 |
+
# cross_attentions
|
448 |
+
key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
|
449 |
+
value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
|
450 |
+
elif past_key_value is not None:
|
451 |
+
# reuse k, v, self_attention
|
452 |
+
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
|
453 |
+
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
|
454 |
+
key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
|
455 |
+
value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
|
456 |
+
else:
|
457 |
+
# self_attention
|
458 |
+
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
|
459 |
+
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
|
460 |
+
|
461 |
+
if self.is_decoder:
|
462 |
+
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
463 |
+
# Further calls to cross_attention layer can then reuse all cross-attention
|
464 |
+
# key/value_states (first "if" case)
|
465 |
+
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
466 |
+
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
467 |
+
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
468 |
+
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
469 |
+
past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
|
470 |
+
|
471 |
+
kv_seq_len = key_states.shape[-2]
|
472 |
+
if past_key_value is not None:
|
473 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
474 |
+
|
475 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
476 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
477 |
+
# cast them back in the correct dtype just to be sure everything works as expected.
|
478 |
+
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
479 |
+
# in fp32. (LlamaRMSNorm handles it correctly)
|
480 |
+
|
481 |
+
input_dtype = query_states.dtype
|
482 |
+
if input_dtype == torch.float32:
|
483 |
+
if torch.is_autocast_enabled():
|
484 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
485 |
+
# Handle the case where the model is quantized
|
486 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
487 |
+
target_dtype = self.config._pre_quantization_dtype
|
488 |
+
else:
|
489 |
+
target_dtype = self.q_proj.weight.dtype
|
490 |
+
|
491 |
+
logger.warning_once(
|
492 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
493 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
494 |
+
f" {target_dtype}."
|
495 |
+
)
|
496 |
+
|
497 |
+
query_states = query_states.to(target_dtype)
|
498 |
+
key_states = key_states.to(target_dtype)
|
499 |
+
value_states = value_states.to(target_dtype)
|
500 |
+
|
501 |
+
attn_output = self._flash_attention_forward(
|
502 |
+
query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
|
503 |
+
)
|
504 |
+
|
505 |
+
attn_output = attn_output.reshape(bsz, q_len, -1)
|
506 |
+
attn_output = self.out_proj(attn_output)
|
507 |
+
|
508 |
+
if not output_attentions:
|
509 |
+
attn_weights = None
|
510 |
+
|
511 |
+
return attn_output, attn_weights, past_key_value
|
512 |
+
|
513 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
|
514 |
+
def _flash_attention_forward(
|
515 |
+
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
|
516 |
+
):
|
517 |
+
"""
|
518 |
+
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
519 |
+
first unpad the input, then computes the attention scores and pad the final attention scores.
|
520 |
+
|
521 |
+
Args:
|
522 |
+
query_states (`torch.Tensor`):
|
523 |
+
Input query states to be passed to Flash Attention API
|
524 |
+
key_states (`torch.Tensor`):
|
525 |
+
Input key states to be passed to Flash Attention API
|
526 |
+
value_states (`torch.Tensor`):
|
527 |
+
Input value states to be passed to Flash Attention API
|
528 |
+
attention_mask (`torch.Tensor`):
|
529 |
+
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
530 |
+
position of padding tokens and 1 for the position of non-padding tokens.
|
531 |
+
dropout (`float`):
|
532 |
+
Attention dropout
|
533 |
+
softmax_scale (`float`, *optional*):
|
534 |
+
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
535 |
+
"""
|
536 |
+
if not self._flash_attn_uses_top_left_mask:
|
537 |
+
causal = self.is_causal
|
538 |
+
else:
|
539 |
+
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
|
540 |
+
causal = self.is_causal and query_length != 1
|
541 |
+
|
542 |
+
# Contains at least one padding token in the sequence
|
543 |
+
if attention_mask is not None:
|
544 |
+
batch_size = query_states.shape[0]
|
545 |
+
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
546 |
+
query_states, key_states, value_states, attention_mask, query_length
|
547 |
+
)
|
548 |
+
|
549 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
550 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
551 |
+
|
552 |
+
attn_output_unpad = flash_attn_varlen_func(
|
553 |
+
query_states,
|
554 |
+
key_states,
|
555 |
+
value_states,
|
556 |
+
cu_seqlens_q=cu_seqlens_q,
|
557 |
+
cu_seqlens_k=cu_seqlens_k,
|
558 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
559 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
560 |
+
dropout_p=dropout,
|
561 |
+
softmax_scale=softmax_scale,
|
562 |
+
causal=causal,
|
563 |
+
)
|
564 |
+
|
565 |
+
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
566 |
+
else:
|
567 |
+
attn_output = flash_attn_func(
|
568 |
+
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
|
569 |
+
)
|
570 |
+
|
571 |
+
return attn_output
|
572 |
+
|
573 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
|
574 |
+
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
575 |
+
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
576 |
+
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
577 |
+
|
578 |
+
key_layer = index_first_axis(
|
579 |
+
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
580 |
+
)
|
581 |
+
value_layer = index_first_axis(
|
582 |
+
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
583 |
+
)
|
584 |
+
if query_length == kv_seq_len:
|
585 |
+
query_layer = index_first_axis(
|
586 |
+
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
|
587 |
+
)
|
588 |
+
cu_seqlens_q = cu_seqlens_k
|
589 |
+
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
590 |
+
indices_q = indices_k
|
591 |
+
elif query_length == 1:
|
592 |
+
max_seqlen_in_batch_q = 1
|
593 |
+
cu_seqlens_q = torch.arange(
|
594 |
+
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
595 |
+
) # There is a memcpy here, that is very bad.
|
596 |
+
indices_q = cu_seqlens_q[:-1]
|
597 |
+
query_layer = query_layer.squeeze(1)
|
598 |
+
else:
|
599 |
+
# The -q_len: slice assumes left padding.
|
600 |
+
attention_mask = attention_mask[:, -query_length:]
|
601 |
+
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
602 |
+
|
603 |
+
return (
|
604 |
+
query_layer,
|
605 |
+
key_layer,
|
606 |
+
value_layer,
|
607 |
+
indices_q,
|
608 |
+
(cu_seqlens_q, cu_seqlens_k),
|
609 |
+
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
610 |
+
)
|
611 |
+
|
612 |
+
|
613 |
+
class IndicTransSdpaAttention(IndicTransAttention):
|
614 |
+
def forward(
|
615 |
+
self,
|
616 |
+
hidden_states: torch.Tensor,
|
617 |
+
key_value_states: Optional[torch.Tensor] = None,
|
618 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
619 |
+
attention_mask: Optional[torch.Tensor] = None,
|
620 |
+
layer_head_mask: Optional[torch.Tensor] = None,
|
621 |
+
output_attentions: bool = False,
|
622 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
623 |
+
"""Input shape: Batch x Time x Channel"""
|
624 |
+
if output_attentions or layer_head_mask is not None:
|
625 |
+
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
626 |
+
logger.warning_once(
|
627 |
+
"IndicTransModel is using IndicTransSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
|
628 |
+
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
629 |
+
)
|
630 |
+
return super().forward(
|
631 |
+
hidden_states,
|
632 |
+
key_value_states=key_value_states,
|
633 |
+
past_key_value=past_key_value,
|
634 |
+
attention_mask=attention_mask,
|
635 |
+
layer_head_mask=layer_head_mask,
|
636 |
+
output_attentions=output_attentions,
|
637 |
+
)
|
638 |
+
|
639 |
+
# if key_value_states are provided this layer is used as a cross-attention layer
|
640 |
+
# for the decoder
|
641 |
+
is_cross_attention = key_value_states is not None
|
642 |
+
|
643 |
+
bsz, tgt_len, _ = hidden_states.size()
|
644 |
+
|
645 |
+
# get query proj
|
646 |
+
query_states = self.q_proj(hidden_states)
|
647 |
+
# get key, value proj
|
648 |
+
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
649 |
+
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
650 |
+
# the provided `key_value_states` to support prefix tuning
|
651 |
+
if (
|
652 |
+
is_cross_attention
|
653 |
+
and past_key_value is not None
|
654 |
+
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
655 |
+
):
|
656 |
+
# reuse k,v, cross_attentions
|
657 |
+
key_states = past_key_value[0]
|
658 |
+
value_states = past_key_value[1]
|
659 |
+
elif is_cross_attention:
|
660 |
+
# cross_attentions
|
661 |
+
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
662 |
+
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
663 |
+
elif past_key_value is not None:
|
664 |
+
# reuse k, v, self_attention
|
665 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
666 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
667 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
668 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
669 |
+
else:
|
670 |
+
# self_attention
|
671 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
672 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
673 |
+
|
674 |
+
if self.is_decoder:
|
675 |
+
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
676 |
+
# Further calls to cross_attention layer can then reuse all cross-attention
|
677 |
+
# key/value_states (first "if" case)
|
678 |
+
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
679 |
+
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
680 |
+
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
681 |
+
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
682 |
+
past_key_value = (key_states, value_states)
|
683 |
+
|
684 |
+
query_states = self._shape(query_states, tgt_len, bsz)
|
685 |
+
|
686 |
+
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
|
687 |
+
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
|
688 |
+
attn_output = F.scaled_dot_product_attention(
|
689 |
+
query_states,
|
690 |
+
key_states,
|
691 |
+
value_states,
|
692 |
+
attn_mask=attention_mask,
|
693 |
+
dropout_p=self.dropout if self.training else 0.0,
|
694 |
+
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
|
695 |
+
is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
|
696 |
+
)
|
697 |
+
|
698 |
+
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
|
699 |
+
raise ValueError(
|
700 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
701 |
+
f" {attn_output.size()}"
|
702 |
+
)
|
703 |
+
|
704 |
+
attn_output = attn_output.transpose(1, 2)
|
705 |
+
|
706 |
+
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
707 |
+
# partitioned across GPUs when using tensor-parallelism.
|
708 |
+
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
709 |
+
|
710 |
+
attn_output = self.out_proj(attn_output)
|
711 |
+
|
712 |
+
return attn_output, None, past_key_value
|
713 |
+
|
714 |
+
|
715 |
+
INDICTRANS_ATTENTION_CLASSES = {
|
716 |
+
"eager": IndicTransAttention,
|
717 |
+
"sdpa": IndicTransSdpaAttention,
|
718 |
+
"flash_attention_2": IndicTransFlashAttention2,
|
719 |
+
}
|
720 |
+
|
721 |
+
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->IndicTrans
|
722 |
+
class IndicTransEncoderLayer(nn.Module):
|
723 |
+
def __init__(self, config: IndicTransConfig):
|
724 |
+
super().__init__()
|
725 |
+
self.embed_dim = config.encoder_embed_dim
|
726 |
+
self.self_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
|
727 |
+
embed_dim=self.embed_dim,
|
728 |
+
num_heads=config.encoder_attention_heads,
|
729 |
+
dropout=config.attention_dropout,
|
730 |
+
config=config,
|
731 |
+
)
|
732 |
+
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
733 |
+
self.dropout = config.dropout
|
734 |
+
self.activation_fn = ACT2FN[config.activation_function]
|
735 |
+
self.activation_dropout = config.activation_dropout
|
736 |
+
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
|
737 |
+
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
738 |
+
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
739 |
+
self.normalize_before = config.encoder_normalize_before
|
740 |
+
|
741 |
+
def forward(
|
742 |
+
self,
|
743 |
+
hidden_states: torch.Tensor,
|
744 |
+
attention_mask: torch.Tensor,
|
745 |
+
layer_head_mask: torch.Tensor,
|
746 |
+
output_attentions: bool = False,
|
747 |
+
) -> torch.Tensor:
|
748 |
+
"""
|
749 |
+
Args:
|
750 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
751 |
+
attention_mask (`torch.FloatTensor`): attention mask of size
|
752 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
753 |
+
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
|
754 |
+
`(encoder_attention_heads,)`.
|
755 |
+
output_attentions (`bool`, *optional*):
|
756 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
757 |
+
returned tensors for more detail.
|
758 |
+
"""
|
759 |
+
residual = hidden_states
|
760 |
+
if self.normalize_before:
|
761 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
762 |
+
hidden_states, attn_weights, _ = self.self_attn(
|
763 |
+
hidden_states=hidden_states,
|
764 |
+
attention_mask=attention_mask,
|
765 |
+
layer_head_mask=layer_head_mask,
|
766 |
+
output_attentions=output_attentions,
|
767 |
+
)
|
768 |
+
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
769 |
+
hidden_states = residual + hidden_states
|
770 |
+
if not self.normalize_before:
|
771 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
772 |
+
|
773 |
+
residual = hidden_states
|
774 |
+
if self.normalize_before:
|
775 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
776 |
+
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
777 |
+
hidden_states = F.dropout(
|
778 |
+
hidden_states, p=self.activation_dropout, training=self.training
|
779 |
+
)
|
780 |
+
hidden_states = self.fc2(hidden_states)
|
781 |
+
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
782 |
+
hidden_states = residual + hidden_states
|
783 |
+
if not self.normalize_before:
|
784 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
785 |
+
|
786 |
+
if hidden_states.dtype == torch.float16 and (
|
787 |
+
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
|
788 |
+
):
|
789 |
+
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
790 |
+
hidden_states = torch.clamp(
|
791 |
+
hidden_states, min=-clamp_value, max=clamp_value
|
792 |
+
)
|
793 |
+
|
794 |
+
outputs = (hidden_states,)
|
795 |
+
|
796 |
+
if output_attentions:
|
797 |
+
outputs += (attn_weights,)
|
798 |
+
|
799 |
+
return outputs
|
800 |
+
|
801 |
+
|
802 |
+
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->IndicTrans
|
803 |
+
class IndicTransDecoderLayer(nn.Module):
|
804 |
+
def __init__(self, config: IndicTransConfig):
|
805 |
+
super().__init__()
|
806 |
+
self.embed_dim = config.decoder_embed_dim
|
807 |
+
|
808 |
+
self.self_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
|
809 |
+
embed_dim=self.embed_dim,
|
810 |
+
num_heads=config.decoder_attention_heads,
|
811 |
+
dropout=config.attention_dropout,
|
812 |
+
is_decoder=True,
|
813 |
+
is_causal=True,
|
814 |
+
config=config,
|
815 |
+
)
|
816 |
+
self.dropout = config.dropout
|
817 |
+
self.activation_fn = ACT2FN[config.activation_function]
|
818 |
+
self.activation_dropout = config.activation_dropout
|
819 |
+
|
820 |
+
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
821 |
+
self.encoder_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
|
822 |
+
self.embed_dim,
|
823 |
+
config.decoder_attention_heads,
|
824 |
+
dropout=config.attention_dropout,
|
825 |
+
is_decoder=True,
|
826 |
+
config=config,
|
827 |
+
)
|
828 |
+
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
829 |
+
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
|
830 |
+
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
|
831 |
+
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
832 |
+
self.normalize_before = config.decoder_normalize_before
|
833 |
+
|
834 |
+
def forward(
|
835 |
+
self,
|
836 |
+
hidden_states: torch.Tensor,
|
837 |
+
attention_mask: Optional[torch.Tensor] = None,
|
838 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
839 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
840 |
+
layer_head_mask: Optional[torch.Tensor] = None,
|
841 |
+
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
|
842 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
843 |
+
output_attentions: Optional[bool] = False,
|
844 |
+
use_cache: Optional[bool] = True,
|
845 |
+
) -> torch.Tensor:
|
846 |
+
"""
|
847 |
+
Args:
|
848 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
849 |
+
attention_mask (`torch.FloatTensor`): attention mask of size
|
850 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
851 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
852 |
+
cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
|
853 |
+
encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
|
854 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
855 |
+
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
|
856 |
+
`(encoder_attention_heads,)`.
|
857 |
+
cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
|
858 |
+
size `(decoder_attention_heads,)`.
|
859 |
+
past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
860 |
+
output_attentions (`bool`, *optional*):
|
861 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
862 |
+
returned tensors for more detail.
|
863 |
+
"""
|
864 |
+
residual = hidden_states
|
865 |
+
if self.normalize_before:
|
866 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
867 |
+
|
868 |
+
# Self Attention
|
869 |
+
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
870 |
+
self_attn_past_key_value = (
|
871 |
+
past_key_value[:2] if past_key_value is not None else None
|
872 |
+
)
|
873 |
+
# add present self-attn cache to positions 1,2 of present_key_value tuple
|
874 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
875 |
+
hidden_states=hidden_states,
|
876 |
+
past_key_value=self_attn_past_key_value,
|
877 |
+
attention_mask=attention_mask,
|
878 |
+
layer_head_mask=layer_head_mask,
|
879 |
+
output_attentions=output_attentions,
|
880 |
+
)
|
881 |
+
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
882 |
+
hidden_states = residual + hidden_states
|
883 |
+
if not self.normalize_before:
|
884 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
885 |
+
|
886 |
+
# Cross-Attention Block
|
887 |
+
cross_attn_present_key_value = None
|
888 |
+
cross_attn_weights = None
|
889 |
+
if encoder_hidden_states is not None:
|
890 |
+
residual = hidden_states
|
891 |
+
if self.normalize_before:
|
892 |
+
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
893 |
+
|
894 |
+
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
895 |
+
cross_attn_past_key_value = (
|
896 |
+
past_key_value[-2:] if past_key_value is not None else None
|
897 |
+
)
|
898 |
+
(
|
899 |
+
hidden_states,
|
900 |
+
cross_attn_weights,
|
901 |
+
cross_attn_present_key_value,
|
902 |
+
) = self.encoder_attn(
|
903 |
+
hidden_states=hidden_states,
|
904 |
+
key_value_states=encoder_hidden_states,
|
905 |
+
attention_mask=encoder_attention_mask,
|
906 |
+
layer_head_mask=cross_attn_layer_head_mask,
|
907 |
+
past_key_value=cross_attn_past_key_value,
|
908 |
+
output_attentions=output_attentions,
|
909 |
+
)
|
910 |
+
hidden_states = F.dropout(
|
911 |
+
hidden_states, p=self.dropout, training=self.training
|
912 |
+
)
|
913 |
+
hidden_states = residual + hidden_states
|
914 |
+
if not self.normalize_before:
|
915 |
+
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
916 |
+
|
917 |
+
# add cross-attn to positions 3,4 of present_key_value tuple
|
918 |
+
present_key_value = present_key_value + cross_attn_present_key_value
|
919 |
+
|
920 |
+
# Fully Connected
|
921 |
+
residual = hidden_states
|
922 |
+
if self.normalize_before:
|
923 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
924 |
+
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
925 |
+
hidden_states = F.dropout(
|
926 |
+
hidden_states, p=self.activation_dropout, training=self.training
|
927 |
+
)
|
928 |
+
hidden_states = self.fc2(hidden_states)
|
929 |
+
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
930 |
+
hidden_states = residual + hidden_states
|
931 |
+
if not self.normalize_before:
|
932 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
933 |
+
|
934 |
+
outputs = (hidden_states,)
|
935 |
+
|
936 |
+
if output_attentions:
|
937 |
+
outputs += (self_attn_weights, cross_attn_weights)
|
938 |
+
|
939 |
+
if use_cache:
|
940 |
+
outputs += (present_key_value,)
|
941 |
+
|
942 |
+
return outputs
|
943 |
+
|
944 |
+
|
945 |
+
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100PretrainedModel->IndicTrans
|
946 |
+
class IndicTransPreTrainedModel(PreTrainedModel):
|
947 |
+
config_class = IndicTransConfig
|
948 |
+
base_model_prefix = "model"
|
949 |
+
supports_gradient_checkpointing = True
|
950 |
+
_no_split_modules = ["IndicTransAttention"]
|
951 |
+
|
952 |
+
def _init_weights(self, module):
|
953 |
+
std = self.config.init_std
|
954 |
+
if isinstance(module, nn.Linear):
|
955 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
956 |
+
if module.bias is not None:
|
957 |
+
module.bias.data.zero_()
|
958 |
+
elif isinstance(module, nn.Embedding):
|
959 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
960 |
+
if module.padding_idx is not None:
|
961 |
+
module.weight.data[module.padding_idx].zero_()
|
962 |
+
|
963 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
964 |
+
if isinstance(module, (IndicTransDecoder, IndicTransEncoder)):
|
965 |
+
module.gradient_checkpointing = value
|
966 |
+
|
967 |
+
|
968 |
+
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100EncoderLayer->IndicTrans
|
969 |
+
class IndicTransEncoder(IndicTransPreTrainedModel):
|
970 |
+
"""
|
971 |
+
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
|
972 |
+
[`IndicTransEncoderLayer`].
|
973 |
+
|
974 |
+
Args:
|
975 |
+
config: IndicTransConfig
|
976 |
+
embed_tokens (nn.Embedding): output embedding
|
977 |
+
"""
|
978 |
+
|
979 |
+
def __init__(
|
980 |
+
self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None
|
981 |
+
):
|
982 |
+
super().__init__(config)
|
983 |
+
|
984 |
+
self.dropout = config.dropout
|
985 |
+
self.layerdrop = config.encoder_layerdrop
|
986 |
+
|
987 |
+
embed_dim = config.encoder_embed_dim
|
988 |
+
self.padding_idx = config.pad_token_id
|
989 |
+
self.max_source_positions = config.max_source_positions
|
990 |
+
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
991 |
+
|
992 |
+
self.embed_tokens = nn.Embedding(
|
993 |
+
config.encoder_vocab_size, embed_dim, self.padding_idx
|
994 |
+
)
|
995 |
+
|
996 |
+
if embed_tokens is not None:
|
997 |
+
self.embed_tokens.weight = embed_tokens.weight
|
998 |
+
|
999 |
+
self.embed_positions = IndicTransSinusoidalPositionalEmbedding(
|
1000 |
+
config.max_source_positions,
|
1001 |
+
embed_dim,
|
1002 |
+
self.padding_idx,
|
1003 |
+
)
|
1004 |
+
self.layers = nn.ModuleList(
|
1005 |
+
[IndicTransEncoderLayer(config) for _ in range(config.encoder_layers)]
|
1006 |
+
)
|
1007 |
+
self.layer_norm = (
|
1008 |
+
nn.LayerNorm(embed_dim) if config.encoder_normalize_before else None
|
1009 |
+
)
|
1010 |
+
self.layernorm_embedding = (
|
1011 |
+
nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
|
1012 |
+
)
|
1013 |
+
|
1014 |
+
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
1015 |
+
self._use_sdpa = config._attn_implementation == "sdpa"
|
1016 |
+
|
1017 |
+
self.gradient_checkpointing = False
|
1018 |
+
# Initialize weights and apply final processing
|
1019 |
+
self.post_init()
|
1020 |
+
|
1021 |
+
def forward(
|
1022 |
+
self,
|
1023 |
+
input_ids: Optional[torch.Tensor] = None,
|
1024 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1025 |
+
head_mask: Optional[torch.Tensor] = None,
|
1026 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1027 |
+
output_attentions: Optional[bool] = None,
|
1028 |
+
output_hidden_states: Optional[bool] = None,
|
1029 |
+
return_dict: Optional[bool] = None,
|
1030 |
+
):
|
1031 |
+
r"""
|
1032 |
+
Args:
|
1033 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
1034 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
1035 |
+
provide it.
|
1036 |
+
|
1037 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
1038 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
1039 |
+
|
1040 |
+
[What are input IDs?](../glossary#input-ids)
|
1041 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1042 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
1043 |
+
|
1044 |
+
- 1 for tokens that are **not masked**,
|
1045 |
+
- 0 for tokens that are **masked**.
|
1046 |
+
|
1047 |
+
[What are attention masks?](../glossary#attention-mask)
|
1048 |
+
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
|
1049 |
+
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
1050 |
+
|
1051 |
+
- 1 indicates the head is **not masked**,
|
1052 |
+
- 0 indicates the head is **masked**.
|
1053 |
+
|
1054 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
1055 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
1056 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
1057 |
+
than the model's internal embedding lookup matrix.
|
1058 |
+
output_attentions (`bool`, *optional*):
|
1059 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
1060 |
+
returned tensors for more detail.
|
1061 |
+
output_hidden_states (`bool`, *optional*):
|
1062 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
1063 |
+
for more detail.
|
1064 |
+
return_dict (`bool`, *optional*):
|
1065 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
1066 |
+
"""
|
1067 |
+
output_attentions = (
|
1068 |
+
output_attentions
|
1069 |
+
if output_attentions is not None
|
1070 |
+
else self.config.output_attentions
|
1071 |
+
)
|
1072 |
+
output_hidden_states = (
|
1073 |
+
output_hidden_states
|
1074 |
+
if output_hidden_states is not None
|
1075 |
+
else self.config.output_hidden_states
|
1076 |
+
)
|
1077 |
+
return_dict = (
|
1078 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1079 |
+
)
|
1080 |
+
|
1081 |
+
# retrieve input_ids and inputs_embeds
|
1082 |
+
if input_ids is not None and inputs_embeds is not None:
|
1083 |
+
raise ValueError(
|
1084 |
+
"You cannot specify both input_ids and inputs_embeds at the same time"
|
1085 |
+
)
|
1086 |
+
elif input_ids is not None:
|
1087 |
+
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
1088 |
+
input_shape = input_ids.size()
|
1089 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
1090 |
+
elif inputs_embeds is not None:
|
1091 |
+
input_shape = inputs_embeds.size()[:-1]
|
1092 |
+
else:
|
1093 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
1094 |
+
|
1095 |
+
if inputs_embeds is None:
|
1096 |
+
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
1097 |
+
|
1098 |
+
embed_pos = self.embed_positions(input_ids, inputs_embeds)
|
1099 |
+
embed_pos = embed_pos.to(inputs_embeds.device)
|
1100 |
+
|
1101 |
+
hidden_states = inputs_embeds + embed_pos
|
1102 |
+
if self.layernorm_embedding is not None:
|
1103 |
+
hidden_states = self.layernorm_embedding(hidden_states)
|
1104 |
+
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
1105 |
+
|
1106 |
+
if attention_mask is not None:
|
1107 |
+
if self._use_flash_attention_2:
|
1108 |
+
attention_mask = attention_mask if 0 in attention_mask else None
|
1109 |
+
elif self._use_sdpa and head_mask is None and not output_attentions:
|
1110 |
+
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
|
1111 |
+
# the manual implementation that requires a 4D causal mask in all cases.
|
1112 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
1113 |
+
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
|
1114 |
+
else:
|
1115 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
1116 |
+
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
|
1117 |
+
|
1118 |
+
|
1119 |
+
encoder_states = () if output_hidden_states else None
|
1120 |
+
all_attentions = () if output_attentions else None
|
1121 |
+
|
1122 |
+
# check if head_mask has a correct number of layers specified if desired
|
1123 |
+
if head_mask is not None:
|
1124 |
+
if head_mask.size()[0] != len(self.layers):
|
1125 |
+
raise ValueError(
|
1126 |
+
f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
|
1127 |
+
f" {head_mask.size()[0]}."
|
1128 |
+
)
|
1129 |
+
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
1130 |
+
|
1131 |
+
for idx, encoder_layer in enumerate(self.layers):
|
1132 |
+
if output_hidden_states:
|
1133 |
+
encoder_states = encoder_states + (hidden_states,)
|
1134 |
+
|
1135 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
1136 |
+
dropout_probability = torch.rand([])
|
1137 |
+
|
1138 |
+
skip_the_layer = (
|
1139 |
+
True
|
1140 |
+
if self.training and (dropout_probability < self.layerdrop)
|
1141 |
+
else False
|
1142 |
+
)
|
1143 |
+
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
1144 |
+
# under deepspeed zero3 all gpus must run in sync
|
1145 |
+
|
1146 |
+
if self.gradient_checkpointing and self.training:
|
1147 |
+
# create gradient checkpointing function
|
1148 |
+
def create_custom_forward(module):
|
1149 |
+
def custom_forward(*inputs):
|
1150 |
+
return module(*inputs, output_attentions)
|
1151 |
+
|
1152 |
+
return custom_forward
|
1153 |
+
|
1154 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
1155 |
+
create_custom_forward(encoder_layer),
|
1156 |
+
hidden_states,
|
1157 |
+
attention_mask,
|
1158 |
+
(head_mask[idx] if head_mask is not None else None),
|
1159 |
+
)
|
1160 |
+
else:
|
1161 |
+
layer_outputs = encoder_layer(
|
1162 |
+
hidden_states,
|
1163 |
+
attention_mask,
|
1164 |
+
layer_head_mask=(
|
1165 |
+
head_mask[idx] if head_mask is not None else None
|
1166 |
+
),
|
1167 |
+
output_attentions=output_attentions,
|
1168 |
+
)
|
1169 |
+
|
1170 |
+
hidden_states = layer_outputs[0]
|
1171 |
+
|
1172 |
+
if skip_the_layer:
|
1173 |
+
layer_outputs = (None, None)
|
1174 |
+
|
1175 |
+
if output_attentions:
|
1176 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
1177 |
+
|
1178 |
+
if self.layer_norm is not None:
|
1179 |
+
hidden_states = self.layer_norm(hidden_states)
|
1180 |
+
|
1181 |
+
if output_hidden_states:
|
1182 |
+
encoder_states = encoder_states + (hidden_states,)
|
1183 |
+
|
1184 |
+
if not return_dict:
|
1185 |
+
return tuple(
|
1186 |
+
v
|
1187 |
+
for v in [hidden_states, encoder_states, all_attentions]
|
1188 |
+
if v is not None
|
1189 |
+
)
|
1190 |
+
return BaseModelOutput(
|
1191 |
+
last_hidden_state=hidden_states,
|
1192 |
+
hidden_states=encoder_states,
|
1193 |
+
attentions=all_attentions,
|
1194 |
+
)
|
1195 |
+
|
1196 |
+
|
1197 |
+
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100DecoderLayer->IndicTrans
|
1198 |
+
class IndicTransDecoder(IndicTransPreTrainedModel):
|
1199 |
+
"""
|
1200 |
+
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`IndicTransDecoderLayer`]
|
1201 |
+
|
1202 |
+
Args:
|
1203 |
+
config: IndicTransConfig
|
1204 |
+
embed_tokens (nn.Embedding): output embedding
|
1205 |
+
"""
|
1206 |
+
|
1207 |
+
def __init__(
|
1208 |
+
self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None
|
1209 |
+
):
|
1210 |
+
super().__init__(config)
|
1211 |
+
self.dropout = config.dropout
|
1212 |
+
self.layerdrop = config.decoder_layerdrop
|
1213 |
+
|
1214 |
+
embed_dim = config.encoder_embed_dim
|
1215 |
+
self.padding_idx = config.pad_token_id
|
1216 |
+
self.max_target_positions = config.max_target_positions
|
1217 |
+
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
1218 |
+
|
1219 |
+
self.embed_tokens = nn.Embedding(
|
1220 |
+
config.decoder_vocab_size, embed_dim, self.padding_idx
|
1221 |
+
)
|
1222 |
+
|
1223 |
+
if embed_tokens is not None:
|
1224 |
+
self.embed_tokens.weight = embed_tokens.weight
|
1225 |
+
|
1226 |
+
self.embed_positions = IndicTransSinusoidalPositionalEmbedding(
|
1227 |
+
config.max_target_positions,
|
1228 |
+
embed_dim,
|
1229 |
+
self.padding_idx,
|
1230 |
+
)
|
1231 |
+
self.layers = nn.ModuleList(
|
1232 |
+
[IndicTransDecoderLayer(config) for _ in range(config.decoder_layers)]
|
1233 |
+
)
|
1234 |
+
self.layer_norm = (
|
1235 |
+
nn.LayerNorm(embed_dim) if config.decoder_normalize_before else None
|
1236 |
+
)
|
1237 |
+
self.layernorm_embedding = (
|
1238 |
+
nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
|
1239 |
+
)
|
1240 |
+
|
1241 |
+
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
1242 |
+
self._use_sdpa = config._attn_implementation == "sdpa"
|
1243 |
+
|
1244 |
+
self.gradient_checkpointing = False
|
1245 |
+
# Initialize weights and apply final processing
|
1246 |
+
self.post_init()
|
1247 |
+
|
1248 |
+
def forward(
|
1249 |
+
self,
|
1250 |
+
input_ids: Optional[torch.Tensor] = None,
|
1251 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1252 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
1253 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
1254 |
+
head_mask: Optional[torch.Tensor] = None,
|
1255 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
1256 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1257 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1258 |
+
use_cache: Optional[bool] = None,
|
1259 |
+
output_attentions: Optional[bool] = None,
|
1260 |
+
output_hidden_states: Optional[bool] = None,
|
1261 |
+
return_dict: Optional[bool] = None,
|
1262 |
+
):
|
1263 |
+
r"""
|
1264 |
+
Args:
|
1265 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
1266 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
1267 |
+
provide it.
|
1268 |
+
|
1269 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
1270 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
1271 |
+
|
1272 |
+
[What are input IDs?](../glossary#input-ids)
|
1273 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1274 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
1275 |
+
|
1276 |
+
- 1 for tokens that are **not masked**,
|
1277 |
+
- 0 for tokens that are **masked**.
|
1278 |
+
|
1279 |
+
[What are attention masks?](../glossary#attention-mask)
|
1280 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
|
1281 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
1282 |
+
of the decoder.
|
1283 |
+
encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
|
1284 |
+
Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
|
1285 |
+
selected in `[0, 1]`:
|
1286 |
+
|
1287 |
+
- 1 for tokens that are **not masked**,
|
1288 |
+
- 0 for tokens that are **masked**.
|
1289 |
+
|
1290 |
+
[What are attention masks?](../glossary#attention-mask)
|
1291 |
+
head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
|
1292 |
+
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
1293 |
+
|
1294 |
+
- 1 indicates the head is **not masked**,
|
1295 |
+
- 0 indicates the head is **masked**.
|
1296 |
+
|
1297 |
+
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
|
1298 |
+
Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
|
1299 |
+
cross-attention on hidden heads. Mask values selected in `[0, 1]`:
|
1300 |
+
|
1301 |
+
- 1 indicates the head is **not masked**,
|
1302 |
+
- 0 indicates the head is **masked**.
|
1303 |
+
|
1304 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
1305 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
1306 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
|
1307 |
+
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
1308 |
+
|
1309 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
|
1310 |
+
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
1311 |
+
|
1312 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
|
1313 |
+
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
|
1314 |
+
all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
|
1315 |
+
shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
|
1316 |
+
`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
|
1317 |
+
control over how to convert `input_ids` indices into associated vectors than the model's internal
|
1318 |
+
embedding lookup matrix.
|
1319 |
+
output_attentions (`bool`, *optional*):
|
1320 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
1321 |
+
returned tensors for more detail.
|
1322 |
+
output_hidden_states (`bool`, *optional*):
|
1323 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
1324 |
+
for more detail.
|
1325 |
+
return_dict (`bool`, *optional*):
|
1326 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
1327 |
+
"""
|
1328 |
+
output_attentions = (
|
1329 |
+
output_attentions
|
1330 |
+
if output_attentions is not None
|
1331 |
+
else self.config.output_attentions
|
1332 |
+
)
|
1333 |
+
output_hidden_states = (
|
1334 |
+
output_hidden_states
|
1335 |
+
if output_hidden_states is not None
|
1336 |
+
else self.config.output_hidden_states
|
1337 |
+
)
|
1338 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1339 |
+
return_dict = (
|
1340 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1341 |
+
)
|
1342 |
+
|
1343 |
+
# retrieve input_ids and inputs_embeds
|
1344 |
+
if input_ids is not None and inputs_embeds is not None:
|
1345 |
+
raise ValueError(
|
1346 |
+
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
1347 |
+
)
|
1348 |
+
elif input_ids is not None:
|
1349 |
+
input_shape = input_ids.size()
|
1350 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
1351 |
+
elif inputs_embeds is not None:
|
1352 |
+
input_shape = inputs_embeds.size()[:-1]
|
1353 |
+
else:
|
1354 |
+
raise ValueError(
|
1355 |
+
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
1356 |
+
)
|
1357 |
+
|
1358 |
+
# past_key_values_length
|
1359 |
+
past_key_values_length = (
|
1360 |
+
past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
1361 |
+
)
|
1362 |
+
|
1363 |
+
if inputs_embeds is None:
|
1364 |
+
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
1365 |
+
|
1366 |
+
|
1367 |
+
if self._use_flash_attention_2:
|
1368 |
+
# 2d mask is passed through the layers
|
1369 |
+
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
1370 |
+
elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
|
1371 |
+
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
|
1372 |
+
# the manual implementation that requires a 4D causal mask in all cases.
|
1373 |
+
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
1374 |
+
attention_mask,
|
1375 |
+
input_shape,
|
1376 |
+
inputs_embeds,
|
1377 |
+
past_key_values_length,
|
1378 |
+
)
|
1379 |
+
else:
|
1380 |
+
# 4d mask is passed through the layers
|
1381 |
+
attention_mask = _prepare_4d_causal_attention_mask(
|
1382 |
+
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
1383 |
+
)
|
1384 |
+
|
1385 |
+
# expand encoder attention mask
|
1386 |
+
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
1387 |
+
if self._use_flash_attention_2:
|
1388 |
+
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
|
1389 |
+
elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
|
1390 |
+
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
|
1391 |
+
# the manual implementation that requires a 4D causal mask in all cases.
|
1392 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
1393 |
+
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
|
1394 |
+
encoder_attention_mask,
|
1395 |
+
inputs_embeds.dtype,
|
1396 |
+
tgt_len=input_shape[-1],
|
1397 |
+
)
|
1398 |
+
else:
|
1399 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
1400 |
+
encoder_attention_mask = _prepare_4d_attention_mask(
|
1401 |
+
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
1402 |
+
)
|
1403 |
+
|
1404 |
+
# embed positions
|
1405 |
+
positions = self.embed_positions(
|
1406 |
+
input_ids, inputs_embeds, past_key_values_length
|
1407 |
+
)
|
1408 |
+
positions = positions.to(inputs_embeds.device)
|
1409 |
+
|
1410 |
+
hidden_states = inputs_embeds + positions
|
1411 |
+
if self.layernorm_embedding is not None:
|
1412 |
+
hidden_states = self.layernorm_embedding(hidden_states)
|
1413 |
+
|
1414 |
+
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
1415 |
+
|
1416 |
+
if self.gradient_checkpointing and self.training:
|
1417 |
+
if use_cache:
|
1418 |
+
logger.warning_once(
|
1419 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting"
|
1420 |
+
" `use_cache=False`..."
|
1421 |
+
)
|
1422 |
+
use_cache = False
|
1423 |
+
|
1424 |
+
# decoder layers
|
1425 |
+
all_hidden_states = () if output_hidden_states else None
|
1426 |
+
all_self_attns = () if output_attentions else None
|
1427 |
+
all_cross_attentions = () if output_attentions else None
|
1428 |
+
next_decoder_cache = () if use_cache else None
|
1429 |
+
|
1430 |
+
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
1431 |
+
for attn_mask, mask_name in zip(
|
1432 |
+
[head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
|
1433 |
+
):
|
1434 |
+
if attn_mask is not None:
|
1435 |
+
if attn_mask.size()[0] != len(self.layers):
|
1436 |
+
raise ValueError(
|
1437 |
+
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
|
1438 |
+
f" {head_mask.size()[0]}."
|
1439 |
+
)
|
1440 |
+
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
1441 |
+
|
1442 |
+
for idx, decoder_layer in enumerate(self.layers):
|
1443 |
+
if output_hidden_states:
|
1444 |
+
all_hidden_states += (hidden_states,)
|
1445 |
+
|
1446 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
1447 |
+
dropout_probability = torch.rand([])
|
1448 |
+
|
1449 |
+
skip_the_layer = (
|
1450 |
+
True
|
1451 |
+
if self.training and (dropout_probability < self.layerdrop)
|
1452 |
+
else False
|
1453 |
+
)
|
1454 |
+
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
1455 |
+
# under deepspeed zero3 all gpus must run in sync
|
1456 |
+
|
1457 |
+
past_key_value = (
|
1458 |
+
past_key_values[idx] if past_key_values is not None else None
|
1459 |
+
)
|
1460 |
+
|
1461 |
+
if self.gradient_checkpointing and self.training:
|
1462 |
+
|
1463 |
+
def create_custom_forward(module):
|
1464 |
+
def custom_forward(*inputs):
|
1465 |
+
# None for past_key_value
|
1466 |
+
return module(*inputs, output_attentions, use_cache)
|
1467 |
+
|
1468 |
+
return custom_forward
|
1469 |
+
|
1470 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
1471 |
+
create_custom_forward(decoder_layer),
|
1472 |
+
hidden_states,
|
1473 |
+
attention_mask,
|
1474 |
+
encoder_hidden_states,
|
1475 |
+
encoder_attention_mask,
|
1476 |
+
head_mask[idx] if head_mask is not None else None,
|
1477 |
+
cross_attn_head_mask[idx]
|
1478 |
+
if cross_attn_head_mask is not None
|
1479 |
+
else None,
|
1480 |
+
None,
|
1481 |
+
)
|
1482 |
+
else:
|
1483 |
+
layer_outputs = decoder_layer(
|
1484 |
+
hidden_states,
|
1485 |
+
attention_mask=attention_mask,
|
1486 |
+
encoder_hidden_states=encoder_hidden_states,
|
1487 |
+
encoder_attention_mask=encoder_attention_mask,
|
1488 |
+
layer_head_mask=(
|
1489 |
+
head_mask[idx] if head_mask is not None else None
|
1490 |
+
),
|
1491 |
+
cross_attn_layer_head_mask=(
|
1492 |
+
cross_attn_head_mask[idx]
|
1493 |
+
if cross_attn_head_mask is not None
|
1494 |
+
else None
|
1495 |
+
),
|
1496 |
+
past_key_value=past_key_value,
|
1497 |
+
output_attentions=output_attentions,
|
1498 |
+
use_cache=use_cache,
|
1499 |
+
)
|
1500 |
+
|
1501 |
+
hidden_states = layer_outputs[0]
|
1502 |
+
|
1503 |
+
if skip_the_layer:
|
1504 |
+
continue
|
1505 |
+
|
1506 |
+
if use_cache:
|
1507 |
+
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
|
1508 |
+
|
1509 |
+
if output_attentions:
|
1510 |
+
all_self_attns += (layer_outputs[1],)
|
1511 |
+
all_cross_attentions += (layer_outputs[2],)
|
1512 |
+
|
1513 |
+
if self.layer_norm is not None:
|
1514 |
+
hidden_states = self.layer_norm(hidden_states)
|
1515 |
+
|
1516 |
+
# add hidden states from the last decoder layer
|
1517 |
+
if output_hidden_states:
|
1518 |
+
all_hidden_states += (hidden_states,)
|
1519 |
+
|
1520 |
+
next_cache = next_decoder_cache if use_cache else None
|
1521 |
+
if not return_dict:
|
1522 |
+
return tuple(
|
1523 |
+
v
|
1524 |
+
for v in [
|
1525 |
+
hidden_states,
|
1526 |
+
next_cache,
|
1527 |
+
all_hidden_states,
|
1528 |
+
all_self_attns,
|
1529 |
+
all_cross_attentions,
|
1530 |
+
]
|
1531 |
+
if v is not None
|
1532 |
+
)
|
1533 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
1534 |
+
last_hidden_state=hidden_states,
|
1535 |
+
past_key_values=next_cache,
|
1536 |
+
hidden_states=all_hidden_states,
|
1537 |
+
attentions=all_self_attns,
|
1538 |
+
cross_attentions=all_cross_attentions,
|
1539 |
+
)
|
1540 |
+
|
1541 |
+
|
1542 |
+
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100Model->IndicTrans
|
1543 |
+
class IndicTransModel(IndicTransPreTrainedModel):
|
1544 |
+
_tied_weights_keys = None
|
1545 |
+
|
1546 |
+
def __init__(self, config: IndicTransConfig):
|
1547 |
+
super().__init__(config)
|
1548 |
+
|
1549 |
+
self.encoder = IndicTransEncoder(config)
|
1550 |
+
self.decoder = IndicTransDecoder(config)
|
1551 |
+
|
1552 |
+
# Initialize weights and apply final processing
|
1553 |
+
self.post_init()
|
1554 |
+
|
1555 |
+
def get_encoder(self):
|
1556 |
+
return self.encoder
|
1557 |
+
|
1558 |
+
def get_decoder(self):
|
1559 |
+
return self.decoder
|
1560 |
+
|
1561 |
+
def forward(
|
1562 |
+
self,
|
1563 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1564 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1565 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
1566 |
+
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
1567 |
+
head_mask: Optional[torch.Tensor] = None,
|
1568 |
+
decoder_head_mask: Optional[torch.Tensor] = None,
|
1569 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
1570 |
+
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
1571 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
1572 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1573 |
+
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
1574 |
+
use_cache: Optional[bool] = None,
|
1575 |
+
output_attentions: Optional[bool] = None,
|
1576 |
+
output_hidden_states: Optional[bool] = None,
|
1577 |
+
return_dict: Optional[bool] = None,
|
1578 |
+
) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
|
1579 |
+
output_attentions = (
|
1580 |
+
output_attentions
|
1581 |
+
if output_attentions is not None
|
1582 |
+
else self.config.output_attentions
|
1583 |
+
)
|
1584 |
+
output_hidden_states = (
|
1585 |
+
output_hidden_states
|
1586 |
+
if output_hidden_states is not None
|
1587 |
+
else self.config.output_hidden_states
|
1588 |
+
)
|
1589 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1590 |
+
return_dict = (
|
1591 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1592 |
+
)
|
1593 |
+
|
1594 |
+
if encoder_outputs is None:
|
1595 |
+
encoder_outputs = self.encoder(
|
1596 |
+
input_ids=input_ids,
|
1597 |
+
attention_mask=attention_mask,
|
1598 |
+
head_mask=head_mask,
|
1599 |
+
inputs_embeds=inputs_embeds,
|
1600 |
+
output_attentions=output_attentions,
|
1601 |
+
output_hidden_states=output_hidden_states,
|
1602 |
+
return_dict=return_dict,
|
1603 |
+
)
|
1604 |
+
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
|
1605 |
+
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
|
1606 |
+
encoder_outputs = BaseModelOutput(
|
1607 |
+
last_hidden_state=encoder_outputs[0],
|
1608 |
+
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
1609 |
+
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
1610 |
+
)
|
1611 |
+
|
1612 |
+
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
|
1613 |
+
decoder_outputs = self.decoder(
|
1614 |
+
input_ids=decoder_input_ids,
|
1615 |
+
attention_mask=decoder_attention_mask,
|
1616 |
+
encoder_hidden_states=encoder_outputs[0],
|
1617 |
+
encoder_attention_mask=attention_mask,
|
1618 |
+
head_mask=decoder_head_mask,
|
1619 |
+
cross_attn_head_mask=cross_attn_head_mask,
|
1620 |
+
past_key_values=past_key_values,
|
1621 |
+
inputs_embeds=decoder_inputs_embeds,
|
1622 |
+
use_cache=use_cache,
|
1623 |
+
output_attentions=output_attentions,
|
1624 |
+
output_hidden_states=output_hidden_states,
|
1625 |
+
return_dict=return_dict,
|
1626 |
+
)
|
1627 |
+
|
1628 |
+
if not return_dict:
|
1629 |
+
return decoder_outputs + encoder_outputs
|
1630 |
+
|
1631 |
+
return Seq2SeqModelOutput(
|
1632 |
+
last_hidden_state=decoder_outputs.last_hidden_state,
|
1633 |
+
past_key_values=decoder_outputs.past_key_values,
|
1634 |
+
decoder_hidden_states=decoder_outputs.hidden_states,
|
1635 |
+
decoder_attentions=decoder_outputs.attentions,
|
1636 |
+
cross_attentions=decoder_outputs.cross_attentions,
|
1637 |
+
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
1638 |
+
encoder_hidden_states=encoder_outputs.hidden_states,
|
1639 |
+
encoder_attentions=encoder_outputs.attentions,
|
1640 |
+
)
|
1641 |
+
|
1642 |
+
|
1643 |
+
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->IndicTrans
|
1644 |
+
class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
|
1645 |
+
base_model_prefix = "model"
|
1646 |
+
_tied_weights_keys = None
|
1647 |
+
_label_smoothing = 0.0
|
1648 |
+
|
1649 |
+
def __init__(self, config: IndicTransConfig):
|
1650 |
+
super().__init__(config)
|
1651 |
+
self.model = IndicTransModel(config)
|
1652 |
+
self.lm_head = nn.Linear(
|
1653 |
+
config.decoder_embed_dim, config.decoder_vocab_size, bias=False
|
1654 |
+
)
|
1655 |
+
|
1656 |
+
if config.share_decoder_input_output_embed:
|
1657 |
+
self.lm_head.weight = self.model.decoder.embed_tokens.weight
|
1658 |
+
|
1659 |
+
self.post_init()
|
1660 |
+
|
1661 |
+
def tie_weights(self):
|
1662 |
+
pass
|
1663 |
+
|
1664 |
+
def get_encoder(self):
|
1665 |
+
return self.model.get_encoder()
|
1666 |
+
|
1667 |
+
def get_decoder(self):
|
1668 |
+
return self.model.get_decoder()
|
1669 |
+
|
1670 |
+
def get_output_embeddings(self):
|
1671 |
+
return self.lm_head
|
1672 |
+
|
1673 |
+
def set_output_embeddings(self, new_embeddings):
|
1674 |
+
self.lm_head = new_embeddings
|
1675 |
+
|
1676 |
+
def set_label_smoothing(self, label_smoothing):
|
1677 |
+
self._label_smoothing = label_smoothing
|
1678 |
+
|
1679 |
+
def forward(
|
1680 |
+
self,
|
1681 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1682 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1683 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
1684 |
+
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
1685 |
+
head_mask: Optional[torch.Tensor] = None,
|
1686 |
+
decoder_head_mask: Optional[torch.Tensor] = None,
|
1687 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
1688 |
+
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
1689 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
1690 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1691 |
+
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
1692 |
+
labels: Optional[torch.LongTensor] = None,
|
1693 |
+
use_cache: Optional[bool] = None,
|
1694 |
+
output_attentions: Optional[bool] = None,
|
1695 |
+
output_hidden_states: Optional[bool] = None,
|
1696 |
+
return_dict: Optional[bool] = None,
|
1697 |
+
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
|
1698 |
+
r"""
|
1699 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1700 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
1701 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
1702 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
1703 |
+
|
1704 |
+
Returns:
|
1705 |
+
"""
|
1706 |
+
return_dict = (
|
1707 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1708 |
+
)
|
1709 |
+
|
1710 |
+
if labels is not None:
|
1711 |
+
if decoder_input_ids is None:
|
1712 |
+
decoder_input_ids = shift_tokens_right(
|
1713 |
+
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
1714 |
+
)
|
1715 |
+
|
1716 |
+
outputs = self.model(
|
1717 |
+
input_ids,
|
1718 |
+
attention_mask=attention_mask,
|
1719 |
+
decoder_input_ids=decoder_input_ids,
|
1720 |
+
encoder_outputs=encoder_outputs,
|
1721 |
+
decoder_attention_mask=decoder_attention_mask,
|
1722 |
+
head_mask=head_mask,
|
1723 |
+
decoder_head_mask=decoder_head_mask,
|
1724 |
+
cross_attn_head_mask=cross_attn_head_mask,
|
1725 |
+
past_key_values=past_key_values,
|
1726 |
+
inputs_embeds=inputs_embeds,
|
1727 |
+
decoder_inputs_embeds=decoder_inputs_embeds,
|
1728 |
+
use_cache=use_cache,
|
1729 |
+
output_attentions=output_attentions,
|
1730 |
+
output_hidden_states=output_hidden_states,
|
1731 |
+
return_dict=return_dict,
|
1732 |
+
)
|
1733 |
+
lm_logits = self.lm_head(outputs[0])
|
1734 |
+
|
1735 |
+
masked_lm_loss = None
|
1736 |
+
if labels is not None:
|
1737 |
+
# move labels to the correct device to enable PP
|
1738 |
+
labels = labels.to(lm_logits.device)
|
1739 |
+
masked_lm_loss = F.cross_entropy(
|
1740 |
+
input=lm_logits.view(-1, self.config.decoder_vocab_size),
|
1741 |
+
target=labels.view(-1),
|
1742 |
+
ignore_index=-100,
|
1743 |
+
label_smoothing=self._label_smoothing,
|
1744 |
+
)
|
1745 |
+
|
1746 |
+
if not return_dict:
|
1747 |
+
output = (lm_logits,) + outputs[1:]
|
1748 |
+
return (
|
1749 |
+
((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
1750 |
+
)
|
1751 |
+
|
1752 |
+
return Seq2SeqLMOutput(
|
1753 |
+
loss=masked_lm_loss,
|
1754 |
+
logits=lm_logits,
|
1755 |
+
past_key_values=outputs.past_key_values,
|
1756 |
+
decoder_hidden_states=outputs.decoder_hidden_states,
|
1757 |
+
decoder_attentions=outputs.decoder_attentions,
|
1758 |
+
cross_attentions=outputs.cross_attentions,
|
1759 |
+
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
1760 |
+
encoder_hidden_states=outputs.encoder_hidden_states,
|
1761 |
+
encoder_attentions=outputs.encoder_attentions,
|
1762 |
+
)
|
1763 |
+
|
1764 |
+
def prepare_inputs_for_generation(
|
1765 |
+
self,
|
1766 |
+
decoder_input_ids,
|
1767 |
+
past_key_values=None,
|
1768 |
+
attention_mask=None,
|
1769 |
+
head_mask=None,
|
1770 |
+
decoder_head_mask=None,
|
1771 |
+
cross_attn_head_mask=None,
|
1772 |
+
use_cache=None,
|
1773 |
+
encoder_outputs=None,
|
1774 |
+
**kwargs,
|
1775 |
+
):
|
1776 |
+
# cut decoder_input_ids if past is used
|
1777 |
+
if past_key_values is not None:
|
1778 |
+
decoder_input_ids = decoder_input_ids[:, -1:]
|
1779 |
+
|
1780 |
+
return {
|
1781 |
+
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
1782 |
+
"encoder_outputs": encoder_outputs,
|
1783 |
+
"past_key_values": past_key_values,
|
1784 |
+
"decoder_input_ids": decoder_input_ids,
|
1785 |
+
"attention_mask": attention_mask,
|
1786 |
+
"head_mask": head_mask,
|
1787 |
+
"decoder_head_mask": decoder_head_mask,
|
1788 |
+
"cross_attn_head_mask": cross_attn_head_mask,
|
1789 |
+
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
1790 |
+
}
|
1791 |
+
|
1792 |
+
@staticmethod
|
1793 |
+
def _reorder_cache(past_key_values, beam_idx):
|
1794 |
+
reordered_past = ()
|
1795 |
+
for layer_past in past_key_values:
|
1796 |
+
reordered_past += (
|
1797 |
+
tuple(
|
1798 |
+
past_state.index_select(0, beam_idx) for past_state in layer_past
|
1799 |
+
),
|
1800 |
+
)
|
1801 |
+
return reordered_past
|
IndicTrans2/huggingface_interface/train_lora.py
ADDED
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import pandas as pd
|
4 |
+
from datasets import Dataset
|
5 |
+
from sacrebleu.metrics import BLEU, CHRF
|
6 |
+
from peft import LoraConfig, get_peft_model
|
7 |
+
from IndicTransToolkit import IndicProcessor, IndicDataCollator
|
8 |
+
|
9 |
+
from transformers import (
|
10 |
+
Seq2SeqTrainer,
|
11 |
+
Seq2SeqTrainingArguments,
|
12 |
+
AutoModelForSeq2SeqLM,
|
13 |
+
AutoTokenizer,
|
14 |
+
EarlyStoppingCallback,
|
15 |
+
)
|
16 |
+
|
17 |
+
bleu_metric = BLEU()
|
18 |
+
chrf_metric = CHRF()
|
19 |
+
|
20 |
+
|
21 |
+
def get_arg_parse():
|
22 |
+
parser = argparse.ArgumentParser()
|
23 |
+
parser.add_argument(
|
24 |
+
"--model",
|
25 |
+
type=str,
|
26 |
+
)
|
27 |
+
parser.add_argument(
|
28 |
+
"--src_lang_list",
|
29 |
+
type=str,
|
30 |
+
help="comma separated list of source languages",
|
31 |
+
)
|
32 |
+
parser.add_argument(
|
33 |
+
"--tgt_lang_list",
|
34 |
+
type=str,
|
35 |
+
help="comma separated list of target languages",
|
36 |
+
)
|
37 |
+
parser.add_argument("--data_dir", type=str)
|
38 |
+
parser.add_argument("--output_dir", type=str)
|
39 |
+
parser.add_argument("--save_steps", type=int, default=1000)
|
40 |
+
parser.add_argument("--eval_steps", type=int, default=1000)
|
41 |
+
parser.add_argument("--batch_size", type=int, default=32)
|
42 |
+
parser.add_argument("--num_train_epochs", type=int, default=100)
|
43 |
+
parser.add_argument("--max_steps", type=int, default=1000000)
|
44 |
+
parser.add_argument("--grad_accum_steps", type=int, default=4)
|
45 |
+
parser.add_argument("--warmup_steps", type=int, default=4000)
|
46 |
+
parser.add_argument("--warmup_ratio", type=int, default=0.0)
|
47 |
+
parser.add_argument("--max_grad_norm", type=float, default=1.0)
|
48 |
+
parser.add_argument("--learning_rate", type=float, default=5e-4)
|
49 |
+
parser.add_argument("--weight_decay", type=float, default=0.0)
|
50 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9)
|
51 |
+
parser.add_argument("--adam_beta2", type=float, default=0.98)
|
52 |
+
parser.add_argument("--dropout", type=float, default=0.0)
|
53 |
+
parser.add_argument("--print_samples", action="store_true")
|
54 |
+
parser.add_argument(
|
55 |
+
"--optimizer",
|
56 |
+
type=str,
|
57 |
+
default="adamw_torch",
|
58 |
+
choices=[
|
59 |
+
"adam_hf",
|
60 |
+
"adamw_torch",
|
61 |
+
"adamw_torch_fused",
|
62 |
+
"adamw_apex_fused",
|
63 |
+
"adafactor",
|
64 |
+
],
|
65 |
+
)
|
66 |
+
parser.add_argument(
|
67 |
+
"--lr_scheduler",
|
68 |
+
type=str,
|
69 |
+
default="inverse_sqrt",
|
70 |
+
choices=[
|
71 |
+
"inverse_sqrt",
|
72 |
+
"linear",
|
73 |
+
"polynomial",
|
74 |
+
"cosine",
|
75 |
+
"constant",
|
76 |
+
"constant_with_warmup",
|
77 |
+
],
|
78 |
+
)
|
79 |
+
parser.add_argument("--label_smoothing", type=float, default=0.0)
|
80 |
+
parser.add_argument("--num_workers", type=int, default=8)
|
81 |
+
parser.add_argument("--metric_for_best_model", type=str, default="eval_loss")
|
82 |
+
parser.add_argument("--greater_is_better", action="store_true")
|
83 |
+
parser.add_argument("--lora_target_modules", type=str, default="q_proj,k_proj")
|
84 |
+
parser.add_argument("--lora_dropout", type=float, default=0.1)
|
85 |
+
parser.add_argument("--lora_r", type=int, default=16)
|
86 |
+
parser.add_argument("--lora_alpha", type=int, default=32)
|
87 |
+
parser.add_argument(
|
88 |
+
"--report_to",
|
89 |
+
type=str,
|
90 |
+
default="none",
|
91 |
+
choices=["wandb", "tensorboard", "azure_ml", "none"],
|
92 |
+
)
|
93 |
+
parser.add_argument("--patience", type=int, default=5),
|
94 |
+
parser.add_argument("--threshold", type=float, default=1e-3)
|
95 |
+
return parser
|
96 |
+
|
97 |
+
|
98 |
+
def load_and_process_translation_dataset(
|
99 |
+
data_dir,
|
100 |
+
split="train",
|
101 |
+
tokenizer=None,
|
102 |
+
processor=None,
|
103 |
+
src_lang_list=None,
|
104 |
+
tgt_lang_list=None,
|
105 |
+
num_proc=8,
|
106 |
+
seed=42
|
107 |
+
):
|
108 |
+
complete_dataset = {
|
109 |
+
"sentence_SRC": [],
|
110 |
+
"sentence_TGT": [],
|
111 |
+
}
|
112 |
+
|
113 |
+
for src_lang in src_lang_list:
|
114 |
+
for tgt_lang in tgt_lang_list:
|
115 |
+
if src_lang == tgt_lang:
|
116 |
+
continue
|
117 |
+
src_path = os.path.join(
|
118 |
+
data_dir, split, f"{src_lang}-{tgt_lang}", f"{split}.{src_lang}"
|
119 |
+
)
|
120 |
+
tgt_path = os.path.join(
|
121 |
+
data_dir, split, f"{src_lang}-{tgt_lang}", f"{split}.{tgt_lang}"
|
122 |
+
)
|
123 |
+
if not os.path.exists(src_path) or not os.path.exists(tgt_path):
|
124 |
+
raise FileNotFoundError(
|
125 |
+
f"Source ({split}.{src_lang}) or Target ({split}.{tgt_lang}) file not found in {data_dir}"
|
126 |
+
)
|
127 |
+
with open(src_path, encoding="utf-8") as src_file, open(
|
128 |
+
tgt_path, encoding="utf-8"
|
129 |
+
) as tgt_file:
|
130 |
+
src_lines = src_file.readlines()
|
131 |
+
tgt_lines = tgt_file.readlines()
|
132 |
+
|
133 |
+
# Ensure both files have the same number of lines
|
134 |
+
assert len(src_lines) == len(
|
135 |
+
tgt_lines
|
136 |
+
), f"Source and Target files have different number of lines for {split}.{src_lang} and {split}.{tgt_lang}"
|
137 |
+
|
138 |
+
complete_dataset["sentence_SRC"] += processor.preprocess_batch(
|
139 |
+
src_lines, src_lang=src_lang, tgt_lang=tgt_lang, is_target=False
|
140 |
+
)
|
141 |
+
|
142 |
+
complete_dataset["sentence_TGT"] += processor.preprocess_batch(
|
143 |
+
tgt_lines, src_lang=tgt_lang, tgt_lang=src_lang, is_target=True
|
144 |
+
)
|
145 |
+
|
146 |
+
complete_dataset = Dataset.from_dict(complete_dataset).shuffle(seed=seed)
|
147 |
+
|
148 |
+
return complete_dataset.map(
|
149 |
+
lambda example: preprocess_fn(
|
150 |
+
example,
|
151 |
+
tokenizer=tokenizer
|
152 |
+
),
|
153 |
+
batched=True,
|
154 |
+
num_proc=num_proc,
|
155 |
+
)
|
156 |
+
|
157 |
+
|
158 |
+
def compute_metrics_factory(
|
159 |
+
tokenizer, metric_dict=None, print_samples=False, n_samples=10
|
160 |
+
):
|
161 |
+
def compute_metrics(eval_preds):
|
162 |
+
preds, labels = eval_preds
|
163 |
+
|
164 |
+
labels[labels == -100] = tokenizer.pad_token_id
|
165 |
+
preds[preds == -100] = tokenizer.pad_token_id
|
166 |
+
|
167 |
+
with tokenizer.as_target_tokenizer():
|
168 |
+
preds = [
|
169 |
+
x.strip()
|
170 |
+
for x in tokenizer.batch_decode(
|
171 |
+
preds, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
172 |
+
)
|
173 |
+
]
|
174 |
+
labels = [
|
175 |
+
x.strip()
|
176 |
+
for x in tokenizer.batch_decode(
|
177 |
+
labels, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
178 |
+
)
|
179 |
+
]
|
180 |
+
|
181 |
+
assert len(preds) == len(
|
182 |
+
labels
|
183 |
+
), "Predictions and Labels have different lengths"
|
184 |
+
|
185 |
+
df = pd.DataFrame({"Predictions": preds, "References": labels}).sample(
|
186 |
+
n=n_samples
|
187 |
+
)
|
188 |
+
|
189 |
+
if print_samples:
|
190 |
+
for pred, label in zip(df["Predictions"].values, df["References"].values):
|
191 |
+
print(f" | > Prediction: {pred}")
|
192 |
+
print(f" | > Reference: {label}\n")
|
193 |
+
|
194 |
+
return {
|
195 |
+
metric_name: metric.corpus_score(preds, [labels]).score
|
196 |
+
for (metric_name, metric) in metric_dict.items()
|
197 |
+
}
|
198 |
+
|
199 |
+
return compute_metrics
|
200 |
+
|
201 |
+
|
202 |
+
def preprocess_fn(example, tokenizer, **kwargs):
|
203 |
+
model_inputs = tokenizer(
|
204 |
+
example["sentence_SRC"], truncation=True, padding=False, max_length=256
|
205 |
+
)
|
206 |
+
|
207 |
+
with tokenizer.as_target_tokenizer():
|
208 |
+
labels = tokenizer(
|
209 |
+
example["sentence_TGT"], truncation=True, padding=False, max_length=256
|
210 |
+
)
|
211 |
+
|
212 |
+
model_inputs["labels"] = labels["input_ids"]
|
213 |
+
return model_inputs
|
214 |
+
|
215 |
+
|
216 |
+
def main(args):
|
217 |
+
print(f" | > Loading {args.model} and tokenizer ...")
|
218 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(
|
219 |
+
args.model,
|
220 |
+
trust_remote_code=True,
|
221 |
+
attn_implementation="eager",
|
222 |
+
dropout=args.dropout
|
223 |
+
)
|
224 |
+
|
225 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
|
226 |
+
processor = IndicProcessor(inference=False) # pre-process before tokenization
|
227 |
+
|
228 |
+
data_collator = IndicDataCollator(
|
229 |
+
tokenizer=tokenizer,
|
230 |
+
model=model,
|
231 |
+
padding="longest", # saves padding tokens
|
232 |
+
pad_to_multiple_of=8, # better to have it as 8 when using fp16
|
233 |
+
label_pad_token_id=-100
|
234 |
+
)
|
235 |
+
|
236 |
+
if args.data_dir is not None:
|
237 |
+
train_dataset = load_and_process_translation_dataset(
|
238 |
+
args.data_dir,
|
239 |
+
split="train",
|
240 |
+
tokenizer=tokenizer,
|
241 |
+
processor=processor,
|
242 |
+
src_lang_list=args.src_lang_list.split(","),
|
243 |
+
tgt_lang_list=args.tgt_lang_list.split(","),
|
244 |
+
)
|
245 |
+
print(f" | > Loaded train dataset from {args.data_dir}. Size: {len(train_dataset)} ...")
|
246 |
+
|
247 |
+
eval_dataset = load_and_process_translation_dataset(
|
248 |
+
args.data_dir,
|
249 |
+
split="dev",
|
250 |
+
tokenizer=tokenizer,
|
251 |
+
processor=processor,
|
252 |
+
src_lang_list=args.src_lang_list.split(","),
|
253 |
+
tgt_lang_list=args.tgt_lang_list.split(","),
|
254 |
+
)
|
255 |
+
print(f" | > Loaded eval dataset from {args.data_dir}. Size: {len(eval_dataset)} ...")
|
256 |
+
else:
|
257 |
+
raise ValueError(" | > Data directory not provided")
|
258 |
+
|
259 |
+
lora_config = LoraConfig(
|
260 |
+
r=args.lora_r,
|
261 |
+
bias="none",
|
262 |
+
inference_mode=False,
|
263 |
+
task_type="SEQ_2_SEQ_LM",
|
264 |
+
lora_alpha=args.lora_alpha,
|
265 |
+
lora_dropout=args.lora_dropout,
|
266 |
+
target_modules=args.lora_target_modules.split(","),
|
267 |
+
)
|
268 |
+
|
269 |
+
model.set_label_smoothing(args.label_smoothing)
|
270 |
+
|
271 |
+
model = get_peft_model(model, lora_config)
|
272 |
+
model.print_trainable_parameters()
|
273 |
+
|
274 |
+
print(f" | > Loading metrics factory with BLEU and chrF ...")
|
275 |
+
seq2seq_compute_metrics = compute_metrics_factory(
|
276 |
+
tokenizer=tokenizer,
|
277 |
+
print_samples=args.print_samples,
|
278 |
+
metric_dict={"BLEU": bleu_metric, "chrF": chrf_metric},
|
279 |
+
)
|
280 |
+
|
281 |
+
training_args = Seq2SeqTrainingArguments(
|
282 |
+
output_dir=args.output_dir,
|
283 |
+
do_train=True,
|
284 |
+
do_eval=True,
|
285 |
+
fp16=True, # use fp16 for faster training
|
286 |
+
logging_strategy="steps",
|
287 |
+
evaluation_strategy="steps",
|
288 |
+
save_strategy="steps",
|
289 |
+
logging_steps=100,
|
290 |
+
save_total_limit=1,
|
291 |
+
predict_with_generate=True,
|
292 |
+
load_best_model_at_end=True,
|
293 |
+
max_steps=args.max_steps, # max_steps overrides num_train_epochs
|
294 |
+
per_device_train_batch_size=args.batch_size,
|
295 |
+
per_device_eval_batch_size=args.batch_size,
|
296 |
+
gradient_accumulation_steps=args.grad_accum_steps,
|
297 |
+
eval_accumulation_steps=args.grad_accum_steps,
|
298 |
+
weight_decay=args.weight_decay,
|
299 |
+
adam_beta1=args.adam_beta1,
|
300 |
+
adam_beta2=args.adam_beta2,
|
301 |
+
max_grad_norm=args.max_grad_norm,
|
302 |
+
optim=args.optimizer,
|
303 |
+
lr_scheduler_type=args.lr_scheduler,
|
304 |
+
warmup_ratio=args.warmup_ratio,
|
305 |
+
warmup_steps=args.warmup_steps,
|
306 |
+
learning_rate=args.learning_rate,
|
307 |
+
num_train_epochs=args.num_train_epochs,
|
308 |
+
save_steps=args.save_steps,
|
309 |
+
eval_steps=args.eval_steps,
|
310 |
+
dataloader_num_workers=args.num_workers,
|
311 |
+
metric_for_best_model=args.metric_for_best_model,
|
312 |
+
greater_is_better=args.greater_is_better,
|
313 |
+
report_to=args.report_to,
|
314 |
+
generation_max_length=256,
|
315 |
+
generation_num_beams=5,
|
316 |
+
sortish_sampler=True,
|
317 |
+
group_by_length=True,
|
318 |
+
include_tokens_per_second=True,
|
319 |
+
include_num_input_tokens_seen=True,
|
320 |
+
dataloader_prefetch_factor=2,
|
321 |
+
)
|
322 |
+
|
323 |
+
# Create Trainer instance
|
324 |
+
trainer = Seq2SeqTrainer(
|
325 |
+
model=model,
|
326 |
+
args=training_args,
|
327 |
+
data_collator=data_collator,
|
328 |
+
train_dataset=train_dataset,
|
329 |
+
eval_dataset=eval_dataset,
|
330 |
+
compute_metrics=seq2seq_compute_metrics,
|
331 |
+
callbacks=[
|
332 |
+
EarlyStoppingCallback(
|
333 |
+
early_stopping_patience=args.patience,
|
334 |
+
early_stopping_threshold=args.threshold,
|
335 |
+
)
|
336 |
+
],
|
337 |
+
)
|
338 |
+
|
339 |
+
print(f" | > Starting training ...")
|
340 |
+
|
341 |
+
try:
|
342 |
+
trainer.train()
|
343 |
+
except KeyboardInterrupt:
|
344 |
+
print(f" | > Training interrupted ...")
|
345 |
+
|
346 |
+
# this will only save the LoRA adapter weights
|
347 |
+
model.save_pretrained(args.output_dir)
|
348 |
+
|
349 |
+
|
350 |
+
|
351 |
+
if __name__ == "__main__":
|
352 |
+
parser = get_arg_parse()
|
353 |
+
args = parser.parse_args()
|
354 |
+
|
355 |
+
main(args)
|
IndicTrans2/huggingface_interface/train_lora.sh
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export CUDA_VISIBLE_DEVICES=0
|
2 |
+
|
3 |
+
data_dir=${1:-"en-indic-exp"}
|
4 |
+
model_name=${2:-"ai4bharat/indictrans2-en-indic-dist-200M"}
|
5 |
+
output_dir=${3:-"output"}
|
6 |
+
src_lang_list=${4:-"eng_Latn"}
|
7 |
+
tgt_lang_list=${5:-"asm_Beng,ben_Beng,guj_Gujr,hin_Deva,kan_Knda,mal_Mlym,mar_Deva,npi_Deva,ory_Orya,pan_Guru,tam_Taml,tel_Telu,urd_Arab"}
|
8 |
+
|
9 |
+
python3 train_lora.py \
|
10 |
+
--data_dir $data_dir \
|
11 |
+
--model_name $model_name \
|
12 |
+
--output_dir $output_dir \
|
13 |
+
--src_lang_list $src_lang_list \
|
14 |
+
--tgt_lang_list $tgt_lang_list \
|
15 |
+
--save_steps 1000 \
|
16 |
+
--max_steps 1000000 \
|
17 |
+
--batch_size 32 \
|
18 |
+
--grad_accum_steps 4 \
|
19 |
+
--warmup_steps 4000 \
|
20 |
+
--max_grad_norm 1.0 \
|
21 |
+
--learning_rate 2e-4 \
|
22 |
+
--adam_beta1 0.9 \
|
23 |
+
--adam_beta2 0.98 \
|
24 |
+
--optimizer adamw_torch \
|
25 |
+
--lr_scheduler inverse_sqrt \
|
26 |
+
--num_workers 16 \
|
27 |
+
--metric_for_best_model eval_BLEU \
|
28 |
+
--greater_is_better \
|
29 |
+
--patience 10 \
|
30 |
+
--weight_decay 0.01 \
|
31 |
+
--lora_target_modules "q_proj,k_proj" \
|
32 |
+
--lora_dropout 0.1 \
|
33 |
+
--lora_r 16 \
|
34 |
+
--lora_alpha 32 \
|
35 |
+
--print_samples
|
IndicTrans2/inference/__init__.py
ADDED
File without changes
|
IndicTrans2/inference/custom_interactive.py
ADDED
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# python wrapper for fairseq-interactive command line tool
|
2 |
+
|
3 |
+
#!/usr/bin/env python3 -u
|
4 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
5 |
+
#
|
6 |
+
# This source code is licensed under the MIT license found in the
|
7 |
+
# LICENSE file in the root directory of this source tree.
|
8 |
+
"""
|
9 |
+
Translate raw text with a trained model. Batches data on-the-fly.
|
10 |
+
"""
|
11 |
+
|
12 |
+
import os
|
13 |
+
import ast
|
14 |
+
from collections import namedtuple
|
15 |
+
|
16 |
+
import torch
|
17 |
+
from fairseq import checkpoint_utils, options, tasks, utils
|
18 |
+
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
19 |
+
from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
|
20 |
+
from fairseq_cli.generate import get_symbols_to_strip_from_output
|
21 |
+
|
22 |
+
import codecs
|
23 |
+
|
24 |
+
PWD = os.path.dirname(__file__)
|
25 |
+
Batch = namedtuple("Batch", "ids src_tokens src_lengths constraints")
|
26 |
+
Translation = namedtuple("Translation", "src_str hypos pos_scores alignments")
|
27 |
+
|
28 |
+
|
29 |
+
def make_batches(
|
30 |
+
lines, cfg, task, max_positions, encode_fn, constrainted_decoding=False
|
31 |
+
):
|
32 |
+
def encode_fn_target(x):
|
33 |
+
return encode_fn(x)
|
34 |
+
|
35 |
+
if constrainted_decoding:
|
36 |
+
# Strip (tab-delimited) contraints, if present, from input lines,
|
37 |
+
# store them in batch_constraints
|
38 |
+
batch_constraints = [list() for _ in lines]
|
39 |
+
for i, line in enumerate(lines):
|
40 |
+
if "\t" in line:
|
41 |
+
lines[i], *batch_constraints[i] = line.split("\t")
|
42 |
+
|
43 |
+
# Convert each List[str] to List[Tensor]
|
44 |
+
for i, constraint_list in enumerate(batch_constraints):
|
45 |
+
batch_constraints[i] = [
|
46 |
+
task.target_dictionary.encode_line(
|
47 |
+
encode_fn_target(constraint),
|
48 |
+
append_eos=False,
|
49 |
+
add_if_not_exist=False,
|
50 |
+
)
|
51 |
+
for constraint in constraint_list
|
52 |
+
]
|
53 |
+
|
54 |
+
if constrainted_decoding:
|
55 |
+
constraints_tensor = pack_constraints(batch_constraints)
|
56 |
+
else:
|
57 |
+
constraints_tensor = None
|
58 |
+
|
59 |
+
tokens, lengths = task.get_interactive_tokens_and_lengths(lines, encode_fn)
|
60 |
+
|
61 |
+
itr = task.get_batch_iterator(
|
62 |
+
dataset=task.build_dataset_for_inference(
|
63 |
+
tokens, lengths, constraints=constraints_tensor
|
64 |
+
),
|
65 |
+
max_tokens=cfg.dataset.max_tokens,
|
66 |
+
max_sentences=cfg.dataset.batch_size,
|
67 |
+
max_positions=max_positions,
|
68 |
+
ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
|
69 |
+
).next_epoch_itr(shuffle=False)
|
70 |
+
for batch in itr:
|
71 |
+
ids = batch["id"]
|
72 |
+
src_tokens = batch["net_input"]["src_tokens"]
|
73 |
+
src_lengths = batch["net_input"]["src_lengths"]
|
74 |
+
constraints = batch.get("constraints", None)
|
75 |
+
|
76 |
+
yield Batch(
|
77 |
+
ids=ids,
|
78 |
+
src_tokens=src_tokens,
|
79 |
+
src_lengths=src_lengths,
|
80 |
+
constraints=constraints,
|
81 |
+
)
|
82 |
+
|
83 |
+
|
84 |
+
class Translator:
|
85 |
+
"""
|
86 |
+
Wrapper class to handle the interaction with fairseq model class for translation
|
87 |
+
"""
|
88 |
+
|
89 |
+
def __init__(
|
90 |
+
self, data_dir, checkpoint_path, batch_size=25, constrained_decoding=False
|
91 |
+
):
|
92 |
+
|
93 |
+
self.constrained_decoding = constrained_decoding
|
94 |
+
self.parser = options.get_generation_parser(interactive=True)
|
95 |
+
# buffer_size is currently not used but we just initialize it to batch
|
96 |
+
# size + 1 to avoid any assertion errors.
|
97 |
+
if self.constrained_decoding:
|
98 |
+
self.parser.set_defaults(
|
99 |
+
path=checkpoint_path,
|
100 |
+
num_workers=-1,
|
101 |
+
constraints="ordered",
|
102 |
+
batch_size=batch_size,
|
103 |
+
buffer_size=batch_size + 1,
|
104 |
+
)
|
105 |
+
else:
|
106 |
+
self.parser.set_defaults(
|
107 |
+
path=checkpoint_path,
|
108 |
+
remove_bpe="subword_nmt",
|
109 |
+
num_workers=-1,
|
110 |
+
batch_size=batch_size,
|
111 |
+
buffer_size=batch_size + 1,
|
112 |
+
)
|
113 |
+
args = options.parse_args_and_arch(self.parser, input_args=[data_dir])
|
114 |
+
# we are explictly setting src_lang and tgt_lang here
|
115 |
+
# generally the data_dir we pass contains {split}-{src_lang}-{tgt_lang}.*.idx files from
|
116 |
+
# which fairseq infers the src and tgt langs(if these are not passed). In deployment we dont
|
117 |
+
# use any idx files and only store the SRC and TGT dictionaries.
|
118 |
+
args.source_lang = "SRC"
|
119 |
+
args.target_lang = "TGT"
|
120 |
+
# since we are truncating sentences to max_seq_len in engine, we can set it to False here
|
121 |
+
args.skip_invalid_size_inputs_valid_test = False
|
122 |
+
|
123 |
+
# we have custom architechtures in this folder and we will let fairseq
|
124 |
+
# import this
|
125 |
+
args.user_dir = os.path.join(PWD, "model_configs")
|
126 |
+
self.cfg = convert_namespace_to_omegaconf(args)
|
127 |
+
|
128 |
+
utils.import_user_module(self.cfg.common)
|
129 |
+
|
130 |
+
if self.cfg.interactive.buffer_size < 1:
|
131 |
+
self.cfg.interactive.buffer_size = 1
|
132 |
+
if self.cfg.dataset.max_tokens is None and self.cfg.dataset.batch_size is None:
|
133 |
+
self.cfg.dataset.batch_size = 1
|
134 |
+
|
135 |
+
assert (
|
136 |
+
not self.cfg.generation.sampling
|
137 |
+
or self.cfg.generation.nbest == self.cfg.generation.beam
|
138 |
+
), "--sampling requires --nbest to be equal to --beam"
|
139 |
+
assert (
|
140 |
+
not self.cfg.dataset.batch_size
|
141 |
+
or self.cfg.dataset.batch_size <= self.cfg.interactive.buffer_size
|
142 |
+
), "--batch-size cannot be larger than --buffer-size"
|
143 |
+
|
144 |
+
# Fix seed for stochastic decoding
|
145 |
+
# if self.cfg.common.seed is not None and not self.cfg.generation.no_seed_provided:
|
146 |
+
# np.random.seed(self.cfg.common.seed)
|
147 |
+
# utils.set_torch_seed(self.cfg.common.seed)
|
148 |
+
|
149 |
+
# if not self.constrained_decoding:
|
150 |
+
# self.use_cuda = torch.cuda.is_available() and not self.cfg.common.cpu
|
151 |
+
# else:
|
152 |
+
# self.use_cuda = False
|
153 |
+
|
154 |
+
self.use_cuda = torch.cuda.is_available() and not self.cfg.common.cpu
|
155 |
+
|
156 |
+
# Setup task, e.g., translation
|
157 |
+
self.task = tasks.setup_task(self.cfg.task)
|
158 |
+
|
159 |
+
# Load ensemble
|
160 |
+
overrides = ast.literal_eval(self.cfg.common_eval.model_overrides)
|
161 |
+
self.models, self._model_args = checkpoint_utils.load_model_ensemble(
|
162 |
+
utils.split_paths(self.cfg.common_eval.path),
|
163 |
+
arg_overrides=overrides,
|
164 |
+
task=self.task,
|
165 |
+
suffix=self.cfg.checkpoint.checkpoint_suffix,
|
166 |
+
strict=(self.cfg.checkpoint.checkpoint_shard_count == 1),
|
167 |
+
num_shards=self.cfg.checkpoint.checkpoint_shard_count,
|
168 |
+
)
|
169 |
+
|
170 |
+
# Set dictionaries
|
171 |
+
self.src_dict = self.task.source_dictionary
|
172 |
+
self.tgt_dict = self.task.target_dictionary
|
173 |
+
|
174 |
+
# Optimize ensemble for generation
|
175 |
+
for model in self.models:
|
176 |
+
if model is None:
|
177 |
+
continue
|
178 |
+
if self.cfg.common.fp16:
|
179 |
+
model.half()
|
180 |
+
if (
|
181 |
+
self.use_cuda
|
182 |
+
and not self.cfg.distributed_training.pipeline_model_parallel
|
183 |
+
):
|
184 |
+
model.cuda()
|
185 |
+
model.prepare_for_inference_(self.cfg)
|
186 |
+
|
187 |
+
# Initialize generator
|
188 |
+
self.generator = self.task.build_generator(self.models, self.cfg.generation)
|
189 |
+
|
190 |
+
self.tokenizer = None
|
191 |
+
self.bpe = None
|
192 |
+
# # Handle tokenization and BPE
|
193 |
+
# self.tokenizer = self.task.build_tokenizer(self.cfg.tokenizer)
|
194 |
+
# self.bpe = self.task.build_bpe(self.cfg.bpe)
|
195 |
+
|
196 |
+
# Load alignment dictionary for unknown word replacement
|
197 |
+
# (None if no unknown word replacement, empty if no path to align dictionary)
|
198 |
+
self.align_dict = utils.load_align_dict(self.cfg.generation.replace_unk)
|
199 |
+
|
200 |
+
self.max_positions = utils.resolve_max_positions(
|
201 |
+
self.task.max_positions(), *[model.max_positions() for model in self.models]
|
202 |
+
)
|
203 |
+
|
204 |
+
def encode_fn(self, x):
|
205 |
+
if self.tokenizer is not None:
|
206 |
+
x = self.tokenizer.encode(x)
|
207 |
+
if self.bpe is not None:
|
208 |
+
x = self.bpe.encode(x)
|
209 |
+
return x
|
210 |
+
|
211 |
+
def decode_fn(self, x):
|
212 |
+
if self.bpe is not None:
|
213 |
+
x = self.bpe.decode(x)
|
214 |
+
if self.tokenizer is not None:
|
215 |
+
x = self.tokenizer.decode(x)
|
216 |
+
return x
|
217 |
+
|
218 |
+
def translate(self, inputs, constraints=None):
|
219 |
+
if self.constrained_decoding and constraints is None:
|
220 |
+
raise ValueError("Constraints cant be None in constrained decoding mode")
|
221 |
+
if not self.constrained_decoding and constraints is not None:
|
222 |
+
raise ValueError("Cannot pass constraints during normal translation")
|
223 |
+
if constraints:
|
224 |
+
constrained_decoding = True
|
225 |
+
modified_inputs = []
|
226 |
+
for _input, constraint in zip(inputs, constraints):
|
227 |
+
modified_inputs.append(_input + f"\t{constraint}")
|
228 |
+
inputs = modified_inputs
|
229 |
+
else:
|
230 |
+
constrained_decoding = False
|
231 |
+
|
232 |
+
start_id = 0
|
233 |
+
results = []
|
234 |
+
final_translations = []
|
235 |
+
for batch in make_batches(
|
236 |
+
inputs,
|
237 |
+
self.cfg,
|
238 |
+
self.task,
|
239 |
+
self.max_positions,
|
240 |
+
self.encode_fn,
|
241 |
+
constrained_decoding,
|
242 |
+
):
|
243 |
+
bsz = batch.src_tokens.size(0)
|
244 |
+
src_tokens = batch.src_tokens
|
245 |
+
src_lengths = batch.src_lengths
|
246 |
+
constraints = batch.constraints
|
247 |
+
if self.use_cuda:
|
248 |
+
src_tokens = src_tokens.cuda()
|
249 |
+
src_lengths = src_lengths.cuda()
|
250 |
+
if constraints is not None:
|
251 |
+
constraints = constraints.cuda()
|
252 |
+
|
253 |
+
sample = {
|
254 |
+
"net_input": {
|
255 |
+
"src_tokens": src_tokens,
|
256 |
+
"src_lengths": src_lengths,
|
257 |
+
},
|
258 |
+
}
|
259 |
+
|
260 |
+
translations = self.task.inference_step(
|
261 |
+
self.generator, self.models, sample, constraints=constraints
|
262 |
+
)
|
263 |
+
|
264 |
+
list_constraints = [[] for _ in range(bsz)]
|
265 |
+
if constrained_decoding:
|
266 |
+
list_constraints = [unpack_constraints(c) for c in constraints]
|
267 |
+
for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
|
268 |
+
src_tokens_i = utils.strip_pad(src_tokens[i], self.tgt_dict.pad())
|
269 |
+
constraints = list_constraints[i]
|
270 |
+
results.append(
|
271 |
+
(
|
272 |
+
start_id + id,
|
273 |
+
src_tokens_i,
|
274 |
+
hypos,
|
275 |
+
{
|
276 |
+
"constraints": constraints,
|
277 |
+
},
|
278 |
+
)
|
279 |
+
)
|
280 |
+
|
281 |
+
# sort output to match input order
|
282 |
+
for id_, src_tokens, hypos, _ in sorted(results, key=lambda x: x[0]):
|
283 |
+
src_str = ""
|
284 |
+
if self.src_dict is not None:
|
285 |
+
src_str = self.src_dict.string(
|
286 |
+
src_tokens, self.cfg.common_eval.post_process
|
287 |
+
)
|
288 |
+
|
289 |
+
# Process top predictions
|
290 |
+
for hypo in hypos[: min(len(hypos), self.cfg.generation.nbest)]:
|
291 |
+
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
|
292 |
+
hypo_tokens=hypo["tokens"].int().cpu(),
|
293 |
+
src_str=src_str,
|
294 |
+
alignment=hypo["alignment"],
|
295 |
+
align_dict=self.align_dict,
|
296 |
+
tgt_dict=self.tgt_dict,
|
297 |
+
|
298 |
+
extra_symbols_to_ignore=get_symbols_to_strip_from_output(
|
299 |
+
self.generator
|
300 |
+
),
|
301 |
+
)
|
302 |
+
detok_hypo_str = self.decode_fn(hypo_str)
|
303 |
+
final_translations.append(detok_hypo_str)
|
304 |
+
return final_translations
|
IndicTrans2/inference/download.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import urduhack
|
2 |
+
urduhack.download()
|
3 |
+
|
4 |
+
import nltk
|
5 |
+
nltk.download('punkt')
|
IndicTrans2/inference/engine.py
ADDED
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
import os
|
3 |
+
import uuid
|
4 |
+
from typing import List, Tuple, Union, Dict
|
5 |
+
|
6 |
+
import regex as re
|
7 |
+
import sentencepiece as spm
|
8 |
+
from indicnlp.normalize import indic_normalize
|
9 |
+
from indicnlp.tokenize import indic_detokenize, indic_tokenize
|
10 |
+
from indicnlp.tokenize.sentence_tokenize import DELIM_PAT_NO_DANDA, sentence_split
|
11 |
+
from indicnlp.transliterate import unicode_transliterate
|
12 |
+
from mosestokenizer import MosesSentenceSplitter
|
13 |
+
from nltk.tokenize import sent_tokenize
|
14 |
+
from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer
|
15 |
+
from tqdm import tqdm
|
16 |
+
|
17 |
+
from .flores_codes_map_indic import flores_codes, iso_to_flores
|
18 |
+
from .normalize_punctuation import punc_norm
|
19 |
+
from .normalize_regex_inference import EMAIL_PATTERN, normalize
|
20 |
+
|
21 |
+
|
22 |
+
def split_sentences(paragraph: str, lang: str) -> List[str]:
|
23 |
+
"""
|
24 |
+
Splits the input text paragraph into sentences. It uses `moses` for English and
|
25 |
+
`indic-nlp` for Indic languages.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
paragraph (str): input text paragraph.
|
29 |
+
lang (str): flores language code.
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
List[str] -> list of sentences.
|
33 |
+
"""
|
34 |
+
if lang == "eng_Latn":
|
35 |
+
with MosesSentenceSplitter(flores_codes[lang]) as splitter:
|
36 |
+
sents_moses = splitter([paragraph])
|
37 |
+
sents_nltk = sent_tokenize(paragraph)
|
38 |
+
if len(sents_nltk) < len(sents_moses):
|
39 |
+
sents = sents_nltk
|
40 |
+
else:
|
41 |
+
sents = sents_moses
|
42 |
+
return [sent.replace("\xad", "") for sent in sents]
|
43 |
+
else:
|
44 |
+
return sentence_split(paragraph, lang=flores_codes[lang], delim_pat=DELIM_PAT_NO_DANDA)
|
45 |
+
|
46 |
+
|
47 |
+
def add_token(sent: str, src_lang: str, tgt_lang: str, delimiter: str = " ") -> str:
|
48 |
+
"""
|
49 |
+
Add special tokens indicating source and target language to the start of the input sentence.
|
50 |
+
The resulting string will have the format: "`{src_lang} {tgt_lang} {input_sentence}`".
|
51 |
+
|
52 |
+
Args:
|
53 |
+
sent (str): input sentence to be translated.
|
54 |
+
src_lang (str): flores lang code of the input sentence.
|
55 |
+
tgt_lang (str): flores lang code in which the input sentence will be translated.
|
56 |
+
delimiter (str): separator to add between language tags and input sentence (default: " ").
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
str: input sentence with the special tokens added to the start.
|
60 |
+
"""
|
61 |
+
return src_lang + delimiter + tgt_lang + delimiter + sent
|
62 |
+
|
63 |
+
|
64 |
+
def apply_lang_tags(sents: List[str], src_lang: str, tgt_lang: str) -> List[str]:
|
65 |
+
"""
|
66 |
+
Add special tokens indicating source and target language to the start of the each input sentence.
|
67 |
+
Each resulting input sentence will have the format: "`{src_lang} {tgt_lang} {input_sentence}`".
|
68 |
+
|
69 |
+
Args:
|
70 |
+
sent (str): input sentence to be translated.
|
71 |
+
src_lang (str): flores lang code of the input sentence.
|
72 |
+
tgt_lang (str): flores lang code in which the input sentence will be translated.
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
List[str]: list of input sentences with the special tokens added to the start.
|
76 |
+
"""
|
77 |
+
tagged_sents = []
|
78 |
+
for sent in sents:
|
79 |
+
tagged_sent = add_token(sent.strip(), src_lang, tgt_lang)
|
80 |
+
tagged_sents.append(tagged_sent)
|
81 |
+
return tagged_sents
|
82 |
+
|
83 |
+
|
84 |
+
def truncate_long_sentences(
|
85 |
+
sents: List[str], placeholder_entity_map_sents: List[Dict]
|
86 |
+
) -> Tuple[List[str], List[Dict]]:
|
87 |
+
"""
|
88 |
+
Truncates the sentences that exceed the maximum sequence length.
|
89 |
+
The maximum sequence for the IndicTrans2 model is limited to 256 tokens.
|
90 |
+
|
91 |
+
Args:
|
92 |
+
sents (List[str]): list of input sentences to truncate.
|
93 |
+
|
94 |
+
Returns:
|
95 |
+
Tuple[List[str], List[Dict]]: tuple containing the list of sentences with truncation applied and the updated placeholder entity maps.
|
96 |
+
"""
|
97 |
+
MAX_SEQ_LEN = 256
|
98 |
+
new_sents = []
|
99 |
+
placeholders = []
|
100 |
+
|
101 |
+
for j, sent in enumerate(sents):
|
102 |
+
words = sent.split()
|
103 |
+
num_words = len(words)
|
104 |
+
if num_words > MAX_SEQ_LEN:
|
105 |
+
sents = []
|
106 |
+
i = 0
|
107 |
+
while i <= len(words):
|
108 |
+
sents.append(" ".join(words[i : i + MAX_SEQ_LEN]))
|
109 |
+
i += MAX_SEQ_LEN
|
110 |
+
placeholders.extend([placeholder_entity_map_sents[j]] * (len(sents)))
|
111 |
+
new_sents.extend(sents)
|
112 |
+
else:
|
113 |
+
placeholders.append(placeholder_entity_map_sents[j])
|
114 |
+
new_sents.append(sent)
|
115 |
+
return new_sents, placeholders
|
116 |
+
|
117 |
+
|
118 |
+
class Model:
|
119 |
+
"""
|
120 |
+
Model class to run the IndicTransv2 models using python interface.
|
121 |
+
"""
|
122 |
+
|
123 |
+
def __init__(
|
124 |
+
self,
|
125 |
+
ckpt_dir: str,
|
126 |
+
device: str = "cuda",
|
127 |
+
input_lang_code_format: str = "flores",
|
128 |
+
model_type: str = "ctranslate2",
|
129 |
+
):
|
130 |
+
"""
|
131 |
+
Initialize the model class.
|
132 |
+
|
133 |
+
Args:
|
134 |
+
ckpt_dir (str): path of the model checkpoint directory.
|
135 |
+
device (str, optional): where to load the model (defaults: cuda).
|
136 |
+
"""
|
137 |
+
self.ckpt_dir = ckpt_dir
|
138 |
+
self.en_tok = MosesTokenizer(lang="en")
|
139 |
+
self.en_normalizer = MosesPunctNormalizer()
|
140 |
+
self.en_detok = MosesDetokenizer(lang="en")
|
141 |
+
self.xliterator = unicode_transliterate.UnicodeIndicTransliterator()
|
142 |
+
|
143 |
+
print("Initializing sentencepiece model for SRC and TGT")
|
144 |
+
self.sp_src = spm.SentencePieceProcessor(
|
145 |
+
model_file=os.path.join(ckpt_dir, "vocab", "model.SRC")
|
146 |
+
)
|
147 |
+
self.sp_tgt = spm.SentencePieceProcessor(
|
148 |
+
model_file=os.path.join(ckpt_dir, "vocab", "model.TGT")
|
149 |
+
)
|
150 |
+
|
151 |
+
self.input_lang_code_format = input_lang_code_format
|
152 |
+
|
153 |
+
print("Initializing model for translation")
|
154 |
+
# initialize the model
|
155 |
+
if model_type == "ctranslate2":
|
156 |
+
import ctranslate2
|
157 |
+
|
158 |
+
self.translator = ctranslate2.Translator(
|
159 |
+
self.ckpt_dir, device=device
|
160 |
+
) # , compute_type="auto")
|
161 |
+
self.translate_lines = self.ctranslate2_translate_lines
|
162 |
+
elif model_type == "fairseq":
|
163 |
+
from .custom_interactive import Translator
|
164 |
+
|
165 |
+
self.translator = Translator(
|
166 |
+
data_dir=os.path.join(self.ckpt_dir, "final_bin"),
|
167 |
+
checkpoint_path=os.path.join(self.ckpt_dir, "model", "checkpoint_best.pt"),
|
168 |
+
batch_size=100,
|
169 |
+
)
|
170 |
+
self.translate_lines = self.fairseq_translate_lines
|
171 |
+
else:
|
172 |
+
raise NotImplementedError(f"Unknown model_type: {model_type}")
|
173 |
+
|
174 |
+
def ctranslate2_translate_lines(self, lines: List[str]) -> List[str]:
|
175 |
+
tokenized_sents = [x.strip().split(" ") for x in lines]
|
176 |
+
translations = self.translator.translate_batch(
|
177 |
+
tokenized_sents,
|
178 |
+
max_batch_size=9216,
|
179 |
+
batch_type="tokens",
|
180 |
+
max_input_length=160,
|
181 |
+
max_decoding_length=256,
|
182 |
+
beam_size=5,
|
183 |
+
)
|
184 |
+
translations = [" ".join(x.hypotheses[0]) for x in translations]
|
185 |
+
return translations
|
186 |
+
|
187 |
+
def fairseq_translate_lines(self, lines: List[str]) -> List[str]:
|
188 |
+
return self.translator.translate(lines)
|
189 |
+
|
190 |
+
def paragraphs_batch_translate__multilingual(self, batch_payloads: List[tuple]) -> List[str]:
|
191 |
+
"""
|
192 |
+
Translates a batch of input paragraphs (including pre/post processing)
|
193 |
+
from any language to any language.
|
194 |
+
|
195 |
+
Args:
|
196 |
+
batch_payloads (List[tuple]): batch of long input-texts to be translated, each in format: (paragraph, src_lang, tgt_lang)
|
197 |
+
|
198 |
+
Returns:
|
199 |
+
List[str]: batch of paragraph-translations in the respective languages.
|
200 |
+
"""
|
201 |
+
paragraph_id_to_sentence_range = []
|
202 |
+
global__sents = []
|
203 |
+
global__preprocessed_sents = []
|
204 |
+
global__preprocessed_sents_placeholder_entity_map = []
|
205 |
+
|
206 |
+
for i in range(len(batch_payloads)):
|
207 |
+
paragraph, src_lang, tgt_lang = batch_payloads[i]
|
208 |
+
if self.input_lang_code_format == "iso":
|
209 |
+
src_lang, tgt_lang = iso_to_flores[src_lang], iso_to_flores[tgt_lang]
|
210 |
+
|
211 |
+
batch = split_sentences(paragraph, src_lang)
|
212 |
+
global__sents.extend(batch)
|
213 |
+
|
214 |
+
preprocessed_sents, placeholder_entity_map_sents = self.preprocess_batch(
|
215 |
+
batch, src_lang, tgt_lang
|
216 |
+
)
|
217 |
+
|
218 |
+
global_sentence_start_index = len(global__preprocessed_sents)
|
219 |
+
global__preprocessed_sents.extend(preprocessed_sents)
|
220 |
+
global__preprocessed_sents_placeholder_entity_map.extend(placeholder_entity_map_sents)
|
221 |
+
paragraph_id_to_sentence_range.append(
|
222 |
+
(global_sentence_start_index, len(global__preprocessed_sents))
|
223 |
+
)
|
224 |
+
|
225 |
+
translations = self.translate_lines(global__preprocessed_sents)
|
226 |
+
|
227 |
+
translated_paragraphs = []
|
228 |
+
for paragraph_id, sentence_range in enumerate(paragraph_id_to_sentence_range):
|
229 |
+
tgt_lang = batch_payloads[paragraph_id][2]
|
230 |
+
if self.input_lang_code_format == "iso":
|
231 |
+
tgt_lang = iso_to_flores[tgt_lang]
|
232 |
+
|
233 |
+
postprocessed_sents = self.postprocess(
|
234 |
+
translations[sentence_range[0] : sentence_range[1]],
|
235 |
+
global__preprocessed_sents_placeholder_entity_map[
|
236 |
+
sentence_range[0] : sentence_range[1]
|
237 |
+
],
|
238 |
+
tgt_lang,
|
239 |
+
)
|
240 |
+
translated_paragraph = " ".join(postprocessed_sents)
|
241 |
+
translated_paragraphs.append(translated_paragraph)
|
242 |
+
|
243 |
+
return translated_paragraphs
|
244 |
+
|
245 |
+
# translate a batch of sentences from src_lang to tgt_lang
|
246 |
+
def batch_translate(self, batch: List[str], src_lang: str, tgt_lang: str) -> List[str]:
|
247 |
+
"""
|
248 |
+
Translates a batch of input sentences (including pre/post processing)
|
249 |
+
from source language to target language.
|
250 |
+
|
251 |
+
Args:
|
252 |
+
batch (List[str]): batch of input sentences to be translated.
|
253 |
+
src_lang (str): flores source language code.
|
254 |
+
tgt_lang (str): flores target language code.
|
255 |
+
|
256 |
+
Returns:
|
257 |
+
List[str]: batch of translated-sentences generated by the model.
|
258 |
+
"""
|
259 |
+
|
260 |
+
assert isinstance(batch, list)
|
261 |
+
|
262 |
+
if self.input_lang_code_format == "iso":
|
263 |
+
src_lang, tgt_lang = iso_to_flores[src_lang], iso_to_flores[tgt_lang]
|
264 |
+
|
265 |
+
preprocessed_sents, placeholder_entity_map_sents = self.preprocess_batch(
|
266 |
+
batch, src_lang, tgt_lang
|
267 |
+
)
|
268 |
+
translations = self.translate_lines(preprocessed_sents)
|
269 |
+
return self.postprocess(translations, placeholder_entity_map_sents, tgt_lang)
|
270 |
+
|
271 |
+
# translate a paragraph from src_lang to tgt_lang
|
272 |
+
def translate_paragraph(self, paragraph: str, src_lang: str, tgt_lang: str) -> str:
|
273 |
+
"""
|
274 |
+
Translates an input text paragraph (including pre/post processing)
|
275 |
+
from source language to target language.
|
276 |
+
|
277 |
+
Args:
|
278 |
+
paragraph (str): input text paragraph to be translated.
|
279 |
+
src_lang (str): flores source language code.
|
280 |
+
tgt_lang (str): flores target language code.
|
281 |
+
|
282 |
+
Returns:
|
283 |
+
str: paragraph translation generated by the model.
|
284 |
+
"""
|
285 |
+
|
286 |
+
assert isinstance(paragraph, str)
|
287 |
+
|
288 |
+
if self.input_lang_code_format == "iso":
|
289 |
+
flores_src_lang = iso_to_flores[src_lang]
|
290 |
+
else:
|
291 |
+
flores_src_lang = src_lang
|
292 |
+
|
293 |
+
sents = split_sentences(paragraph, flores_src_lang)
|
294 |
+
postprocessed_sents = self.batch_translate(sents, src_lang, tgt_lang)
|
295 |
+
translated_paragraph = " ".join(postprocessed_sents)
|
296 |
+
|
297 |
+
return translated_paragraph
|
298 |
+
|
299 |
+
def preprocess_batch(self, batch: List[str], src_lang: str, tgt_lang: str) -> List[str]:
|
300 |
+
"""
|
301 |
+
Preprocess an array of sentences by normalizing, tokenization, and possibly transliterating it. It also tokenizes the
|
302 |
+
normalized text sequences using sentence piece tokenizer and also adds language tags.
|
303 |
+
|
304 |
+
Args:
|
305 |
+
batch (List[str]): input list of sentences to preprocess.
|
306 |
+
src_lang (str): flores language code of the input text sentences.
|
307 |
+
tgt_lang (str): flores language code of the output text sentences.
|
308 |
+
|
309 |
+
Returns:
|
310 |
+
Tuple[List[str], List[Dict]]: a tuple of list of preprocessed input text sentences and also a corresponding list of dictionary
|
311 |
+
mapping placeholders to their original values.
|
312 |
+
"""
|
313 |
+
preprocessed_sents, placeholder_entity_map_sents = self.preprocess(batch, lang=src_lang)
|
314 |
+
tokenized_sents = self.apply_spm(preprocessed_sents)
|
315 |
+
tokenized_sents, placeholder_entity_map_sents = truncate_long_sentences(
|
316 |
+
tokenized_sents, placeholder_entity_map_sents
|
317 |
+
)
|
318 |
+
tagged_sents = apply_lang_tags(tokenized_sents, src_lang, tgt_lang)
|
319 |
+
return tagged_sents, placeholder_entity_map_sents
|
320 |
+
|
321 |
+
def apply_spm(self, sents: List[str]) -> List[str]:
|
322 |
+
"""
|
323 |
+
Applies sentence piece encoding to the batch of input sentences.
|
324 |
+
|
325 |
+
Args:
|
326 |
+
sents (List[str]): batch of the input sentences.
|
327 |
+
|
328 |
+
Returns:
|
329 |
+
List[str]: batch of encoded sentences with sentence piece model
|
330 |
+
"""
|
331 |
+
return [" ".join(self.sp_src.encode(sent, out_type=str)) for sent in sents]
|
332 |
+
|
333 |
+
def preprocess_sent(
|
334 |
+
self,
|
335 |
+
sent: str,
|
336 |
+
normalizer: Union[MosesPunctNormalizer, indic_normalize.IndicNormalizerFactory],
|
337 |
+
lang: str,
|
338 |
+
) -> Tuple[str, Dict]:
|
339 |
+
"""
|
340 |
+
Preprocess an input text sentence by normalizing, tokenization, and possibly transliterating it.
|
341 |
+
|
342 |
+
Args:
|
343 |
+
sent (str): input text sentence to preprocess.
|
344 |
+
normalizer (Union[MosesPunctNormalizer, indic_normalize.IndicNormalizerFactory]): an object that performs normalization on the text.
|
345 |
+
lang (str): flores language code of the input text sentence.
|
346 |
+
|
347 |
+
Returns:
|
348 |
+
Tuple[str, Dict]: A tuple containing the preprocessed input text sentence and a corresponding dictionary
|
349 |
+
mapping placeholders to their original values.
|
350 |
+
"""
|
351 |
+
iso_lang = flores_codes[lang]
|
352 |
+
sent = punc_norm(sent, iso_lang)
|
353 |
+
sent, placeholder_entity_map = normalize(sent)
|
354 |
+
|
355 |
+
transliterate = True
|
356 |
+
if lang.split("_")[1] in ["Arab", "Aran", "Olck", "Mtei", "Latn"]:
|
357 |
+
transliterate = False
|
358 |
+
|
359 |
+
if iso_lang == "en":
|
360 |
+
processed_sent = " ".join(
|
361 |
+
self.en_tok.tokenize(self.en_normalizer.normalize(sent.strip()), escape=False)
|
362 |
+
)
|
363 |
+
elif transliterate:
|
364 |
+
# transliterates from the any specific language to devanagari
|
365 |
+
# which is why we specify lang2_code as "hi".
|
366 |
+
processed_sent = self.xliterator.transliterate(
|
367 |
+
" ".join(
|
368 |
+
indic_tokenize.trivial_tokenize(normalizer.normalize(sent.strip()), iso_lang)
|
369 |
+
),
|
370 |
+
iso_lang,
|
371 |
+
"hi",
|
372 |
+
).replace(" ् ", "्")
|
373 |
+
else:
|
374 |
+
# we only need to transliterate for joint training
|
375 |
+
processed_sent = " ".join(
|
376 |
+
indic_tokenize.trivial_tokenize(normalizer.normalize(sent.strip()), iso_lang)
|
377 |
+
)
|
378 |
+
|
379 |
+
return processed_sent, placeholder_entity_map
|
380 |
+
|
381 |
+
def preprocess(self, sents: List[str], lang: str):
|
382 |
+
"""
|
383 |
+
Preprocess an array of sentences by normalizing, tokenization, and possibly transliterating it.
|
384 |
+
|
385 |
+
Args:
|
386 |
+
batch (List[str]): input list of sentences to preprocess.
|
387 |
+
lang (str): flores language code of the input text sentences.
|
388 |
+
|
389 |
+
Returns:
|
390 |
+
Tuple[List[str], List[Dict]]: a tuple of list of preprocessed input text sentences and also a corresponding list of dictionary
|
391 |
+
mapping placeholders to their original values.
|
392 |
+
"""
|
393 |
+
processed_sents, placeholder_entity_map_sents = [], []
|
394 |
+
|
395 |
+
if lang == "eng_Latn":
|
396 |
+
normalizer = None
|
397 |
+
else:
|
398 |
+
normfactory = indic_normalize.IndicNormalizerFactory()
|
399 |
+
normalizer = normfactory.get_normalizer(flores_codes[lang])
|
400 |
+
|
401 |
+
for sent in sents:
|
402 |
+
sent, placeholder_entity_map = self.preprocess_sent(sent, normalizer, lang)
|
403 |
+
processed_sents.append(sent)
|
404 |
+
placeholder_entity_map_sents.append(placeholder_entity_map)
|
405 |
+
|
406 |
+
return processed_sents, placeholder_entity_map_sents
|
407 |
+
|
408 |
+
def postprocess(
|
409 |
+
self,
|
410 |
+
sents: List[str],
|
411 |
+
placeholder_entity_map: List[Dict],
|
412 |
+
lang: str,
|
413 |
+
common_lang: str = "hin_Deva",
|
414 |
+
) -> List[str]:
|
415 |
+
"""
|
416 |
+
Postprocesses a batch of input sentences after the translation generations.
|
417 |
+
|
418 |
+
Args:
|
419 |
+
sents (List[str]): batch of translated sentences to postprocess.
|
420 |
+
placeholder_entity_map (List[Dict]): dictionary mapping placeholders to the original entity values.
|
421 |
+
lang (str): flores language code of the input sentences.
|
422 |
+
common_lang (str, optional): flores language code of the transliterated language (defaults: hin_Deva).
|
423 |
+
|
424 |
+
Returns:
|
425 |
+
List[str]: postprocessed batch of input sentences.
|
426 |
+
"""
|
427 |
+
|
428 |
+
lang_code, script_code = lang.split("_")
|
429 |
+
# SPM decode
|
430 |
+
for i in range(len(sents)):
|
431 |
+
# sent_tokens = sents[i].split(" ")
|
432 |
+
# sents[i] = self.sp_tgt.decode(sent_tokens)
|
433 |
+
|
434 |
+
sents[i] = sents[i].replace(" ", "").replace("▁", " ").strip()
|
435 |
+
|
436 |
+
# Fixes for Perso-Arabic scripts
|
437 |
+
# TODO: Move these normalizations inside indic-nlp-library
|
438 |
+
if script_code in {"Arab", "Aran"}:
|
439 |
+
# UrduHack adds space before punctuations. Since the model was trained without fixing this issue, let's fix it now
|
440 |
+
sents[i] = sents[i].replace(" ؟", "؟").replace(" ۔", "۔").replace(" ،", "،")
|
441 |
+
# Kashmiri bugfix for palatalization: https://github.com/AI4Bharat/IndicTrans2/issues/11
|
442 |
+
sents[i] = sents[i].replace("ٮ۪", "ؠ")
|
443 |
+
|
444 |
+
assert len(sents) == len(placeholder_entity_map)
|
445 |
+
|
446 |
+
for i in range(0, len(sents)):
|
447 |
+
for key in placeholder_entity_map[i].keys():
|
448 |
+
sents[i] = sents[i].replace(key, placeholder_entity_map[i][key])
|
449 |
+
|
450 |
+
# Detokenize and transliterate to native scripts if applicable
|
451 |
+
postprocessed_sents = []
|
452 |
+
|
453 |
+
if lang == "eng_Latn":
|
454 |
+
for sent in sents:
|
455 |
+
postprocessed_sents.append(self.en_detok.detokenize(sent.split(" ")))
|
456 |
+
else:
|
457 |
+
for sent in sents:
|
458 |
+
outstr = indic_detokenize.trivial_detokenize(
|
459 |
+
self.xliterator.transliterate(
|
460 |
+
sent, flores_codes[common_lang], flores_codes[lang]
|
461 |
+
),
|
462 |
+
flores_codes[lang],
|
463 |
+
)
|
464 |
+
|
465 |
+
# Oriya bug: indic-nlp-library produces ଯ଼ instead of ୟ when converting from Devanagari to Odia
|
466 |
+
# TODO: Find out what's the issue with unicode transliterator for Oriya and fix it
|
467 |
+
if lang_code == "ory":
|
468 |
+
outstr = outstr.replace("ଯ଼", 'ୟ')
|
469 |
+
|
470 |
+
postprocessed_sents.append(outstr)
|
471 |
+
|
472 |
+
return postprocessed_sents
|
IndicTrans2/inference/flores_codes_map_indic.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
FLORES language code mapping to 2 letter ISO language code for compatibility
|
3 |
+
with Indic NLP Library (https://github.com/anoopkunchukuttan/indic_nlp_library)
|
4 |
+
"""
|
5 |
+
flores_codes = {
|
6 |
+
"asm_Beng": "as",
|
7 |
+
"awa_Deva": "hi",
|
8 |
+
"ben_Beng": "bn",
|
9 |
+
"bho_Deva": "hi",
|
10 |
+
"brx_Deva": "hi",
|
11 |
+
"doi_Deva": "hi",
|
12 |
+
"eng_Latn": "en",
|
13 |
+
"gom_Deva": "kK",
|
14 |
+
"guj_Gujr": "gu",
|
15 |
+
"hin_Deva": "hi",
|
16 |
+
"hne_Deva": "hi",
|
17 |
+
"kan_Knda": "kn",
|
18 |
+
"kas_Arab": "ur",
|
19 |
+
"kas_Deva": "hi",
|
20 |
+
"kha_Latn": "en",
|
21 |
+
"lus_Latn": "en",
|
22 |
+
"mag_Deva": "hi",
|
23 |
+
"mai_Deva": "hi",
|
24 |
+
"mal_Mlym": "ml",
|
25 |
+
"mar_Deva": "mr",
|
26 |
+
"mni_Beng": "bn",
|
27 |
+
"mni_Mtei": "hi",
|
28 |
+
"npi_Deva": "ne",
|
29 |
+
"ory_Orya": "or",
|
30 |
+
"pan_Guru": "pa",
|
31 |
+
"san_Deva": "hi",
|
32 |
+
"sat_Olck": "or",
|
33 |
+
"snd_Arab": "ur",
|
34 |
+
"snd_Deva": "hi",
|
35 |
+
"tam_Taml": "ta",
|
36 |
+
"tel_Telu": "te",
|
37 |
+
"urd_Arab": "ur",
|
38 |
+
}
|
39 |
+
|
40 |
+
|
41 |
+
flores_to_iso = {
|
42 |
+
"asm_Beng": "as",
|
43 |
+
"awa_Deva": "awa",
|
44 |
+
"ben_Beng": "bn",
|
45 |
+
"bho_Deva": "bho",
|
46 |
+
"brx_Deva": "brx",
|
47 |
+
"doi_Deva": "doi",
|
48 |
+
"eng_Latn": "en",
|
49 |
+
"gom_Deva": "gom",
|
50 |
+
"guj_Gujr": "gu",
|
51 |
+
"hin_Deva": "hi",
|
52 |
+
"hne_Deva": "hne",
|
53 |
+
"kan_Knda": "kn",
|
54 |
+
"kas_Arab": "ksa",
|
55 |
+
"kas_Deva": "ksd",
|
56 |
+
"kha_Latn": "kha",
|
57 |
+
"lus_Latn": "lus",
|
58 |
+
"mag_Deva": "mag",
|
59 |
+
"mai_Deva": "mai",
|
60 |
+
"mal_Mlym": "ml",
|
61 |
+
"mar_Deva": "mr",
|
62 |
+
"mni_Beng": "mnib",
|
63 |
+
"mni_Mtei": "mnim",
|
64 |
+
"npi_Deva": "ne",
|
65 |
+
"ory_Orya": "or",
|
66 |
+
"pan_Guru": "pa",
|
67 |
+
"san_Deva": "sa",
|
68 |
+
"sat_Olck": "sat",
|
69 |
+
"snd_Arab": "sda",
|
70 |
+
"snd_Deva": "sdd",
|
71 |
+
"tam_Taml": "ta",
|
72 |
+
"tel_Telu": "te",
|
73 |
+
"urd_Arab": "ur",
|
74 |
+
}
|
75 |
+
|
76 |
+
iso_to_flores = {iso_code: flores_code for flores_code, iso_code in flores_to_iso.items()}
|
77 |
+
# Patch for digraphic langs.
|
78 |
+
iso_to_flores["ks"] = "kas_Arab"
|
79 |
+
iso_to_flores["ks_Deva"] = "kas_Deva"
|
80 |
+
iso_to_flores["mni"] = "mni_Mtei"
|
81 |
+
iso_to_flores["mni_Beng"] = "mni_Beng"
|
82 |
+
iso_to_flores["sd"] = "snd_Arab"
|
83 |
+
iso_to_flores["sd_Deva"] = "snd_Deva"
|
IndicTrans2/inference/indic_num_map.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
A dictionary mapping intended to normalize the numerals in Indic languages from
|
3 |
+
native script to Roman script. This is done to ensure that the figures / numbers
|
4 |
+
mentioned in native script are perfectly preserved during translation.
|
5 |
+
"""
|
6 |
+
INDIC_NUM_MAP = {
|
7 |
+
"\u09e6": "0",
|
8 |
+
"0": "0",
|
9 |
+
"\u0ae6": "0",
|
10 |
+
"\u0ce6": "0",
|
11 |
+
"\u0966": "0",
|
12 |
+
"\u0660": "0",
|
13 |
+
"\uabf0": "0",
|
14 |
+
"\u0b66": "0",
|
15 |
+
"\u0a66": "0",
|
16 |
+
"\u1c50": "0",
|
17 |
+
"\u06f0": "0",
|
18 |
+
"\u09e7": "1",
|
19 |
+
"1": "1",
|
20 |
+
"\u0ae7": "1",
|
21 |
+
"\u0967": "1",
|
22 |
+
"\u0ce7": "1",
|
23 |
+
"\u06f1": "1",
|
24 |
+
"\uabf1": "1",
|
25 |
+
"\u0b67": "1",
|
26 |
+
"\u0a67": "1",
|
27 |
+
"\u1c51": "1",
|
28 |
+
"\u0c67": "1",
|
29 |
+
"\u09e8": "2",
|
30 |
+
"2": "2",
|
31 |
+
"\u0ae8": "2",
|
32 |
+
"\u0968": "2",
|
33 |
+
"\u0ce8": "2",
|
34 |
+
"\u06f2": "2",
|
35 |
+
"\uabf2": "2",
|
36 |
+
"\u0b68": "2",
|
37 |
+
"\u0a68": "2",
|
38 |
+
"\u1c52": "2",
|
39 |
+
"\u0c68": "2",
|
40 |
+
"\u09e9": "3",
|
41 |
+
"3": "3",
|
42 |
+
"\u0ae9": "3",
|
43 |
+
"\u0969": "3",
|
44 |
+
"\u0ce9": "3",
|
45 |
+
"\u06f3": "3",
|
46 |
+
"\uabf3": "3",
|
47 |
+
"\u0b69": "3",
|
48 |
+
"\u0a69": "3",
|
49 |
+
"\u1c53": "3",
|
50 |
+
"\u0c69": "3",
|
51 |
+
"\u09ea": "4",
|
52 |
+
"4": "4",
|
53 |
+
"\u0aea": "4",
|
54 |
+
"\u096a": "4",
|
55 |
+
"\u0cea": "4",
|
56 |
+
"\u06f4": "4",
|
57 |
+
"\uabf4": "4",
|
58 |
+
"\u0b6a": "4",
|
59 |
+
"\u0a6a": "4",
|
60 |
+
"\u1c54": "4",
|
61 |
+
"\u0c6a": "4",
|
62 |
+
"\u09eb": "5",
|
63 |
+
"5": "5",
|
64 |
+
"\u0aeb": "5",
|
65 |
+
"\u096b": "5",
|
66 |
+
"\u0ceb": "5",
|
67 |
+
"\u06f5": "5",
|
68 |
+
"\uabf5": "5",
|
69 |
+
"\u0b6b": "5",
|
70 |
+
"\u0a6b": "5",
|
71 |
+
"\u1c55": "5",
|
72 |
+
"\u0c6b": "5",
|
73 |
+
"\u09ec": "6",
|
74 |
+
"6": "6",
|
75 |
+
"\u0aec": "6",
|
76 |
+
"\u096c": "6",
|
77 |
+
"\u0cec": "6",
|
78 |
+
"\u06f6": "6",
|
79 |
+
"\uabf6": "6",
|
80 |
+
"\u0b6c": "6",
|
81 |
+
"\u0a6c": "6",
|
82 |
+
"\u1c56": "6",
|
83 |
+
"\u0c6c": "6",
|
84 |
+
"\u09ed": "7",
|
85 |
+
"7": "7",
|
86 |
+
"\u0aed": "7",
|
87 |
+
"\u096d": "7",
|
88 |
+
"\u0ced": "7",
|
89 |
+
"\u06f7": "7",
|
90 |
+
"\uabf7": "7",
|
91 |
+
"\u0b6d": "7",
|
92 |
+
"\u0a6d": "7",
|
93 |
+
"\u1c57": "7",
|
94 |
+
"\u0c6d": "7",
|
95 |
+
"\u09ee": "8",
|
96 |
+
"8": "8",
|
97 |
+
"\u0aee": "8",
|
98 |
+
"\u096e": "8",
|
99 |
+
"\u0cee": "8",
|
100 |
+
"\u06f8": "8",
|
101 |
+
"\uabf8": "8",
|
102 |
+
"\u0b6e": "8",
|
103 |
+
"\u0a6e": "8",
|
104 |
+
"\u1c58": "8",
|
105 |
+
"\u0c6e": "8",
|
106 |
+
"\u09ef": "9",
|
107 |
+
"9": "9",
|
108 |
+
"\u0aef": "9",
|
109 |
+
"\u096f": "9",
|
110 |
+
"\u0cef": "9",
|
111 |
+
"\u06f9": "9",
|
112 |
+
"\uabf9": "9",
|
113 |
+
"\u0b6f": "9",
|
114 |
+
"\u0a6f": "9",
|
115 |
+
"\u1c59": "9",
|
116 |
+
"\u0c6f": "9",
|
117 |
+
}
|
IndicTrans2/inference/model_configs/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from . import custom_transformer
|
IndicTrans2/inference/model_configs/custom_transformer.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fairseq.models import register_model_architecture
|
2 |
+
from fairseq.models.transformer import base_architecture
|
3 |
+
|
4 |
+
|
5 |
+
@register_model_architecture("transformer", "transformer_2x")
|
6 |
+
def transformer_big(args):
|
7 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
|
8 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
|
9 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
|
10 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
|
11 |
+
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
|
12 |
+
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
|
13 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
|
14 |
+
base_architecture(args)
|
15 |
+
|
16 |
+
|
17 |
+
@register_model_architecture("transformer", "transformer_4x")
|
18 |
+
def transformer_huge(args):
|
19 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1536)
|
20 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
|
21 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
|
22 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
|
23 |
+
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1536)
|
24 |
+
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
|
25 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
|
26 |
+
base_architecture(args)
|
27 |
+
|
28 |
+
|
29 |
+
@register_model_architecture("transformer", "transformer_9x")
|
30 |
+
def transformer_xlarge(args):
|
31 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 2048)
|
32 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 8192)
|
33 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
|
34 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
|
35 |
+
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048)
|
36 |
+
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8192)
|
37 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
|
38 |
+
base_architecture(args)
|
39 |
+
|
40 |
+
|
41 |
+
@register_model_architecture("transformer", "transformer_12e12d_9xeq")
|
42 |
+
def transformer_vxlarge(args):
|
43 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1536)
|
44 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
|
45 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
|
46 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
|
47 |
+
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1536)
|
48 |
+
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
|
49 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
|
50 |
+
args.encoder_layers = getattr(args, "encoder_layers", 12)
|
51 |
+
args.decoder_layers = getattr(args, "decoder_layers", 12)
|
52 |
+
base_architecture(args)
|
53 |
+
|
54 |
+
|
55 |
+
@register_model_architecture("transformer", "transformer_18_18")
|
56 |
+
def transformer_deep(args):
|
57 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
|
58 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 8 * 1024)
|
59 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
|
60 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
|
61 |
+
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
|
62 |
+
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
|
63 |
+
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8 * 1024)
|
64 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
|
65 |
+
args.encoder_layers = getattr(args, "encoder_layers", 18)
|
66 |
+
args.decoder_layers = getattr(args, "decoder_layers", 18)
|
67 |
+
base_architecture(args)
|
68 |
+
|
69 |
+
|
70 |
+
@register_model_architecture("transformer", "transformer_24_24")
|
71 |
+
def transformer_xdeep(args):
|
72 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
|
73 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 8 * 1024)
|
74 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
|
75 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
|
76 |
+
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
|
77 |
+
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
|
78 |
+
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8 * 1024)
|
79 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
|
80 |
+
args.encoder_layers = getattr(args, "encoder_layers", 24)
|
81 |
+
args.decoder_layers = getattr(args, "decoder_layers", 24)
|
82 |
+
base_architecture(args)
|
IndicTrans2/inference/normalize-punctuation.perl
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env perl
|
2 |
+
#
|
3 |
+
# This file is part of moses. Its use is licensed under the GNU Lesser General
|
4 |
+
# Public License version 2.1 or, at your option, any later version.
|
5 |
+
|
6 |
+
use warnings;
|
7 |
+
use strict;
|
8 |
+
|
9 |
+
my $language = "en";
|
10 |
+
my $PENN = 0;
|
11 |
+
|
12 |
+
while (@ARGV) {
|
13 |
+
$_ = shift;
|
14 |
+
/^-b$/ && ($| = 1, next); # not buffered (flush each line)
|
15 |
+
/^-l$/ && ($language = shift, next);
|
16 |
+
/^[^\-]/ && ($language = $_, next);
|
17 |
+
/^-penn$/ && ($PENN = 1, next);
|
18 |
+
}
|
19 |
+
|
20 |
+
while(<STDIN>) {
|
21 |
+
s/\r//g;
|
22 |
+
# remove extra spaces
|
23 |
+
s/\(/ \(/g;
|
24 |
+
s/\)/\) /g; s/ +/ /g;
|
25 |
+
s/\) ([\.\!\:\?\;\,])/\)$1/g;
|
26 |
+
s/\( /\(/g;
|
27 |
+
s/ \)/\)/g;
|
28 |
+
s/(\d) \%/$1\%/g;
|
29 |
+
s/ :/:/g;
|
30 |
+
s/ ;/;/g;
|
31 |
+
# normalize unicode punctuation
|
32 |
+
if ($PENN == 0) {
|
33 |
+
s/\`/\'/g;
|
34 |
+
s/\'\'/ \" /g;
|
35 |
+
}
|
36 |
+
|
37 |
+
s/„/\"/g;
|
38 |
+
s/“/\"/g;
|
39 |
+
s/”/\"/g;
|
40 |
+
s/–/-/g;
|
41 |
+
s/—/ - /g; s/ +/ /g;
|
42 |
+
s/´/\'/g;
|
43 |
+
s/([a-z])‘([a-z])/$1\'$2/gi;
|
44 |
+
s/([a-z])’([a-z])/$1\'$2/gi;
|
45 |
+
s/‘/\'/g;
|
46 |
+
s/‚/\'/g;
|
47 |
+
s/’/\"/g;
|
48 |
+
s/''/\"/g;
|
49 |
+
s/´´/\"/g;
|
50 |
+
s/…/.../g;
|
51 |
+
# French quotes
|
52 |
+
s/ « / \"/g;
|
53 |
+
s/« /\"/g;
|
54 |
+
s/«/\"/g;
|
55 |
+
s/ » /\" /g;
|
56 |
+
s/ »/\"/g;
|
57 |
+
s/»/\"/g;
|
58 |
+
# handle pseudo-spaces
|
59 |
+
s/ \%/\%/g;
|
60 |
+
s/nº /nº /g;
|
61 |
+
s/ :/:/g;
|
62 |
+
s/ ºC/ ºC/g;
|
63 |
+
s/ cm/ cm/g;
|
64 |
+
s/ \?/\?/g;
|
65 |
+
s/ \!/\!/g;
|
66 |
+
s/ ;/;/g;
|
67 |
+
s/, /, /g; s/ +/ /g;
|
68 |
+
|
69 |
+
# English "quotation," followed by comma, style
|
70 |
+
if ($language eq "en") {
|
71 |
+
s/\"([,\.]+)/$1\"/g;
|
72 |
+
}
|
73 |
+
# Czech is confused
|
74 |
+
elsif ($language eq "cs" || $language eq "cz") {
|
75 |
+
}
|
76 |
+
# German/Spanish/French "quotation", followed by comma, style
|
77 |
+
else {
|
78 |
+
s/,\"/\",/g;
|
79 |
+
s/(\.+)\"(\s*[^<])/\"$1$2/g; # don't fix period at end of sentence
|
80 |
+
}
|
81 |
+
|
82 |
+
|
83 |
+
if ($language eq "de" || $language eq "es" || $language eq "cz" || $language eq "cs" || $language eq "fr") {
|
84 |
+
s/(\d) (\d)/$1,$2/g;
|
85 |
+
}
|
86 |
+
else {
|
87 |
+
s/(\d) (\d)/$1.$2/g;
|
88 |
+
}
|
89 |
+
print $_;
|
90 |
+
}
|
IndicTrans2/inference/normalize_punctuation.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# IMPORTANT NOTE: DO NOT DIRECTLY EDIT THIS FILE
|
2 |
+
# This file was manually ported from `normalize-punctuation.perl`
|
3 |
+
# TODO: Only supports English, add others
|
4 |
+
|
5 |
+
import regex as re
|
6 |
+
multispace_regex = re.compile("[ ]{2,}")
|
7 |
+
multidots_regex = re.compile(r"\.{2,}")
|
8 |
+
end_bracket_space_punc_regex = re.compile(r"\) ([\.!:?;,])")
|
9 |
+
digit_space_percent = re.compile(r"(\d) %")
|
10 |
+
double_quot_punc = re.compile(r"\"([,\.]+)")
|
11 |
+
digit_nbsp_digit = re.compile(r"(\d) (\d)")
|
12 |
+
|
13 |
+
def punc_norm(text, lang="en"):
|
14 |
+
text = text.replace('\r', '') \
|
15 |
+
.replace('(', " (") \
|
16 |
+
.replace(')', ") ") \
|
17 |
+
\
|
18 |
+
.replace("( ", "(") \
|
19 |
+
.replace(" )", ")") \
|
20 |
+
\
|
21 |
+
.replace(" :", ':') \
|
22 |
+
.replace(" ;", ';') \
|
23 |
+
.replace('`', "'") \
|
24 |
+
\
|
25 |
+
.replace('„', '"') \
|
26 |
+
.replace('“', '"') \
|
27 |
+
.replace('”', '"') \
|
28 |
+
.replace('–', '-') \
|
29 |
+
.replace('—', " - ") \
|
30 |
+
.replace('´', "'") \
|
31 |
+
.replace('‘', "'") \
|
32 |
+
.replace('‚', "'") \
|
33 |
+
.replace('’', "'") \
|
34 |
+
.replace("''", "\"") \
|
35 |
+
.replace("´´", '"') \
|
36 |
+
.replace('…', "...") \
|
37 |
+
.replace(" « ", " \"") \
|
38 |
+
.replace("« ", '"') \
|
39 |
+
.replace('«', '"') \
|
40 |
+
.replace(" » ", "\" ") \
|
41 |
+
.replace(" »", '"') \
|
42 |
+
.replace('»', '"') \
|
43 |
+
.replace(" %", '%') \
|
44 |
+
.replace("nº ", "nº ") \
|
45 |
+
.replace(" :", ':') \
|
46 |
+
.replace(" ºC", " ºC") \
|
47 |
+
.replace(" cm", " cm") \
|
48 |
+
.replace(" ?", '?') \
|
49 |
+
.replace(" !", '!') \
|
50 |
+
.replace(" ;", ';') \
|
51 |
+
.replace(", ", ", ") \
|
52 |
+
|
53 |
+
|
54 |
+
text = multispace_regex.sub(' ', text)
|
55 |
+
text = multidots_regex.sub('.', text)
|
56 |
+
text = end_bracket_space_punc_regex.sub(r")\1", text)
|
57 |
+
text = digit_space_percent.sub(r"\1%", text)
|
58 |
+
text = double_quot_punc.sub(r'\1"', text) # English "quotation," followed by comma, style
|
59 |
+
text = digit_nbsp_digit.sub(r"\1.\2", text) # What does it mean?
|
60 |
+
return text.strip(' ')
|
IndicTrans2/inference/normalize_punctuation.sh
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
set -euo pipefail
|
2 |
+
|
3 |
+
root=$(dirname $0)
|
4 |
+
|
5 |
+
lang_map_path=$root/utils.map_token_lang.tsv
|
6 |
+
|
7 |
+
usage () {
|
8 |
+
echo "usage: $0 lang" >&2
|
9 |
+
exit 1
|
10 |
+
}
|
11 |
+
|
12 |
+
[ $# -eq 1 ] || usage
|
13 |
+
|
14 |
+
lang=$1
|
15 |
+
|
16 |
+
declare -A lang_map
|
17 |
+
|
18 |
+
while read line; do
|
19 |
+
key=$(cut -f1 <<< "$line")
|
20 |
+
val=$(cut -f2 <<< "$line")
|
21 |
+
lang_map[$key]=$val
|
22 |
+
done < $lang_map_path
|
23 |
+
|
24 |
+
if [ -v "lang_map[$lang]" ]; then
|
25 |
+
lang=${lang_map[$lang]}
|
26 |
+
elif [ -v "lang_map[${lang:0:3}]" ]; then
|
27 |
+
lang=${lang_map[${lang:0:3}]}
|
28 |
+
else
|
29 |
+
echo "undefined mapping: ${lang}, falling back to: en" >&2
|
30 |
+
lang=en
|
31 |
+
fi
|
32 |
+
|
33 |
+
perl $root/normalize-punctuation.perl $lang
|
IndicTrans2/inference/normalize_regex_inference.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple
|
2 |
+
import regex as re
|
3 |
+
import sys
|
4 |
+
from tqdm import tqdm
|
5 |
+
from .indic_num_map import INDIC_NUM_MAP
|
6 |
+
|
7 |
+
|
8 |
+
URL_PATTERN = r'\b(?<![\w/.])(?:(?:https?|ftp)://)?(?:(?:[\w-]+\.)+(?!\.))(?:[\w/\-?#&=%.]+)+(?!\.\w+)\b'
|
9 |
+
EMAIL_PATTERN = r'[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}'
|
10 |
+
# handles dates, time, percentages, proportion, ratio, etc
|
11 |
+
NUMERAL_PATTERN = r"(~?\d+\.?\d*\s?%?\s?-?\s?~?\d+\.?\d*\s?%|~?\d+%|\d+[-\/.,:']\d+[-\/.,:'+]\d+(?:\.\d+)?|\d+[-\/.:'+]\d+(?:\.\d+)?)"
|
12 |
+
# handles upi, social media handles and hashtags
|
13 |
+
OTHER_PATTERN = r'[A-Za-z0-9]*[#|@]\w+'
|
14 |
+
|
15 |
+
|
16 |
+
def normalize_indic_numerals(line: str):
|
17 |
+
"""
|
18 |
+
Normalize the numerals in Indic languages from native script to Roman script (if present).
|
19 |
+
|
20 |
+
Args:
|
21 |
+
line (str): an input string with Indic numerals to be normalized.
|
22 |
+
|
23 |
+
Returns:
|
24 |
+
str: an input string with the all Indic numerals normalized to Roman script.
|
25 |
+
"""
|
26 |
+
return "".join([INDIC_NUM_MAP.get(c, c) for c in line])
|
27 |
+
|
28 |
+
|
29 |
+
def wrap_with_placeholders(text: str, patterns: list) -> Tuple[str, dict]:
|
30 |
+
"""
|
31 |
+
Wraps substrings with matched patterns in the given text with placeholders and returns
|
32 |
+
the modified text along with a mapping of the placeholders to their original value.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
text (str): an input string which needs to be wrapped with the placeholders.
|
36 |
+
pattern (list): list of patterns to search for in the input string.
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
Tuple[str, dict]: a tuple containing the modified text and a dictionary mapping
|
40 |
+
placeholders to their original values.
|
41 |
+
"""
|
42 |
+
serial_no = 1
|
43 |
+
|
44 |
+
placeholder_entity_map = dict()
|
45 |
+
|
46 |
+
for pattern in patterns:
|
47 |
+
matches = set(re.findall(pattern, text))
|
48 |
+
|
49 |
+
# wrap common match with placeholder tags
|
50 |
+
for match in matches:
|
51 |
+
if pattern==URL_PATTERN :
|
52 |
+
#Avoids false positive URL matches for names with initials.
|
53 |
+
temp = match.replace(".",'')
|
54 |
+
if len(temp)<4:
|
55 |
+
continue
|
56 |
+
if pattern==NUMERAL_PATTERN :
|
57 |
+
#Short numeral patterns do not need placeholder based handling.
|
58 |
+
temp = match.replace(" ",'').replace(".",'').replace(":",'')
|
59 |
+
if len(temp)<4:
|
60 |
+
continue
|
61 |
+
|
62 |
+
#Set of Translations of "ID" in all the suppported languages have been collated.
|
63 |
+
#This has been added to deal with edge cases where placeholders might get translated.
|
64 |
+
indic_failure_cases = ['آی ڈی ', 'ꯑꯥꯏꯗꯤ', 'आईडी', 'आई . डी . ', 'ऐटि', 'آئی ڈی ', 'ᱟᱭᱰᱤ ᱾', 'आयडी', 'ऐडि', 'आइडि']
|
65 |
+
placeholder = "<ID{}>".format(serial_no)
|
66 |
+
alternate_placeholder = "< ID{} >".format(serial_no)
|
67 |
+
placeholder_entity_map[placeholder] = match
|
68 |
+
placeholder_entity_map[alternate_placeholder] = match
|
69 |
+
|
70 |
+
for i in indic_failure_cases:
|
71 |
+
placeholder_temp = "<{}{}>".format(i,serial_no)
|
72 |
+
placeholder_entity_map[placeholder_temp] = match
|
73 |
+
placeholder_temp = "< {}{} >".format(i, serial_no)
|
74 |
+
placeholder_entity_map[placeholder_temp] = match
|
75 |
+
placeholder_temp = "< {} {} >".format(i, serial_no)
|
76 |
+
placeholder_entity_map[placeholder_temp] = match
|
77 |
+
|
78 |
+
text = text.replace(match, placeholder)
|
79 |
+
serial_no+=1
|
80 |
+
|
81 |
+
text = re.sub("\s+", " ", text)
|
82 |
+
|
83 |
+
#Regex has failure cases in trailing "/" in URLs, so this is a workaround.
|
84 |
+
text = text.replace(">/",">")
|
85 |
+
|
86 |
+
return text, placeholder_entity_map
|
87 |
+
|
88 |
+
|
89 |
+
def normalize(text: str, patterns: list = [EMAIL_PATTERN, URL_PATTERN, NUMERAL_PATTERN, OTHER_PATTERN]) -> Tuple[str, dict]:
|
90 |
+
"""
|
91 |
+
Normalizes and wraps the spans of input string with placeholder tags. It first normalizes
|
92 |
+
the Indic numerals in the input string to Roman script. Later, it uses the input string with normalized
|
93 |
+
Indic numerals to wrap the spans of text matching the pattern with placeholder tags.
|
94 |
+
|
95 |
+
Args:
|
96 |
+
text (str): input string.
|
97 |
+
pattern (list): list of patterns to search for in the input string.
|
98 |
+
|
99 |
+
Returns:
|
100 |
+
Tuple[str, dict]: a tuple containing the modified text and a dictionary mapping
|
101 |
+
placeholders to their original values.
|
102 |
+
"""
|
103 |
+
text = normalize_indic_numerals(text.strip("\n"))
|
104 |
+
text, placeholder_entity_map = wrap_with_placeholders(text, patterns)
|
105 |
+
return text, placeholder_entity_map
|
IndicTrans2/inference/requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
git+https://github.com/anoopkunchukuttan/indic_nlp_library
|
2 |
+
git+https://github.com/pytorch/fairseq
|
3 |
+
sacremoses
|
4 |
+
pandas
|
5 |
+
mock
|
6 |
+
nltk
|
7 |
+
sacrebleu
|
8 |
+
urduhack[tf]
|
9 |
+
mosestokenizer
|
10 |
+
ctranslate2
|
11 |
+
sentencepiece
|
IndicTrans2/inference/triton_server/Dockerfile
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ARG BASE_IMAGE=nvcr.io/nvidia/tritonserver:22.12-py3
|
2 |
+
FROM ${BASE_IMAGE}
|
3 |
+
|
4 |
+
# Ensure apt-get won't prompt for selecting options
|
5 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
6 |
+
ENV PYTHONIOENCODING=utf8
|
7 |
+
|
8 |
+
WORKDIR /home
|
9 |
+
|
10 |
+
WORKDIR /home/indicTrans2
|
11 |
+
COPY requirements.txt .
|
12 |
+
RUN pip install -r requirements.txt
|
13 |
+
|
14 |
+
COPY download.py .
|
15 |
+
RUN python3 download.py
|
16 |
+
|
17 |
+
COPY . ./inference
|
18 |
+
|
19 |
+
WORKDIR /home/
|
20 |
+
COPY ./triton_server/triton_repo ./triton_repo
|
21 |
+
|
22 |
+
CMD ["tritonserver", "--model-repository=/home/triton_repo", "--log-verbose=2", "--strict-model-config=false", "--http-port=8000", "--grpc-port=8001", "--metrics-port=8002"]
|
23 |
+
EXPOSE 8000
|
24 |
+
EXPOSE 8001
|
25 |
+
EXPOSE 8002
|
IndicTrans2/inference/triton_server/README.md
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Triton server
|
2 |
+
|
3 |
+
## Building the image
|
4 |
+
|
5 |
+
```
|
6 |
+
cd indicTrans2/inference/
|
7 |
+
docker build -f triton_server/Dockerfile -t indictrans2_triton .
|
8 |
+
```
|
9 |
+
|
10 |
+
## Running the container
|
11 |
+
|
12 |
+
Place the `en-indic` and `indic-en` checkpoint folders into `indicTrans2/checkpoints` directory
|
13 |
+
|
14 |
+
Then start the server by:
|
15 |
+
```
|
16 |
+
docker run --shm-size=256m --gpus=1 --rm -v ${PWD}/../checkpoints/:/models/checkpoints -p 8000:8000 -t indictrans2_triton
|
17 |
+
```
|
18 |
+
|
19 |
+
## Sample client
|
20 |
+
|
21 |
+
- Do `pip install tritonclient[all] gevent` first.
|
22 |
+
- Then `python3 triton_server/client.py`
|
IndicTrans2/inference/triton_server/azure_ml/README.md
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Deployment on Azure Machine Learning
|
2 |
+
|
3 |
+
## Pre-requisites
|
4 |
+
|
5 |
+
```
|
6 |
+
cd inference/triton_server
|
7 |
+
```
|
8 |
+
|
9 |
+
Set the environment for AML:
|
10 |
+
```
|
11 |
+
export RESOURCE_GROUP=Dhruva-prod
|
12 |
+
export WORKSPACE_NAME=dhruva--central-india
|
13 |
+
export DOCKER_REGISTRY=dhruvaprod
|
14 |
+
```
|
15 |
+
|
16 |
+
Also remember to edit the `yml` files accordingly.
|
17 |
+
|
18 |
+
## Registering the model
|
19 |
+
|
20 |
+
```
|
21 |
+
az ml model create --file azure_ml/model.yml --resource-group $RESOURCE_GROUP --workspace-name $WORKSPACE_NAME
|
22 |
+
```
|
23 |
+
|
24 |
+
## Pushing the docker image to Container Registry
|
25 |
+
|
26 |
+
```
|
27 |
+
az acr login --name $DOCKER_REGISTRY
|
28 |
+
docker tag indictrans2_triton $DOCKER_REGISTRY.azurecr.io/nmt/triton-indictrans-v2:latest
|
29 |
+
docker push $DOCKER_REGISTRY.azurecr.io/nmt/triton-indictrans-v2:latest
|
30 |
+
```
|
31 |
+
|
32 |
+
## Creating the execution environment
|
33 |
+
|
34 |
+
```
|
35 |
+
az ml environment create -f azure_ml/environment.yml -g $RESOURCE_GROUP -w $WORKSPACE_NAME
|
36 |
+
```
|
37 |
+
|
38 |
+
## Publishing the endpoint for online inference
|
39 |
+
|
40 |
+
```
|
41 |
+
az ml online-endpoint create -f azure_ml/endpoint.yml -g $RESOURCE_GROUP -w $WORKSPACE_NAME
|
42 |
+
```
|
43 |
+
|
44 |
+
Now from the Azure Portal, open the Container Registry, and grant ACR_PULL permission for the above endpoint, so that it is allowed to download the docker image.
|
45 |
+
|
46 |
+
## Attaching a deployment
|
47 |
+
|
48 |
+
```
|
49 |
+
az ml online-deployment create -f azure_ml/deployment.yml --all-traffic -g $RESOURCE_GROUP -w $WORKSPACE_NAME
|
50 |
+
```
|
51 |
+
|
52 |
+
## Testing if inference works
|
53 |
+
|
54 |
+
1. From Azure ML Studio, go to the "Consume" tab, and get the endpoint domain (without `https://` or trailing `/`) and an authentication key.
|
55 |
+
2. In `client.py`, enable `ENABLE_SSL = True`, and then set the `ENDPOINT_URL` variable as well as `Authorization` value inside `HTTP_HEADERS`.
|
56 |
+
3. Run `python3 client.py`
|
IndicTrans2/inference/triton_server/azure_ml/deployment.yml
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$schema: https://azuremlschemas.azureedge.net/latest/managedOnlineDeployment.schema.json
|
2 |
+
name: ai4b-indictransv2--t4-piv--gpu
|
3 |
+
endpoint_name: ai4b-indictransv2--t4
|
4 |
+
model: azureml:indictrans-v2--models:1
|
5 |
+
model_mount_path: /models
|
6 |
+
environment: azureml:triton-indictrans-v2-env:1
|
7 |
+
instance_type: Standard_NC4as_T4_v3
|
8 |
+
instance_count: 1
|
9 |
+
request_settings:
|
10 |
+
request_timeout_ms: 90000
|
11 |
+
max_concurrent_requests_per_instance: 100
|
12 |
+
max_queue_wait_ms: 2000
|
13 |
+
app_insights_enabled: true
|
IndicTrans2/inference/triton_server/azure_ml/endpoint.yml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
$schema: https://azuremlschemas.azureedge.net/latest/managedOnlineEndpoint.schema.json
|
2 |
+
name: ai4b-indictransv2--t4
|
3 |
+
auth_mode: key
|
IndicTrans2/inference/triton_server/azure_ml/environment.yml
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$schema: https://azuremlschemas.azureedge.net/latest/environment.schema.json
|
2 |
+
name: triton-indictrans-v2-env
|
3 |
+
image: dhruvaprod.azurecr.io/nmt/triton-indictrans-v2:latest
|
4 |
+
version: 1
|
5 |
+
inference_config:
|
6 |
+
liveness_route:
|
7 |
+
path: /v2/health/live
|
8 |
+
port: 8000
|
9 |
+
readiness_route:
|
10 |
+
path: /v2/health/ready
|
11 |
+
port: 8000
|
12 |
+
scoring_route:
|
13 |
+
path: /
|
14 |
+
port: 8000
|
IndicTrans2/inference/triton_server/azure_ml/model.yml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$schema: https://azuremlschemas.azureedge.net/latest/model.schema.json
|
2 |
+
name: indictrans-v2--models
|
3 |
+
version: 1
|
4 |
+
path: ../../../checkpoints
|
5 |
+
type: triton_model
|
IndicTrans2/inference/triton_server/client.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tritonclient.http as http_client
|
2 |
+
from tritonclient.utils import *
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
ENABLE_SSL = False
|
6 |
+
ENDPOINT_URL = 'localhost:8000'
|
7 |
+
HTTP_HEADERS = {"Authorization": "Bearer __PASTE_KEY_HERE__"}
|
8 |
+
|
9 |
+
# Connect to the server
|
10 |
+
if ENABLE_SSL:
|
11 |
+
import gevent.ssl
|
12 |
+
triton_http_client = http_client.InferenceServerClient(
|
13 |
+
url=ENDPOINT_URL, verbose=False,
|
14 |
+
ssl=True, ssl_context_factory=gevent.ssl._create_default_https_context,
|
15 |
+
)
|
16 |
+
else:
|
17 |
+
triton_http_client = http_client.InferenceServerClient(
|
18 |
+
url=ENDPOINT_URL, verbose=False,
|
19 |
+
)
|
20 |
+
|
21 |
+
print("Is server ready - {}".format(triton_http_client.is_server_ready(headers=HTTP_HEADERS)))
|
22 |
+
|
23 |
+
def get_string_tensor(string_values, tensor_name):
|
24 |
+
string_obj = np.array(string_values, dtype="object")
|
25 |
+
input_obj = http_client.InferInput(tensor_name, string_obj.shape, np_to_triton_dtype(string_obj.dtype))
|
26 |
+
input_obj.set_data_from_numpy(string_obj)
|
27 |
+
return input_obj
|
28 |
+
|
29 |
+
def get_translation_input_for_triton(texts: list, src_lang: str, tgt_lang: str):
|
30 |
+
return [
|
31 |
+
get_string_tensor([[text] for text in texts], "INPUT_TEXT"),
|
32 |
+
get_string_tensor([[src_lang]] * len(texts), "INPUT_LANGUAGE_ID"),
|
33 |
+
get_string_tensor([[tgt_lang]] * len(texts), "OUTPUT_LANGUAGE_ID"),
|
34 |
+
]
|
35 |
+
|
36 |
+
# Prepare input and output tensors
|
37 |
+
input_sentences = ["Hello world, I am Ram and I am from Ayodhya.", "How are you Ravan bro?"]
|
38 |
+
inputs = get_translation_input_for_triton(input_sentences, "en", "hi")
|
39 |
+
output0 = http_client.InferRequestedOutput("OUTPUT_TEXT")
|
40 |
+
|
41 |
+
# Send request
|
42 |
+
response = triton_http_client.infer(
|
43 |
+
"nmt",
|
44 |
+
model_version='1',
|
45 |
+
inputs=inputs,
|
46 |
+
outputs=[output0],
|
47 |
+
headers=HTTP_HEADERS,
|
48 |
+
)#.get_response()
|
49 |
+
|
50 |
+
# Decode the response
|
51 |
+
output_batch = response.as_numpy('OUTPUT_TEXT').tolist()
|
52 |
+
for input_sentence, translation in zip(input_sentences, output_batch):
|
53 |
+
print()
|
54 |
+
print(input_sentence)
|
55 |
+
print(translation[0].decode("utf-8"))
|
IndicTrans2/inference/triton_server/dhruva/ulca_model.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
IndicTrans2/inference/triton_server/triton_repo/nmt/1/model.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import json
|
4 |
+
import numpy as np
|
5 |
+
import triton_python_backend_utils as pb_utils
|
6 |
+
|
7 |
+
PWD = os.path.dirname(__file__)
|
8 |
+
|
9 |
+
INFERENCE_MODULE_DIR = "/home/indicTrans2/"
|
10 |
+
sys.path.insert(0, INFERENCE_MODULE_DIR)
|
11 |
+
from inference.engine import Model, iso_to_flores
|
12 |
+
INDIC_LANGUAGES = set(iso_to_flores)
|
13 |
+
|
14 |
+
ALLOWED_DIRECTION_STRINGS = {"en-indic", "indic-en", "indic-indic"}
|
15 |
+
FORCE_PIVOTING = False
|
16 |
+
DEFAULT_PIVOT_LANG = "en"
|
17 |
+
|
18 |
+
class TritonPythonModel:
|
19 |
+
def initialize(self, args):
|
20 |
+
self.model_config = json.loads(args['model_config'])
|
21 |
+
self.model_instance_device_id = json.loads(args['model_instance_device_id'])
|
22 |
+
self.output_name = "OUTPUT_TEXT"
|
23 |
+
self.output_dtype = pb_utils.triton_string_to_numpy(
|
24 |
+
pb_utils.get_output_config_by_name(self.model_config, self.output_name)["data_type"])
|
25 |
+
|
26 |
+
|
27 |
+
# checkpoints_root_dir = os.path.join(PWD, "checkpoints")
|
28 |
+
checkpoints_root_dir = "/models/checkpoints"
|
29 |
+
checkpoint_folders = [ f.path for f in os.scandir(checkpoints_root_dir) if f.is_dir() ]
|
30 |
+
# The assumption is that, each folder name is `<src_direction>-to-<tgt_direction>`
|
31 |
+
|
32 |
+
if not checkpoint_folders:
|
33 |
+
raise RuntimeError(f"No checkpoint folders in: {checkpoints_root_dir}")
|
34 |
+
|
35 |
+
self.models = {}
|
36 |
+
for checkpoint_folder in checkpoint_folders:
|
37 |
+
direction_string = os.path.basename(checkpoint_folder)
|
38 |
+
assert direction_string in ALLOWED_DIRECTION_STRINGS, f"Checkpoint folder-name `{direction_string}` not allowed"
|
39 |
+
self.models[direction_string] = Model(os.path.join(checkpoint_folder, "ct2_fp16_model"), input_lang_code_format="iso", model_type="ctranslate2")
|
40 |
+
# self.models[direction_string] = Model(checkpoint_folder, input_lang_code_format="iso", model_type="fairseq")
|
41 |
+
|
42 |
+
self.pivot_lang = None
|
43 |
+
if "en-indic" in self.models and "indic-en" in self.models:
|
44 |
+
if "indic-indic" not in self.models:
|
45 |
+
self.pivot_lang = DEFAULT_PIVOT_LANG
|
46 |
+
elif FORCE_PIVOTING:
|
47 |
+
del self.models["indic-indic"]
|
48 |
+
self.pivot_lang = DEFAULT_PIVOT_LANG
|
49 |
+
|
50 |
+
def get_direction_string(self, input_language_id, output_language_id):
|
51 |
+
direction_string = None
|
52 |
+
if input_language_id == DEFAULT_PIVOT_LANG and output_language_id in INDIC_LANGUAGES:
|
53 |
+
direction_string = "en-indic"
|
54 |
+
elif input_language_id in INDIC_LANGUAGES:
|
55 |
+
if output_language_id == DEFAULT_PIVOT_LANG:
|
56 |
+
direction_string = "indic-en"
|
57 |
+
elif output_language_id in INDIC_LANGUAGES:
|
58 |
+
direction_string = "indic-indic"
|
59 |
+
return direction_string
|
60 |
+
|
61 |
+
def get_model(self, input_language_id, output_language_id):
|
62 |
+
direction_string = self.get_direction_string(input_language_id, output_language_id)
|
63 |
+
|
64 |
+
if direction_string in self.models:
|
65 |
+
return self.models[direction_string]
|
66 |
+
raise RuntimeError(f"Language-pair not supported: {input_language_id}-{output_language_id}")
|
67 |
+
|
68 |
+
def execute(self,requests):
|
69 |
+
# print("REQ_COUNT", len(requests))
|
70 |
+
modelwise_batches = {}
|
71 |
+
responses = []
|
72 |
+
for request_id, request in enumerate(requests):
|
73 |
+
input_text_batch = pb_utils.get_input_tensor_by_name(request, "INPUT_TEXT").as_numpy()
|
74 |
+
input_language_id_batch = pb_utils.get_input_tensor_by_name(request, "INPUT_LANGUAGE_ID").as_numpy()
|
75 |
+
output_language_id_batch = pb_utils.get_input_tensor_by_name(request, "OUTPUT_LANGUAGE_ID").as_numpy()
|
76 |
+
|
77 |
+
input_text_batch = [input_text[0].decode("utf-8", "ignore") for input_text in input_text_batch]
|
78 |
+
input_language_id_batch = [input_language_id[0].decode("utf-8", "ignore") for input_language_id in input_language_id_batch]
|
79 |
+
output_language_id_batch = [output_language_id[0].decode("utf-8", "ignore") for output_language_id in output_language_id_batch]
|
80 |
+
|
81 |
+
responses.append([['']] * len(input_text_batch))
|
82 |
+
|
83 |
+
for input_id, (input_text, input_language_id, output_language_id) in enumerate(zip(input_text_batch, input_language_id_batch, output_language_id_batch)):
|
84 |
+
direction_string = self.get_direction_string(input_language_id, output_language_id)
|
85 |
+
if direction_string not in self.models:
|
86 |
+
if direction_string == "indic-indic" and self.pivot_lang:
|
87 |
+
pass
|
88 |
+
else:
|
89 |
+
raise RuntimeError(f"Language-pair not supported: {input_language_id}-{output_language_id}")
|
90 |
+
|
91 |
+
if direction_string not in modelwise_batches:
|
92 |
+
modelwise_batches[direction_string] = {
|
93 |
+
"payloads": [],
|
94 |
+
"text_id_to_req_id_input_id": [],
|
95 |
+
}
|
96 |
+
|
97 |
+
modelwise_batches[direction_string]["payloads"].append([input_text, input_language_id, output_language_id])
|
98 |
+
modelwise_batches[direction_string]["text_id_to_req_id_input_id"].append((request_id, input_id))
|
99 |
+
|
100 |
+
for direction_string, batch in modelwise_batches.items():
|
101 |
+
if direction_string == "indic-indic" and self.pivot_lang:
|
102 |
+
model = self.get_model("hi", self.pivot_lang)
|
103 |
+
original_langs = []
|
104 |
+
for i in range(len(batch["payloads"])):
|
105 |
+
original_langs.append(batch["payloads"][i][2])
|
106 |
+
batch["payloads"][i][2] = self.pivot_lang
|
107 |
+
|
108 |
+
pivot_texts = model.paragraphs_batch_translate__multilingual(batch["payloads"])
|
109 |
+
|
110 |
+
for i in range(len(batch["payloads"])):
|
111 |
+
batch["payloads"][i][0] = pivot_texts[i]
|
112 |
+
batch["payloads"][i][1] = self.pivot_lang
|
113 |
+
batch["payloads"][i][2] = original_langs[i]
|
114 |
+
|
115 |
+
model = self.get_model(self.pivot_lang, "hi")
|
116 |
+
translations = model.paragraphs_batch_translate__multilingual(batch["payloads"])
|
117 |
+
else:
|
118 |
+
model = self.models[direction_string]
|
119 |
+
translations = model.paragraphs_batch_translate__multilingual(batch["payloads"])
|
120 |
+
# translations = ["bro"] * len(batch["payloads"])
|
121 |
+
|
122 |
+
for translation, (request_id, output_id) in zip(translations, batch["text_id_to_req_id_input_id"]):
|
123 |
+
responses[request_id][output_id] = [translation]
|
124 |
+
|
125 |
+
for i in range(len(responses)):
|
126 |
+
responses[i] = pb_utils.InferenceResponse(output_tensors=[
|
127 |
+
pb_utils.Tensor(
|
128 |
+
self.output_name,
|
129 |
+
np.array(responses[i], dtype=self.output_dtype),
|
130 |
+
)
|
131 |
+
])
|
132 |
+
return responses
|
133 |
+
|
134 |
+
def execute_sequential(self,requests):
|
135 |
+
# print("REQ_COUNT", len(requests))
|
136 |
+
responses = []
|
137 |
+
for request in requests:
|
138 |
+
input_text_batch = pb_utils.get_input_tensor_by_name(request, "INPUT_TEXT").as_numpy()
|
139 |
+
input_language_id_batch = pb_utils.get_input_tensor_by_name(request, "INPUT_LANGUAGE_ID").as_numpy()
|
140 |
+
output_language_id_batch = pb_utils.get_input_tensor_by_name(request, "OUTPUT_LANGUAGE_ID").as_numpy()
|
141 |
+
|
142 |
+
input_text_batch = [input_text[0].decode("utf-8", "ignore") for input_text in input_text_batch]
|
143 |
+
input_language_id_batch = [input_language_id[0].decode("utf-8", "ignore") for input_language_id in input_language_id_batch]
|
144 |
+
output_language_id_batch = [output_language_id[0].decode("utf-8", "ignore") for output_language_id in output_language_id_batch]
|
145 |
+
|
146 |
+
generated_outputs = []
|
147 |
+
|
148 |
+
for input_text, input_language_id, output_language_id in zip(input_text_batch, input_language_id_batch, output_language_id_batch):
|
149 |
+
if self.pivot_lang and (input_language_id != self.pivot_lang and output_language_id != self.pivot_lang):
|
150 |
+
model = self.get_model(input_language_id, self.pivot_lang)
|
151 |
+
pivot_text = model.translate_paragraph(input_text, input_language_id, self.pivot_lang)
|
152 |
+
|
153 |
+
model = self.get_model(self.pivot_lang, output_language_id)
|
154 |
+
translation = model.translate_paragraph(pivot_text, self.pivot_lang, output_language_id)
|
155 |
+
else:
|
156 |
+
model = self.get_model(input_language_id, output_language_id)
|
157 |
+
translation = model.translate_paragraph(input_text, input_language_id, output_language_id)
|
158 |
+
generated_outputs.append([translation])
|
159 |
+
|
160 |
+
inference_response = pb_utils.InferenceResponse(output_tensors=[
|
161 |
+
pb_utils.Tensor(
|
162 |
+
self.output_name,
|
163 |
+
np.array(generated_outputs, dtype=self.output_dtype),
|
164 |
+
)
|
165 |
+
])
|
166 |
+
responses.append(inference_response)
|
167 |
+
return responses
|
IndicTrans2/inference/triton_server/triton_repo/nmt/config.pbtxt
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
backend: "python"
|
2 |
+
max_batch_size: 512
|
3 |
+
input [{
|
4 |
+
name: "INPUT_TEXT"
|
5 |
+
data_type: TYPE_STRING
|
6 |
+
dims: 1
|
7 |
+
},
|
8 |
+
{
|
9 |
+
name: "INPUT_LANGUAGE_ID"
|
10 |
+
data_type: TYPE_STRING
|
11 |
+
dims: 1
|
12 |
+
},
|
13 |
+
{
|
14 |
+
name: "OUTPUT_LANGUAGE_ID"
|
15 |
+
data_type: TYPE_STRING
|
16 |
+
dims: 1
|
17 |
+
}]
|
18 |
+
|
19 |
+
output {
|
20 |
+
name: "OUTPUT_TEXT"
|
21 |
+
data_type: TYPE_STRING
|
22 |
+
dims: 1
|
23 |
+
}
|
24 |
+
|
25 |
+
dynamic_batching {
|
26 |
+
|
27 |
+
}
|
28 |
+
|
29 |
+
instance_group [{
|
30 |
+
count: 1
|
31 |
+
kind: KIND_GPU
|
32 |
+
}]
|
IndicTrans2/inference/utils.map_token_lang.tsv
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
asm_Beng hi
|
2 |
+
ben_Beng hi
|
3 |
+
brx_Deva hi
|
4 |
+
doi_Deva hi
|
5 |
+
gom_Deva hi
|
6 |
+
eng_Latn en
|
7 |
+
guj_Gujr hi
|
8 |
+
hin_Deva hi
|
9 |
+
kan_Knda hi
|
10 |
+
kas_Arab ar
|
11 |
+
kas_Deva hi
|
12 |
+
mai_Deva hi
|
13 |
+
mar_Deva hi
|
14 |
+
mal_Mlym hi
|
15 |
+
mni_Beng hi
|
16 |
+
mni_Mtei en
|
17 |
+
npi_Deva hi
|
18 |
+
ory_Orya hi
|
19 |
+
pan_Guru hi
|
20 |
+
san_Deva hi
|
21 |
+
sat_Olck hi
|
22 |
+
snd_Arab ar
|
23 |
+
snd_Deva hi
|
24 |
+
tam_Taml hi
|
25 |
+
tel_Telu hi
|
26 |
+
urd_Arab ar
|