Esmail Atta Gumaan committed
Commit • cd89176
1 Parent(s): df91c19
Upload 7 files
- TrainLlTRA.ipynb +1 -0
- configuration.py +32 -0
- dataset.py +78 -0
- model.py +221 -0
- tokenizer_ar.json +0 -0
- tokenizer_en.json +0 -0
- train.py +203 -0
TrainLlTRA.ipynb
ADDED
@@ -0,0 +1 @@
Cell 1 - install dependencies:

!pip install datasets
!pip install torchmetrics

Output (abridged): installs datasets-2.16.1 (pulling in dill-0.3.7 and multiprocess-0.70.15) and torchmetrics-1.2.1 (pulling in lightning-utilities-0.10.0) into the Colab environment.

Cell 2 - mount Google Drive and launch training:

from google.colab import drive

drive.mount('/content/drive')

import os

os.chdir('/content/drive/MyDrive/LlTRA')

%run train.py

Output (abridged): runs on CUDA with a Tesla T4 (14.7 GB). The opus_infopankki ar-en data (2.78 MB, 50,769 examples) is downloaded; the maximum source sentence length is 79 tokens and the maximum target length 83. Training resumes from the checkpoint opus_infopankki_weights/tmodel_00.pt and processes epochs 01 through 19, each covering 5712 batches at roughly 8.4 it/s (about 11 minutes per epoch), with the reported batch loss falling from 1.870 at epoch 01 to 1.437 at epoch 19. After every epoch, two validation examples are printed as SOURCE / TARGET / PREDICTED triples. The Arabic SOURCE lines are mojibake in this dump; the English TARGET/PREDICTED pairs show the model improving from near-random output in early epochs (e.g. TARGET "A foreign-language family is entitled to interpreting services as necessary." vs. PREDICTED "in a native language is provided by the services of the services for the elderly .") to largely faithful translations later (e.g. by epoch 07, PREDICTED "you have committed crimes and are considered a danger to public order or safety" matches its TARGET exactly).
configuration.py
ADDED
@@ -0,0 +1,32 @@
from pathlib import Path

def Get_configuration():
    return {
        "batch_size": 8,
        "num_epochs": 20,
        "lr": 10**-4,
        "sequence_length": 100,
        "d_model": 512,
        "datasource": 'opus_infopankki',
        "source_language": "ar",
        "target_language": "en",
        "model_folder": "weights",
        "model_basename": "tmodel_",
        "preload": "latest",
        "tokenizer_file": "tokenizer_{0}.json",
        "experiment_name": "runs/tmodel"
    }

def Get_weights_file_path(config, epoch: str):
    model_folder = f"{config['datasource']}_{config['model_folder']}"
    model_filename = f"{config['model_basename']}{epoch}.pt"
    return str(Path('.') / model_folder / model_filename)

def latest_weights_file_path(config):
    model_folder = f"{config['datasource']}_{config['model_folder']}"
    model_filename = f"{config['model_basename']}*"
    weights_files = list(Path(model_folder).glob(model_filename))
    if len(weights_files) == 0:
        return None
    weights_files.sort()
    return str(weights_files[-1])
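For reference, a minimal sketch of how these helpers resolve checkpoint paths (the printed paths are illustrative; they assume POSIX separators and the zero-padded epoch names seen in the training log):

from configuration import Get_configuration, Get_weights_file_path, latest_weights_file_path

config = Get_configuration()

print(Get_weights_file_path(config, "05"))
# opus_infopankki_weights/tmodel_05.pt

print(latest_weights_file_path(config))
# Lexicographically last checkpoint, e.g. opus_infopankki_weights/tmodel_19.pt,
# or None if no checkpoint has been written yet.

Note that latest_weights_file_path relies on plain lexicographic sorting, so the zero-padded epoch suffixes seen in the training log (tmodel_00.pt, tmodel_01.pt, ...) are what keep "latest" correct; an unpadded tmodel_9.pt would sort after tmodel_19.pt.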
dataset.py
ADDED
@@ -0,0 +1,78 @@
import torch
from torch.utils.data import Dataset

class BilingualDataset(Dataset):

    def __init__(self, dataset, source_tokenizer, target_tokenizer, source_language, target_language, sequence_length):
        super().__init__()
        self.dataset = dataset
        self.source_tokenizer = source_tokenizer
        self.target_tokenizer = target_tokenizer
        self.source_language = source_language
        self.target_language = target_language
        self.sequence_length = sequence_length

        self.SOS_token = torch.tensor([target_tokenizer.token_to_id("[SOS]")], dtype=torch.int64)
        self.PAD_token = torch.tensor([target_tokenizer.token_to_id("[PAD]")], dtype=torch.int64)
        self.EOS_token = torch.tensor([target_tokenizer.token_to_id("[EOS]")], dtype=torch.int64)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        source_target_dataset = self.dataset[index]
        source_text = source_target_dataset['translation'][self.source_language]
        target_text = source_target_dataset['translation'][self.target_language]

        encode_source_tokenizer = self.source_tokenizer.encode(source_text).ids
        encode_target_tokenizer = self.target_tokenizer.encode(target_text).ids

        # Padding counts: the encoder input carries both [SOS] and [EOS] (-2);
        # the decoder input carries only [SOS] (-1), since [EOS] belongs to the labels.
        encode_source_padding = self.sequence_length - len(encode_source_tokenizer) - 2
        encode_target_padding = self.sequence_length - len(encode_target_tokenizer) - 1

        if encode_source_padding < 0 or encode_target_padding < 0:
            raise ValueError("sequence is too long")

        encoder_input = torch.cat(
            [
                self.SOS_token,
                torch.tensor(encode_source_tokenizer, dtype=torch.int64),
                self.EOS_token,
                torch.tensor([self.PAD_token] * encode_source_padding, dtype=torch.int64)
            ]
        )

        decoder_input = torch.cat(
            [
                self.SOS_token,
                torch.tensor(encode_target_tokenizer, dtype=torch.int64),
                torch.tensor([self.PAD_token] * encode_target_padding, dtype=torch.int64)
            ]
        )

        # Labels: target tokens followed by [EOS], then padding - i.e. the decoder
        # input shifted left by one. (The original placed [EOS] after the padding,
        # which would teach the model to emit [EOS] at the final position only.)
        Target = torch.cat(
            [
                torch.tensor(encode_target_tokenizer, dtype=torch.int64),
                self.EOS_token,
                torch.tensor([self.PAD_token] * encode_target_padding, dtype=torch.int64)
            ]
        )

        assert encoder_input.size(0) == self.sequence_length
        assert decoder_input.size(0) == self.sequence_length
        assert Target.size(0) == self.sequence_length

        return {
            "encoder_input": encoder_input,
            "decoder_input": decoder_input,
            # (1, 1, sequence_length): hide [PAD] positions from the encoder's self-attention.
            "encoder_input_mask": (encoder_input != self.PAD_token).unsqueeze(0).unsqueeze(0).int(),
            # (1, sequence_length) broadcast against (1, sequence_length, sequence_length):
            # padding mask combined with the causal mask.
            "decoder_input_mask": (decoder_input != self.PAD_token).unsqueeze(0).int() & casual_mask(decoder_input.size(0)),
            "Target": Target,
            "source_text": source_text,
            "target_text": target_text
        }


def casual_mask(size):
    # Strict upper triangle is 1; "== 0" flips it so position i may attend only to j <= i.
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0
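To see what casual_mask produces, a minimal, self-contained check:

from dataset import casual_mask

m = casual_mask(4)
print(m.shape)     # torch.Size([1, 4, 4])
print(m[0].int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)
# Row i is 1 only for positions j <= i; AND-ed with the padding mask in
# "decoder_input_mask", this stops the decoder attending to future tokens.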
model.py
ADDED
@@ -0,0 +1,221 @@
#LlTRA = Language to Language Transformer model.
import math
import torch
import torch.nn as nn

class InputEmbeddingsLayer(nn.Module):
    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)
    def forward(self, x):
        # Scale embeddings by sqrt(d_model), as in "Attention Is All You Need".
        return self.embedding(x) * math.sqrt(self.d_model)

class PositionalEncodingLayer(nn.Module):
    def __init__(self, d_model: int, sequence_length: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.sequence_length = sequence_length
        self.dropout = nn.Dropout(dropout)

        PE = torch.zeros(sequence_length, d_model)
        Position = torch.arange(0, sequence_length, dtype=torch.float).unsqueeze(1)
        deviation_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        # Sine on even dimensions, cosine on odd dimensions.
        PE[:, 0::2] = torch.sin(Position * deviation_term)
        PE[:, 1::2] = torch.cos(Position * deviation_term)
        PE = PE.unsqueeze(0)
        self.register_buffer('PE', PE)
    def forward(self, x):
        x = x + (self.PE[:, :x.shape[1], :]).requires_grad_(False)
        return self.dropout(x)

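A quick shape check for the positional-encoding buffer (a minimal sketch, assuming model.py is importable as-is):

import torch
from model import PositionalEncodingLayer

pe = PositionalEncodingLayer(d_model=512, sequence_length=100, dropout=0.0)
print(pe.PE.shape)           # torch.Size([1, 100, 512])

x = torch.zeros(2, 10, 512)  # batch of 2, 10 token positions
print(pe(x).shape)           # torch.Size([2, 10, 512]); PE is sliced to the input length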
class NormalizationLayer(nn.Module):
    def __init__(self, Epsilon: float = 10**-6) -> None:
        super().__init__()
        self.Epsilon = Epsilon
        self.Alpha = nn.Parameter(torch.ones(1))   # learnable scale
        self.Bias = nn.Parameter(torch.zeros(1))   # learnable shift, zero-initialised
                                                   # (the original initialised it to one)
    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)

        return self.Alpha * (x - mean) / (std + self.Epsilon) + self.Bias

class FeedForwardBlock(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.Linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.Linear_2 = nn.Linear(d_ff, d_model)
    def forward(self, x):
        # Position-wise FFN: d_model -> d_ff -> d_model with ReLU in between.
        return self.Linear_2(self.dropout(torch.relu(self.Linear_1(x))))

class MultiHeadAttentionBlock(nn.Module):
    def __init__(self, d_model: int, heads: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.heads = heads

        assert d_model % heads == 0, "d_model is not divisible by heads"

        self.d_k = d_model // heads

        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)

        self.W_O = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def Attention(query, key, value, mask, dropout: nn.Dropout):
        d_k = query.shape[-1]
        # Scaled dot-product attention scores: (batch, heads, seq, seq).
        self_attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            self_attention_scores.masked_fill_(mask == 0, -1e9)

        self_attention_scores = self_attention_scores.softmax(dim=-1)
        if dropout is not None:
            self_attention_scores = dropout(self_attention_scores)

        return (self_attention_scores @ value), self_attention_scores
    def forward(self, query, key, value, mask):
        Query = self.W_Q(query)
        Key = self.W_K(key)
        Value = self.W_V(value)

        # (batch, seq, d_model) -> (batch, heads, seq, d_k)
        Query = Query.view(Query.shape[0], Query.shape[1], self.heads, self.d_k).transpose(1, 2)
        Key = Key.view(Key.shape[0], Key.shape[1], self.heads, self.d_k).transpose(1, 2)
        Value = Value.view(Value.shape[0], Value.shape[1], self.heads, self.d_k).transpose(1, 2)

        x, self.self_attention_scores = MultiHeadAttentionBlock.Attention(Query, Key, Value, mask, self.dropout)

        # Merge the heads back: (batch, heads, seq, d_k) -> (batch, seq, d_model).
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.heads * self.d_k)

        return self.W_O(x)

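The head split and merge can be sanity-checked in isolation (a minimal sketch; the batch size and the use of mask=None are arbitrary):

import torch
from model import MultiHeadAttentionBlock

attention = MultiHeadAttentionBlock(d_model=512, heads=8, dropout=0.1)
x = torch.randn(2, 100, 512)          # (batch, sequence_length, d_model)
out = attention(x, x, x, None)        # self-attention: query = key = value, no mask
print(out.shape)                      # torch.Size([2, 100, 512])
# Internally Q, K, V are reshaped to (2, 8, 100, 64) - eight heads of d_k = 64 -
# attended over independently, merged back, and projected through W_O.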
99 |
+
class ResidualConnection(nn.Module):
|
100 |
+
def __init__(self, dropout: float) -> None:
|
101 |
+
super().__init__()
|
102 |
+
self.dropout = nn.Dropout(dropout)
|
103 |
+
self.normalization = NormalizationLayer()
|
104 |
+
def forward(self, x, subLayer):
|
105 |
+
return x + self.dropout(subLayer(self.normalization(x)))
|
106 |
+
|
107 |
+
class EncoderBlock(nn.Module):
|
108 |
+
def __init__(self, encoder_self_attention_block: MultiHeadAttentionBlock, encoder_feed_forward_block: FeedForwardBlock, dropout: float) -> None:
|
109 |
+
super().__init__()
|
110 |
+
self.encoder_self_attention_block = encoder_self_attention_block
|
111 |
+
self.encoder_feed_forward_block = encoder_feed_forward_block
|
112 |
+
self.residual_connection = nn.ModuleList([ResidualConnection(dropout) for _ in range(2)])
|
113 |
+
def forward(self, x, source_mask):
|
114 |
+
x = self.residual_connection[0](x, lambda x: self.encoder_self_attention_block(x, x, x, source_mask))
|
115 |
+
x = self.residual_connection[1](x, self.encoder_feed_forward_block)
|
116 |
+
|
117 |
+
return x
|
118 |
+
|
119 |
+
class Encoder(nn.Module):
|
120 |
+
def __init__(self, Layers: nn.ModuleList) -> None:
|
121 |
+
super().__init__()
|
122 |
+
self.Layers = Layers
|
123 |
+
self.normalization = NormalizationLayer()
|
124 |
+
def forward(self, x, source_mask):
|
125 |
+
for layer in self.Layers:
|
126 |
+
x = layer(x, source_mask)
|
127 |
+
|
128 |
+
return self.normalization(x)
|
129 |
+
|
130 |
+
class DecoderBlock(nn.Module):
    def __init__(self, decoder_self_attention_block: MultiHeadAttentionBlock, decoder_cross_attention_block: MultiHeadAttentionBlock, decoder_feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.decoder_self_attention_block = decoder_self_attention_block
        self.decoder_cross_attention_block = decoder_cross_attention_block
        self.decoder_feed_forward_block = decoder_feed_forward_block
        self.residual_connection = nn.ModuleList([ResidualConnection(dropout) for _ in range(3)])

    def forward(self, x, Encoder_output, source_mask, target_mask):
        x = self.residual_connection[0](x, lambda x: self.decoder_self_attention_block(x, x, x, target_mask))
        # cross-attention attends over the encoder output, so it takes the source
        # mask here, not the target mask
        x = self.residual_connection[1](x, lambda x: self.decoder_cross_attention_block(x, Encoder_output, Encoder_output, source_mask))
        x = self.residual_connection[2](x, self.decoder_feed_forward_block)

        return x

class Decoder(nn.Module):
    def __init__(self, Layers: nn.ModuleList) -> None:
        super().__init__()
        self.Layers = Layers
        self.normalization = NormalizationLayer()

    def forward(self, x, Encoder_output, source_mask, target_mask):
        for layer in self.Layers:
            x = layer(x, Encoder_output, source_mask, target_mask)

        return self.normalization(x)
class LinearLayer(nn.Module):
    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.Linear = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        return self.Linear(x)
class TransformerBlock(nn.Module):
    def __init__(self, encoder: Encoder, decoder: Decoder, source_embedding: InputEmbeddingsLayer, target_embedding: InputEmbeddingsLayer, source_position: PositionalEncodingLayer, target_position: PositionalEncodingLayer, Linear: LinearLayer) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.source_embedding = source_embedding
        self.target_embedding = target_embedding
        self.source_position = source_position
        self.target_position = target_position
        self.Linear = Linear

    def encode(self, source_language, source_mask):
        source_language = self.source_embedding(source_language)
        source_language = self.source_position(source_language)
        return self.encoder(source_language, source_mask)

    def decode(self, Encoder_output, source_mask, target_language, target_mask):
        target_language = self.target_embedding(target_language)
        target_language = self.target_position(target_language)
        return self.decoder(target_language, Encoder_output, source_mask, target_mask)

    def linear(self, x):
        return self.Linear(x)
def TransformerModel(source_vocab_size: int, target_vocab_size: int, source_sequence_length: int, target_sequence_length: int, d_model: int = 512, Layers: int = 6, heads: int = 8, dropout: float = 0.1, d_ff: int = 2048) -> TransformerBlock:

    source_embedding = InputEmbeddingsLayer(d_model, source_vocab_size)
    source_position = PositionalEncodingLayer(d_model, source_sequence_length, dropout)

    target_embedding = InputEmbeddingsLayer(d_model, target_vocab_size)
    target_position = PositionalEncodingLayer(d_model, target_sequence_length, dropout)

    EncoderBlocks = []
    for _ in range(Layers):
        encoder_self_attention_block = MultiHeadAttentionBlock(d_model, heads, dropout)
        encoder_feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        encoder_block = EncoderBlock(encoder_self_attention_block, encoder_feed_forward_block, dropout)
        EncoderBlocks.append(encoder_block)

    DecoderBlocks = []
    for _ in range(Layers):
        decoder_self_attention_block = MultiHeadAttentionBlock(d_model, heads, dropout)
        decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, heads, dropout)
        decoder_feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        decoder_block = DecoderBlock(decoder_self_attention_block, decoder_cross_attention_block, decoder_feed_forward_block, dropout)
        DecoderBlocks.append(decoder_block)

    encoder = Encoder(nn.ModuleList(EncoderBlocks))
    decoder = Decoder(nn.ModuleList(DecoderBlocks))

    linear = LinearLayer(d_model, target_vocab_size)

    Transformer = TransformerBlock(encoder, decoder, source_embedding, target_embedding, source_position, target_position, linear)

    # Xavier-uniform initialization for weight matrices; xavier_uniform_ is the
    # in-place version (the non-underscore alias is deprecated)
    for T in Transformer.parameters():
        if T.dim() > 1:
            nn.init.xavier_uniform_(T)

    return Transformer
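As a quick sanity check, the builder can be exercised on its own; the vocabulary sizes and sequence lengths below are illustrative placeholders, not values taken from this repository:

from model import TransformerModel

model = TransformerModel(source_vocab_size=32000, target_vocab_size=32000,
                         source_sequence_length=350, target_sequence_length=350)
print(sum(p.numel() for p in model.parameters()))   # rough parameter count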
tokenizer_ar.json
ADDED
The diff for this file is too large to render. See raw diff

tokenizer_en.json
ADDED
The diff for this file is too large to render. See raw diff
train.py
ADDED
@@ -0,0 +1,203 @@
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter

from model import TransformerModel
from dataset import BilingualDataset, casual_mask
from configuration import Get_configuration, Get_weights_file_path, latest_weights_file_path

from datasets import load_dataset

from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordLevelTrainer

from pathlib import Path

import warnings
from tqdm import tqdm
import os

def greedy_search(model, source, source_mask, source_tokenizer, target_tokenizer, max_len, device):
    sos_idx = target_tokenizer.token_to_id('[SOS]')
    eos_idx = target_tokenizer.token_to_id('[EOS]')

    # encode the source once; the decoder reuses this output at every step
    encoder_output = model.encode(source, source_mask)

    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)
    while True:
        if decoder_input.size(1) == max_len:
            break

        decoder_mask = casual_mask(decoder_input.size(1)).type_as(source_mask).to(device)

        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

        # get the next token (the token with the maximum probability)
        prob = model.linear(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1
        )

        if next_word == eos_idx:
            break

    return decoder_input.squeeze(0)

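# greedy_search depends on casual_mask from dataset.py, which is not rendered in
# this diff. Assuming the usual upper-triangular implementation, casual_mask(4)
# lets each position i attend only to positions <= i:
#   [[ True, False, False, False],
#    [ True,  True, False, False],
#    [ True,  True,  True, False],
#    [ True,  True,  True,  True]]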
def run_validation(model, validation_ds, source_tokenizer, target_tokenizer, max_len, device, print_msg, global_step, writer, num_examples=2):
    model.eval()
    count = 0
    console_width = 80

    with torch.no_grad():
        for batch in validation_ds:
            count += 1
            encoder_input = batch["encoder_input"].to(device)
            encoder_mask = batch["encoder_input_mask"].to(device)

            assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

            model_out = greedy_search(model, encoder_input, encoder_mask, source_tokenizer, target_tokenizer, max_len, device)

            # read the source sentence from the source-side key (assumed to be
            # "source_text" in dataset.py); reading "target_text" here, as before,
            # printed the reference translation twice
            source_text = batch["source_text"][0]
            target_text = batch["target_text"][0]
            model_out_text = target_tokenizer.decode(model_out.detach().cpu().numpy())

            print_msg('-' * console_width)
            print_msg(f"{f'SOURCE: ':>12}{source_text}")
            print_msg(f"{f'TARGET: ':>12}{target_text}")
            print_msg(f"{f'PREDICTED: ':>12}{model_out_text}")

            if count == num_examples:
                break

def Get_All_Sentences(dataset, language):
    for item in dataset:
        yield item['translation'][language]

def Build_Tokenizer(configuration, dataset, language):
    tokenizer_path = Path(configuration['tokenizer_file'].format(language))
    if not Path.exists(tokenizer_path):
        tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
        tokenizer.train_from_iterator(Get_All_Sentences(dataset, language), trainer=trainer)
        tokenizer.save(str(tokenizer_path))
    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
    return tokenizer

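# A quick illustration of what Build_Tokenizer returns (the sentence and ids are
# made up; actual ids depend on the trained vocabulary):
#   tokenizer = Build_Tokenizer(configuration, dataset_Raw, "en")
#   tokenizer.encode("hello world").ids        # e.g. [512, 907]
#   tokenizer.token_to_id("[PAD]")             # padding id, used for masking
#   tokenizer.decode([512, 907])               # "hello world"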
def Get_dataset(configuration):
    dataset_Raw = load_dataset(f"{configuration['datasource']}", f"{configuration['source_language']}-{configuration['target_language']}", split="train")

    source_tokenizer = Build_Tokenizer(configuration, dataset_Raw, configuration['source_language'])
    target_tokenizer = Build_Tokenizer(configuration, dataset_Raw, configuration['target_language'])

    # keep 90% of the sentence pairs for training and the rest for validation
    train_dataset_Size = int(0.9 * len(dataset_Raw))
    validation_dataset_Size = len(dataset_Raw) - train_dataset_Size

    train_dataset_Raw, validation_dataset_Raw = random_split(dataset_Raw, [train_dataset_Size, validation_dataset_Size])

    train_dataset = BilingualDataset(train_dataset_Raw, source_tokenizer, target_tokenizer, configuration['source_language'], configuration['target_language'], configuration['sequence_length'])
    validation_dataset = BilingualDataset(validation_dataset_Raw, source_tokenizer, target_tokenizer, configuration['source_language'], configuration['target_language'], configuration['sequence_length'])

    maximum_source_sequence_length = 0
    maximum_target_sequence_length = 0

    for item in dataset_Raw:
        source_id = source_tokenizer.encode(item['translation'][configuration['source_language']]).ids
        target_id = target_tokenizer.encode(item['translation'][configuration['target_language']]).ids
        maximum_source_sequence_length = max(maximum_source_sequence_length, len(source_id))
        maximum_target_sequence_length = max(maximum_target_sequence_length, len(target_id))

    print(f"maximum_source_sequence_length: {maximum_source_sequence_length}")
    print(f"maximum_target_sequence_length: {maximum_target_sequence_length}")

    train_dataLoader = DataLoader(train_dataset, batch_size=configuration['batch_size'], shuffle=True)
    validation_dataLoader = DataLoader(validation_dataset, batch_size=1, shuffle=True)

    return train_dataLoader, validation_dataLoader, source_tokenizer, target_tokenizer

def Get_model(configuration, source_vocab_size, target_vocab_size):
    model = TransformerModel(source_vocab_size, target_vocab_size, configuration['sequence_length'], configuration['sequence_length'], configuration['d_model'])
    return model

def train_model(configuration):
    # torch.has_mps is deprecated; torch.backends.mps.is_available() is the supported check
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print("Using device:", device)

    Path(f"{configuration['datasource']}_{configuration['model_folder']}").mkdir(parents=True, exist_ok=True)

    train_dataLoader, validation_dataLoader, source_tokenizer, target_tokenizer = Get_dataset(configuration)
    model = Get_model(configuration, source_tokenizer.get_vocab_size(), target_tokenizer.get_vocab_size()).to(device)

    writer = SummaryWriter(configuration['experiment_name'])

    optimizer = torch.optim.Adam(model.parameters(), lr=configuration['lr'], eps=1e-9)

    initial_epoch = 0
    global_step = 0
    preload = configuration['preload']
    model_filename = latest_weights_file_path(configuration) if preload == 'latest' else Get_weights_file_path(configuration, preload) if preload else None
    if model_filename:
        print(f'Preloading model {model_filename}')
        state = torch.load(model_filename)
        model.load_state_dict(state['model_state_dict'])
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']
    else:
        print('No model to preload, starting from scratch')

    # the loss is computed over target-language token ids, so the PAD id must come
    # from the target tokenizer (the original used the source tokenizer here)
    loss_fn = nn.CrossEntropyLoss(ignore_index=target_tokenizer.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

    for epoch in range(initial_epoch, configuration['num_epochs']):
        torch.cuda.empty_cache()
        batch_iterator = tqdm(train_dataLoader, desc=f"Processing Epoch {epoch:02d}")
        for batch in batch_iterator:
            model.train()
            encoder_input = batch['encoder_input'].to(device)
            decoder_input = batch['decoder_input'].to(device)
            encoder_mask = batch['encoder_input_mask'].to(device)
            # the decoder needs its own causal mask (key assumed to be
            # 'decoder_input_mask' in dataset.py); reusing 'encoder_input_mask',
            # as before, would let the decoder attend to future positions
            decoder_mask = batch['decoder_input_mask'].to(device)

            encoder_output = model.encode(encoder_input, encoder_mask)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
            proj_output = model.linear(decoder_output)

            Target = batch['Target'].to(device)

            loss = loss_fn(proj_output.view(-1, target_tokenizer.get_vocab_size()), Target.view(-1))
            batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})

            writer.add_scalar('train loss', loss.item(), global_step)
            writer.flush()

            loss.backward()

            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            # run_validation(model, validation_dataLoader, source_tokenizer, target_tokenizer, configuration['sequence_length'], device, lambda msg: batch_iterator.write(msg), global_step, writer)

            global_step += 1

        # save a checkpoint at the end of every epoch
        model_filename = Get_weights_file_path(configuration, f"{epoch:02d}")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'global_step': global_step
        }, model_filename)


if __name__ == '__main__':
    warnings.filterwarnings("ignore")
    configuration = Get_configuration()
    train_model(configuration)
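configuration.py is part of this commit but not rendered above. From the lookups in train.py, Get_configuration() must supply at least the following keys; every value shown here is an illustrative placeholder, not the repository's actual default:

configuration = {
    "datasource": "opus_books",              # placeholder dataset name for load_dataset
    "source_language": "en",                 # matches tokenizer_en.json above
    "target_language": "ar",                 # matches tokenizer_ar.json above
    "sequence_length": 350,                  # placeholder
    "d_model": 512,
    "batch_size": 8,                         # placeholder
    "num_epochs": 20,                        # placeholder
    "lr": 1e-4,                              # placeholder
    "preload": "latest",                     # or None, or a specific epoch tag like "05"
    "model_folder": "weights",               # placeholder
    "tokenizer_file": "tokenizer_{0}.json",  # yields tokenizer_en.json / tokenizer_ar.json
    "experiment_name": "runs/tmodel",        # placeholder TensorBoard log dir
}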