diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..a6344aac8c09253b3b630fb776ae94478aa0275b
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
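These attributes route common model and archive formats through Git LFS. As a rough illustration (a hypothetical helper, not part of this repo), the suffix-style patterns can be approximated with `fnmatch`; git's own wildmatch rules are richer (e.g. `saved_model/**/*`), so this is only an approximation.

```python
# Hypothetical helper, not part of this repo: a rough check of which filenames
# the suffix-style .gitattributes patterns above would send through Git LFS.
# git's pattern matching is richer than fnmatch (e.g. "saved_model/**/*"),
# so treat this as an approximation only.
from fnmatch import fnmatch

LFS_SUFFIX_PATTERNS = ["*.bin", "*.ckpt", "*.h5", "*.onnx", "*.pt", "*.pth",
                       "*.safetensors", "*.zip", "*tfevents*"]

def goes_through_lfs(filename: str) -> bool:
    return any(fnmatch(filename, pattern) for pattern in LFS_SUFFIX_PATTERNS)

print(goes_through_lfs("Ashaar_model/pytorch_model.bin"))  # True
print(goes_through_lfs("app.py"))                          # False
```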
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..768a532aae40bbe0c4c5ad398536cdc8c46fae59
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,7 @@
+deep-learning-models/*
+deep-learning-models.zip
+__MACOSX/*
+__pycache__/*
+poetry_diacritizer/*
+*.pyc
+deep-learning-models.zip:Zone.Identifier
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1b71d2feea99522490bbac0cf7b45b4d3b01157a
--- /dev/null
+++ b/README.md
@@ -0,0 +1,14 @@
+---
+title: Ashaar
+emoji: 🧑🎤
+colorFrom: purple
+colorTo: blue
+sdk: gradio
+sdk_version: 3.35.2
+app_file: app.py
+pinned: false
+license: apache-2.0
+duplicated_from: arbml/Ashaar
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..580d3b353dfe066a53293417f4380121aaa5827b
--- /dev/null
+++ b/app.py
@@ -0,0 +1,151 @@
+import os
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+import gradio as gr
+from transformers import pipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from Ashaar.utils import get_output_df, get_highlighted_patterns_html
+from Ashaar.bait_analysis import BaitAnalysis
+from langs import *
+import sys
+import json
+import argparse
+
+arg_parser = argparse.ArgumentParser()
+arg_parser.add_argument('--lang', type = str, default = 'ar')
+args = arg_parser.parse_args()
+lang = args.lang
+
+if lang == 'ar':
+ TITLE = TITLE_ar
+ DESCRIPTION = DESCRIPTION_ar
+ textbox_trg_text = textbox_trg_text_ar
+ textbox_inp_text = textbox_inp_text_ar
+ btn_trg_text = btn_trg_text_ar
+ btn_inp_text = btn_inp_text_ar
+ css = """ #textbox{ direction: RTL;}"""
+
+else:
+ TITLE = TITLE_en
+ DESCRIPTION = DESCRIPTION_en
+ textbox_trg_text = textbox_trg_text_en
+ textbox_inp_text = textbox_inp_text_en
+ btn_trg_text = btn_trg_text_en
+ btn_inp_text = btn_inp_text_en
+ css = ""
+
+gpt_tokenizer = AutoTokenizer.from_pretrained('arbml/ashaar_tokenizer')
+model = AutoModelForCausalLM.from_pretrained('arbml/Ashaar_model')
+
+theme_to_token = json.load(open("extra/theme_tokens.json", "r"))
+token_to_theme = {t:m for m,t in theme_to_token.items()}
+meter_to_token = json.load(open("extra/meter_tokens.json", "r"))
+token_to_meter = {t:m for m,t in meter_to_token.items()}
+
+analysis = BaitAnalysis()
+meter, theme, qafiyah = "", "", ""
+
+def analyze(poem):
+    global meter, theme, qafiyah
+ shatrs = poem.split("\n")
+ baits = [' # '.join(shatrs[2*i:2*i+2]) for i in range(len(shatrs)//2)]
+ output = analysis.analyze(baits,override_tashkeel=True)
+ meter = output['meter']
+ qafiyah = output['qafiyah'][0]
+ theme = output['theme'][-1]
+ df = get_output_df(output)
+ return get_highlighted_patterns_html(df), gr.Button.update(interactive=True)
+
+def generate(inputs, top_p = 3):
+ baits = inputs.split('\n')
+ if len(baits) % 2 !=0:
+ baits = baits[:-1]
+ poem = ' '.join(['<|bsep|> '+baits[i]+' <|vsep|> '+baits[i+1]+' |bsep|>' for i in range(0, len(baits), 2)])
+ prompt = f"""
+ {meter_to_token[meter]} {qafiyah} {theme_to_token[theme]}
+ <|psep|>
+ {poem}
+ """.strip()
+ print(prompt)
+ encoded_input = gpt_tokenizer(prompt, return_tensors='pt')
+    output = model.generate(**encoded_input, max_length=512, top_p=top_p, do_sample=True)
+
+ result = ""
+ prev_token = ""
+ line_cnts = 0
+ for i, beam in enumerate(output[:, len(encoded_input.input_ids[0]):]):
+ if line_cnts >= 10:
+ break
+ for token in beam:
+ if line_cnts >= 10:
+ break
+ decoded = gpt_tokenizer.decode(token)
+ if 'meter' in decoded or 'theme' in decoded:
+ break
+ if decoded in ["<|vsep|>", "|bsep|>"]:
+ result += "\n"
+ line_cnts+=1
+ elif decoded in ['<|bsep|>', '<|psep|>', '|psep|>']:
+ pass
+ else:
+ result += decoded
+ prev_token = decoded
+ else:
+ break
+ # return theme+" "+ f"من بحر {meter} مع قافية بحر ({qafiyah})" + "\n" +result
+ return result, gr.Button.update(interactive=False)
+
+examples = [
+ [
+"""القلب أعلم يا عذول بدائه
+وأحق منك بجفنه وبمائه"""
+ ],
+ [
+"""رمتِ الفؤادَ مليحة عذراءُ
+ بسهامِ لحظٍ ما لهنَّ دواءُ"""
+ ],
+ [
+"""أذَلَّ الحِرْصُ والطَّمَعُ الرِّقابَا
+وقَد يَعفو الكَريمُ، إذا استَرَابَا"""
+ ]
+]
+
+with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
+ with gr.Row():
+ with gr.Column():
+ gr.HTML(TITLE)
+ gr.HTML(DESCRIPTION)
+
+ with gr.Row():
+ with gr.Column():
+ textbox_output = gr.Textbox(lines=10, label=textbox_trg_text, elem_id="textbox")
+ with gr.Column():
+ inputs = gr.Textbox(lines=10, label=textbox_inp_text, elem_id="textbox")
+
+
+ with gr.Row():
+ with gr.Column():
+ if lang == 'ar':
+ trg_btn = gr.Button(btn_trg_text, interactive=False)
+ else:
+ trg_btn = gr.Button(btn_trg_text)
+
+ with gr.Column():
+ if lang == 'ar':
+ inp_btn = gr.Button(btn_inp_text)
+ else:
+ inp_btn = gr.Button(btn_inp_text, interactive = False)
+
+ with gr.Row():
+ html_output = gr.HTML()
+
+ if lang == 'en':
+ gr.Examples(examples, textbox_output)
+ inp_btn.click(generate, inputs = textbox_output, outputs=[inputs, inp_btn])
+ trg_btn.click(analyze, inputs = textbox_output, outputs=[html_output,inp_btn])
+ else:
+ gr.Examples(examples, inputs)
+ trg_btn.click(generate, inputs = inputs, outputs=[textbox_output, trg_btn])
+ inp_btn.click(analyze, inputs = inputs, outputs=[html_output,trg_btn] )
+
+# demo.launch(server_name = '0.0.0.0', share=True)
+demo.launch()
\ No newline at end of file
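For reference, a minimal sketch of the prompt format that `generate()` above assembles before calling the GPT-2 model: the meter token, qafiyah, and theme token (token strings taken from `extra/meter_tokens.json` and `extra/theme_tokens.json`), then `<|psep|>`, then the verses wrapped in the bait/verse separator tokens. The qafiyah letter is an illustrative value and the whitespace handling is simplified relative to the f-string in `generate()`; this is an illustration, not code from `app.py`.

```python
# Sketch of the conditional prompt fed to arbml/Ashaar_model by generate().
def build_prompt(shatrs, meter_token, qafiyah, theme_token):
    # Pair consecutive shatrs (hemistichs) into baits, using the same separator
    # strings as generate() (the closing ' |bsep|>' is copied verbatim from app.py).
    poem = ' '.join(
        '<|bsep|> ' + shatrs[i] + ' <|vsep|> ' + shatrs[i + 1] + ' |bsep|>'
        for i in range(0, len(shatrs), 2)
    )
    # Conditioning header: meter token, qafiyah letter, theme token, then <|psep|>.
    return f"{meter_token} {qafiyah} {theme_token}\n<|psep|>\n{poem}"

example = build_prompt(
    ["القلب أعلم يا عذول بدائه", "وأحق منك بجفنه وبمائه"],
    meter_token="<|meter_14|>",   # الكامل in extra/meter_tokens.json
    qafiyah="ء",                  # illustrative value, not pipeline output
    theme_token="<|theme_8|>",    # قصيدة غزل in extra/theme_tokens.json
)
print(example)
```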
diff --git a/extra/labels.txt b/extra/labels.txt
new file mode 100644
index 0000000000000000000000000000000000000000..83f603a419018c67d26256ed251e78bddf55e8e9
--- /dev/null
+++ b/extra/labels.txt
@@ -0,0 +1,17 @@
+saree
+kamel
+mutakareb
+mutadarak
+munsareh
+madeed
+mujtath
+ramal
+baseet
+khafeef
+taweel
+wafer
+hazaj
+rajaz
+mudhare
+muqtadheb
+prose
\ No newline at end of file
diff --git a/extra/labels_ar.txt b/extra/labels_ar.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6fd7f7e26266ebb7bec65d6c9c08e98d708654be
--- /dev/null
+++ b/extra/labels_ar.txt
@@ -0,0 +1,17 @@
+السريع
+الكامل
+المتقارب
+المتدارك
+المنسرح
+المديد
+المجتث
+الرمل
+البسيط
+الخفيف
+الطويل
+الوافر
+الهزج
+الرجز
+المضارع
+المقتضب
+النثر
\ No newline at end of file
diff --git a/extra/meter_tokens.json b/extra/meter_tokens.json
new file mode 100644
index 0000000000000000000000000000000000000000..07d91d393eb97e4b333cbb64ec3ee0b2483919b5
--- /dev/null
+++ b/extra/meter_tokens.json
@@ -0,0 +1 @@
+{"\u0627\u0644\u062e\u0641\u064a\u0641": "<|meter_0|>", "\u0627\u0644\u0645\u0636\u0627\u0631\u0639": "<|meter_1|>", "\u0627\u0644\u0645\u062c\u062a\u062b": "<|meter_2|>", "\u0627\u0644\u0631\u0645\u0644": "<|meter_3|>", "\u0627\u0644\u0628\u0633\u064a\u0637": "<|meter_4|>", "\u0627\u0644\u0645\u062a\u0642\u0627\u0631\u0628": "<|meter_5|>", "\u0627\u0644\u0648\u0627\u0641\u0631": "<|meter_6|>", "\u0627\u0644\u0645\u0642\u062a\u0636\u0628": "<|meter_7|>", "\u0627\u0644\u0645\u062f\u064a\u062f": "<|meter_8|>", "\u0627\u0644\u0646\u062b\u0631": "<|meter_9|>", "\u0627\u0644\u0647\u0632\u062c": "<|meter_10|>", "\u0627\u0644\u0645\u062a\u062f\u0627\u0631\u0643": "<|meter_11|>", "\u0627\u0644\u0645\u0646\u0633\u0631\u062d": "<|meter_12|>", "\u0627\u0644\u0637\u0648\u064a\u0644": "<|meter_13|>", "\u0627\u0644\u0643\u0627\u0645\u0644": "<|meter_14|>", "\u0627\u0644\u0631\u062c\u0632": "<|meter_15|>", "\u0627\u0644\u0633\u0631\u064a\u0639": "<|meter_16|>"}
\ No newline at end of file
diff --git a/extra/theme_tokens.json b/extra/theme_tokens.json
new file mode 100644
index 0000000000000000000000000000000000000000..b3a78c4f1e05a32405ace65a7ca4556dfe484987
--- /dev/null
+++ b/extra/theme_tokens.json
@@ -0,0 +1 @@
+{"\u0642\u0635\u064a\u062f\u0629 \u0642\u0635\u064a\u0631\u0647": "<|theme_0|>", "\u0642\u0635\u064a\u062f\u0629 \u0645\u062f\u062d": "<|theme_1|>", "\u0642\u0635\u064a\u062f\u0629 \u0648\u0637\u0646\u064a\u0647": "<|theme_2|>", "\u0642\u0635\u064a\u062f\u0629 \u0631\u0648\u0645\u0646\u0633\u064a\u0647": "<|theme_3|>", "\u0642\u0635\u064a\u062f\u0629 \u0647\u062c\u0627\u0621": "<|theme_4|>", "\u0642\u0635\u064a\u062f\u0629 \u0627\u0639\u062a\u0630\u0627\u0631": "<|theme_5|>", "\u0642\u0635\u064a\u062f\u0629 \u0633\u064a\u0627\u0633\u064a\u0629": "<|theme_6|>", "\u0642\u0635\u064a\u062f\u0629 \u0641\u0631\u0627\u0642": "<|theme_7|>", "\u0642\u0635\u064a\u062f\u0629 \u063a\u0632\u0644": "<|theme_8|>", "\u0642\u0635\u064a\u062f\u0629 \u0630\u0645": "<|theme_9|>", "\u0642\u0635\u064a\u062f\u0629 \u0631\u062b\u0627\u0621": "<|theme_10|>", "null": "<|theme_11|>", "\u0642\u0635\u064a\u062f\u0629 \u0634\u0648\u0642": "<|theme_12|>", "\u0642\u0635\u064a\u062f\u0629 \u0627\u0644\u0645\u0639\u0644\u0642\u0627\u062a": "<|theme_13|>", "\u0642\u0635\u064a\u062f\u0629 \u0627\u0644\u0627\u0646\u0627\u0634\u064a\u062f": "<|theme_14|>", "\u0642\u0635\u064a\u062f\u0629 \u062d\u0632\u064a\u0646\u0647": "<|theme_15|>", "\u0642\u0635\u064a\u062f\u0629 \u0639\u062a\u0627\u0628": "<|theme_16|>", "\u0642\u0635\u064a\u062f\u0629 \u0639\u0627\u0645\u0647": "<|theme_17|>", "\u0642\u0635\u064a\u062f\u0629 \u062f\u064a\u0646\u064a\u0629": "<|theme_18|>"}
\ No newline at end of file
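The `\uXXXX` escapes above are just JSON-encoded Arabic; `json.load` returns readable keys mapping meter/theme names to the special tokens used in the prompt. Inverting the map recovers the name from a token, mirroring `token_to_meter` / `token_to_theme` in `app.py`:

```python
# Sketch: load the meter token map and invert it, as app.py does.
import json

with open("extra/meter_tokens.json", encoding="utf-8") as f:
    meter_to_token = json.load(f)
token_to_meter = {tok: name for name, tok in meter_to_token.items()}

print(token_to_meter["<|meter_13|>"])  # الطويل
```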
diff --git a/extra/theme_tokens.txt b/extra/theme_tokens.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/langs.py b/langs.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce66ea7bb4884344c705c066657646185ff3ebc0
--- /dev/null
+++ b/langs.py
@@ -0,0 +1,59 @@
+IMG = """
+
+
+
+"""
+TITLE_ar="""أَشْعــَـار: تحليل وإنشاء الشعر العربي
"""
+DESCRIPTION_ar = IMG
+
+DESCRIPTION_ar +="""
+هذا البرنامج يتيح للمستخدم تحليل وإنشاء الشعر العربي.
+لإنشاء الشعر العربي تم تدريب نموذج يقوم باستخدام البحر والقافية والعاطفة لإنشاء إكمال للقصيدة بناءً على هذه الشروط.
+بالإضافة إلى نموذج إنشاء الشعر يحتوي البرنامج على نماذج لتصنيف الحقبة الزمنية والعاطفة والبحر وكذلك تشكيل الشعر.
+يقوم البرنامج باستخدام هذه النماذج لإيجاد الخلل في القصيدة من خلال إضافة ألوان معينة تدل على أماكن الخلل.
+لاستخدام البرنامج قم في البداية بكتابة قصيدة تحتوي على عدد زوجي من الأبيات ومن ثم قم بالضغط على تحليل، وبعد انتهاء التحليل بالإمكان إنشاء إكمال للقصيدة.
+عند الضغط على زر التحليل يتم إنشاء جدول التحليل الذي يشرح العديد من الأشياء:
+
+"""
+DESCRIPTION_ar+= """
+
+ - المشكل : تشكيل كل شطر من القصيدة المدخلة
+ - الكتابة العروضية: وتقوم هذه الكتابة على التعبير عن كل منطوق في اللغة وتبيانه حتى لو لم يكن يكتب إملائياً
+
+ - التفعيلة: تفعيلات القصيدة ، مثالاً : طَويلٌ لَهُ دُونَ البُحورِ فضائل فَعُوْلُنْ مَفَاْعِيْلُنْ فَعُوْلُنْ مَفَاْعِلُ
+
+ - النمط: يحدد حركة وسكون كل حرف في الكتابة العروضية. نستخدم الألوان التالية للرمز إلى خلل في الكتابة العروضية: الأحمر: حرف محذوف، الأزرق: حرف مضاف، الأصفر: حركة مقلوبة.
+
+
+"""
+DESCRIPTION_ar+= """
+قمنا بتوفير الشفرة البرمجية كلها على
+ GitHub.
+
+"""
+
+TITLE_en="""Ashaar: Arabic Poetry Analysis and Generation
"""
+DESCRIPTION_en = IMG
+
+DESCRIPTION_en +="""
+The demo provides a way to analyze a poem and also to generate a completion for it.
+The generative model is a character-based conditional GPT-2 model. The pipeline contains several models for
+classification, diacritization, and conditional generation. Check our GitHub for more technical details
+about this work. The demo exposes two basic pipelines: Analyze, which predicts the meter, era, theme, diacritized text, qafiyah, and arudi style,
+and Generate, which takes the input text, meter, theme, and qafiyah to generate the full poem.
+"""
+
+btn_trg_text_ar = "إنشاء"
+btn_inp_text_ar = "تحليل"
+
+btn_inp_text_en = "Generate"
+btn_trg_text_en = "Analyze"
+
+textbox_inp_text_ar = "القصيدة المدخلة"
+textbox_trg_text_ar = "القصيدة المنشئة"
+
+textbox_trg_text_en = "Input Poem"
+textbox_inp_text_en = "Generated Poem"
+
+
+
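A hedged sketch of the Analyze pipeline described above, based only on how `app.py` uses it: `BaitAnalysis.analyze()` is called on a list of baits (two shatrs joined with `' # '`) and the app reads `output['meter']`, `output['qafiyah'][0]`, and `output['theme'][-1]`. Any other fields of the returned dict are not shown in this diff.

```python
# Sketch of the Analyze pipeline, mirroring the calls made in app.py.
from Ashaar.bait_analysis import BaitAnalysis

analysis = BaitAnalysis()
baits = ["القلب أعلم يا عذول بدائه # وأحق منك بجفنه وبمائه"]
output = analysis.analyze(baits, override_tashkeel=True)
# app.py only consumes these three fields of the result:
print(output['meter'], output['qafiyah'][0], output['theme'][-1])
```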
diff --git a/poetry_diacritizer/__init__.py b/poetry_diacritizer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..42dcd7aa19e499d4ac240deb5d7e68bcf33795ed
--- /dev/null
+++ b/poetry_diacritizer/__init__.py
@@ -0,0 +1 @@
+from poetry_diacritizer import predict
\ No newline at end of file
diff --git a/poetry_diacritizer/config/ashaar.yml b/poetry_diacritizer/config/ashaar.yml
new file mode 100644
index 0000000000000000000000000000000000000000..aa7db0ed6cb9cf517007892117bac2ad3e5cb029
--- /dev/null
+++ b/poetry_diacritizer/config/ashaar.yml
@@ -0,0 +1,52 @@
+session_name: base
+
+data_directory: "data"
+data_type: "ashaar_proc"
+log_directory: "log_dir_ashaar"
+load_training_data: true
+load_test_data: false
+load_validation_data: true
+n_training_examples: null # null loads all training examples, good for fast loading
+n_test_examples: null # null loads all test examples
+n_validation_examples: null # null loads all validation examples
+test_file_name: "test.csv"
+is_data_preprocessed: false # The data file is organized as (original text | text | diacritics)
+data_separator: '|' # Required if the data is already preprocessed
+diacritics_separator: '*' # Required if the data is already preprocessed
+text_encoder: ArabicEncoderWithStartSymbol
+text_cleaner: valid_arabic_cleaners # a whitelist that keeps only Arabic letters, punctuation, and spaces
+max_len: 600 # sentences longer than this will not be used
+max_sen_len: null
+
+max_steps: 10000
+learning_rate: 0.001
+batch_size: 32
+adam_beta1: 0.9
+adam_beta2: 0.999
+use_decay: true
+weight_decay: 0.0
+embedding_dim: 256
+use_prenet: false
+prenet_sizes: [512, 256]
+cbhg_projections: [128, 256]
+cbhg_filters: 16
+cbhg_gru_units: 256
+post_cbhg_layers_units: [256, 256]
+post_cbhg_use_batch_norm: true
+
+use_mixed_precision: false
+optimizer_type: Adam
+device: cuda
+
+# LOGGING
+evaluate_frequency: 50000000
+max_eval_batches: 100
+evaluate_with_error_rates_frequency: 1000
+n_predicted_text_tensorboard: 10 # To be written to the tensorboard
+model_save_frequency: 1000
+train_plotting_frequency: 50000000 # No plotting for this model
+n_steps_avg_losses: [100, 500, 1_000, 5_000] # command line display of average loss values for the last n steps
+error_rates_n_batches: 10000 # if calculating error rate is slow, then you can specify the number of batches to be calculated
+
+test_model_path: null # load the last saved model
+train_resume_model_path: null # load last saved model
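These YAML files configure the diacritizer's training and evaluation. The diacritizer's own config loader is not part of this diff, so as a minimal sketch, the file can simply be read with PyYAML to inspect the settings:

```python
# Hedged sketch: read a diacritizer config with PyYAML (the project's actual
# loader is not shown in this diff).
import yaml

with open("poetry_diacritizer/config/ashaar.yml", encoding="utf-8") as f:
    config = yaml.safe_load(f)

print(config["session_name"])   # base
print(config["max_steps"])      # 10000
print(config["text_encoder"])   # ArabicEncoderWithStartSymbol
```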
diff --git a/poetry_diacritizer/config/baseline.yml b/poetry_diacritizer/config/baseline.yml
new file mode 100644
index 0000000000000000000000000000000000000000..55865f95099a5bf3a5a36116526de3fc64d9c11e
--- /dev/null
+++ b/poetry_diacritizer/config/baseline.yml
@@ -0,0 +1,47 @@
+session_name: base
+
+data_directory: "data"
+data_type: "CA_MSA"
+log_directory: "log_dir"
+load_training_data: true
+load_test_data: false
+load_validation_data: true
+n_training_examples: null # null loads all training examples, good for fast loading
+n_test_examples: null # null loads all test examples
+n_validation_examples: null # null loads all validation examples
+test_file_name: "test.csv"
+is_data_preprocessed: false # The data file is organized as (original text | text | diacritics)
+data_separator: '|' # Required if the data is already preprocessed
+diacritics_separator: '*' # Required if the data is already preprocessed
+text_encoder: ArabicEncoderWithStartSymbol
+text_cleaner: valid_arabic_cleaners # a whitelist that keeps only Arabic letters, punctuation, and spaces
+max_len: 600 # sentences longer than this will not be used
+
+
+max_steps: 2_000_000
+learning_rate: 0.001
+batch_size: 64
+adam_beta1: 0.9
+adam_beta2: 0.999
+use_decay: true
+weight_decay: 0.0
+embedding_dim: 512
+n_layers: 3
+layers_units: [256, 256, 256]
+use_mixed_precision: false
+optimizer_type: Adam
+use_batch_norm: False
+device: cuda
+max_sen_len: 256
+
+# LOGGING
+evaluate_frequency: 5000
+evaluate_with_error_rates_frequency: 5000
+n_predicted_text_tensorboard: 10 # To be written to the tensorboard
+model_save_frequency: 5000
+train_plotting_frequency: 50000000 # No plotting for this model
+n_steps_avg_losses: [100, 500, 1_000, 5_000] # command line display of average loss values for the last n steps
+error_rates_n_batches: 10000 # if calculating error rate is slow, then you can specify the number of batches to be calculated
+
+test_model_path: null # load the last saved model
+train_resume_model_path: null # load last saved model
diff --git a/poetry_diacritizer/config/cbhg.yml b/poetry_diacritizer/config/cbhg.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d17d29c129066b5e5c202eed30e97f63768ffe0c
--- /dev/null
+++ b/poetry_diacritizer/config/cbhg.yml
@@ -0,0 +1,52 @@
+session_name: base
+
+data_directory: "data"
+data_type: "CA_MSA"
+log_directory: "log_dir_cbhg"
+load_training_data: true
+load_test_data: false
+load_validation_data: true
+n_training_examples: null # null loads all training examples, good for fast loading
+n_test_examples: null # null loads all test examples
+n_validation_examples: null # null loads all validation examples
+test_file_name: "test.csv"
+is_data_preprocessed: false # The data file is organized as (original text | text | diacritics)
+data_separator: '|' # Required if the data is already preprocessed
+diacritics_separator: '*' # Required if the data is already preprocessed
+text_encoder: ArabicEncoderWithStartSymbol
+text_cleaner: valid_arabic_cleaners # a whitelist that keeps only Arabic letters, punctuation, and spaces
+max_len: 600 # sentences longer than this will not be used
+max_sen_len: null
+
+max_steps: 5000
+learning_rate: 0.001
+batch_size: 32
+adam_beta1: 0.9
+adam_beta2: 0.999
+use_decay: true
+weight_decay: 0.0
+embedding_dim: 256
+use_prenet: false
+prenet_sizes: [512, 256]
+cbhg_projections: [128, 256]
+cbhg_filters: 16
+cbhg_gru_units: 256
+post_cbhg_layers_units: [256, 256]
+post_cbhg_use_batch_norm: true
+
+use_mixed_precision: false
+optimizer_type: Adam
+device: cuda
+
+# LOGGING
+evaluate_frequency: 50000000
+max_eval_batches: 100
+evaluate_with_error_rates_frequency: 1000
+n_predicted_text_tensorboard: 10 # To be written to the tensorboard
+model_save_frequency: 5000
+train_plotting_frequency: 50000000 # No plotting for this model
+n_steps_avg_losses: [100, 500, 1_000, 5_000] # command line display of average loss values for the last n steps
+error_rates_n_batches: 10000 # if calculating error rate is slow, then you can specify the number of batches to be calculated
+
+test_model_path: null # load the last saved model
+train_resume_model_path: null # load last saved model
diff --git a/poetry_diacritizer/config/cbhg2.yml b/poetry_diacritizer/config/cbhg2.yml
new file mode 100644
index 0000000000000000000000000000000000000000..ed6a7984d73eeb5c32f510e46d6c569cada0833d
--- /dev/null
+++ b/poetry_diacritizer/config/cbhg2.yml
@@ -0,0 +1,51 @@
+session_name: base
+
+data_directory: "ashaar"
+data_type: "CA_MSA"
+log_directory: "/content/drive/MyDrive/Research/Barmajan/Diacritization/log_ashaar_dir"
+load_training_data: true
+load_test_data: false
+load_validation_data: true
+n_training_examples: null # null loads all training examples, good for fast loading
+n_test_examples: null # null loads all test examples
+n_validation_examples: null # null loads all validation examples
+test_file_name: "test.csv"
+is_data_preprocessed: false # The data file is organized as (original text | text | diacritics)
+data_separator: '|' # Required if the data is already preprocessed
+diacritics_separator: '*' # Required if the data is already preprocessed
+text_encoder: ArabicEncoderWithStartSymbol
+text_cleaner: valid_arabic_cleaners # a whitelist that keeps only Arabic letters, punctuation, and spaces
+max_len: 600 # sentences longer than this will not be used
+
+
+max_steps: 25_000
+learning_rate: 0.001
+batch_size: 32
+adam_beta1: 0.9
+adam_beta2: 0.999
+use_decay: true
+weight_decay: 0.0
+embedding_dim: 256
+use_prenet: false
+prenet_sizes: [512, 256]
+cbhg_projections: [128, 256]
+cbhg_filters: 16
+cbhg_gru_units: 256
+post_cbhg_layers_units: [256, 256]
+post_cbhg_use_batch_norm: true
+
+use_mixed_precision: false
+optimizer_type: Adam
+device: cuda
+
+# LOGGING
+evaluate_frequency: 1000
+evaluate_with_error_rates_frequency: 1000
+n_predicted_text_tensorboard: 10 # To be written to the tensorboard
+model_save_frequency: 1000
+train_plotting_frequency: 50000000 # No plotting for this model
+n_steps_avg_losses: [100, 500, 1_000, 5_000] # command line display of average loss values for the last n steps
+error_rates_n_batches: 10000 # if calculating error rate is slow, then you can specify the number of batches to be calculated
+
+test_model_path: null # load the last saved model
+train_resume_model_path: "/content/drive/MyDrive/Research/Barmajan/Diacritization/log_cleaned_dir/CA_MSA.base.cbhg/models/20000-snapshot.pt" # load last saved model
diff --git a/poetry_diacritizer/config/gpt-0.yml b/poetry_diacritizer/config/gpt-0.yml
new file mode 100644
index 0000000000000000000000000000000000000000..0fddfb248e208d1be3b4338c6383cedddc0b6a09
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-0.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_0
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 0
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: true
+use_mixed_precision: false
+weight_decay: 0.0
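Judging from the diffs, `gpt-0.yml` through `gpt-9.yml` are identical except for `n_layer` and `log_directory`. A purely illustrative sketch of that relationship (not a script in the repo):

```python
# Sketch: the gpt-N configs differ only in n_layer and log_directory, so the
# family can be reconstructed in memory from gpt-0.yml.
import yaml

with open("poetry_diacritizer/config/gpt-0.yml", encoding="utf-8") as f:
    base = yaml.safe_load(f)

family = {n: dict(base, n_layer=n, log_directory=f"log_dir_{n}") for n in range(10)}
print(family[3]["n_layer"], family[3]["log_directory"])  # 3 log_dir_3
```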
diff --git a/poetry_diacritizer/config/gpt-1.yml b/poetry_diacritizer/config/gpt-1.yml
new file mode 100644
index 0000000000000000000000000000000000000000..a6c3bdd0eae691fceeaad53e0552f6653f7c1165
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-1.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_1
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 1
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: true
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-2.yml b/poetry_diacritizer/config/gpt-2.yml
new file mode 100644
index 0000000000000000000000000000000000000000..b675a76fe8744754e1b3aa726348856294a05b43
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-2.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_2
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 2
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: true
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-3.yml b/poetry_diacritizer/config/gpt-3.yml
new file mode 100644
index 0000000000000000000000000000000000000000..25d6c6324ecf96d968add522cf3a5ad9e6edba8e
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-3.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_3
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 3
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: true
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-4.yml b/poetry_diacritizer/config/gpt-4.yml
new file mode 100644
index 0000000000000000000000000000000000000000..9cb3477f921e982dd158bfc20ccc24757b1abbd7
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-4.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_4
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 4
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: true
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-5.yml b/poetry_diacritizer/config/gpt-5.yml
new file mode 100644
index 0000000000000000000000000000000000000000..bb8336e17e40952bdbad36ff3bac217ce338cd2b
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-5.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_5
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 5
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: true
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-6.yml b/poetry_diacritizer/config/gpt-6.yml
new file mode 100644
index 0000000000000000000000000000000000000000..7c09933aad8055205c791875a9aafc82538b0705
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-6.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_6
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 6
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: true
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-7.yml b/poetry_diacritizer/config/gpt-7.yml
new file mode 100644
index 0000000000000000000000000000000000000000..cd416dd30a81bc20e07e830b4ce67507eb723120
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-7.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_7
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 7
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: true
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-8.yml b/poetry_diacritizer/config/gpt-8.yml
new file mode 100644
index 0000000000000000000000000000000000000000..1c88a296e1edfa647b020d2c853344504b82fbdf
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-8.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_8
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 8
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: true
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-9.yml b/poetry_diacritizer/config/gpt-9.yml
new file mode 100644
index 0000000000000000000000000000000000000000..f6b1b6aa02e3c08a21dc842aa363959abc2144ef
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-9.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_9
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 9
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: true
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-cls-0-tash-proc.yml b/poetry_diacritizer/config/gpt-cls-0-tash-proc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..3592f0642c05dc7691069316252dd6aa80fe0a02
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-cls-0-tash-proc.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: tash_proc
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_cls_0_tash_proc
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 0
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: false
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-cls-0-test.yml b/poetry_diacritizer/config/gpt-cls-0-test.yml
new file mode 100644
index 0000000000000000000000000000000000000000..7a423ac3851471a77d20a43f38e9eeeb019314a5
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-cls-0-test.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_cls_0_test
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 0
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: false
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-cls-0.yml b/poetry_diacritizer/config/gpt-cls-0.yml
new file mode 100644
index 0000000000000000000000000000000000000000..6241a072a02e76e396ac1424ef19c7600c58dd38
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-cls-0.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_cls_0
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 0
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: false
+use_mixed_precision: false
+weight_decay: 0.0
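Similarly, the `gpt-cls-*` configs differ from the `gpt-*` family mainly in `use_lstm: false` (versus `true`) and their `log_directory`, while the `*-tash-proc` variants additionally switch `data_type` to `tash_proc`. A small sketch, under those assumptions, that compares two of the files key by key:

```python
# Sketch: diff two config variants key by key to see how the families differ.
import yaml

def load(path):
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f)

a = load("poetry_diacritizer/config/gpt-0.yml")
b = load("poetry_diacritizer/config/gpt-cls-0.yml")
diff = {k: (a.get(k), b.get(k)) for k in sorted(set(a) | set(b)) if a.get(k) != b.get(k)}
print(diff)  # expected: differing log_directory and use_lstm (True vs False)
```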
diff --git a/poetry_diacritizer/config/gpt-cls-1-tash-proc.yml b/poetry_diacritizer/config/gpt-cls-1-tash-proc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..c566643b400ca2ce143904dba3b6c04ca140df7c
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-cls-1-tash-proc.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: tash_proc
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_cls_1_tash_proc
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 1
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: false
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-cls-1.yml b/poetry_diacritizer/config/gpt-cls-1.yml
new file mode 100644
index 0000000000000000000000000000000000000000..a819f33ab470edaf7844df7e0ee234cd1db0d3ac
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-cls-1.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_cls_1
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 1
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: false
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-cls-2-tash-proc.yml b/poetry_diacritizer/config/gpt-cls-2-tash-proc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..0e7d1305cce3ed24f6d280117191cf3905c3f58b
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-cls-2-tash-proc.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: tash_proc
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_cls_2_tash_proc
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 2
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: false
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-cls-2.yml b/poetry_diacritizer/config/gpt-cls-2.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d0950449d9927f204a835d304a8d65a5b346b09a
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-cls-2.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_cls_2
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 2
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: false
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-cls-3-tash-proc.yml b/poetry_diacritizer/config/gpt-cls-3-tash-proc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..2c208a90a99c7d87809f0694480df5f035db0451
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-cls-3-tash-proc.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: tash_proc
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_cls_3_tash_proc
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 3
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: false
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-cls-3.yml b/poetry_diacritizer/config/gpt-cls-3.yml
new file mode 100644
index 0000000000000000000000000000000000000000..c3367147fc21fade3392118821f09720dd7861f7
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-cls-3.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_cls_3
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 3
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: false
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-cls-4-tash-proc.yml b/poetry_diacritizer/config/gpt-cls-4-tash-proc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..fd264ce30e5254a55bbc7a72af1a2c0d5abd055d
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-cls-4-tash-proc.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: tash_proc
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_cls_4_tash_proc
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 4
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: false
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-cls-4.yml b/poetry_diacritizer/config/gpt-cls-4.yml
new file mode 100644
index 0000000000000000000000000000000000000000..36df842d376d6556ce49011b166420dbf1954ddc
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-cls-4.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_cls_4
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 4
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: false
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-cls-5-tash-proc.yml b/poetry_diacritizer/config/gpt-cls-5-tash-proc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..3ae054175a5e2be59f950196192ca1b6699604a7
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-cls-5-tash-proc.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: tash_proc
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_cls_5_tash_proc
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 5
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: false
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-cls-5-test.yml b/poetry_diacritizer/config/gpt-cls-5-test.yml
new file mode 100644
index 0000000000000000000000000000000000000000..75a7b42ab2c11bb5c31e6e8ec7221d4992ba95c9
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-cls-5-test.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: logs/log_dir_cls_5_test
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 5
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: false
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-cls-5.yml b/poetry_diacritizer/config/gpt-cls-5.yml
new file mode 100644
index 0000000000000000000000000000000000000000..cc320e5889ef1d551ef20a0ad41c61655e971faf
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-cls-5.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_cls_5
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 5
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: false
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-cls-6-tash-proc.yml b/poetry_diacritizer/config/gpt-cls-6-tash-proc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..caf37a3cffbd020d9ee7cd943491048156141141
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-cls-6-tash-proc.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: tash_proc
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_cls_6_tash_proc
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 6
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: false
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-cls-6.yml b/poetry_diacritizer/config/gpt-cls-6.yml
new file mode 100644
index 0000000000000000000000000000000000000000..8a1e515604b80fef39302d29292f2ae2c0306930
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-cls-6.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_cls_6
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 6
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: false
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-cls-7-tash-proc.yml b/poetry_diacritizer/config/gpt-cls-7-tash-proc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..9e3472ee025d6c4904f996b7938ace9b11846a92
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-cls-7-tash-proc.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: tash_proc
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_cls_7_tash_proc
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 7
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: false
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-cls-7.yml b/poetry_diacritizer/config/gpt-cls-7.yml
new file mode 100644
index 0000000000000000000000000000000000000000..6dabe95622bd13bdfb9fb4d03194f142bbea9993
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-cls-7.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_cls_7
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 7
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: false
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-cls-8-tash-proc.yml b/poetry_diacritizer/config/gpt-cls-8-tash-proc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..8d9f50135a1a69076d34ae3e73e0c59d12d011c4
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-cls-8-tash-proc.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: tash_proc
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_cls_8_tash_proc
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 8
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: false
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-cls-8.yml b/poetry_diacritizer/config/gpt-cls-8.yml
new file mode 100644
index 0000000000000000000000000000000000000000..6e306620e0381d588b9365d000810cc8ba373b16
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-cls-8.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_cls_8
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 8
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: false
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-cls-9-tash-proc.yml b/poetry_diacritizer/config/gpt-cls-9-tash-proc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..2bd7dc01952584e5104abb0c18431b805afdd3e1
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-cls-9-tash-proc.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: tash_proc
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_cls_9_tash_proc
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 9
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: false
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-cls-9-test.yml b/poetry_diacritizer/config/gpt-cls-9-test.yml
new file mode 100644
index 0000000000000000000000000000000000000000..363c56591ff1dbb76e0494e7d6e3491bad22fc81
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-cls-9-test.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: logs/log_dir_cls_9_test
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 9
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: false
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-cls-9.yml b/poetry_diacritizer/config/gpt-cls-9.yml
new file mode 100644
index 0000000000000000000000000000000000000000..3349fda2f42e4a61d1c171ee5cf39cefa0c6158b
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-cls-9.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_cls_9
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 9
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: false
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-cls-tash-proc.yml b/poetry_diacritizer/config/gpt-cls-tash-proc.yml
new file mode 100644
index 0000000000000000000000000000000000000000..f82b338bca16b7d8b02e144805a449fce250f3e4
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-cls-tash-proc.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: tash_proc
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_cls_0_test
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 5000
+model_save_frequency: 5000
+n_layer: 0
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: false
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-lstm-0-50K.yml b/poetry_diacritizer/config/gpt-lstm-0-50K.yml
new file mode 100644
index 0000000000000000000000000000000000000000..f976d4322589406fa08847a77b95e0d294b18ae0
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-lstm-0-50K.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_lstm_0_50K
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 50000
+model_save_frequency: 5000
+n_layer: 0
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: true
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-lstm-1-50K.yml b/poetry_diacritizer/config/gpt-lstm-1-50K.yml
new file mode 100644
index 0000000000000000000000000000000000000000..1fa4f8585bed32dd3ce9b1ebcfb0762001cc369a
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-lstm-1-50K.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_lstm_1_50K
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 50000
+model_save_frequency: 5000
+n_layer: 1
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: true
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-lstm-2-50K.yml b/poetry_diacritizer/config/gpt-lstm-2-50K.yml
new file mode 100644
index 0000000000000000000000000000000000000000..3a029f2ff7118a12bcea4b568a610a98686d7b53
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-lstm-2-50K.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_lstm_2_50K
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 50000
+model_save_frequency: 5000
+n_layer: 2
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: true
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-lstm-3-50K.yml b/poetry_diacritizer/config/gpt-lstm-3-50K.yml
new file mode 100644
index 0000000000000000000000000000000000000000..20694c3341379a0a0de448feec0b03f75854ea0e
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-lstm-3-50K.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_lstm_3_50K
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 50000
+model_save_frequency: 5000
+n_layer: 3
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: true
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-lstm-4-50K.yml b/poetry_diacritizer/config/gpt-lstm-4-50K.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d4e295b500ac7fc6ebe29a65aff04edc32abb124
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-lstm-4-50K.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_lstm_4_50K
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 50000
+model_save_frequency: 5000
+n_layer: 4
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: true
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-lstm-5-50K.yml b/poetry_diacritizer/config/gpt-lstm-5-50K.yml
new file mode 100644
index 0000000000000000000000000000000000000000..1d5392d1e2d3234e7160fde9ecca1b3e4aee256b
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-lstm-5-50K.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_lstm_5_50K
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 50000
+model_save_frequency: 5000
+n_layer: 5
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: true
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-lstm-6-50K.yml b/poetry_diacritizer/config/gpt-lstm-6-50K.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d12cee47afdc10fd94cd88e4c6aa5971c2dba746
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-lstm-6-50K.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_lstm_6_50K
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 50000
+model_save_frequency: 5000
+n_layer: 6
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: true
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-lstm-7-50K.yml b/poetry_diacritizer/config/gpt-lstm-7-50K.yml
new file mode 100644
index 0000000000000000000000000000000000000000..5bb815d27b322765aa7ba5a90e09ca47f52e7947
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-lstm-7-50K.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_lstm_7_50K
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 50000
+model_save_frequency: 5000
+n_layer: 7
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: true
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-lstm-8-50K.yml b/poetry_diacritizer/config/gpt-lstm-8-50K.yml
new file mode 100644
index 0000000000000000000000000000000000000000..96f0b8a7f485064bde473b5100e2e73692c41427
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-lstm-8-50K.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_lstm_8_50K
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 50000
+model_save_frequency: 5000
+n_layer: 8
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: true
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt-lstm-9-50K.yml b/poetry_diacritizer/config/gpt-lstm-9-50K.yml
new file mode 100644
index 0000000000000000000000000000000000000000..291076f09d97b8eb2000f812e9c5f4d6f820e389
--- /dev/null
+++ b/poetry_diacritizer/config/gpt-lstm-9-50K.yml
@@ -0,0 +1,46 @@
+adam_beta1: 0.9
+adam_beta2: 0.999
+base_model_path: ashaar-from-scratch-with-spaces-no-tatweel-epochs-75
+batch_size: 64
+data_directory: data
+data_separator: '|'
+data_type: CA_MSA
+device: cuda
+diacritics_separator: '*'
+error_rates_n_batches: 10000
+evaluate_frequency: 50000000
+evaluate_with_error_rates_frequency: 1000
+freeze: true
+is_data_preprocessed: false
+learning_rate: 0.001
+load_test_data: false
+load_training_data: true
+load_validation_data: true
+log_directory: log_dir_lstm_9_50K
+max_eval_batches: -1
+max_len: 600
+max_sen_len: 256
+max_steps: 50000
+model_save_frequency: 5000
+n_layer: 9
+n_predicted_text_tensorboard: 10
+n_steps_avg_losses:
+- 100
+- 500
+- 1000
+- 5000
+n_test_examples: null
+n_training_examples: null
+n_validation_examples: null
+optimizer_type: Adam
+session_name: base
+test_file_name: test.csv
+test_model_path: null
+text_cleaner: valid_arabic_cleaners
+text_encoder: ArabicEncoderWithStartSymbol
+train_plotting_frequency: 50000000
+train_resume_model_path: null
+use_decay: true
+use_lstm: true
+use_mixed_precision: false
+weight_decay: 0.0
diff --git a/poetry_diacritizer/config/gpt.yml b/poetry_diacritizer/config/gpt.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d66b2ce5b8ec5760136e58b317f46fd97bf6da12
--- /dev/null
+++ b/poetry_diacritizer/config/gpt.yml
@@ -0,0 +1,48 @@
+session_name: base
+
+data_directory: "data"
+data_type: "CA_MSA"
+log_directory: "log_dir"
+base_model_path: "ashaar-from-scratch-with-spaces-no-tatweel-epochs-75"
+load_training_data: true
+load_test_data: false
+load_validation_data: true
+n_training_examples: null # null loads all training examples, good for fast loading
+n_test_examples: null # null loads all test examples
+n_validation_examples: null # null loads all validation examples
+test_file_name: "test.csv"
+is_data_preprocessed: false # The data file is organized as (original text | text | diacritics)
+data_separator: '|' # Required if the data is already preprocessed
+diacritics_separator: '*' # Required if the data is already preprocessed
+text_encoder: ArabicEncoderWithStartSymbol
+text_cleaner: valid_arabic_cleaners # a white list that uses only Arabic letters, punctuations, and a space
+max_len: 600 # sentences larger than this size will not be used
+
+
+max_steps: 5000
+learning_rate: 0.001
+batch_size: 64
+adam_beta1: 0.9
+adam_beta2: 0.999
+use_decay: true
+weight_decay: 0.0
+use_mixed_precision: false
+optimizer_type: Adam
+device: cuda
+max_sen_len: 256
+freeze: true
+n_layer: -1
+use_lstm: false
+
+# LOGGING
+evaluate_frequency: 50000000
+max_eval_batches: -1
+evaluate_with_error_rates_frequency: 1000
+n_predicted_text_tensorboard: 10 # To be written to the tensorboard
+model_save_frequency: 5000
+train_plotting_frequency: 50000000 # No plotting for this model
+n_steps_avg_losses: [100, 500, 1_000, 5_000] # command line display of average loss values for the last n steps
+error_rates_n_batches: 10000 # if calculating error rate is slow, then you can specify the number of batches to be calculated
+
+test_model_path: null # load the last saved model
+train_resume_model_path: null # load last saved model
diff --git a/poetry_diacritizer/config/seq2seq.yml b/poetry_diacritizer/config/seq2seq.yml
new file mode 100644
index 0000000000000000000000000000000000000000..b277c568a36bc7fe250717ac8ea2fe402e2459eb
--- /dev/null
+++ b/poetry_diacritizer/config/seq2seq.yml
@@ -0,0 +1,62 @@
+session_name: base-baseline-encoder
+
+
+data_directory: "data"
+data_type: "CA_MSA"
+log_directory: "log_dir"
+load_training_data: true
+load_test_data: false
+load_validation_data: true
+n_training_examples: null # null loads all training examples, good for fast loading
+n_test_examples: null # null loads all test examples
+n_validation_examples: null # null loads all validation examples
+test_file_name: "test.csv"
+is_data_preprocessed: false # The data file is organized as (original text | text | diacritics)
+data_separator: '|' # Required if the data is already preprocessed
+diacritics_separator: '*' # Required if the data is already preprocessed
+text_encoder: ArabicEncoderWithStartSymbol
+text_cleaner: valid_arabic_cleaners # a white list that uses only Arabic letters, punctuations, and a space
+max_len: 600 # sentences larger than this size will not be used
+
+max_steps: 2_000_000
+learning_rate: 0.001
+batch_size: 16
+adam_beta1: 0.9
+adam_beta2: 0.999
+use_decay: true
+weight_decay: 0.0
+
+encoder_embedding_dim: 256
+decoder_embedding_dim: 256
+
+encoder_dim: 512 # used by the decoder
+encoder_units: [256, 256, 256]
+use_batch_norm: true
+decoder_units: 256
+decoder_layers: 2
+attention_units: 256
+use_decoder_prenet: true
+teacher_forcing_probability: 0.0
+decoder_prenet_depth: [256, 128]
+is_attention_accumulative: true
+attention_type: LocationSensitive
+
+
+
+use_mixed_precision: false
+optimizer_type: Adam
+text_encoder: ArabicEncoderWithStartSymbol
+text_cleaner: null
+device: cuda
+
+# LOGGING
+evaluate_frequency: 5000
+evaluate_with_error_rates_frequency: 5000
+n_predicted_text_tensorboard: 10 # To be written to the tensorboard
+model_save_frequency: 5000
+train_plotting_frequency: 1000
+n_steps_avg_losses: [100, 500, 1_000, 5_000] # command line display of average loss values for the last n steps
+error_rates_n_batches: 10000 # if calculating error rate is slow, then you can specify the number of batches to be calculated
+
+test_model_path: null # load the last saved model
+train_resume_model_path: null # load last saved model
diff --git a/poetry_diacritizer/config/tacotron_based.yml b/poetry_diacritizer/config/tacotron_based.yml
new file mode 100644
index 0000000000000000000000000000000000000000..2fe694f39c7852c8bd768318c797561bb743b302
--- /dev/null
+++ b/poetry_diacritizer/config/tacotron_based.yml
@@ -0,0 +1,62 @@
+session_name: base
+
+data_directory: "data"
+data_type: "CA_MSA"
+log_directory: "log_dir"
+load_training_data: true
+load_test_data: false
+load_validation_data: true
+n_training_examples: null # null loads all training examples, good for fast loading
+n_test_examples: null # null loads all test examples
+n_validation_examples: null # null loads all validation examples
+test_file_name: "test.csv"
+is_data_preprocessed: false # The data file is organized as (original text | text | diacritics)
+data_separator: '|' # Required if the data is already preprocessed
+diacritics_separator: '*' # Required if the data is already preprocessed
+text_encoder: ArabicEncoderWithStartSymbol
+text_cleaner: valid_arabic_cleaners # a white list that uses only Arabic letters, punctuations, and a space
+max_len: 600 # sentences larger than this size will not be used
+
+
+max_steps: 2_000_000
+learning_rate: 0.001
+warmup_steps: 3000.0
+batch_size: 16
+adam_beta1: 0.9
+adam_beta2: 0.999
+use_decay: true
+weight_decay: 0.0
+CLIP: 1.0
+
+encoder_embedding_dim: 256
+decoder_embedding_dim: 256
+prenet_sizes: [256, 128]
+cbhg_projections: [128, 128]
+cbhg_filters: 16
+cbhg_gru_units: 128
+
+use_encoder_prenet: true
+use_decoder_prenet: true
+encoder_dim: 256
+decoder_units: 256
+decoder_layers: 2
+attention_units: 256
+teacher_forcing_probability: 0.0
+decoder_prenet_depth: [256, 128]
+is_attention_accumulative: true
+attention_type: LocationSensitive
+use_mixed_precision: false
+optimizer_type: Adam
+device: cuda
+
+# LOGGING
+evaluate_frequency: 5000
+evaluate_with_error_rates_frequency: 5000
+n_predicted_text_tensorboard: 10 # To be written to the tensorboard
+model_save_frequency: 5000
+train_plotting_frequency: 1000
+n_steps_avg_losses: [100, 500, 1_000, 5_000] # command line display of average loss values for the last n steps
+error_rates_n_batches: 10000 # if calculating error rate is slow, then you can specify the number of batches to be calculated
+
+test_model_path: null # load the last saved model
+train_resume_model_path: null # load last saved model
diff --git a/poetry_diacritizer/config/tashkeela.yml b/poetry_diacritizer/config/tashkeela.yml
new file mode 100644
index 0000000000000000000000000000000000000000..07c5d2fd76a9a9ac5a3b6b3c491f559d662d81ac
--- /dev/null
+++ b/poetry_diacritizer/config/tashkeela.yml
@@ -0,0 +1,52 @@
+session_name: base
+
+data_directory: "data"
+data_type: "tash_proc"
+log_directory: "log_dir_tashkeela"
+load_training_data: true
+load_test_data: false
+load_validation_data: true
+n_training_examples: null # null loads all training examples, good for fast loading
+n_test_examples: null # null loads all test examples
+n_validation_examples: null # null loads all validation examples
+test_file_name: "/home/g201080740/Arabic_Diacritization/data/ashaar_proc/test.csv"
+is_data_preprocessed: false # The data file is organized as (original text | text | diacritics)
+data_separator: '|' # Required if the data is already preprocessed
+diacritics_separator: '*' # Required if the data is already preprocessed
+text_encoder: ArabicEncoderWithStartSymbol
+text_cleaner: valid_arabic_cleaners # a white list that uses only Arabic letters, punctuations, and a space
+max_len: 600 # sentences larger than this size will not be used
+max_sen_len: null
+
+max_steps: 10000
+learning_rate: 0.001
+batch_size: 32
+adam_beta1: 0.9
+adam_beta2: 0.999
+use_decay: true
+weight_decay: 0.0
+embedding_dim: 256
+use_prenet: false
+prenet_sizes: [512, 256]
+cbhg_projections: [128, 256]
+cbhg_filters: 16
+cbhg_gru_units: 256
+post_cbhg_layers_units: [256, 256]
+post_cbhg_use_batch_norm: true
+
+use_mixed_precision: false
+optimizer_type: Adam
+device: cuda
+
+# LOGGING
+evaluate_frequency: 50000000
+max_eval_batches: 100
+evaluate_with_error_rates_frequency: 1000
+n_predicted_text_tensorboard: 10 # To be written to the tensorboard
+model_save_frequency: 1000
+train_plotting_frequency: 50000000 # No plotting for this model
+n_steps_avg_losses: [100, 500, 1_000, 5_000] # command line display of average loss values for the last n steps
+error_rates_n_batches: 10000 # if calculating error rate is slow, then you can specify the number of batches to be calculated
+
+test_model_path: null # load the last saved model
+train_resume_model_path: null # load last saved model
diff --git a/poetry_diacritizer/config/tashkeela_ashaar.yml b/poetry_diacritizer/config/tashkeela_ashaar.yml
new file mode 100644
index 0000000000000000000000000000000000000000..9ebecb402888cf8df2cd75c0df5fb25651d5c3e7
--- /dev/null
+++ b/poetry_diacritizer/config/tashkeela_ashaar.yml
@@ -0,0 +1,52 @@
+session_name: base
+
+data_directory: "data"
+data_type: "ashaar_proc"
+log_directory: "log_dir_tashkeela_ashaar"
+load_training_data: true
+load_test_data: false
+load_validation_data: true
+n_training_examples: null # null loads all training examples, good for fast loading
+n_test_examples: null # null loads all test examples
+n_validation_examples: null # null loads all validation examples
+test_file_name: "test.csv"
+is_data_preprocessed: false # The data file is organized as (original text | text | diacritics)
+data_separator: '|' # Required if the data is already preprocessed
+diacritics_separator: '*' # Required if the data is already preprocessed
+text_encoder: ArabicEncoderWithStartSymbol
+text_cleaner: valid_arabic_cleaners # a white list that uses only Arabic letters, punctuations, and a space
+max_len: 600 # sentences larger than this size will not be used
+max_sen_len: null
+
+max_steps: 20000
+learning_rate: 0.001
+batch_size: 32
+adam_beta1: 0.9
+adam_beta2: 0.999
+use_decay: true
+weight_decay: 0.0
+embedding_dim: 256
+use_prenet: false
+prenet_sizes: [512, 256]
+cbhg_projections: [128, 256]
+cbhg_filters: 16
+cbhg_gru_units: 256
+post_cbhg_layers_units: [256, 256]
+post_cbhg_use_batch_norm: true
+
+use_mixed_precision: false
+optimizer_type: Adam
+device: cuda
+
+# LOGGING
+evaluate_frequency: 50000000
+max_eval_batches: 100
+evaluate_with_error_rates_frequency: 100
+n_predicted_text_tensorboard: 10 # To be written to the tensorboard
+model_save_frequency: 100
+train_plotting_frequency: 50000000 # No plotting for this model
+n_steps_avg_losses: [100, 500, 1_000, 5_000] # command line display of average loss values for the last n steps
+error_rates_n_batches: 100 # if calculating error rate is slow, then you can specify the number of batches to be calculated
+
+test_model_path: null # load the last saved model
+train_resume_model_path: '/home/g201080740/Arabic_Diacritization/log_dir_tashkeela/tash_proc.base.cbhg/models/10000-snapshot.pt' # load last saved model
diff --git a/poetry_diacritizer/config/test.yml b/poetry_diacritizer/config/test.yml
new file mode 100644
index 0000000000000000000000000000000000000000..8cf70eceb3c978fc23677aea50be8f684fa1b4b2
--- /dev/null
+++ b/poetry_diacritizer/config/test.yml
@@ -0,0 +1,52 @@
+session_name: base
+
+data_directory: "data"
+data_type: "ashaar_proc"
+log_directory: "log_dir_ashaar"
+load_training_data: true
+load_test_data: false
+load_validation_data: true
+n_training_examples: null # null loads all training examples, good for fast loading
+n_test_examples: null # null loads all test examples
+n_validation_examples: null # null loads all validation examples
+test_file_name: "test.csv"
+is_data_preprocessed: false # The data file is organized as (original text | text | diacritics)
+data_separator: '|' # Required if the data is already preprocessed
+diacritics_separator: '*' # Required if the data is already preprocessed
+text_encoder: ArabicEncoderWithStartSymbol
+text_cleaner: valid_arabic_cleaners # a white list that uses only Arabic letters, punctuations, and a space
+max_len: 600 # sentences larger than this size will not be used
+max_sen_len: null
+
+max_steps: 10000
+learning_rate: 0.001
+batch_size: 32
+adam_beta1: 0.9
+adam_beta2: 0.999
+use_decay: true
+weight_decay: 0.0
+embedding_dim: 256
+use_prenet: false
+prenet_sizes: [512, 256]
+cbhg_projections: [128, 256]
+cbhg_filters: 16
+cbhg_gru_units: 256
+post_cbhg_layers_units: [256, 256]
+post_cbhg_use_batch_norm: true
+
+use_mixed_precision: false
+optimizer_type: Adam
+device: cuda
+
+# LOGGING
+evaluate_frequency: 50000000
+max_eval_batches: 100
+evaluate_with_error_rates_frequency: 1000
+n_predicted_text_tensorboard: 10 # To be written to the tensorboard
+model_save_frequency: 5000
+train_plotting_frequency: 50000000 # No plotting for this model
+n_steps_avg_losses: [100, 500, 1_000, 5_000] # command line display of average loss values for the last n steps
+error_rates_n_batches: 10000 # if calculating error rate is slow, then you can specify the number of batches to be calculated
+
+test_model_path: null # load the last saved model
+train_resume_model_path: null # load last saved model
diff --git a/poetry_diacritizer/config_manager.py b/poetry_diacritizer/config_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..4473d6017694823444543bc86d7d9e8d0dee6aba
--- /dev/null
+++ b/poetry_diacritizer/config_manager.py
@@ -0,0 +1,350 @@
+from enum import Enum
+import os
+from pathlib import Path
+import shutil
+import subprocess
+from typing import Any, Dict
+
+import ruamel.yaml
+import torch
+
+from poetry_diacritizer.models.baseline import BaseLineModel
+from poetry_diacritizer.models.cbhg import CBHGModel
+from poetry_diacritizer.models.gpt import GPTModel
+from poetry_diacritizer.models.seq2seq import Decoder as Seq2SeqDecoder, Encoder as Seq2SeqEncoder, Seq2Seq
+from poetry_diacritizer.models.tacotron_based import (
+ Decoder as TacotronDecoder,
+ Encoder as TacotronEncoder,
+ Tacotron,
+)
+
+from poetry_diacritizer.options import AttentionType, LossType, OptimizerType
+from poetry_diacritizer.util.text_encoders import (
+ ArabicEncoderWithStartSymbol,
+ BasicArabicEncoder,
+ TextEncoder,
+)
+
+
+class ConfigManager:
+ """Co/home/almodhfer/Projects/daicritization/temp_results/CA_MSA/cbhg-new/model-10.ptnfig Manager"""
+
+ def __init__(self, config_path: str, model_kind: str):
+ available_models = ["baseline", "cbhg", "seq2seq", "tacotron_based", "gpt"]
+ if model_kind not in available_models:
+ raise TypeError(f"model_kind must be in {available_models}")
+ self.config_path = Path(config_path)
+ self.model_kind = model_kind
+ self.yaml = ruamel.yaml.YAML()
+ self.config: Dict[str, Any] = self._load_config()
+ self.git_hash = self._get_git_hash()
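+        # The session name is "<data_type>.<session_name>.<model_kind>",
+        # e.g. "CA_MSA.base.gpt"; it names the log and model directories below.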
+ self.session_name = ".".join(
+ [
+ self.config["data_type"],
+ self.config["session_name"],
+ f"{model_kind}",
+ ]
+ )
+
+ self.data_dir = Path(
+ os.path.join(self.config["data_directory"], self.config["data_type"])
+ )
+ self.base_dir = Path(
+ os.path.join(self.config["log_directory"], self.session_name)
+ )
+ self.log_dir = Path(os.path.join(self.base_dir, "logs"))
+ self.prediction_dir = Path(os.path.join(self.base_dir, "predictions"))
+ self.plot_dir = Path(os.path.join(self.base_dir, "plots"))
+ self.models_dir = Path(os.path.join(self.base_dir, "models"))
+ if "sp_model_path" in self.config:
+ self.sp_model_path = self.config["sp_model_path"]
+ else:
+ self.sp_model_path = None
+ self.text_encoder: TextEncoder = self.get_text_encoder()
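+        # Derive the vocabulary sizes from the text encoder and store them in the
+        # config so the model constructors below can read them.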
+ self.config["len_input_symbols"] = len(self.text_encoder.input_symbols)
+ self.config["len_target_symbols"] = len(self.text_encoder.target_symbols)
+ if self.model_kind in ["seq2seq", "tacotron_based"]:
+ self.config["attention_type"] = AttentionType[self.config["attention_type"]]
+ self.config["optimizer"] = OptimizerType[self.config["optimizer_type"]]
+
+ def _load_config(self):
+ with open(self.config_path, "rb") as model_yaml:
+ _config = self.yaml.load(model_yaml)
+ return _config
+
+ @staticmethod
+ def _get_git_hash():
+ try:
+ return (
+ subprocess.check_output(["git", "describe", "--always"])
+ .strip()
+ .decode()
+ )
+ except Exception as e:
+ print(f"WARNING: could not retrieve git hash. {e}")
+
+ def _check_hash(self):
+ try:
+ git_hash = (
+ subprocess.check_output(["git", "describe", "--always"])
+ .strip()
+ .decode()
+ )
+ if self.config["git_hash"] != git_hash:
+ print(
+ f"""WARNING: git hash mismatch. Current: {git_hash}.
+ Config hash: {self.config['git_hash']}"""
+ )
+ except Exception as e:
+ print(f"WARNING: could not check git hash. {e}")
+
+ @staticmethod
+ def _print_dict_values(values, key_name, level=0, tab_size=2):
+ tab = level * tab_size * " "
+ print(tab + "-", key_name, ":", values)
+
+ def _print_dictionary(self, dictionary, recursion_level=0):
+ for key in dictionary.keys():
+            if isinstance(dictionary[key], dict):
+                self._print_dictionary(dictionary[key], recursion_level + 1)
+ else:
+ self._print_dict_values(
+ dictionary[key], key_name=key, level=recursion_level
+ )
+
+ def print_config(self):
+ print("\nCONFIGURATION", self.session_name)
+ self._print_dictionary(self.config)
+
+ def update_config(self):
+ self.config["git_hash"] = self._get_git_hash()
+
+ def dump_config(self):
+ self.update_config()
+ _config = {}
+ for key, val in self.config.items():
+ if isinstance(val, Enum):
+ _config[key] = val.name
+ else:
+ _config[key] = val
+ with open(self.base_dir / "config.yml", "w") as model_yaml:
+ self.yaml.dump(_config, model_yaml)
+
+ def create_remove_dirs(
+ self,
+ clear_dir: bool = False,
+ clear_logs: bool = False,
+ clear_weights: bool = False,
+ clear_all: bool = False,
+ ):
+ self.base_dir.mkdir(exist_ok=True, parents=True)
+ self.plot_dir.mkdir(exist_ok=True)
+ self.prediction_dir.mkdir(exist_ok=True)
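+        # Optionally wipe existing logs and/or saved models, asking for
+        # confirmation before deleting anything.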
+ if clear_dir:
+ delete = input(f"Delete {self.log_dir} AND {self.models_dir}? (y/[n])")
+ if delete == "y":
+ shutil.rmtree(self.log_dir, ignore_errors=True)
+ shutil.rmtree(self.models_dir, ignore_errors=True)
+ if clear_logs:
+ delete = input(f"Delete {self.log_dir}? (y/[n])")
+ if delete == "y":
+ shutil.rmtree(self.log_dir, ignore_errors=True)
+ if clear_weights:
+ delete = input(f"Delete {self.models_dir}? (y/[n])")
+ if delete == "y":
+ shutil.rmtree(self.models_dir, ignore_errors=True)
+ self.log_dir.mkdir(exist_ok=True)
+ self.models_dir.mkdir(exist_ok=True)
+
+ def get_last_model_path(self):
+ """
+        Return the path of the last saved model in the models directory,
+        or None if no checkpoint has been saved yet.
+ """
+ models = os.listdir(self.models_dir)
+ models = [model for model in models if model[-3:] == ".pt"]
+ if len(models) == 0:
+ return None
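+        # Checkpoints are named "<global_step>-snapshot.pt"; pick the largest step.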
+ _max = max(int(m.split(".")[0].split("-")[0]) for m in models)
+ model_name = f"{_max}-snapshot.pt"
+ last_model_path = os.path.join(self.models_dir, model_name)
+
+ return last_model_path
+
+ def load_model(self, model_path: str = None):
+ """
+        Build the model and load its saved state.
+        Args:
+            model_path (str): path to a checkpoint; if None, the last snapshot
+                in the models directory is loaded (or a fresh model is returned).
+ """
+
+ model = self.get_model()
+
+ with open(self.base_dir / f"{self.model_kind}_network.txt", "w") as file:
+ file.write(str(model))
+
+ if model_path is None:
+ last_model_path = self.get_last_model_path()
+ if last_model_path is None:
+ return model, 1
+ else:
+ last_model_path = model_path
+
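+        # A snapshot stores both the weights and the global step, so training
+        # resumes from the step right after the one that was saved.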
+ saved_model = torch.load(last_model_path)
+ out = model.load_state_dict(saved_model["model_state_dict"])
+ print(out)
+ global_step = saved_model["global_step"] + 1
+ return model, global_step
+
+ def get_model(self, ignore_hash=False):
+ if not ignore_hash:
+ self._check_hash()
+ if self.model_kind == "cbhg":
+ return self.get_cbhg()
+
+ elif self.model_kind == "seq2seq":
+ return self.get_seq2seq()
+
+ elif self.model_kind == "tacotron_based":
+ return self.get_tacotron_based()
+
+ elif self.model_kind == "baseline":
+ return self.get_baseline()
+
+ elif self.model_kind == "gpt":
+ return self.get_gpt()
+
+ def get_gpt(self):
+ model = GPTModel(
+ self.config["base_model_path"],
+ freeze=self.config["freeze"],
+ n_layer=self.config["n_layer"],
+ use_lstm=self.config["use_lstm"],
+ )
+ return model
+
+ def get_baseline(self):
+ model = BaseLineModel(
+ embedding_dim=self.config["embedding_dim"],
+ inp_vocab_size=self.config["len_input_symbols"],
+ targ_vocab_size=self.config["len_target_symbols"],
+ layers_units=self.config["layers_units"],
+ use_batch_norm=self.config["use_batch_norm"],
+ )
+
+ return model
+
+ def get_cbhg(self):
+ model = CBHGModel(
+ embedding_dim=self.config["embedding_dim"],
+ inp_vocab_size=self.config["len_input_symbols"],
+ targ_vocab_size=self.config["len_target_symbols"],
+ use_prenet=self.config["use_prenet"],
+ prenet_sizes=self.config["prenet_sizes"],
+ cbhg_gru_units=self.config["cbhg_gru_units"],
+ cbhg_filters=self.config["cbhg_filters"],
+ cbhg_projections=self.config["cbhg_projections"],
+ post_cbhg_layers_units=self.config["post_cbhg_layers_units"],
+ post_cbhg_use_batch_norm=self.config["post_cbhg_use_batch_norm"],
+ )
+
+ return model
+
+ def get_seq2seq(self):
+ encoder = Seq2SeqEncoder(
+ embedding_dim=self.config["encoder_embedding_dim"],
+ inp_vocab_size=self.config["len_input_symbols"],
+ layers_units=self.config["encoder_units"],
+ use_batch_norm=self.config["use_batch_norm"],
+ )
+
+ decoder = TacotronDecoder(
+ self.config["len_target_symbols"],
+ start_symbol_id=self.text_encoder.start_symbol_id,
+ embedding_dim=self.config["decoder_embedding_dim"],
+ encoder_dim=self.config["encoder_dim"],
+ decoder_units=self.config["decoder_units"],
+ decoder_layers=self.config["decoder_layers"],
+ attention_type=self.config["attention_type"],
+ attention_units=self.config["attention_units"],
+ is_attention_accumulative=self.config["is_attention_accumulative"],
+ use_prenet=self.config["use_decoder_prenet"],
+ prenet_depth=self.config["decoder_prenet_depth"],
+ teacher_forcing_probability=self.config["teacher_forcing_probability"],
+ )
+
+ model = Tacotron(encoder=encoder, decoder=decoder)
+
+ return model
+
+ def get_tacotron_based(self):
+ encoder = TacotronEncoder(
+ embedding_dim=self.config["encoder_embedding_dim"],
+ inp_vocab_size=self.config["len_input_symbols"],
+ prenet_sizes=self.config["prenet_sizes"],
+ use_prenet=self.config["use_encoder_prenet"],
+ cbhg_gru_units=self.config["cbhg_gru_units"],
+ cbhg_filters=self.config["cbhg_filters"],
+ cbhg_projections=self.config["cbhg_projections"],
+ )
+
+ decoder = TacotronDecoder(
+ self.config["len_target_symbols"],
+ start_symbol_id=self.text_encoder.start_symbol_id,
+ embedding_dim=self.config["decoder_embedding_dim"],
+ encoder_dim=self.config["encoder_dim"],
+ decoder_units=self.config["decoder_units"],
+ decoder_layers=self.config["decoder_layers"],
+ attention_type=self.config["attention_type"],
+ attention_units=self.config["attention_units"],
+ is_attention_accumulative=self.config["is_attention_accumulative"],
+ use_prenet=self.config["use_decoder_prenet"],
+ prenet_depth=self.config["decoder_prenet_depth"],
+ teacher_forcing_probability=self.config["teacher_forcing_probability"],
+ )
+
+ model = Tacotron(encoder=encoder, decoder=decoder)
+
+ return model
+
+ def get_text_encoder(self):
+ """Getting the class of TextEncoder from config"""
+ if self.config["text_cleaner"] not in [
+ "basic_cleaners",
+ "valid_arabic_cleaners",
+ None,
+ ]:
+ raise Exception(f"cleaner is not known {self.config['text_cleaner']}")
+
+ if self.config["text_encoder"] == "BasicArabicEncoder":
+ text_encoder = BasicArabicEncoder(
+ cleaner_fn=self.config["text_cleaner"], sp_model_path=self.sp_model_path
+ )
+ elif self.config["text_encoder"] == "ArabicEncoderWithStartSymbol":
+ text_encoder = ArabicEncoderWithStartSymbol(
+ cleaner_fn=self.config["text_cleaner"], sp_model_path=self.sp_model_path
+ )
+ else:
+ raise Exception(
+ f"the text encoder is not found {self.config['text_encoder']}"
+ )
+
+ return text_encoder
+
+ def get_loss_type(self):
+ try:
+ loss_type = LossType[self.config["loss_type"]]
+        except KeyError:
+            raise Exception(
+                f"The loss type is not correct: {self.config.get('loss_type')}"
+            )
+ return loss_type
+
+
+if __name__ == "__main__":
+ config_path = "config/tacotron-base-config.yml"
+ model_kind = "tacotron"
+ config = ConfigManager(config_path=config_path, model_kind=model_kind)
diff --git a/poetry_diacritizer/create_configs.py b/poetry_diacritizer/create_configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd47aacb01ee07c8bc673ff33daff334fe85d0f2
--- /dev/null
+++ b/poetry_diacritizer/create_configs.py
@@ -0,0 +1,13 @@
+import yaml
+
+fname = "config/gpt-cls-tash-proc.yml"
+
+stream = open(fname, 'r')
+data = yaml.load(stream, Loader=yaml.FullLoader)
+
+for i in range(0, 10):
+ data['n_layer'] = i
+ data['log_directory'] = f'log_dir_cls_{i}_tash_proc'
+ data['max_steps'] = 5000
+ with open(f"config/gpt-cls-{i}-tash-proc.yml", 'w') as yaml_file:
+ yaml_file.write( yaml.dump(data, default_flow_style=False))
diff --git a/poetry_diacritizer/dataset.py b/poetry_diacritizer/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..12969e9d2397096258c69dea7363547f3442c2dc
--- /dev/null
+++ b/poetry_diacritizer/dataset.py
@@ -0,0 +1,238 @@
+"""
+Loading the diacritization dataset
+"""
+
+import os
+
+from diacritization_evaluation import util
+import pandas as pd
+import torch
+from torch.utils.data import DataLoader, Dataset
+
+from .config_manager import ConfigManager
+
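+# Arabic diacritics (haraqat) mapped to their names; used below to collapse runs
+# of consecutive diacritics during preprocessing.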
+BASIC_HARAQAT = {
+ "َ": "Fatha ",
+ "ً": "Fathatah ",
+ "ُ": "Damma ",
+ "ٌ": "Dammatan ",
+ "ِ": "Kasra ",
+ "ٍ": "Kasratan ",
+ "ْ": "Sukun ",
+ "ّ": "Shaddah ",
+}
+
+
+class DiacritizationDataset(Dataset):
+ """
+ The diacritization dataset
+ """
+
+ def __init__(self, config_manager: ConfigManager, list_ids, data):
+ "Initialization"
+ self.list_ids = list_ids
+ self.data = data
+ self.text_encoder = config_manager.text_encoder
+ self.config = config_manager.config
+
+ def __len__(self):
+ "Denotes the total number of samples"
+ return len(self.list_ids)
+
+ def preprocess(self, book):
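+        # When two diacritics appear back to back, drop the earlier one so that
+        # only the last diacritic attached to a letter is kept.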
+ out = ""
+ i = 0
+ while i < len(book):
+ if i < len(book) - 1:
+ if book[i] in BASIC_HARAQAT and book[i + 1] in BASIC_HARAQAT:
+ i += 1
+ continue
+ out += book[i]
+ i += 1
+ return out
+
+ def __getitem__(self, index):
+ "Generates one sample of data"
+ # Select sample
+ id = self.list_ids[index]
+ if self.config["is_data_preprocessed"]:
+ data = self.data.iloc[id]
+ inputs = torch.Tensor(self.text_encoder.input_to_sequence(data[1]))
+ targets = torch.Tensor(
+ self.text_encoder.target_to_sequence(
+ data[2].split(self.config["diacritics_separator"])
+ )
+ )
+ return inputs, targets, data[0]
+
+ data = self.data[id]
+ non_cleaned = data
+
+ data = self.text_encoder.clean(data)
+ data = data[: self.config["max_sen_len"]]
+ text, inputs, diacritics = util.extract_haraqat(data)
+
+ inputs = torch.Tensor(self.text_encoder.input_to_sequence("".join(inputs)))
+ diacritics = torch.Tensor(self.text_encoder.target_to_sequence(diacritics))
+
+ return inputs, diacritics, text
+
+
+def collate_fn(data):
+ """
+ Padding the input and output sequences
+ """
+
+ def merge(sequences):
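+        # Right-pad every sequence with zeros up to the longest one in the batch.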
+ lengths = [len(seq) for seq in sequences]
+ padded_seqs = torch.zeros(len(sequences), max(lengths)).long()
+ for i, seq in enumerate(sequences):
+ end = lengths[i]
+ padded_seqs[i, :end] = seq[:end]
+ return padded_seqs, lengths
+
+ data.sort(key=lambda x: len(x[0]), reverse=True)
+
+ # separate source and target sequences
+ src_seqs, trg_seqs, original = zip(*data)
+
+ # merge sequences (from tuple of 1D tensor to 2D tensor)
+ src_seqs, src_lengths = merge(src_seqs)
+ trg_seqs, trg_lengths = merge(trg_seqs)
+
+ batch = {
+ "original": original,
+ "src": src_seqs,
+ "target": trg_seqs,
+ "lengths": torch.LongTensor(src_lengths), # src_lengths = trg_lengths
+ }
+ return batch
+
+
+def load_training_data(config_manager: ConfigManager, loader_parameters):
+ """
+ Loading the training data using pandas
+ """
+
+ if not config_manager.config["load_training_data"]:
+ return []
+
+ path = os.path.join(config_manager.data_dir, "train.csv")
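+    # Preprocessed data ships as a separator-delimited CSV; otherwise the file is
+    # read as raw lines and filtered by maximum length.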
+ if config_manager.config["is_data_preprocessed"]:
+ train_data = pd.read_csv(
+ path,
+ encoding="utf-8",
+ sep=config_manager.config["data_separator"],
+ nrows=config_manager.config["n_training_examples"],
+ header=None,
+ )
+
+ # train_data = train_data[train_data[0] <= config_manager.config["max_len"]]
+ training_set = DiacritizationDataset(
+ config_manager, train_data.index, train_data
+ )
+ else:
+ with open(path, encoding="utf8") as file:
+ train_data = file.readlines()
+ train_data = [
+ text
+ for text in train_data
+ if len(text) <= config_manager.config["max_len"] and len(text) > 0
+ ]
+ training_set = DiacritizationDataset(
+ config_manager, [idx for idx in range(len(train_data))], train_data
+ )
+
+ train_iterator = DataLoader(
+ training_set, collate_fn=collate_fn, **loader_parameters
+ )
+
+ print(f"Length of training iterator = {len(train_iterator)}")
+ return train_iterator
+
+
+def load_test_data(config_manager: ConfigManager, loader_parameters):
+ """
+ Loading the test data using pandas
+ """
+ if not config_manager.config["load_test_data"]:
+ return []
+ test_file_name = config_manager.config.get("test_file_name", "test.csv")
+ path = os.path.join(config_manager.data_dir, test_file_name)
+ if config_manager.config["is_data_preprocessed"]:
+ test_data = pd.read_csv(
+ path,
+ encoding="utf-8",
+ sep=config_manager.config["data_separator"],
+ nrows=config_manager.config["n_test_examples"],
+ header=None,
+ )
+ # test_data = test_data[test_data[0] <= config_manager.config["max_len"]]
+ test_dataset = DiacritizationDataset(config_manager, test_data.index, test_data)
+ else:
+ with open(path, encoding="utf8") as file:
+ test_data = file.readlines()
+ max_len = config_manager.config["max_len"]
+ test_data = [text[:max_len] for text in test_data]
+ test_dataset = DiacritizationDataset(
+ config_manager, [idx for idx in range(len(test_data))], test_data
+ )
+
+ test_iterator = DataLoader(test_dataset, collate_fn=collate_fn, **loader_parameters)
+
+ print(f"Length of test iterator = {len(test_iterator)}")
+ return test_iterator
+
+
+def load_validation_data(config_manager: ConfigManager, loader_parameters):
+ """
+ Loading the validation data using pandas
+ """
+
+ if not config_manager.config["load_validation_data"]:
+ return []
+ path = os.path.join(config_manager.data_dir, "eval.csv")
+ if config_manager.config["is_data_preprocessed"]:
+ valid_data = pd.read_csv(
+ path,
+ encoding="utf-8",
+ sep=config_manager.config["data_separator"],
+ nrows=config_manager.config["n_validation_examples"],
+ header=None,
+ )
+ valid_data = valid_data[valid_data[0] <= config_manager.config["max_len"]]
+ valid_dataset = DiacritizationDataset(
+ config_manager, valid_data.index, valid_data
+ )
+ else:
+ with open(path, encoding="utf8") as file:
+ valid_data = file.readlines()
+
+ max_len = config_manager.config["max_len"]
+ valid_data = [text[:max_len] for text in valid_data]
+ valid_dataset = DiacritizationDataset(
+ config_manager, [idx for idx in range(len(valid_data))], valid_data
+ )
+
+ valid_iterator = DataLoader(
+ valid_dataset, collate_fn=collate_fn, **loader_parameters
+ )
+
+ print(f"Length of valid iterator = {len(valid_iterator)}")
+ return valid_iterator
+
+
+def load_iterators(config_manager: ConfigManager):
+ """
+    Load the training, test, and validation data iterators
+    Args:
+        config_manager (ConfigManager): the configuration manager for the run
+ """
+ params = {
+ "batch_size": config_manager.config["batch_size"],
+ "shuffle": True,
+ "num_workers": 2,
+ }
+ train_iterator = load_training_data(config_manager, loader_parameters=params)
+ valid_iterator = load_validation_data(config_manager, loader_parameters=params)
+ test_iterator = load_test_data(config_manager, loader_parameters=params)
+ return train_iterator, test_iterator, valid_iterator
diff --git a/poetry_diacritizer/diacritize.py b/poetry_diacritizer/diacritize.py
new file mode 100644
index 0000000000000000000000000000000000000000..09314d9a8eb3afa437e69046c112c48e1450b01f
--- /dev/null
+++ b/poetry_diacritizer/diacritize.py
@@ -0,0 +1,36 @@
+import argparse
+from diacritizer import TransformerDiacritizer
+from itertools import repeat
+import random
+
+import numpy as np
+import torch
+
+
+SEED = 1234
+random.seed(SEED)
+np.random.seed(SEED)
+torch.manual_seed(SEED)
+torch.cuda.manual_seed(SEED)
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+
+
+def diacritization_parser():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_kind", dest="model_kind", type=str, required=True)
+ parser.add_argument("--config", dest="config", type=str, required=True)
+ parser.add_argument("--text", dest="text", type=str, required=True)
+ return parser
+
+
+parser = diacritization_parser()
+args = parser.parse_args()
+
+
+if args.model_kind in ["transformer"]:
+    diacritizer = TransformerDiacritizer(args.config, args.model_kind)
+else:
+ raise ValueError("The model kind is not supported")
+
+diacritizer.diacritize_text(args.text)
diff --git a/poetry_diacritizer/diacritizer.py b/poetry_diacritizer/diacritizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..63fc3ed940a81dc560d68781dd4d73357cfc6350
--- /dev/null
+++ b/poetry_diacritizer/diacritizer.py
@@ -0,0 +1,98 @@
+from typing import Dict
+import torch
+from .config_manager import ConfigManager
+
+
+class Diacritizer:
+ def __init__(
+ self, config_path: str, model_kind: str, load_model: bool = False
+ ) -> None:
+ self.config_path = config_path
+ self.model_kind = model_kind
+ self.config_manager = ConfigManager(
+ config_path=config_path, model_kind=model_kind
+ )
+ self.config = self.config_manager.config
+ self.text_encoder = self.config_manager.text_encoder
+ if self.config.get("device"):
+ self.device = self.config["device"]
+ else:
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ if load_model:
+ self.model, self.global_step = self.config_manager.load_model()
+ self.model = self.model.to(self.device)
+
+ self.start_symbol_id = self.text_encoder.start_symbol_id
+
+ def set_model(self, model: torch.nn.Module):
+ self.model = model
+
+    def diacritize_text(self, text: str):
+        seq = self.text_encoder.input_to_sequence(text)
+        # wrap the single sentence in the batch dict that diacritize_batch expects
+        batch = {"src": torch.LongTensor([seq]).to(self.device), "lengths": torch.LongTensor([len(seq)])}
+        return self.diacritize_batch(batch)
+
+ def diacritize_batch(self, batch):
+ raise NotImplementedError()
+
+ def diacritize_iterators(self, iterator):
+ pass
+
+
+class CBHGDiacritizer(Diacritizer):
+ def diacritize_batch(self, batch):
+ self.model.eval()
+ inputs = batch["src"]
+ lengths = batch["lengths"]
+ outputs = self.model(inputs.to(self.device), lengths.to("cpu"))
+ diacritics = outputs["diacritics"]
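+        # pick the most probable diacritic id for every character position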
+ predictions = torch.max(diacritics, 2).indices
+ sentences = []
+
+ for src, prediction in zip(inputs, predictions):
+ sentence = self.text_encoder.combine_text_and_haraqat(
+ list(src.detach().cpu().numpy()),
+ list(prediction.detach().cpu().numpy()),
+ )
+ sentences.append(sentence)
+
+ return sentences
+
+
+class Seq2SeqDiacritizer(Diacritizer):
+ def diacritize_batch(self, batch):
+ self.model.eval()
+ inputs = batch["src"]
+ lengths = batch["lengths"]
+ outputs = self.model(inputs.to(self.device), lengths.to("cpu"))
+ diacritics = outputs["diacritics"]
+ predictions = torch.max(diacritics, 2).indices
+ sentences = []
+
+ for src, prediction in zip(inputs, predictions):
+ sentence = self.text_encoder.combine_text_and_haraqat(
+ list(src.detach().cpu().numpy()),
+ list(prediction.detach().cpu().numpy()),
+ )
+ sentences.append(sentence)
+
+ return sentences
+
+class GPTDiacritizer(Diacritizer):
+ def diacritize_batch(self, batch):
+ self.model.eval()
+ inputs = batch["src"]
+ lengths = batch["lengths"]
+ outputs = self.model(inputs.to(self.device), lengths.to("cpu"))
+ diacritics = outputs["diacritics"]
+ predictions = torch.max(diacritics, 2).indices
+ sentences = []
+
+ for src, prediction in zip(inputs, predictions):
+ sentence = self.text_encoder.combine_text_and_haraqat(
+ list(src.detach().cpu().numpy()),
+ list(prediction.detach().cpu().numpy()),
+ )
+ sentences.append(sentence)
+
+ return sentences
diff --git a/poetry_diacritizer/models/__init__.py b/poetry_diacritizer/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..750e4bee526b17e354cbb6dcdda8e5ea759e9634
--- /dev/null
+++ b/poetry_diacritizer/models/__init__.py
@@ -0,0 +1,5 @@
+from . import baseline
+from . import cbhg
+from . import gpt
+from . import seq2seq
+from . import tacotron_based
\ No newline at end of file
diff --git a/poetry_diacritizer/models/__pycache__/__init__.cpython-310.pyc b/poetry_diacritizer/models/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1b44e2e2a05a44627455428319a8d1da595a849c
Binary files /dev/null and b/poetry_diacritizer/models/__pycache__/__init__.cpython-310.pyc differ
diff --git a/poetry_diacritizer/models/__pycache__/baseline.cpython-310.pyc b/poetry_diacritizer/models/__pycache__/baseline.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aebe6cbf643fccd468daf5bc18f3998d2432926e
Binary files /dev/null and b/poetry_diacritizer/models/__pycache__/baseline.cpython-310.pyc differ
diff --git a/poetry_diacritizer/models/__pycache__/baseline.cpython-38.pyc b/poetry_diacritizer/models/__pycache__/baseline.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..25a2d5476993f8f063ea5d169c800c96e012e976
Binary files /dev/null and b/poetry_diacritizer/models/__pycache__/baseline.cpython-38.pyc differ
diff --git a/poetry_diacritizer/models/__pycache__/cbhg.cpython-310.pyc b/poetry_diacritizer/models/__pycache__/cbhg.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..54c4f644a9c82f3d9a2ab581e2fb0c229cebe79b
Binary files /dev/null and b/poetry_diacritizer/models/__pycache__/cbhg.cpython-310.pyc differ
diff --git a/poetry_diacritizer/models/__pycache__/cbhg.cpython-38.pyc b/poetry_diacritizer/models/__pycache__/cbhg.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0c218cfcd6d5a02ad4e79d00a83bb0df147b38db
Binary files /dev/null and b/poetry_diacritizer/models/__pycache__/cbhg.cpython-38.pyc differ
diff --git a/poetry_diacritizer/models/__pycache__/gpt.cpython-310.pyc b/poetry_diacritizer/models/__pycache__/gpt.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a1ca00cb0007f38dd08c72557597e6bc8ae1f43b
Binary files /dev/null and b/poetry_diacritizer/models/__pycache__/gpt.cpython-310.pyc differ
diff --git a/poetry_diacritizer/models/__pycache__/gpt.cpython-38.pyc b/poetry_diacritizer/models/__pycache__/gpt.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a83b710c0bd2798d39636a4525f28bc095b012f8
Binary files /dev/null and b/poetry_diacritizer/models/__pycache__/gpt.cpython-38.pyc differ
diff --git a/poetry_diacritizer/models/__pycache__/gpt_model.cpython-310.pyc b/poetry_diacritizer/models/__pycache__/gpt_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e03dfa21187af6abb94d6f5e7382e82e011a5210
Binary files /dev/null and b/poetry_diacritizer/models/__pycache__/gpt_model.cpython-310.pyc differ
diff --git a/poetry_diacritizer/models/__pycache__/seq2seq.cpython-310.pyc b/poetry_diacritizer/models/__pycache__/seq2seq.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..957672735a1200c548fc37693f332cc1062de0e2
Binary files /dev/null and b/poetry_diacritizer/models/__pycache__/seq2seq.cpython-310.pyc differ
diff --git a/poetry_diacritizer/models/__pycache__/seq2seq.cpython-38.pyc b/poetry_diacritizer/models/__pycache__/seq2seq.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2b12fecbc881b569d27512947e9dc4b661d0ed76
Binary files /dev/null and b/poetry_diacritizer/models/__pycache__/seq2seq.cpython-38.pyc differ
diff --git a/poetry_diacritizer/models/__pycache__/tacotron_based.cpython-310.pyc b/poetry_diacritizer/models/__pycache__/tacotron_based.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..75a8dca96bdd441548c52a4347684fdab920e185
Binary files /dev/null and b/poetry_diacritizer/models/__pycache__/tacotron_based.cpython-310.pyc differ
diff --git a/poetry_diacritizer/models/__pycache__/tacotron_based.cpython-38.pyc b/poetry_diacritizer/models/__pycache__/tacotron_based.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7f9b1808d583f5c935bceff01ee35ace1c4ffac1
Binary files /dev/null and b/poetry_diacritizer/models/__pycache__/tacotron_based.cpython-38.pyc differ
diff --git a/poetry_diacritizer/models/baseline.py b/poetry_diacritizer/models/baseline.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b1e2c6ccb2160e394ecde108020689d7cf30290
--- /dev/null
+++ b/poetry_diacritizer/models/baseline.py
@@ -0,0 +1,60 @@
+from typing import List
+from torch import nn
+import torch
+
+
+class BaseLineModel(nn.Module):
+ def __init__(
+ self,
+ inp_vocab_size: int,
+ targ_vocab_size: int,
+ embedding_dim: int = 512,
+ layers_units: List[int] = [256, 256, 256],
+ use_batch_norm: bool = False,
+ ):
+ super().__init__()
+ self.targ_vocab_size = targ_vocab_size
+ self.embedding = nn.Embedding(inp_vocab_size, embedding_dim)
+
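+        # prepend embedding_dim // 2 so the first BiLSTM's input size
+        # (2 * previous units) matches the embedding dimension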
+ layers_units = [embedding_dim // 2] + layers_units
+
+ layers = []
+
+ for i in range(1, len(layers_units)):
+ layers.append(
+ nn.LSTM(
+ layers_units[i - 1] * 2,
+ layers_units[i],
+ bidirectional=True,
+ batch_first=True,
+ )
+ )
+ if use_batch_norm:
+ layers.append(nn.BatchNorm1d(layers_units[i] * 2))
+
+ self.layers = nn.ModuleList(layers)
+ self.projections = nn.Linear(layers_units[-1] * 2, targ_vocab_size)
+ self.layers_units = layers_units
+ self.use_batch_norm = use_batch_norm
+
+ def forward(self, src: torch.Tensor, lengths: torch.Tensor, target=None):
+
+ outputs = self.embedding(src)
+
+ # embedded_inputs = [batch_size, src_len, embedding_dim]
+
+ for i, layer in enumerate(self.layers):
+ if isinstance(layer, nn.BatchNorm1d):
+ outputs = layer(outputs.permute(0, 2, 1))
+ outputs = outputs.permute(0, 2, 1)
+ continue
+ if i > 0:
+ outputs, (hn, cn) = layer(outputs, (hn, cn))
+ else:
+ outputs, (hn, cn) = layer(outputs)
+
+ predictions = self.projections(outputs)
+
+ output = {"diacritics": predictions}
+
+ return output
diff --git a/poetry_diacritizer/models/cbhg.py b/poetry_diacritizer/models/cbhg.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2b8f061d10ec6b1a9490029a4b4ed43fdd5e861
--- /dev/null
+++ b/poetry_diacritizer/models/cbhg.py
@@ -0,0 +1,121 @@
+"""
+The CBHG model implementation
+"""
+from typing import List, Optional
+
+from torch import nn
+import torch
+
+from poetry_diacritizer.modules.tacotron_modules import CBHG, Prenet
+
+
+class CBHGModel(nn.Module):
+ """CBHG model implementation as described in the paper:
+ https://ieeexplore.ieee.org/document/9274427
+
+ Args:
+ inp_vocab_size (int): the number of the input symbols
+ targ_vocab_size (int): the number of the target symbols (diacritics)
+ embedding_dim (int): the embedding size
+ use_prenet (bool): whether to use prenet or not
+ prenet_sizes (List[int]): the sizes of the prenet networks
+ cbhg_gru_units (int): the number of units of the CBHG GRU, which is the last
+ layer of the CBHG Model.
+ cbhg_filters (int): number of filters used in the CBHG module
+ cbhg_projections: projections used in the CBHG module
+
+ Returns:
+        Dict[str, Tensor]: a dictionary with the key "diacritics" holding the predictions
+ """
+
+ def __init__(
+ self,
+ inp_vocab_size: int,
+ targ_vocab_size: int,
+ embedding_dim: int = 512,
+ use_prenet: bool = True,
+ prenet_sizes: List[int] = [512, 256],
+ cbhg_gru_units: int = 512,
+ cbhg_filters: int = 16,
+ cbhg_projections: List[int] = [128, 256],
+ post_cbhg_layers_units: List[int] = [256, 256],
+ post_cbhg_use_batch_norm: bool = True
+ ):
+ super().__init__()
+ self.use_prenet = use_prenet
+ self.embedding = nn.Embedding(inp_vocab_size, embedding_dim)
+ if self.use_prenet:
+ self.prenet = Prenet(embedding_dim, prenet_depth=prenet_sizes)
+
+ self.cbhg = CBHG(
+ prenet_sizes[-1] if self.use_prenet else embedding_dim,
+ cbhg_gru_units,
+ K=cbhg_filters,
+ projections=cbhg_projections,
+ )
+
+ layers = []
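+        # the CBHG GRU is bidirectional, so its output size is cbhg_gru_units * 2,
+        # which is why the first post-CBHG LSTM input below is doubled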
+ post_cbhg_layers_units = [cbhg_gru_units] + post_cbhg_layers_units
+
+ for i in range(1, len(post_cbhg_layers_units)):
+ layers.append(
+ nn.LSTM(
+ post_cbhg_layers_units[i - 1] * 2,
+ post_cbhg_layers_units[i],
+ bidirectional=True,
+ batch_first=True,
+ )
+ )
+ if post_cbhg_use_batch_norm:
+ layers.append(nn.BatchNorm1d(post_cbhg_layers_units[i] * 2))
+
+ self.post_cbhg_layers = nn.ModuleList(layers)
+ self.projections = nn.Linear(post_cbhg_layers_units[-1] * 2, targ_vocab_size)
+ self.post_cbhg_layers_units = post_cbhg_layers_units
+ self.post_cbhg_use_batch_norm = post_cbhg_use_batch_norm
+
+
+ def forward(
+ self,
+ src: torch.Tensor,
+ lengths: Optional[torch.Tensor] = None,
+ target: Optional[torch.Tensor] = None, # not required in this model
+ ):
+ """Compute forward propagation"""
+
+ # src = [batch_size, src len]
+ # lengths = [batch_size]
+ # target = [batch_size, trg len]
+
+ embedding_out = self.embedding(src)
+ # embedding_out; [batch_size, src_len, embedding_dim]
+
+ cbhg_input = embedding_out
+ if self.use_prenet:
+ cbhg_input = self.prenet(embedding_out)
+
+ # cbhg_input = [batch_size, src_len, prenet_sizes[-1]]
+
+ outputs = self.cbhg(cbhg_input, lengths)
+
+ hn = torch.zeros((2, 2, 2))
+ cn = torch.zeros((2, 2, 2))
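+        # placeholder states; the first LSTM layer below (i == 0) overwrites hn and cn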
+
+ for i, layer in enumerate(self.post_cbhg_layers):
+ if isinstance(layer, nn.BatchNorm1d):
+ outputs = layer(outputs.permute(0, 2, 1))
+ outputs = outputs.permute(0, 2, 1)
+ continue
+ if i > 0:
+ outputs, (hn, cn) = layer(outputs, (hn, cn))
+ else:
+ outputs, (hn, cn) = layer(outputs)
+
+
+ predictions = self.projections(outputs)
+
+ # predictions = [batch_size, src len, targ_vocab_size]
+
+ output = {"diacritics": predictions}
+
+ return output
diff --git a/poetry_diacritizer/models/gpt.py b/poetry_diacritizer/models/gpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..45a53a63377930d1f90d48dcb185c63c50b237ae
--- /dev/null
+++ b/poetry_diacritizer/models/gpt.py
@@ -0,0 +1,83 @@
+from typing import List
+from torch import nn
+import torch
+from pathlib import Path
+import json
+from .gpt_model import Model, HParams
+
+
+class GPTModel(nn.Module):
+ def __init__(self, path, n_layer=-1, freeze=True, use_lstm=False):
+ super().__init__()
+ root = Path(path)
+
+ params = json.loads((root / "params.json").read_text())
+ hparams = params["hparams"]
+ hparams.setdefault("n_hidden", hparams["n_embed"])
+ self.model = Model(HParams(**hparams))
+ state = torch.load(root / "model.pt", map_location="cpu")
+ state_dict = self.fixed_state_dict(state["state_dict"])
+ self.model.load_state_dict(state_dict)
+ self.activation = {}
+ self.freeze = freeze
+ self.n_layer = n_layer
+ if self.freeze:
+ for param in self.model.parameters():
+ param.requires_grad = False
+
+ self.activation = {}
+ self.use_lstm = use_lstm
+ self.set_hook(self.n_layer)
+ self.in_fc_layer = 512 if self.use_lstm else 768
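+        # 512 = 2 * 256 (final BiLSTM output); 768 is the hidden size of the hooked GPT block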
+ self.lstm1 = nn.LSTM(
+ 768,
+ 256,
+ bidirectional=True,
+ batch_first=True,
+ )
+ self.lstm2 = nn.LSTM(
+ 512,
+ 256,
+ bidirectional=True,
+ batch_first=True,
+ )
+ self.lstm3 = nn.LSTM(
+ 512,
+ 256,
+ bidirectional=True,
+ batch_first=True,
+ )
+ self.fc = nn.Linear(self.in_fc_layer, 17)
+
+ def get_activation(self, name):
+ def hook(model, input, output):
+ self.activation[name] = output[0].detach()
+
+ return hook
+
+ def set_hook(self, n_layer=0):
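+        # register a forward hook on the chosen transformer block so its hidden
+        # states are cached in self.activation["feats"] during forward()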
+ self.model.blocks[n_layer].register_forward_hook(self.get_activation("feats"))
+
+ def fixed_state_dict(self, state_dict):
+ if all(k.startswith("module.") for k in state_dict):
+ # legacy multi-GPU format
+ state_dict = {k[len("module.") :]: v for k, v in state_dict.items()}
+ return state_dict
+
+ def forward(self, src: torch.Tensor, lengths: torch.Tensor, target=None):
+
+ # logits shape [batch_size, 256, 500]
+ logits = self.model(src)["logits"]
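+        # discard the LM logits and use the hidden states captured by the hook instead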
+ logits = self.activation["feats"]
+
+ if self.use_lstm:
+ x, (h, cn) = self.lstm1(logits)
+ x, (h, cn) = self.lstm2(x)
+ x, (h, cn) = self.lstm3(x)
+ else:
+ x = logits
+ predictions = self.fc(x)
+
+ output = {"diacritics": predictions}
+
+ return output
diff --git a/poetry_diacritizer/models/gpt_model.py b/poetry_diacritizer/models/gpt_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a64aaf9e56067543a2aab17d9b20f6170b5b75f
--- /dev/null
+++ b/poetry_diacritizer/models/gpt_model.py
@@ -0,0 +1,213 @@
+"""
+OpenAI's GPT-2 ported to PyTorch.
+"""
+import math
+
+import attr
+import torch
+from torch import nn
+from torch.nn import functional as F
+import torch.utils.checkpoint
+
+
+@attr.s(auto_attribs=True, frozen=True)
+class HParams:
+ n_vocab: int
+ n_ctx: int
+ n_embed: int
+ n_hidden: int
+ n_head: int
+ n_layer: int
+ gradient_checkpointing: bool = False
+
+
+class Model(nn.Module):
+ def __init__(self, hparams: HParams):
+ super().__init__()
+ self.hparams = hparams
+ self.wpe = nn.Embedding(hparams.n_ctx, hparams.n_embed)
+ nn.init.normal_(self.wpe.weight, std=0.01)
+ self.wte = nn.Embedding(hparams.n_vocab, hparams.n_embed)
+ nn.init.normal_(self.wte.weight, std=0.02)
+ self.blocks = nn.ModuleList(
+ [Block(hparams) for _ in range(hparams.n_layer)])
+ self.ln_f = Norm(self.hparams.n_hidden)
+ if hparams.n_hidden != hparams.n_embed:
+ self.in_proj = Conv1D(hparams.n_embed, hparams.n_hidden)
+ self.out_proj = Conv1D(hparams.n_hidden, hparams.n_embed)
+ else:
+ self.in_proj = self.out_proj = None
+
+ def forward(self, x, past=None):
+ # Embedding
+ past_length = 0 if past is None else past.shape[-2]
+ batch_size, n_ctx = x.shape
+ position = position_for(batch_size, n_ctx, past_length, x.device)
+ h = self.wte(x) + self.wpe(position)
+ assert h.shape == (batch_size, n_ctx, self.hparams.n_embed)
+ if self.in_proj:
+ h = self.in_proj(h)
+ # Transformer
+ presents = []
+ for i, block in enumerate(self.blocks):
+ if self.hparams.gradient_checkpointing:
+ h, present = torch.utils.checkpoint.checkpoint(
+ block, h, past[:, i] if past is not None else None)
+ else:
+ h, present = block(
+ h, past=past[:, i] if past is not None else None)
+ presents.append(present)
+ h = self.ln_f(h)
+ if self.out_proj:
+ h = self.out_proj(h)
+ # Output logits
+ h_flat = h.reshape([batch_size * n_ctx, self.hparams.n_embed])
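+        # weight tying: the output projection reuses the token embedding matrix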
+ logits = torch.matmul(h_flat, self.wte.weight.t())
+ logits = logits.reshape([batch_size, n_ctx, self.hparams.n_vocab])
+ return {
+ 'presents': torch.stack(tuple(presents), dim=1),
+ 'logits': logits,
+ }
+
+
+class Block(nn.Module):
+ def __init__(self, hparams: HParams):
+ super().__init__()
+ self.ln_1 = Norm(hparams.n_hidden)
+ self.ln_2 = Norm(hparams.n_hidden)
+ self.mlp = MLP(hparams.n_hidden, hparams.n_hidden * 4)
+ self.attn = Attention(hparams)
+
+ def forward(self, x, past):
+ a, present = self.attn(self.ln_1(x), past=past)
+ x = x + a
+ m = self.mlp(self.ln_2(x))
+ x = x + m
+ return x, present
+
+
+class Norm(nn.Module):
+ """ Normalize to mean = 0, std = 1, then do a diagonal affine transform.
+ """
+ def __init__(self, n_features, *, dim=-1, epsilon=1e-5):
+ super().__init__()
+ self.n_features = n_features
+ self.dim = dim
+ self.epsilon = epsilon
+ self.g = nn.Parameter(torch.ones(n_features))
+ self.b = nn.Parameter(torch.zeros(n_features))
+
+ def forward(self, x):
+ assert x.shape[-1] == self.n_features
+ u = torch.mean(x, dim=self.dim, keepdim=True)
+ xmu = x - u
+ s = torch.mean(xmu * xmu, dim=self.dim, keepdim=True)
+ return xmu * torch.rsqrt(s + self.epsilon) * self.g + self.b
+
+
+class MLP(nn.Module):
+ def __init__(self, n_features, n_hidden):
+ super().__init__()
+ self.c_fc = Conv1D(n_features, n_hidden)
+ self.c_proj = Conv1D(n_hidden, n_features)
+
+ def forward(self, x):
+ x = gelu(self.c_fc(x))
+ x = self.c_proj(x)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(self, hparams: HParams):
+ super().__init__()
+ assert hparams.n_hidden % hparams.n_head == 0
+ self.hparams = hparams
+ self.c_attn = Conv1D(hparams.n_hidden, hparams.n_hidden * 3)
+ self.c_proj = Conv1D(hparams.n_hidden, hparams.n_hidden)
+
+ def forward(self, x, past):
+ assert len(x.shape) == 3 # [batch, sequence, features]
+ assert x.shape[-1] == self.hparams.n_hidden
+ if past is not None:
+ # Should be [batch, 2, heads, sequence, features], where 2 is [k, v]
+ assert len(past.shape) == 5
+ assert past.shape[-1] == self.hparams.n_hidden
+ c = self.c_attn(x)
+ q, k, v = map(self.split_heads, torch.split(c, x.shape[-1], dim=2))
+ present = torch.stack([k, v], dim=1)
+ if past is not None:
+ pk, pv = past[:, 0], past[:, 1]
+ k = torch.cat([pk, k], dim=-2)
+ v = torch.cat([pv, v], dim=-2)
+ a = self.multihead_attn(q, k, v)
+ a = self.merge_heads(a)
+ a = self.c_proj(a)
+ return a, present
+
+ def split_heads(self, x):
+ """ From [batch, sequence, features] to
+ [batch, heads, sequence, features].
+ """
+ return self.split_states(x, self.hparams.n_head).permute(0, 2, 1, 3)
+
+ @staticmethod
+ def split_states(x, n):
+ """ Reshape the last dimension of x into [n, x.shape[-1]/n].
+ """
+ *start, m = x.shape
+ return x.reshape(start + [n, m // n])
+
+ def merge_heads(self, x):
+ """ Reverse of split_heads.
+ """
+ return self.merge_states(x.permute(0, 2, 1, 3))
+
+ @staticmethod
+ def merge_states(x):
+ """ Smash the last two dimensions of x into a single dimension.
+ """
+ *start, a, b = x.shape
+ return x.reshape(start + [a * b])
+
+ def mask_attn_weights(self, w):
+ # w has shape [batch, heads, dst_sequence, src_sequence],
+ # where information flows from src to dst.
+ _, _, nd, ns = w.shape
+ b = self.attention_mask(nd, ns, dtype=w.dtype, device=w.device)
+ b = b.reshape((1, 1, nd, ns))
+ w = w * b - 1e4 * (1 - b)
+ return w
+
+ @staticmethod
+ def attention_mask(nd, ns, *, dtype, device=None):
+ """ 1's in the lower triangle, counting from the lower right corner.
+ Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd),
+ but doesn't produce garbage on TPUs.
+ """
+ i = torch.arange(0, nd).unsqueeze(1)
+ j = torch.arange(ns)
+ return (i >= j - ns + nd).to(dtype=dtype, device=device)
+
+ def multihead_attn(self, q, k, v):
+ # q, k, v have shape [batch, heads, sequence, features]
+ w = torch.matmul(q, k.permute(0, 1, 3, 2))
+ w = w / math.sqrt(v.shape[-1])
+ w = self.mask_attn_weights(w)
+ w = F.softmax(w, dim=-1)
+ a = torch.matmul(w, v)
+ return a
+
+
+class Conv1D(nn.Linear):
+ def reset_parameters(self):
+ nn.init.normal_(self.weight, std=0.02)
+ nn.init.zeros_(self.bias)
+
+
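+# tanh approximation of the GELU activation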
+def gelu(x, c=math.sqrt(2 / math.pi)):
+ return 0.5 * x * (1 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3))))
+
+
+def position_for(batch_size, n_steps, past_length, device=None):
+ return (torch.arange(past_length, n_steps + past_length, device=device)
+ .unsqueeze(0).repeat(batch_size, 1))
diff --git a/poetry_diacritizer/models/seq2seq.py b/poetry_diacritizer/models/seq2seq.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fcf05c5cfb6087d90601b246a1235c47ded6903
--- /dev/null
+++ b/poetry_diacritizer/models/seq2seq.py
@@ -0,0 +1,277 @@
+from typing import List, Optional
+
+import torch
+from torch import nn
+from torch.autograd import Variable
+
+from poetry_diacritizer.modules.attention import AttentionWrapper
+from poetry_diacritizer.modules.layers import ConvNorm
+from poetry_diacritizer.modules.tacotron_modules import CBHG, Prenet
+from poetry_diacritizer.options import AttentionType
+from poetry_diacritizer.util.utils import get_mask_from_lengths
+
+
+class Seq2Seq(nn.Module):
+ def __init__(self, encoder: nn.Module, decoder: nn.Module):
+ super().__init__()
+ # Trying smaller std
+ self.encoder = encoder
+ self.decoder = decoder
+
+ def forward(
+ self,
+ src: torch.Tensor,
+ lengths: torch.Tensor,
+ target: Optional[torch.Tensor] = None,
+ ):
+
+ encoder_outputs = self.encoder(src, lengths)
+ mask = get_mask_from_lengths(encoder_outputs, lengths)
+ outputs, alignments = self.decoder(encoder_outputs, target, mask)
+
+ output = {"diacritics": outputs, "attention": alignments}
+
+ return output
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ inp_vocab_size: int,
+ embedding_dim: int = 512,
+ layers_units: List[int] = [256, 256, 256],
+ use_batch_norm: bool = False,
+ ):
+ super().__init__()
+ self.embedding = nn.Embedding(inp_vocab_size, embedding_dim)
+
+ layers_units = [embedding_dim // 2] + layers_units
+
+ layers = []
+
+ for i in range(1, len(layers_units)):
+ layers.append(
+ nn.LSTM(
+ layers_units[i - 1] * 2,
+ layers_units[i],
+ bidirectional=True,
+ batch_first=True,
+ )
+ )
+ if use_batch_norm:
+ layers.append(nn.BatchNorm1d(layers_units[i] * 2))
+
+ self.layers = nn.ModuleList(layers)
+ self.layers_units = layers_units
+ self.use_batch_norm = use_batch_norm
+
+ def forward(self, inputs: torch.Tensor, inputs_lengths: torch.Tensor):
+
+ outputs = self.embedding(inputs)
+
+ # embedded_inputs = [batch_size, src_len, embedding_dim]
+
+ for i, layer in enumerate(self.layers):
+ if isinstance(layer, nn.BatchNorm1d):
+ outputs = layer(outputs.permute(0, 2, 1))
+ outputs = outputs.permute(0, 2, 1)
+ continue
+ if i > 0:
+ outputs, (hn, cn) = layer(outputs, (hn, cn))
+ else:
+ outputs, (hn, cn) = layer(outputs)
+
+ return outputs
+
+class Decoder(nn.Module):
+ """A seq2seq decoder that decode a diacritic at a time ,
+ Args:
+ encoder_dim (int): the encoder output dim
+ decoder_units (int): the number of neurons for each decoder layer
+ decoder_layers (int): number of decoder layers
+ """
+
+ def __init__(
+ self,
+ trg_vocab_size: int,
+ start_symbol_id: int,
+ encoder_dim: int = 256,
+ embedding_dim: int = 256,
+ decoder_units: int = 256,
+ decoder_layers: int = 2,
+ attention_units: int = 256,
+ attention_type: AttentionType = AttentionType.LocationSensitive,
+ is_attention_accumulative: bool = False,
+ prenet_depth: List[int] = [256, 128],
+ use_prenet: bool = True,
+ teacher_forcing_probability: float = 0.0,
+ ):
+ super().__init__()
+
+ self.output_dim: int = trg_vocab_size
+ self.start_symbol_id = start_symbol_id
+ self.attention_units = attention_units
+ self.decoder_units = decoder_units
+ self.encoder_dim = encoder_dim
+ self.use_prenet = use_prenet
+ self.teacher_forcing_probability = teacher_forcing_probability
+ self.is_attention_accumulative = is_attention_accumulative
+ self.embbeding = nn.Embedding(trg_vocab_size, embedding_dim, padding_idx=0)
+ attention_in = embedding_dim
+ if use_prenet:
+ self.prenet = Prenet(embedding_dim, prenet_depth)
+ attention_in = prenet_depth[-1]
+
+ self.attention_layer = nn.GRUCell(encoder_dim + attention_in, attention_units)
+ self.attention_wrapper = AttentionWrapper(attention_type, attention_units)
+ self.keys_layer = nn.Linear(encoder_dim, attention_units, bias=False)
+ self.project_to_decoder_in = nn.Linear(
+ attention_units + encoder_dim,
+ decoder_units,
+ )
+
+ self.decoder_rnns = nn.ModuleList(
+ [nn.GRUCell(decoder_units, decoder_units) for _ in range(decoder_layers)]
+ )
+
+ self.diacritics_layer = nn.Linear(decoder_units, trg_vocab_size)
+ self.device = "cuda"
+
+ def decode(
+ self,
+ diacritic: torch.Tensor,
+ ):
+ """
+ Decode one time-step
+ Args:
+ diacritic (Tensor): (batch_size, 1)
+        Returns:
+            Tuple[Tensor, Tensor]: the diacritic logits and the attention alignment
+ """
+
+ diacritic = self.embbeding(diacritic)
+ if self.use_prenet:
+ prenet_out = self.prenet(diacritic)
+ else:
+ prenet_out = diacritic
+
+ cell_input = torch.cat((prenet_out, self.prev_attention), -1)
+
+ self.attention_hidden = self.attention_layer(cell_input, self.attention_hidden)
+ output = self.attention_hidden
+
+ # The queries are the hidden state of the RNN layer
+ attention, alignment = self.attention_wrapper(
+ query=self.attention_hidden,
+ values=self.encoder_outputs,
+ keys=self.keys,
+ mask=self.mask,
+ prev_alignment=self.prev_alignment,
+ )
+
+ decoder_input = torch.cat((output, attention), -1)
+
+ decoder_input = self.project_to_decoder_in(decoder_input)
+
+ for idx in range(len(self.decoder_rnns)):
+ self.decoder_hiddens[idx] = self.decoder_rnns[idx](
+ decoder_input, self.decoder_hiddens[idx]
+ )
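+            # residual connection around each decoder GRU cell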
+ decoder_input = self.decoder_hiddens[idx] + decoder_input
+
+ output = decoder_input
+
+ output = self.diacritics_layer(output)
+
+ if self.is_attention_accumulative:
+ self.prev_alignment = self.prev_alignment + alignment
+ else:
+ self.prev_alignment = alignment
+
+ self.prev_attention = attention
+
+ return output, alignment
+
+ def inference(self):
+ """Generate diacritics one at a time"""
+ batch_size = self.encoder_outputs.size(0)
+ trg_len = self.encoder_outputs.size(1)
+ diacritic = (
+ torch.full((batch_size,), self.start_symbol_id).to(self.device).long()
+ )
+ outputs, alignments = [], []
+ self.initialize()
+
+ for _ in range(trg_len):
+ output, alignment = self.decode(diacritic=diacritic)
+
+ outputs.append(output)
+ alignments.append(alignment)
+ diacritic = torch.max(output, 1).indices
+
+ alignments = torch.stack(alignments).transpose(0, 1)
+ outputs = torch.stack(outputs).transpose(0, 1).contiguous()
+ return outputs, alignments
+
+ def forward(
+ self,
+ encoder_outputs: torch.Tensor,
+ diacritics: Optional[torch.Tensor] = None,
+ input_mask: Optional[torch.Tensor] = None,
+ ):
+ """calculate forward propagation
+ Args:
+ encoder_outputs (Tensor): the output of the encoder
+ (batch_size, Tx, encoder_units * 2)
+ diacritics(Tensor): target sequence
+ input_mask (Tensor): the inputs mask (batch_size, Tx)
+ """
+ self.mask = input_mask
+ self.encoder_outputs = encoder_outputs
+ self.keys = self.keys_layer(encoder_outputs)
+
+ if diacritics is None:
+ return self.inference()
+
+ batch_size = diacritics.size(0)
+ trg_len = diacritics.size(1)
+
+ # Init decoder states
+ outputs = []
+ alignments = []
+
+ self.initialize()
+
+ diacritic = (
+ torch.full((batch_size,), self.start_symbol_id).to(self.device).long()
+ )
+
+ for time in range(trg_len):
+ output, alignment = self.decode(diacritic=diacritic)
+ outputs += [output]
+ alignments += [alignment]
+ #if random.random() > self.teacher_forcing_probability:
+ diacritic = diacritics[:, time] # use training input
+ #else:
+ #diacritic = torch.max(output, 1).indices # use last output
+
+ alignments = torch.stack(alignments).transpose(0, 1)
+ outputs = torch.stack(outputs).transpose(0, 1).contiguous()
+
+ return outputs, alignments
+
+ def initialize(self):
+ """Initialize the first step variables"""
+ batch_size = self.encoder_outputs.size(0)
+ src_len = self.encoder_outputs.size(1)
+ self.attention_hidden = Variable(
+ torch.zeros(batch_size, self.attention_units)
+ ).to(self.device)
+ self.decoder_hiddens = [
+ Variable(torch.zeros(batch_size, self.decoder_units)).to(self.device)
+ for _ in range(len(self.decoder_rnns))
+ ]
+ self.prev_attention = Variable(torch.zeros(batch_size, self.encoder_dim)).to(
+ self.device
+ )
+ self.prev_alignment = Variable(torch.zeros(batch_size, src_len)).to(self.device)
diff --git a/poetry_diacritizer/models/tacotron_based.py b/poetry_diacritizer/models/tacotron_based.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bbd408e25b485fb80040683658c42ab9d382221
--- /dev/null
+++ b/poetry_diacritizer/models/tacotron_based.py
@@ -0,0 +1,47 @@
+from typing import List
+from poetry_diacritizer.models.seq2seq import Seq2Seq, Decoder as Seq2SeqDecoder
+from poetry_diacritizer.modules.tacotron_modules import CBHG, Prenet
+from torch import nn
+
+
+class Tacotron(Seq2Seq):
+ pass
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ inp_vocab_size: int,
+ embedding_dim: int = 512,
+ use_prenet: bool = True,
+ prenet_sizes: List[int] = [256, 128],
+ cbhg_gru_units: int = 128,
+ cbhg_filters: int = 16,
+ cbhg_projections: List[int] = [128, 128],
+ padding_idx: int = 0,
+ ):
+ super().__init__()
+ self.use_prenet = use_prenet
+
+ self.embedding = nn.Embedding(
+ inp_vocab_size, embedding_dim, padding_idx=padding_idx
+ )
+ if use_prenet:
+ self.prenet = Prenet(embedding_dim, prenet_depth=prenet_sizes)
+ self.cbhg = CBHG(
+ prenet_sizes[-1] if use_prenet else embedding_dim,
+ cbhg_gru_units,
+ K=cbhg_filters,
+ projections=cbhg_projections,
+ )
+
+ def forward(self, inputs, input_lengths=None):
+
+ outputs = self.embedding(inputs)
+ if self.use_prenet:
+ outputs = self.prenet(outputs)
+ return self.cbhg(outputs, input_lengths)
+
+
+class Decoder(Seq2SeqDecoder):
+ pass
diff --git a/poetry_diacritizer/modules/__pycache__/attention.cpython-310.pyc b/poetry_diacritizer/modules/__pycache__/attention.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eaf3c110b2ab63768c06ae6c31a1c24e7ac3b7e1
Binary files /dev/null and b/poetry_diacritizer/modules/__pycache__/attention.cpython-310.pyc differ
diff --git a/poetry_diacritizer/modules/__pycache__/attention.cpython-38.pyc b/poetry_diacritizer/modules/__pycache__/attention.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..46cafe6f4216a7cb1a00a144b0b2375f42c823fc
Binary files /dev/null and b/poetry_diacritizer/modules/__pycache__/attention.cpython-38.pyc differ
diff --git a/poetry_diacritizer/modules/__pycache__/layers.cpython-310.pyc b/poetry_diacritizer/modules/__pycache__/layers.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5264df65507ddc7d18638598b3ee69099d170c83
Binary files /dev/null and b/poetry_diacritizer/modules/__pycache__/layers.cpython-310.pyc differ
diff --git a/poetry_diacritizer/modules/__pycache__/layers.cpython-38.pyc b/poetry_diacritizer/modules/__pycache__/layers.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..383065f87cf4b5c18994e4268ed2f2d8c8699d9b
Binary files /dev/null and b/poetry_diacritizer/modules/__pycache__/layers.cpython-38.pyc differ
diff --git a/poetry_diacritizer/modules/__pycache__/tacotron_modules.cpython-310.pyc b/poetry_diacritizer/modules/__pycache__/tacotron_modules.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..98f0f1826b41ab5837423bf6fbb3f64f8b2e1833
Binary files /dev/null and b/poetry_diacritizer/modules/__pycache__/tacotron_modules.cpython-310.pyc differ
diff --git a/poetry_diacritizer/modules/__pycache__/tacotron_modules.cpython-38.pyc b/poetry_diacritizer/modules/__pycache__/tacotron_modules.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..233c924d260ced5d2ff99c41777172202655646d
Binary files /dev/null and b/poetry_diacritizer/modules/__pycache__/tacotron_modules.cpython-38.pyc differ
diff --git a/poetry_diacritizer/modules/attention.py b/poetry_diacritizer/modules/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae916b43783efa55f2f29e7df79dc4d2dfffbc1b
--- /dev/null
+++ b/poetry_diacritizer/modules/attention.py
@@ -0,0 +1,199 @@
+from typing import Optional
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from poetry_diacritizer.options import AttentionType
+
+
+class BahdanauAttention(nn.Module):
+ def __init__(self, dim):
+ super(BahdanauAttention, self).__init__()
+ self.query_layer = nn.Linear(dim, dim, bias=False)
+ self.tanh = nn.Tanh()
+ self.v = nn.Linear(dim, 1, bias=False)
+
+ def forward(self, query: torch.Tensor, keys: torch.Tensor):
+ """
+ Args:
+ query: (B, 1, dim) or (batch, dim)
+            keys: (batch, max_time, dim), the projected encoder outputs
+ """
+ if query.dim() == 2:
+ # insert time-axis for broadcasting
+ query = query.unsqueeze(1)
+ # (batch, 1, dim)
+ query = self.query_layer(query)
+
+ # (batch, max_time, 1)
+ alignment = self.v(self.tanh(query + keys))
+
+ # (batch, max_time)
+ return alignment.squeeze(-1)
+
+
+class LocationSensitive(nn.Module):
+ def __init__(self, dim):
+ super(LocationSensitive, self).__init__()
+ self.query_layer = nn.Linear(dim, dim, bias=False)
+ self.v = nn.Linear(dim, 1, bias=True)
+ self.location_layer = nn.Linear(32, dim, bias=False)
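+        # same-length padding for the kernel-size-31 location convolution below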
+ padding = int((31 - 1) / 2)
+ self.location_conv = torch.nn.Conv1d(
+ 1, 32, kernel_size=31, stride=1, padding=padding, dilation=1, bias=False
+ )
+
+ self.score_mask_value = -float("inf")
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ keys: torch.Tensor,
+ prev_alignments: torch.Tensor,
+ ):
+ # keys = keys.permute(1,0,2)
+ query = self.query_layer(query)
+ if query.dim() == 2:
+ # insert time-axis for broadcasting
+ query = query.unsqueeze(1)
+ # -> [batch_size, 1, attention_dim]
+
+ alignments = prev_alignments.unsqueeze(1)
+
+ # location features [batch_size, max_time, filters]
+ filters = self.location_conv(alignments)
+ location_features = self.location_layer(filters.transpose(1, 2))
+
+ alignments = self.v(torch.tanh(query + location_features + keys))
+ return alignments.squeeze(-1)
+
+
+class AttentionWrapper(nn.Module):
+ def __init__(
+ self,
+ attention_type: AttentionType = AttentionType.LocationSensitive,
+ attention_units: int = 256,
+ score_mask_value=-float("inf"),
+ ):
+ super().__init__()
+ self.score_mask_value = score_mask_value
+ self.attention_type = attention_type
+
+ if attention_type == AttentionType.LocationSensitive:
+ self.attention_mechanism = LocationSensitive(attention_units)
+ elif attention_type == AttentionType.Content_Based:
+ self.attention_mechanism = BahdanauAttention(attention_units)
+ else:
+ raise Exception("The attention type is not known")
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ keys: torch.Tensor,
+ values: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ prev_alignment: Optional[torch.Tensor] = None,
+ ):
+
+ # Alignment
+ # (batch, max_time)
+ if self.attention_type == AttentionType.Content_Based:
+ alignment = self.attention_mechanism(query, keys)
+ else:
+ alignment = self.attention_mechanism(query, keys, prev_alignment)
+
+ # Attention context vector
+
+ if mask is not None:
+ alignment.data.masked_fill_(mask, self.score_mask_value)
+
+ alignment = F.softmax(alignment, dim=1)
+ attention = torch.bmm(alignment.unsqueeze(1), values)
+ attention = attention.squeeze(1)
+
+ return attention, alignment
+
+
+class MultiHeadAttentionLayer(nn.Module):
+ def __init__(self, hid_dim: int, n_heads: int, dropout: float = 0.0):
+ super().__init__()
+
+ assert hid_dim % n_heads == 0
+
+ self.hid_dim = hid_dim
+ self.n_heads = n_heads
+ self.head_dim = hid_dim // n_heads
+
+ self.fc_q = nn.Linear(hid_dim, hid_dim)
+ self.fc_k = nn.Linear(hid_dim, hid_dim)
+ self.fc_v = nn.Linear(hid_dim, hid_dim)
+
+ self.fc_o = nn.Linear(hid_dim * 2, hid_dim)
+
+ if dropout != 0.0:
+ self.dropout = nn.Dropout(dropout)
+
+ self.use_dropout = dropout != 0.0
+
+ device = next(self.parameters()).device
+
+ self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)
+
+ def forward(self, query, key, value, mask=None):
+
+ batch_size = query.shape[0]
+
+ # query = [batch size, query len, hid dim]
+ # key = [batch size, key len, hid dim]
+ # value = [batch size, value len, hid dim]
+
+ Q = self.fc_q(query)
+ K = self.fc_k(key)
+ V = self.fc_v(value)
+
+ # Q = [batch size, query len, hid dim]
+ # K = [batch size, key len, hid dim]
+ # V = [batch size, value len, hid dim]
+
+ Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
+ K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
+ V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
+
+ # Q = [batch size, n heads, query len, head dim]
+ # K = [batch size, n heads, key len, head dim]
+ # V = [batch size, n heads, value len, head dim]
+
+ energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
+
+ # energy = [batch size, n heads, query len, key len]
+
+ if mask is not None:
+ energy = energy.masked_fill(mask == 0, -float("inf"))
+
+ attention = torch.softmax(energy, dim=-1)
+
+ # attention = [batch size, n heads, query len, key len]
+
+ if self.use_dropout:
+ context_vector = torch.matmul(self.dropout(attention), V)
+ else:
+ context_vector = torch.matmul(attention, V)
+
+ # x = [batch size, n heads, query len, head dim]
+
+ context_vector = context_vector.permute(0, 2, 1, 3).contiguous()
+
+ # x = [batch size, query len, n heads, head dim]
+
+ context_vector = context_vector.view(batch_size, -1, self.hid_dim)
+
+ x = torch.cat((query, context_vector), dim=-1)
+
+ # x = [batch size, query len, hid dim * 2]
+
+ x = self.fc_o(x)
+
+ # x = [batch size, query len, hid dim]
+
+ return x, attention
diff --git a/poetry_diacritizer/modules/layers.py b/poetry_diacritizer/modules/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..64d7d68f5d3a7d58c2615939220168a94bbd4475
--- /dev/null
+++ b/poetry_diacritizer/modules/layers.py
@@ -0,0 +1,70 @@
+import torch
+from torch import nn
+from typing import Any
+
+
+class BatchNormConv1d(nn.Module):
+ """
+ A nn.Conv1d followed by an optional activation function, and nn.BatchNorm1d
+ """
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ kernel_size: int,
+ stride: int,
+ padding: int,
+ activation: Any = None,
+ ):
+ super().__init__()
+ self.conv1d = nn.Conv1d(
+ in_dim,
+ out_dim,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ bias=False,
+ )
+ self.bn = nn.BatchNorm1d(out_dim)
+ self.activation = activation
+
+ def forward(self, x: Any):
+ x = self.conv1d(x)
+ if self.activation is not None:
+ x = self.activation(x)
+ return self.bn(x)
+
+
+class LinearNorm(torch.nn.Module):
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
+ super().__init__()
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
+
+ torch.nn.init.xavier_uniform_(
+ self.linear_layer.weight,
+ gain=torch.nn.init.calculate_gain(w_init_gain))
+
+ def forward(self, x):
+ return self.linear_layer(x)
+
+
+class ConvNorm(torch.nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
+ padding=None, dilation=1, bias=True, w_init_gain='linear'):
+ super().__init__()
+ if padding is None:
+ assert(kernel_size % 2 == 1)
+ padding = int(dilation * (kernel_size - 1) / 2)
+
+ self.conv = torch.nn.Conv1d(in_channels, out_channels,
+ kernel_size=kernel_size, stride=stride,
+ padding=padding, dilation=dilation,
+ bias=bias)
+
+ torch.nn.init.xavier_uniform_(
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
+
+ def forward(self, signal):
+ conv_signal = self.conv(signal)
+ return conv_signal
diff --git a/poetry_diacritizer/modules/tacotron_modules.py b/poetry_diacritizer/modules/tacotron_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..81859a8f1b2999e3c92bfc9c1fbc1d5a0d1d4a27
--- /dev/null
+++ b/poetry_diacritizer/modules/tacotron_modules.py
@@ -0,0 +1,174 @@
+"""
+Some custom modules that are used by the TTS model
+"""
+from typing import List
+import torch
+from torch import nn
+
+from poetry_diacritizer.modules.layers import BatchNormConv1d
+
+
+class Prenet(nn.Module):
+ """
+    A prenet is a stack of linear layers, each followed by a ReLU activation
+    and dropout
+    Args:
+        in_dim (int): the input dimension
+        prenet_depth (List[int]): the output sizes of the linear layers
+        dropout (float): the dropout probability
+ """
+
+ def __init__(
+        self, in_dim: int, prenet_depth: List[int] = [256, 128], dropout: float = 0.5
+ ):
+ """ Initializing the prenet module """
+ super().__init__()
+ in_sizes = [in_dim] + prenet_depth[:-1]
+ self.layers = nn.ModuleList(
+ [
+ nn.Linear(in_size, out_size)
+ for (in_size, out_size) in zip(in_sizes, prenet_depth)
+ ]
+ )
+ self.relu = nn.ReLU()
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, inputs: torch.Tensor):
+ """Calculate forward propagation
+ Args:
+            inputs (Tensor): the inputs to the prenet; the shapes differ because
+                it is used in both the encoder and the decoder
+ Returns:
+ Tensor: the output of the forward propagation
+ """
+ for linear in self.layers:
+ inputs = self.dropout(self.relu(linear(inputs)))
+ return inputs
+
+
+class Highway(nn.Module):
+ """Highway Networks were developed by (Srivastava et al., 2015)
+ to overcome the difficulty of training deep neural networks
+ (https://arxiv.org/abs/1507.06228).
+ Args:
+ in_size (int): the input size
+ out_size (int): the output size
+ """
+
+ def __init__(self, in_size, out_size):
+ """
+ Initializing Highway networks
+ """
+ super().__init__()
+ self.H = nn.Linear(in_size, out_size)
+ self.H.bias.data.zero_()
+ self.T = nn.Linear(in_size, out_size)
+ self.T.bias.data.fill_(-1)
+ self.relu = nn.ReLU()
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, inputs: torch.Tensor):
+ """Calculate forward propagation
+ Args:
+ inputs (Tensor):
+ """
+ H = self.relu(self.H(inputs))
+ T = self.sigmoid(self.T(inputs))
+ return H * T + inputs * (1.0 - T)
+
+
+class CBHG(nn.Module):
+ """The CBHG module (1-D Convolution Bank + Highway network + Bidirectional GRU)
+ was proposed by (Lee et al., 2017, https://www.aclweb.org/anthology/Q17-1026)
+ for a character-level NMT model.
+ It was adapted by (Wang et al., 2017) for building the Tacotron.
+ It is used in both the encoder and decoder with different parameters.
+ """
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ K: int,
+ projections: List[int],
+ highway_layers: int = 4,
+ ):
+ """Initializing the CBHG module
+ Args:
+            in_dim (int): the input size
+            out_dim (int): the GRU hidden size (the output size is out_dim * 2)
+            K (int): number of filter banks
+            projections (List[int]): channel sizes of the projection convolutions
+            highway_layers (int): number of highway layers
+ """
+ super().__init__()
+
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+ self.relu = nn.ReLU()
+ self.conv1d_banks = nn.ModuleList(
+ [
+ BatchNormConv1d(
+ in_dim,
+ in_dim,
+ kernel_size=k,
+ stride=1,
+ padding=k // 2,
+ activation=self.relu,
+ )
+ for k in range(1, K + 1)
+ ]
+ )
+ self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
+
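+        # the conv bank concatenates K outputs of in_dim channels each, hence the
+        # first projection takes K * in_dim input channels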
+ in_sizes = [K * in_dim] + projections[:-1]
+ activations = [self.relu] * (len(projections) - 1) + [None]
+ self.conv1d_projections = nn.ModuleList(
+ [
+ BatchNormConv1d(
+ in_size, out_size, kernel_size=3, stride=1, padding=1, activation=ac
+ )
+ for (in_size, out_size, ac) in zip(in_sizes, projections, activations)
+ ]
+ )
+
+ self.pre_highway = nn.Linear(projections[-1], in_dim, bias=False)
+        self.highways = nn.ModuleList([Highway(in_dim, in_dim) for _ in range(highway_layers)])
+
+ self.gru = nn.GRU(in_dim, out_dim, 1, batch_first=True, bidirectional=True)
+
+ def forward(self, inputs, input_lengths=None):
+ # (B, T_in, in_dim)
+ x = inputs
+ x = x.transpose(1, 2)
+ T = x.size(-1)
+
+ # (B, in_dim*K, T_in)
+ # Concat conv1d bank outputs
+ x = torch.cat([conv1d(x)[:, :, :T] for conv1d in self.conv1d_banks], dim=1)
+ assert x.size(1) == self.in_dim * len(self.conv1d_banks)
+ x = self.max_pool1d(x)[:, :, :T]
+
+ for conv1d in self.conv1d_projections:
+ x = conv1d(x)
+
+ # (B, T_in, in_dim)
+ # Back to the original shape
+ x = x.transpose(1, 2)
+
+ if x.size(-1) != self.in_dim:
+ x = self.pre_highway(x)
+
+ # Residual connection
+ x += inputs
+ for highway in self.highways:
+ x = highway(x)
+
+ if input_lengths is not None:
+ x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True)
+
+ # (B, T_in, in_dim*2)
+ self.gru.flatten_parameters()
+ outputs, _ = self.gru(x)
+
+ if input_lengths is not None:
+ outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
+
+ return outputs
diff --git a/poetry_diacritizer/options.py b/poetry_diacritizer/options.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b850c03d2bab803449965f724fbc61d74f2bde0
--- /dev/null
+++ b/poetry_diacritizer/options.py
@@ -0,0 +1,39 @@
+"""
+Types of various choices used during training
+"""
+from enum import Enum
+
+
+class AttentionType(Enum):
+ """Type of attention used during training"""
+
+ LocationSensitive = 1
+ Content_Based = 2
+ MultiHead = 3
+
+
+class LearningRateType(Enum):
+ """Type of learning rate used during training"""
+
+ Learning_Rate_Decay = 1
+ Cosine_Scheduler = 2
+ SquareRoot_Scheduler = 3
+
+
+class OptimizerType(Enum):
+ """Type of optimizer used during training"""
+
+ Adam = 1
+ SGD = 2
+ AdamW = 3
+
+
+class LossType(Enum):
+ """Type of loss function used during training"""
+
+ L1_LOSS = 1
+ MSE_LOSS = 2
+ L1_LOSS_MASKED = 3
+ MSE_LOSS_MASKED = 4
+ BOTH = 5
+ BOTH_MASKED = 6
diff --git a/poetry_diacritizer/predict.py b/poetry_diacritizer/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..5787dbd8a67b8ce535663bd5d848dca7e460e554
--- /dev/null
+++ b/poetry_diacritizer/predict.py
@@ -0,0 +1,167 @@
+import os
+from typing import Dict
+
+from diacritization_evaluation import der, wer
+import torch
+from torch import nn
+from torch import optim
+from torch.cuda.amp import autocast
+from torch.utils.tensorboard.writer import SummaryWriter
+from tqdm.notebook import tqdm
+from tqdm import trange
+from diacritization_evaluation import util
+
+from .config_manager import ConfigManager
+from .dataset import load_iterators
+from .diacritizer import CBHGDiacritizer, Seq2SeqDiacritizer
+from .options import OptimizerType
+import gdown
+
+class Trainer:
+ def run(self):
+ raise NotImplementedError
+
+
+class GeneralTrainer(Trainer):
+ def __init__(self, config_path: str, model_kind: str) -> None:
+ self.config_path = config_path
+ self.model_kind = model_kind
+ self.config_manager = ConfigManager(
+ config_path=config_path, model_kind=model_kind
+ )
+ self.config = self.config_manager.config
+ self.losses = []
+ self.lr = 0
+ self.pad_idx = 0
+ self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_idx)
+ self.set_device()
+
+ self.config_manager.create_remove_dirs()
+ self.text_encoder = self.config_manager.text_encoder
+ self.start_symbol_id = self.text_encoder.start_symbol_id
+ self.summary_manager = SummaryWriter(log_dir=self.config_manager.log_dir)
+
+ self.model = self.config_manager.get_model()
+
+ self.optimizer = self.get_optimizer()
+ self.model = self.model.to(self.device)
+
+ self.load_model(model_path=self.config.get("train_resume_model_path"))
+ self.load_diacritizer()
+
+ self.initialize_model()
+
+
+ def set_device(self):
+ if self.config.get("device"):
+ self.device = self.config["device"]
+ else:
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ def load_diacritizer(self):
+ if self.model_kind in ["cbhg", "baseline"]:
+ self.diacritizer = CBHGDiacritizer(self.config_path, self.model_kind)
+ elif self.model_kind in ["seq2seq", "tacotron_based"]:
+ self.diacritizer = Seq2SeqDiacritizer(self.config_path, self.model_kind)
+
+ def initialize_model(self):
+ if self.global_step > 1:
+ return
+ if self.model_kind == "transformer":
+ print("Initializing using xavier_uniform_")
+ self.model.apply(initialize_weights)
+
+
+ def load_model(self, model_path: str = None, load_optimizer: bool = True):
+ with open(
+ self.config_manager.base_dir / f"{self.model_kind}_network.txt", "w"
+ ) as file:
+ file.write(str(self.model))
+
+ if model_path is None:
+ last_model_path = self.config_manager.get_last_model_path()
+ if last_model_path is None:
+ self.global_step = 1
+ return
+ else:
+ last_model_path = model_path
+
+ print(f"loading from {last_model_path}")
+        saved_model = torch.load(last_model_path, map_location=torch.device(self.device))
+ self.model.load_state_dict(saved_model["model_state_dict"])
+ if load_optimizer:
+ self.optimizer.load_state_dict(saved_model["optimizer_state_dict"])
+ self.global_step = saved_model["global_step"] + 1
+
+class DiacritizationTester(GeneralTrainer):
+ def __init__(self, config_path: str, model_kind: str, model_path: str) -> None:
+ # if config_path == 'config/test.yml' or config_path == "Arabic_Diacritization/config/test.yml":
+ # print("Exporting the pretrained models ... ")
+ # url = 'https://drive.google.com/uc?id=12aYNY7cbsLNzhdPdC2K3u1sgrb1lpzwO'
+ # gdown.cached_download(url,'model.zip', quiet=False, postprocess=gdown.extractall)
+
+ self.config_path = config_path
+ self.model_kind = model_kind
+ self.config_manager = ConfigManager(
+ config_path=config_path, model_kind=model_kind
+ )
+ self.config = self.config_manager.config
+ # print(self.config)
+ self.pad_idx = 0
+ self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_idx)
+ self.set_device()
+
+ self.text_encoder = self.config_manager.text_encoder
+ self.start_symbol_id = self.text_encoder.start_symbol_id
+
+ self.model = self.config_manager.get_model()
+
+ self.model = self.model.to(self.device)
+ self.load_model(model_path=model_path, load_optimizer=False)
+ self.load_diacritizer()
+ self.diacritizer.set_model(self.model)
+ self.initialize_model()
+
+ def collate_fn(self, data):
+ """
+ Padding the input and output sequences
+ """
+
+ def merge(sequences):
+ lengths = [len(seq) for seq in sequences]
+ padded_seqs = torch.zeros(len(sequences), max(lengths)).long()
+ for i, seq in enumerate(sequences):
+ end = lengths[i]
+ padded_seqs[i, :end] = seq[:end]
+ return padded_seqs, lengths
+
+ data.sort(key=lambda x: len(x[0]), reverse=True)
+
+ # separate source and target sequences
+ src_seqs, trg_seqs, original = zip(*data)
+
+ # merge sequences (from tuple of 1D tensor to 2D tensor)
+ src_seqs, src_lengths = merge(src_seqs)
+ trg_seqs, trg_lengths = merge(trg_seqs)
+
+ batch = {
+ "original": original,
+ "src": src_seqs,
+ "target": trg_seqs,
+ "lengths": torch.LongTensor(src_lengths), # src_lengths = trg_lengths
+ }
+ return batch
+
+ def get_batch(self, sentence):
+ data = self.text_encoder.clean(sentence)
+ text, inputs, diacritics = util.extract_haraqat(data)
+ inputs = torch.Tensor(self.text_encoder.input_to_sequence("".join(inputs)))
+ diacritics = torch.Tensor(self.text_encoder.target_to_sequence(diacritics))
+ batch = self.collate_fn([(inputs, diacritics, text)])
+ return batch
+
+ def infer(self, sentence):
+ self.model.eval()
+ batch = self.get_batch(sentence)
+ predicted = self.diacritizer.diacritize_batch(batch)
+ return predicted[0]
diff --git a/poetry_diacritizer/test.py b/poetry_diacritizer/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..b230ddf5ba4901aee0cf5e5d102fcca328038eeb
--- /dev/null
+++ b/poetry_diacritizer/test.py
@@ -0,0 +1,31 @@
+import argparse
+import random
+from tester import DiacritizationTester
+
+import numpy as np
+import torch
+
+
+SEED = 1234
+random.seed(SEED)
+np.random.seed(SEED)
+torch.manual_seed(SEED)
+torch.cuda.manual_seed(SEED)
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+
+
+def train_parser():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model", dest="model_kind", type=str, required=True)
+ parser.add_argument("--config", dest="config", type=str, required=True)
+ parser.add_argument("--model_path", dest="model_path", type=str, required=False)
+ parser.add_argument("--test", dest="test", type=bool)
+ return parser
+
+
+parser = train_parser()
+args = parser.parse_args()
+
+tester = DiacritizationTester(args.config, args.model_kind)
+tester.run()
diff --git a/poetry_diacritizer/tester.py b/poetry_diacritizer/tester.py
new file mode 100644
index 0000000000000000000000000000000000000000..50d622e7edb0ed989fcd3273d35e74d66f11ce75
--- /dev/null
+++ b/poetry_diacritizer/tester.py
@@ -0,0 +1,63 @@
+from .config_manager import ConfigManager
+import os
+from typing import Dict
+
+from torch import nn
+from tqdm import tqdm
+from tqdm import trange
+
+from dataset import load_iterators
+from trainer import GeneralTrainer
+
+
+class DiacritizationTester(GeneralTrainer):
+ def __init__(self, config_path: str, model_kind: str) -> None:
+ self.config_path = config_path
+ self.model_kind = model_kind
+ self.config_manager = ConfigManager(
+ config_path=config_path, model_kind=model_kind
+ )
+ self.config = self.config_manager.config
+ self.pad_idx = 0
+ self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_idx)
+ self.set_device()
+
+ self.text_encoder = self.config_manager.text_encoder
+ self.start_symbol_id = self.text_encoder.start_symbol_id
+
+ self.model = self.config_manager.get_model()
+
+ self.model = self.model.to(self.device)
+
+ self.load_model(model_path=self.config["test_model_path"], load_optimizer=False)
+ self.load_diacritizer()
+ self.diacritizer.set_model(self.model)
+
+ self.initialize_model()
+
+ self.print_config()
+
+ def run(self):
+ self.config_manager.config["load_training_data"] = False
+ self.config_manager.config["load_validation_data"] = False
+ self.config_manager.config["load_test_data"] = True
+ _, test_iterator, _ = load_iterators(self.config_manager)
+ tqdm_eval = trange(0, len(test_iterator), leave=True)
+ tqdm_error_rates = trange(0, len(test_iterator), leave=True)
+
+ loss, acc = self.evaluate(test_iterator, tqdm_eval, log = False)
+ error_rates, _ = self.evaluate_with_error_rates(test_iterator, tqdm_error_rates, log = False)
+
+ tqdm_eval.close()
+ tqdm_error_rates.close()
+
+ WER = error_rates["WER"]
+ DER = error_rates["DER"]
+ DER1 = error_rates["DER*"]
+ WER1 = error_rates["WER*"]
+
+ error_rates = f"DER: {DER}, WER: {WER}, DER*: {DER1}, WER*: {WER1}"
+
+ print(f"global step : {self.global_step}")
+ print(f"Evaluate {self.global_step}: accuracy, {acc}, loss: {loss}")
+ print(f"WER/DER {self.global_step}: {error_rates}")
diff --git a/poetry_diacritizer/train.py b/poetry_diacritizer/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1c233ce494365e0bb3a31c671aa015b3ecc8043
--- /dev/null
+++ b/poetry_diacritizer/train.py
@@ -0,0 +1,49 @@
+import argparse
+import random
+
+import numpy as np
+import torch
+
+from trainer import CBHGTrainer, Seq2SeqTrainer, GPTTrainer
+
+SEED = 1234
+random.seed(SEED)
+np.random.seed(SEED)
+torch.manual_seed(SEED)
+torch.cuda.manual_seed(SEED)
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+
+
+def train_parser():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_kind", dest="model_kind", type=str, required=True)
+ parser.add_argument(
+ "--model_desc", dest="model_desc", type=str, required=False, default=""
+ )
+ parser.add_argument("--config", dest="config", type=str, required=True)
+ parser.add_argument(
+ "--reset_dir",
+ dest="clear_dir",
+ action="store_true",
+ help="deletes everything under this config's folder.",
+ )
+ return parser
+
+
+parser = train_parser()
+args = parser.parse_args()
+
+
+if args.model_kind in ["seq2seq"]:
+ trainer = Seq2SeqTrainer(args.config, args.model_kind, args.model_desc)
+elif args.model_kind in ["tacotron_based"]:
+ trainer = Seq2SeqTrainer(args.config, args.model_kind, args.model_desc)
+elif args.model_kind in ["baseline", "cbhg"]:
+ trainer = CBHGTrainer(args.config, args.model_kind, args.model_desc)
+elif args.model_kind in ["gpt"]:
+ trainer = GPTTrainer(args.config, args.model_kind, args.model_desc)
+else:
+ raise ValueError("The model kind is not supported")
+
+trainer.run()
diff --git a/poetry_diacritizer/trainer.py b/poetry_diacritizer/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..748a21465d7c93ad8fdc374fbc6bd6d40a575ee7
--- /dev/null
+++ b/poetry_diacritizer/trainer.py
@@ -0,0 +1,447 @@
+import os
+from typing import Dict
+
+from diacritization_evaluation import der, wer
+import torch
+from torch import nn
+from torch import optim
+from torch.cuda.amp import autocast
+from torch.utils.tensorboard.writer import SummaryWriter
+from tqdm import tqdm
+from tqdm import trange
+
+from .config_manager import ConfigManager
+from dataset import load_iterators
+from diacritizer import CBHGDiacritizer, Seq2SeqDiacritizer, GPTDiacritizer
+from poetry_diacritizer.util.learning_rates import LearningRateDecay
+from poetry_diacritizer.options import OptimizerType
+from poetry_diacritizer.util.utils import (
+ categorical_accuracy,
+ count_parameters,
+ initialize_weights,
+ plot_alignment,
+ repeater,
+)
+
+import wandb
+
+wandb.login()
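+# wandb.login() runs at import time; it expects WANDB_API_KEY to be set (or a
+# previous `wandb login`), otherwise it prompts before training can start.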
+
+
+class Trainer:
+ def run(self):
+ raise NotImplementedError
+
+
+class GeneralTrainer(Trainer):
+ def __init__(self, config_path: str, model_kind: str, model_desc: str) -> None:
+ self.config_path = config_path
+ self.model_kind = model_kind
+ self.config_manager = ConfigManager(
+ config_path=config_path, model_kind=model_kind
+ )
+ self.config = self.config_manager.config
+ self.losses = []
+ self.lr = 0
+ self.pad_idx = 0
+ self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_idx)
+ self.set_device()
+
+ self.config_manager.create_remove_dirs()
+ self.text_encoder = self.config_manager.text_encoder
+ self.start_symbol_id = self.text_encoder.start_symbol_id
+ self.summary_manager = SummaryWriter(log_dir=self.config_manager.log_dir)
+ if model_desc == "":
+ model_desc = self.model_kind
+ wandb.init(project="diacratization", name=model_desc, config=self.config)
+ self.model = self.config_manager.get_model()
+
+ self.optimizer = self.get_optimizer()
+ self.model = self.model.to(self.device)
+
+ self.load_model(model_path=self.config.get("train_resume_model_path"))
+ self.load_diacritizer()
+
+ self.initialize_model()
+
+ self.print_config()
+
+ def set_device(self):
+ if self.config.get("device"):
+ self.device = self.config["device"]
+ else:
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ def print_config(self):
+ self.config_manager.dump_config()
+ self.config_manager.print_config()
+
+ if self.global_step > 1:
+ print(f"loaded form {self.global_step}")
+
+ parameters_count = count_parameters(self.model)
+ print(f"The model has {parameters_count} trainable parameters parameters")
+
+ def load_diacritizer(self):
+ if self.model_kind in ["cbhg", "baseline"]:
+ self.diacritizer = CBHGDiacritizer(self.config_path, self.model_kind)
+ elif self.model_kind in ["seq2seq", "tacotron_based"]:
+ self.diacritizer = Seq2SeqDiacritizer(self.config_path, self.model_kind)
+ elif self.model_kind in ["gpt"]:
+ self.diacritizer = GPTDiacritizer(self.config_path, self.model_kind)
+
+ def initialize_model(self):
+ if self.global_step > 1:
+ return
+ if self.model_kind == "transformer":
+ print("Initializing using xavier_uniform_")
+ self.model.apply(initialize_weights)
+
+ def print_losses(self, step_results, tqdm):
+ self.summary_manager.add_scalar(
+ "loss/loss", step_results["loss"], global_step=self.global_step
+ )
+
+ tqdm.display(f"loss: {step_results['loss']}", pos=3)
+ for pos, n_steps in enumerate(self.config["n_steps_avg_losses"]):
+ if len(self.losses) > n_steps:
+
+ self.summary_manager.add_scalar(
+ f"loss/loss-{n_steps}",
+ sum(self.losses[-n_steps:]) / n_steps,
+ global_step=self.global_step,
+ )
+ tqdm.display(
+ f"{n_steps}-steps average loss: {sum(self.losses[-n_steps:]) / n_steps}",
+ pos=pos + 4,
+ )
+
+ def evaluate(self, iterator, tqdm, use_target=True, log = True):
+ epoch_loss = 0
+ epoch_acc = 0
+ self.model.eval()
+ tqdm.set_description(f"Eval: {self.global_step}")
+ with torch.no_grad():
+ for batch_inputs in iterator:
+ batch_inputs["src"] = batch_inputs["src"].to(self.device)
+ batch_inputs["lengths"] = batch_inputs["lengths"].to("cpu")
+ if use_target:
+ batch_inputs["target"] = batch_inputs["target"].to(self.device)
+ else:
+ batch_inputs["target"] = None
+
+ outputs = self.model(
+ src=batch_inputs["src"],
+ target=batch_inputs["target"],
+ lengths=batch_inputs["lengths"],
+ )
+
+ predictions = outputs["diacritics"]
+
+ predictions = predictions.view(-1, predictions.shape[-1])
+ targets = batch_inputs["target"]
+ targets = targets.view(-1)
+ loss = self.criterion(predictions, targets.to(self.device))
+ acc = categorical_accuracy(
+ predictions, targets.to(self.device), self.pad_idx
+ )
+
+ epoch_loss += loss.item()
+ epoch_acc += acc.item()
+ if log:
+ wandb.log({"evaluate_loss": loss.item(), "evaluate_acc": acc.item()})
+ tqdm.update()
+
+ tqdm.reset()
+ return epoch_loss / len(iterator), epoch_acc / len(iterator)
+
+ def evaluate_with_error_rates(self, iterator, tqdm, log = True):
+ all_orig = []
+ all_predicted = []
+ results = {}
+ self.diacritizer.set_model(self.model)
+ evaluated_batches = 0
+ tqdm.set_description(f"Calculating DER/WER {self.global_step}: ")
+ for i, batch in enumerate(iterator):
+ if evaluated_batches > int(self.config["error_rates_n_batches"]):
+ break
+
+ predicted = self.diacritizer.diacritize_batch(batch)
+ all_predicted += predicted
+ all_orig += batch["original"]
+ if i > self.config["max_eval_batches"]:
+ break
+ tqdm.update()
+
+ summary_texts = []
+ orig_path = os.path.join(self.config_manager.prediction_dir, f"original.txt")
+ predicted_path = os.path.join(
+ self.config_manager.prediction_dir, f"predicted.txt"
+ )
+
+ table = wandb.Table(columns=["original", "predicted"])
+ with open(orig_path, "w", encoding="utf8") as file:
+ for sentence in all_orig:
+ file.write(f"{sentence}\n")
+
+ with open(predicted_path, "w", encoding="utf8") as file:
+ for sentence in all_predicted:
+ file.write(f"{sentence}\n")
+
+ for i in range(int(self.config["n_predicted_text_tensorboard"])):
+            if i >= len(all_predicted):
+ break
+
+ summary_texts.append(
+ (f"eval-text/{i}", f"{ all_orig[i]} |-> {all_predicted[i]}")
+ )
+ if i < 10:
+ table.add_data(all_orig[i], all_predicted[i])
+
+ if log:
+ wandb.log({f"prediction_{self.global_step}": table}, commit=False)
+
+ results["DER"] = der.calculate_der_from_path(orig_path, predicted_path)
+ results["DER*"] = der.calculate_der_from_path(
+ orig_path, predicted_path, case_ending=False
+ )
+ results["WER"] = wer.calculate_wer_from_path(orig_path, predicted_path)
+ results["WER*"] = wer.calculate_wer_from_path(
+ orig_path, predicted_path, case_ending=False
+ )
+ if log:
+ wandb.log(results)
+ tqdm.reset()
+ return results, summary_texts
+
+ def run(self):
+ scaler = torch.cuda.amp.GradScaler()
+ train_iterator, _, validation_iterator = load_iterators(self.config_manager)
+ print("data loaded")
+ print("----------------------------------------------------------")
+ tqdm_eval = trange(0, len(validation_iterator), leave=True)
+ tqdm_error_rates = trange(0, len(validation_iterator), leave=True)
+ tqdm_eval.set_description("Eval")
+ tqdm_error_rates.set_description("WER/DER : ")
+ tqdm = trange(self.global_step, self.config["max_steps"] + 1, leave=True)
+
+ for batch_inputs in repeater(train_iterator):
+ tqdm.set_description(f"Global Step {self.global_step}")
+ if self.config["use_decay"]:
+ self.lr = self.adjust_learning_rate(
+ self.optimizer, global_step=self.global_step
+ )
+ self.optimizer.zero_grad()
+ if self.device == "cuda" and self.config["use_mixed_precision"]:
+ with autocast():
+ step_results = self.run_one_step(batch_inputs)
+ scaler.scale(step_results["loss"]).backward()
+ scaler.unscale_(self.optimizer)
+ if self.config.get("CLIP"):
+ torch.nn.utils.clip_grad_norm_(
+ self.model.parameters(), self.config["CLIP"]
+ )
+
+ scaler.step(self.optimizer)
+
+ scaler.update()
+ else:
+ step_results = self.run_one_step(batch_inputs)
+
+ loss = step_results["loss"]
+ loss.backward()
+ if self.config.get("CLIP"):
+ torch.nn.utils.clip_grad_norm_(
+ self.model.parameters(), self.config["CLIP"]
+ )
+ self.optimizer.step()
+
+ self.losses.append(step_results["loss"].item())
+ wandb.log({"train_loss": step_results["loss"].item()})
+
+ self.print_losses(step_results, tqdm)
+
+ self.summary_manager.add_scalar(
+ "meta/learning_rate", self.lr, global_step=self.global_step
+ )
+
+ if self.global_step % self.config["model_save_frequency"] == 0:
+ torch.save(
+ {
+ "global_step": self.global_step,
+ "model_state_dict": self.model.state_dict(),
+ "optimizer_state_dict": self.optimizer.state_dict(),
+ },
+ os.path.join(
+ self.config_manager.models_dir,
+ f"{self.global_step}-snapshot.pt",
+ ),
+ )
+
+ if self.global_step % self.config["evaluate_frequency"] == 0:
+ loss, acc = self.evaluate(validation_iterator, tqdm_eval)
+ self.summary_manager.add_scalar(
+ "evaluate/loss", loss, global_step=self.global_step
+ )
+ self.summary_manager.add_scalar(
+ "evaluate/acc", acc, global_step=self.global_step
+ )
+ tqdm.display(
+ f"Evaluate {self.global_step}: accuracy, {acc}, loss: {loss}", pos=8
+ )
+ self.model.train()
+
+ if (
+ self.global_step % self.config["evaluate_with_error_rates_frequency"]
+ == 0
+ ):
+                error_rates, summary_texts = self.evaluate_with_error_rates(
+ validation_iterator, tqdm_error_rates
+ )
+ if error_rates:
+ WER = error_rates["WER"]
+ DER = error_rates["DER"]
+ DER1 = error_rates["DER*"]
+ WER1 = error_rates["WER*"]
+
+ self.summary_manager.add_scalar(
+ "error_rates/WER",
+ WER / 100,
+ global_step=self.global_step,
+ )
+ self.summary_manager.add_scalar(
+ "error_rates/DER",
+ DER / 100,
+ global_step=self.global_step,
+ )
+ self.summary_manager.add_scalar(
+ "error_rates/DER*",
+ DER1 / 100,
+ global_step=self.global_step,
+ )
+ self.summary_manager.add_scalar(
+ "error_rates/WER*",
+ WER1 / 100,
+ global_step=self.global_step,
+ )
+
+ error_rates = f"DER: {DER}, WER: {WER}, DER*: {DER1}, WER*: {WER1}"
+ tqdm.display(f"WER/DER {self.global_step}: {error_rates}", pos=9)
+
+                    for tag, text in summary_texts:
+ self.summary_manager.add_text(tag, text)
+
+ self.model.train()
+
+ if self.global_step % self.config["train_plotting_frequency"] == 0:
+ self.plot_attention(step_results)
+
+ self.report(step_results, tqdm)
+
+ self.global_step += 1
+ if self.global_step > self.config["max_steps"]:
+ print("Training Done.")
+ return
+
+ tqdm.update()
+
+ def run_one_step(self, batch_inputs: Dict[str, torch.Tensor]):
+ batch_inputs["src"] = batch_inputs["src"].to(self.device)
+ batch_inputs["lengths"] = batch_inputs["lengths"].to("cpu")
+ batch_inputs["target"] = batch_inputs["target"].to(self.device)
+
+ outputs = self.model(
+ src=batch_inputs["src"],
+ target=batch_inputs["target"],
+ lengths=batch_inputs["lengths"],
+ )
+
+ predictions = outputs["diacritics"].contiguous()
+ targets = batch_inputs["target"].contiguous()
+ predictions = predictions.view(-1, predictions.shape[-1])
+ targets = targets.view(-1)
+ loss = self.criterion(predictions.to(self.device), targets.to(self.device))
+ outputs.update({"loss": loss})
+ return outputs
+
+ def predict(self, iterator):
+ pass
+
+ def load_model(self, model_path: str = None, load_optimizer: bool = True):
+ with open(
+ self.config_manager.base_dir / f"{self.model_kind}_network.txt", "w"
+ ) as file:
+ file.write(str(self.model))
+
+ if model_path is None:
+ last_model_path = self.config_manager.get_last_model_path()
+ if last_model_path is None:
+ self.global_step = 1
+ return
+ else:
+ last_model_path = model_path
+
+ print(f"loading from {last_model_path}")
+ saved_model = torch.load(last_model_path)
+ self.model.load_state_dict(saved_model["model_state_dict"])
+ if load_optimizer:
+ self.optimizer.load_state_dict(saved_model["optimizer_state_dict"])
+ self.global_step = saved_model["global_step"] + 1
+
+ def get_optimizer(self):
+ if self.config["optimizer"] == OptimizerType.Adam:
+ optimizer = optim.Adam(
+ self.model.parameters(),
+ lr=self.config["learning_rate"],
+ betas=(self.config["adam_beta1"], self.config["adam_beta2"]),
+ weight_decay=self.config["weight_decay"],
+ )
+ elif self.config["optimizer"] == OptimizerType.SGD:
+ optimizer = optim.SGD(
+ self.model.parameters(), lr=self.config["learning_rate"], momentum=0.9
+ )
+ else:
+ raise ValueError("Optimizer option is not valid")
+
+ return optimizer
+
+ def get_learning_rate(self):
+ return LearningRateDecay(
+ lr=self.config["learning_rate"],
+ warmup_steps=self.config.get("warmup_steps", 4000.0),
+ )
+
+ def adjust_learning_rate(self, optimizer, global_step):
+ learning_rate = self.get_learning_rate()(global_step=global_step)
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = learning_rate
+ return learning_rate
+
+ def plot_attention(self, results):
+ pass
+
+ def report(self, results, tqdm):
+ pass
+
+
+class Seq2SeqTrainer(GeneralTrainer):
+ def plot_attention(self, results):
+ plot_alignment(
+ results["attention"][0],
+ str(self.config_manager.plot_dir),
+ self.global_step,
+ )
+
+ self.summary_manager.add_image(
+ "Train/attention",
+ results["attention"][0].unsqueeze(0),
+ global_step=self.global_step,
+ )
+
+
+class GPTTrainer(GeneralTrainer):
+ pass
+
+
+class CBHGTrainer(GeneralTrainer):
+ pass
diff --git a/poetry_diacritizer/util/__pycache__/constants.cpython-310.pyc b/poetry_diacritizer/util/__pycache__/constants.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7667a7b623fb5c0a2c33e2d0631354406f5d11a8
Binary files /dev/null and b/poetry_diacritizer/util/__pycache__/constants.cpython-310.pyc differ
diff --git a/poetry_diacritizer/util/__pycache__/constants.cpython-38.pyc b/poetry_diacritizer/util/__pycache__/constants.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f76f18037ad096b27541207785aebd552fb22f1d
Binary files /dev/null and b/poetry_diacritizer/util/__pycache__/constants.cpython-38.pyc differ
diff --git a/poetry_diacritizer/util/__pycache__/decorators.cpython-310.pyc b/poetry_diacritizer/util/__pycache__/decorators.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c30a0f7e5f2f11e2ab5980adf1b15932c89bf8c7
Binary files /dev/null and b/poetry_diacritizer/util/__pycache__/decorators.cpython-310.pyc differ
diff --git a/poetry_diacritizer/util/__pycache__/decorators.cpython-38.pyc b/poetry_diacritizer/util/__pycache__/decorators.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ac7e742ef78842ad558ae8f9aeec340fefa74944
Binary files /dev/null and b/poetry_diacritizer/util/__pycache__/decorators.cpython-38.pyc differ
diff --git a/poetry_diacritizer/util/__pycache__/learning_rates.cpython-310.pyc b/poetry_diacritizer/util/__pycache__/learning_rates.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ebc9a5fe65ba911226f424d45bef6d8ce9934015
Binary files /dev/null and b/poetry_diacritizer/util/__pycache__/learning_rates.cpython-310.pyc differ
diff --git a/poetry_diacritizer/util/__pycache__/learning_rates.cpython-38.pyc b/poetry_diacritizer/util/__pycache__/learning_rates.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..924dd008ea08e2e555edf64a1511b7b3bfd1884d
Binary files /dev/null and b/poetry_diacritizer/util/__pycache__/learning_rates.cpython-38.pyc differ
diff --git a/poetry_diacritizer/util/__pycache__/text_cleaners.cpython-310.pyc b/poetry_diacritizer/util/__pycache__/text_cleaners.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..70986b3183eb0addd941bc805cce7c5d1edd68af
Binary files /dev/null and b/poetry_diacritizer/util/__pycache__/text_cleaners.cpython-310.pyc differ
diff --git a/poetry_diacritizer/util/__pycache__/text_cleaners.cpython-38.pyc b/poetry_diacritizer/util/__pycache__/text_cleaners.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d341a35ceb11174c762ef87efbf0d484f2600f34
Binary files /dev/null and b/poetry_diacritizer/util/__pycache__/text_cleaners.cpython-38.pyc differ
diff --git a/poetry_diacritizer/util/__pycache__/text_encoders.cpython-310.pyc b/poetry_diacritizer/util/__pycache__/text_encoders.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eba84512ec27dcf06e859e53b3ce3d1ba1dbb5d0
Binary files /dev/null and b/poetry_diacritizer/util/__pycache__/text_encoders.cpython-310.pyc differ
diff --git a/poetry_diacritizer/util/__pycache__/text_encoders.cpython-38.pyc b/poetry_diacritizer/util/__pycache__/text_encoders.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b0bac5dd415dee85738f570ceb1efdc2b93f25a8
Binary files /dev/null and b/poetry_diacritizer/util/__pycache__/text_encoders.cpython-38.pyc differ
diff --git a/poetry_diacritizer/util/__pycache__/utils.cpython-310.pyc b/poetry_diacritizer/util/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c67cfab38f6864b98242288e80d669e5ebdf2aa3
Binary files /dev/null and b/poetry_diacritizer/util/__pycache__/utils.cpython-310.pyc differ
diff --git a/poetry_diacritizer/util/__pycache__/utils.cpython-38.pyc b/poetry_diacritizer/util/__pycache__/utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..39437988f40485fd7565bf01bbe6c8634176c59f
Binary files /dev/null and b/poetry_diacritizer/util/__pycache__/utils.cpython-38.pyc differ
diff --git a/poetry_diacritizer/util/constants.py b/poetry_diacritizer/util/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..2915b2846e5f1b1678991e81f6572776ace8a4c9
--- /dev/null
+++ b/poetry_diacritizer/util/constants.py
@@ -0,0 +1,34 @@
+"""
+Constants that are used by the model
+"""
+HARAQAT = ["ْ", "ّ", "ٌ", "ٍ", "ِ", "ً", "َ", "ُ"]
+ARAB_CHARS = "ىعظحرسيشضق ثلصطكآماإهزءأفؤغجئدةخوبذتن"
+PUNCTUATIONS = [".", "،", ":", "؛", "-", "؟"]
+VALID_ARABIC = HARAQAT + list(ARAB_CHARS)
+BASIC_HARAQAT = {
+ "َ": "Fatha ",
+ "ً": "Fathatah ",
+ "ُ": "Damma ",
+ "ٌ": "Dammatan ",
+ "ِ": "Kasra ",
+ "ٍ": "Kasratan ",
+ "ْ": "Sukun ",
+ "ّ": "Shaddah ",
+}
+ALL_POSSIBLE_HARAQAT = {
+ "": "No Diacritic ",
+ "َ": "Fatha ",
+ "ً": "Fathatah ",
+ "ُ": "Damma ",
+ "ٌ": "Dammatan ",
+ "ِ": "Kasra ",
+ "ٍ": "Kasratan ",
+ "ْ": "Sukun ",
+ "ّ": "Shaddah ",
+ "َّ": "Shaddah + Fatha ",
+ "ًّ": "Shaddah + Fathatah ",
+ "ُّ": "Shaddah + Damma ",
+ "ٌّ": "Shaddah + Dammatan ",
+ "ِّ": "Shaddah + Kasra ",
+ "ٍّ": "Shaddah + Kasratan ",
+}
diff --git a/poetry_diacritizer/util/decorators.py b/poetry_diacritizer/util/decorators.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a1a46c8ae63dfb6d9cb99c0ef7321c26985f275
--- /dev/null
+++ b/poetry_diacritizer/util/decorators.py
@@ -0,0 +1,27 @@
+import traceback
+from time import time
+
+
+def ignore_exception(f):
+ def apply_func(*args, **kwargs):
+ try:
+ result = f(*args, **kwargs)
+ return result
+ except Exception:
+            # Exceptions are deliberately swallowed; uncomment to debug them:
+            # print(f"Caught exception in {f}:")
+            # traceback.print_exc()
+ return None
+
+ return apply_func
+
+
+def time_it(f):
+ def apply_func(*args, **kwargs):
+ t_start = time()
+ result = f(*args, **kwargs)
+ t_end = time()
+ dur = round(t_end - t_start, ndigits=2)
+ return result, dur
+
+ return apply_func
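+
+# Note: a function wrapped with @time_it returns a (result, elapsed_seconds)
+# tuple instead of its bare result, with the duration rounded to two decimals.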
diff --git a/poetry_diacritizer/util/learning_rates.py b/poetry_diacritizer/util/learning_rates.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd3325b4ed746f2d65e00750e40156aef6b6d851
--- /dev/null
+++ b/poetry_diacritizer/util/learning_rates.py
@@ -0,0 +1,70 @@
+import numpy as np
+import math
+
+
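+# LearningRateDecay implements a Noam-style schedule: the rate grows linearly for
+# `warmup_steps` steps, peaks at `lr` when step == warmup_steps, then decays ~ step**-0.5.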
+class LearningRateDecay:
+ def __init__(self, lr=0.002, warmup_steps=4000.0) -> None:
+ self.lr = lr
+ self.warmup_steps = warmup_steps
+
+ def __call__(self, global_step) -> float:
+ step = global_step + 1.0
+ lr = (
+ self.lr
+ * self.warmup_steps ** 0.5
+ * np.minimum(step * self.warmup_steps ** -1.5, step ** -0.5)
+ )
+
+ return lr
+
+class SquareRootScheduler:
+ def __init__(self, lr=0.1):
+ self.lr = lr
+
+ def __call__(self, global_step):
+ global_step = global_step // 1000
+ return self.lr * pow(global_step + 1.0, -0.5)
+
+
+class CosineScheduler:
+ def __init__(
+ self, max_update, base_lr=0.02, final_lr=0, warmup_steps=0, warmup_begin_lr=0
+ ):
+ self.base_lr_orig = base_lr
+ self.max_update = max_update
+ self.final_lr = final_lr
+ self.warmup_steps = warmup_steps
+ self.warmup_begin_lr = warmup_begin_lr
+ self.max_steps = self.max_update - self.warmup_steps
+
+ def get_warmup_lr(self, global_step):
+ increase = (
+ (self.base_lr_orig - self.warmup_begin_lr)
+ * float(global_step)
+ / float(self.warmup_steps)
+ )
+ return self.warmup_begin_lr + increase
+
+ def __call__(self, global_step):
+ if global_step < self.warmup_steps:
+ return self.get_warmup_lr(global_step)
+ if global_step <= self.max_update:
+ self.base_lr = (
+ self.final_lr
+ + (self.base_lr_orig - self.final_lr)
+ * (
+ 1
+ + math.cos(
+ math.pi * (global_step - self.warmup_steps) / self.max_steps
+ )
+ )
+ / 2
+ )
+ return self.base_lr
+
+def adjust_learning_rate(optimizer, global_step):
+ lr = LearningRateDecay()(global_step=global_step)
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = lr
+ return lr
+
diff --git a/poetry_diacritizer/util/text_cleaners.py b/poetry_diacritizer/util/text_cleaners.py
new file mode 100644
index 0000000000000000000000000000000000000000..04b66ee7a261feb58e5636147e9af1213abb2c75
--- /dev/null
+++ b/poetry_diacritizer/util/text_cleaners.py
@@ -0,0 +1,146 @@
+import re
+from .constants import VALID_ARABIC
+from itertools import product, combinations
+
+_whitespace_re = re.compile(r"\s+")
+
+
+def collapse_whitespace(text):
+ text = re.sub(_whitespace_re, " ", text)
+ return text
+
+
+def basic_cleaners(text):
+ text = collapse_whitespace(text)
+ return text.strip()
+
+
+# def valid_arabic_cleaners(text):
+# text = filter(lambda char: char in VALID_ARABIC, text)
+# text = collapse_whitespace(''.join(list(text)))
+# return text.strip()
+
+harakat = ["\u0650", "\u064E", "\u064F"] # [kasra, fatha, damma, ]
+sukun = ["\u0652"] # [sukun]
+mostly_saken = [
+ "\u0627",
+ "\u0648",
+ "\u0649",
+ "\u064A",
+] # [alef, waw, alef maqsurah, ya'a]
+
+always_saken = [
+ "\u0627",
+ "\u0649",
+]
+
+tnween_chars = [
+ "\u064c",
+ "\u064d",
+ "\u064b",
+]  # [damm tanween, kasra tanween, fatha tanween]
+shadda_chars = ["\u0651"]
+all_tashkeel = harakat+tnween_chars+sukun+shadda_chars
+
+
+all_chars = list("إةابتثجحخدذرزسشصضطظعغفقكلمنهويىأءئؤ ")
+prem_chars = harakat + sukun + mostly_saken + tnween_chars + shadda_chars + all_chars
+
+def not_valid_tashkeel_comb(comb):
+ all_comb = list(product(harakat+sukun+tnween_chars, repeat = 2))+list(product(shadda_chars+sukun, repeat = 2))
+ if comb in all_comb or comb[::-1] in all_comb:
+ return True
+ else:
+ return False
+
+def remove_tanween_on_alef(text):
+ text_copy = ""
+ for i in range(0, len(text)):
+
+ # if there is shaddah or character followed by alef followed by tanween add
+ if i < len(text) - 2 and text[i] in all_chars+shadda_chars and text[i+1] in always_saken and text[i+2] == tnween_chars[2]:
+ text_copy += text[i] + tnween_chars[2]
+
+ #ignore current harakah if there is alef followed by tanween
+ elif i < len(text) - 2 and text[i] in harakat and text[i+1] in always_saken and text[i+2] == tnween_chars[2] :
+ text_copy += tnween_chars[2]
+
+ # if the current char is tanween with alef is the previous character drop tanween
+ elif i > 0 and text[i] == tnween_chars[2] and text[i-1] in always_saken:
+ continue
+
+ else:
+ text_copy += text[i]
+ return text_copy
+
+def dont_start_by_harakah(text):
+ text_copy = ""
+ for i, char in enumerate(text):
+ if not(char in all_tashkeel):
+ text_copy = text[i:]
+ break
+ return text_copy
+
+def valid_arabic_cleaners(text):
+ prev_text = text
+ for i in range(5):
+ text = prev_text
+ cleaned_text = ""
+ text = filter(lambda char: char in VALID_ARABIC, text)
+ text = collapse_whitespace(''.join(list(text)))
+ text = dont_start_by_harakah(text)
+ text = text.strip()
+ i = 0
+ cnt = 0
+ len_text = len(text)
+ while( i < len_text):
+ if text[i] in all_tashkeel:
+ cnt += 1
+ else:
+ cnt = 0
+
+ # don't allow three consecutive tashkeel
+ if cnt > 2:
+ i+= 1
+ continue
+
+ # remove second tanween and sukun
+ if i > 1 and text[i] in tnween_chars+sukun and text[i-2] in tnween_chars+sukun:
+ i += 1
+ continue
+
+ # don't allow harakah followed by shaddah or tanween
+ if i < len(text) - 1 and text[i] in harakat and text[i+1] in tnween_chars+sukun+shadda_chars:
+ i += 1
+ continue
+
+ # don't allow harkah on space
+ if i> 0 and text[i] in all_tashkeel and text[i-1] == " " :
+ i += 1
+ continue
+
+ # only allow permissable combinations
+ if not_valid_tashkeel_comb((text[i], text[i-1])):
+ i+=1
+ continue
+
+ # don't allow harkah on alef, alef maqsura, if there is no tashkeel before move it back
+ if i> 1 and text[i] in harakat and text[i-1] in always_saken :
+ if text[i-2] in all_tashkeel: # in case there is a tashkeelah before alef
+ continue
+ else:
+ cleaned_text = text[:i-1]+text[i]+ always_saken[always_saken.index(text[i-1])]
+ i += 1
+
+ if i < len(text):
+ cleaned_text+= text[i]
+ i += 1
+
+ # only allow tanween before alef
+ cleaned_text = remove_tanween_on_alef(cleaned_text)
+ cleaned_text = re.sub(r" +", " ", cleaned_text).strip()
+ if prev_text == cleaned_text:
+ break
+ else:
+ prev_text = cleaned_text
+ return cleaned_text
\ No newline at end of file
diff --git a/poetry_diacritizer/util/text_encoders.py b/poetry_diacritizer/util/text_encoders.py
new file mode 100644
index 0000000000000000000000000000000000000000..b49c5603afa2d41ad6e0145b719443f0f4ce9301
--- /dev/null
+++ b/poetry_diacritizer/util/text_encoders.py
@@ -0,0 +1,160 @@
+from . import text_cleaners
+from typing import Dict, List, Optional
+from .constants import ALL_POSSIBLE_HARAQAT
+import sentencepiece as spm
+
+
+class TextEncoder:
+ pad = "P"
+
+ def __init__(
+ self,
+ input_chars: List[str],
+ target_charts: List[str],
+ cleaner_fn: Optional[str] = None,
+ reverse_input: bool = False,
+ reverse_target: bool = False,
+ sp_model_path=None,
+ ):
+ if cleaner_fn:
+ self.cleaner_fn = getattr(text_cleaners, cleaner_fn)
+ else:
+ self.cleaner_fn = None
+
+ self.input_symbols: List[str] = [TextEncoder.pad] + input_chars
+ self.target_symbols: List[str] = [TextEncoder.pad] + target_charts
+
+ if sp_model_path is None:
+ self.input_symbol_to_id: Dict[str, int] = {
+ s: i for i, s in enumerate(self.input_symbols)
+ }
+ self.input_id_to_symbol: Dict[int, str] = {
+ i: s for i, s in enumerate(self.input_symbols)
+ }
+ else:
+ sp_model = spm.SentencePieceProcessor()
+ sp_model.load(sp_model_path + "/sp.model")
+ self.input_symbol_to_id: Dict[str, int] = {
+ s: sp_model.PieceToId(s+'▁') for s in self.input_symbols
+ }
+ self.input_symbol_to_id[" "] = sp_model.PieceToId("|") # encode space
+ self.input_symbol_to_id[TextEncoder.pad] = 0 # encode padding
+
+ self.input_space_id = sp_model.PieceToId("|")
+ self.input_id_to_symbol: Dict[int, str] = {
+ i: s for s, i in self.input_symbol_to_id.items()
+ }
+
+ self.target_symbol_to_id: Dict[str, int] = {
+ s: i for i, s in enumerate(self.target_symbols)
+ }
+ self.target_id_to_symbol: Dict[int, str] = {
+ i: s for i, s in enumerate(self.target_symbols)
+ }
+
+ self.reverse_input = reverse_input
+ self.reverse_target = reverse_target
+ self.input_pad_id = self.input_symbol_to_id[self.pad]
+ self.target_pad_id = self.target_symbol_to_id[self.pad]
+ self.start_symbol_id = None
+
+ def input_to_sequence(self, text: str) -> List[int]:
+ if self.reverse_input:
+ text = "".join(list(reversed(text)))
+ sequence = [self.input_symbol_to_id[s] for s in text if s not in [self.pad]]
+
+ return sequence
+
+ def target_to_sequence(self, text: str) -> List[int]:
+ if self.reverse_target:
+ text = "".join(list(reversed(text)))
+ sequence = [self.target_symbol_to_id[s] for s in text if s not in [self.pad]]
+
+ return sequence
+
+ def sequence_to_input(self, sequence: List[int]):
+ return [
+ self.input_id_to_symbol[symbol]
+ for symbol in sequence
+ if symbol in self.input_id_to_symbol and symbol not in [self.input_pad_id]
+ ]
+
+ def sequence_to_target(self, sequence: List[int]):
+ return [
+ self.target_id_to_symbol[symbol]
+ for symbol in sequence
+ if symbol in self.target_id_to_symbol and symbol not in [self.target_pad_id]
+ ]
+
+ def clean(self, text):
+ if self.cleaner_fn:
+ return self.cleaner_fn(text)
+ return text
+
+ def combine_text_and_haraqat(self, input_ids: List[int], output_ids: List[int]):
+ """
+ Combines the input text with its corresponding haraqat
+ Args:
+        input_ids: a list of ids representing the input text
+        output_ids: a list of ids representing the output diacritics
+ Returns:
+ text: the text after merging the inputs text representation with the output
+ representation
+ """
+ output = ""
+ for i, input_id in enumerate(input_ids):
+ if input_id == self.input_pad_id:
+ break
+ output += self.input_id_to_symbol[input_id]
+ # if input_id == self.input_space_id:
+ # continue
+ output += self.target_id_to_symbol[output_ids[i]]
+ return output
+
+ def __str__(self):
+ return type(self).__name__
+
+
+class BasicArabicEncoder(TextEncoder):
+ def __init__(
+ self,
+ cleaner_fn="basic_cleaners",
+ reverse_input: bool = False,
+ reverse_target: bool = False,
+ sp_model_path=None,
+ ):
+ input_chars: List[str] = list("بض.غىهظخة؟:طس،؛فندؤلوئآك-يذاصشحزءمأجإ ترقعث")
+ target_charts: List[str] = list(ALL_POSSIBLE_HARAQAT.keys())
+
+ super().__init__(
+ input_chars,
+ target_charts,
+ cleaner_fn=cleaner_fn,
+ reverse_input=reverse_input,
+ reverse_target=reverse_target,
+ sp_model_path=sp_model_path,
+ )
+
+
+class ArabicEncoderWithStartSymbol(TextEncoder):
+ def __init__(
+ self,
+ cleaner_fn="basic_cleaners",
+ reverse_input: bool = False,
+ reverse_target: bool = False,
+ sp_model_path=None,
+ ):
+ input_chars: List[str] = list("بض.غىهظخة؟:طس،؛فندؤلوئآك-يذاصشحزءمأجإ ترقعث")
+ # the only difference from the basic encoder is adding the start symbol
+ target_charts: List[str] = list(ALL_POSSIBLE_HARAQAT.keys()) + ["s"]
+
+ super().__init__(
+ input_chars,
+ target_charts,
+ cleaner_fn=cleaner_fn,
+ reverse_input=reverse_input,
+ reverse_target=reverse_target,
+ sp_model_path=sp_model_path,
+ )
+
+ self.start_symbol_id = self.target_symbol_to_id["s"]
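+
+# Rough usage sketch (sentence is illustrative; in this repo the diacritics list
+# comes from util.extract_haraqat, see get_batch in the tester):
+#   encoder = ArabicEncoderWithStartSymbol(cleaner_fn="valid_arabic_cleaners")
+#   text = encoder.clean("ذهب الولد")
+#   input_ids = encoder.input_to_sequence(text)
+#   target_ids = encoder.target_to_sequence(diacritics)  # one entry per character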
diff --git a/poetry_diacritizer/util/utils.py b/poetry_diacritizer/util/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9d33ca2361e48e9781cfee644dd9ddcffd6a59a
--- /dev/null
+++ b/poetry_diacritizer/util/utils.py
@@ -0,0 +1,238 @@
+import os
+from typing import Any
+
+import matplotlib.pyplot as plt
+import torch
+from torch import nn
+from itertools import repeat
+from poetry_diacritizer.util.decorators import ignore_exception
+from dataclasses import dataclass
+import numpy as np
+
+# needed by calculate_error_rates below (same dependency already used in trainer.py)
+from diacritization_evaluation import der, wer
+
+
+@dataclass
+class ErrorRate:
+ wer: float
+ der: float
+ wer_without_case_ending: float
+ der_without_case_ending: float
+
+
+def epoch_time(start_time, end_time):
+ elapsed_time = end_time - start_time
+ elapsed_mins = int(elapsed_time / 60)
+ elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
+ return elapsed_mins, elapsed_secs
+
+
+@ignore_exception
+def plot_alignment(alignment: torch.Tensor, path: str, global_step: Any = 0):
+ """
+ Plot alignment and save it into a path
+ Args:
+ alignment (Tensor): the encoder-decoder alignment
+ path (str): a path used to save the alignment plot
+ global_step (int): used in the name of the output alignment plot
+ """
+ alignment = alignment.squeeze(1).transpose(0, 1).cpu().detach().numpy()
+ fig, axs = plt.subplots()
+ img = axs.imshow(alignment, aspect="auto", origin="lower", interpolation="none")
+ fig.colorbar(img, ax=axs)
+ xlabel = "Decoder timestep"
+ plt.xlabel(xlabel)
+ plt.ylabel("Encoder timestep")
+ plt.tight_layout()
+ plot_name = f"{global_step}.png"
+ plt.savefig(os.path.join(path, plot_name), dpi=300, format="png")
+ plt.close()
+
+
+def get_mask_from_lengths(memory, memory_lengths):
+ """Get mask tensor from list of length
+ Args:
+ memory: (batch, max_time, dim)
+ memory_lengths: array like
+ """
+ mask = memory.data.new(memory.size(0), memory.size(1)).bool().zero_()
+ for idx, length in enumerate(memory_lengths):
+ mask[idx][:length] = 1
+ return ~mask
+
+
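+# repeater() cycles over the DataLoader indefinitely so the training loop can be
+# driven by `global_step` / `max_steps` instead of epochs.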
+def repeater(data_loader):
+ for loader in repeat(data_loader):
+ for data in loader:
+ yield data
+
+
+def count_parameters(model):
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+def initialize_weights(m):
+ if hasattr(m, "weight") and m.weight.dim() > 1:
+ nn.init.xavier_uniform_(m.weight.data)
+
+
+def get_encoder_layers_attentions(model):
+ attentions = []
+ for layer in model.encoder.layers:
+ attentions.append(layer.self_attention.attention)
+ return attentions
+
+
+def get_decoder_layers_attentions(model):
+ self_attns, src_attens = [], []
+ for layer in model.decoder.layers:
+ self_attns.append(layer.self_attention.attention)
+ src_attens.append(layer.encoder_attention.attention)
+ return self_attns, src_attens
+
+
+def display_attention(
+ attention, path, global_step: int, name="att", n_heads=4, n_rows=2, n_cols=2
+):
+ assert n_rows * n_cols == n_heads
+
+ fig = plt.figure(figsize=(15, 15))
+
+ for i in range(n_heads):
+
+ ax = fig.add_subplot(n_rows, n_cols, i + 1)
+
+ _attention = attention.squeeze(0)[i].transpose(0, 1).cpu().detach().numpy()
+ cax = ax.imshow(_attention, aspect="auto", origin="lower", interpolation="none")
+
+ plot_name = f"{global_step}-{name}.png"
+ plt.savefig(os.path.join(path, plot_name), dpi=300, format="png")
+ plt.close()
+
+
+def plot_multi_head(model, path, global_step):
+ encoder_attentions = get_encoder_layers_attentions(model)
+ decoder_attentions, attentions = get_decoder_layers_attentions(model)
+    # plot each layer's attention rather than repeating the first layer
+    for i in range(len(attentions)):
+        display_attention(
+            attentions[i][0], path, global_step, f"encoder-decoder-layer{i + 1}"
+        )
+    for i in range(len(decoder_attentions)):
+        display_attention(
+            decoder_attentions[i][0], path, global_step, f"decoder-layer{i + 1}"
+        )
+    for i in range(len(encoder_attentions)):
+        display_attention(
+            encoder_attentions[i][0], path, global_step, f"encoder-layer {i + 1}"
+        )
+
+
+def make_src_mask(src, pad_idx=0):
+
+ # src = [batch size, src len]
+
+ src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)
+
+ # src_mask = [batch size, 1, 1, src len]
+
+ return src_mask
+
+
+def get_angles(pos, i, model_dim):
+ angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(model_dim))
+ return pos * angle_rates
+
+
+def positional_encoding(position, model_dim):
+ angle_rads = get_angles(
+ np.arange(position)[:, np.newaxis],
+ np.arange(model_dim)[np.newaxis, :],
+ model_dim,
+ )
+
+ # apply sin to even indices in the array; 2i
+ angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
+
+ # apply cos to odd indices in the array; 2i+1
+ angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
+
+ pos_encoding = angle_rads[np.newaxis, ...]
+
+ return torch.from_numpy(pos_encoding)
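+
+# The returned tensor has shape (1, position, model_dim) and float64 dtype; callers
+# typically cast it with .float() before adding it to float32 embeddings.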
+
+
+def calculate_error_rates(original_file_path: str, target_file_path: str) -> ErrorRate:
+ """
+ Calculates ErrorRates from paths
+ """
+ assert os.path.isfile(original_file_path)
+ assert os.path.isfile(target_file_path)
+
+ _wer = wer.calculate_wer_from_path(
+ inp_path=original_file_path, out_path=target_file_path, case_ending=True
+ )
+
+ _wer_without_case_ending = wer.calculate_wer_from_path(
+ inp_path=original_file_path, out_path=target_file_path, case_ending=False
+ )
+
+ _der = der.calculate_der_from_path(
+ inp_path=original_file_path, out_path=target_file_path, case_ending=True
+ )
+
+ _der_without_case_ending = der.calculate_der_from_path(
+ inp_path=original_file_path, out_path=target_file_path, case_ending=False
+ )
+
+ error_rates = ErrorRate(
+ _wer,
+ _der,
+ _wer_without_case_ending,
+ _der_without_case_ending,
+ )
+
+ return error_rates
+
+
+def categorical_accuracy(preds, y, tag_pad_idx, device=None):
+    """
+    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
+    """
+    max_preds = preds.argmax(
+        dim=1, keepdim=True
+    ) # get the index of the max probability
+    non_pad_elements = torch.nonzero((y != tag_pad_idx))
+    correct = max_preds[non_pad_elements].squeeze(1).eq(y[non_pad_elements])
+    # default to the targets' device so CPU-only runs do not require CUDA
+    device = device or y.device
+    return correct.sum() / torch.FloatTensor([y[non_pad_elements].shape[0]]).to(device)
+
+
+def write_to_files(input_path, output_path, input_list, output_list):
+ with open(input_path, "w", encoding="utf8") as file:
+ for inp in input_list:
+ file.write(inp + "\n")
+ with open(output_path, "w", encoding="utf8") as file:
+ for out in output_list:
+ file.write(out + "\n")
+
+
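+# NOTE: this redefines make_src_mask from above with identical behavior; being the
+# later definition, it is the one in effect at import time.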
+def make_src_mask(src: torch.Tensor, pad_idx=0):
+ return (src != pad_idx).unsqueeze(1).unsqueeze(2)
+
+
+def make_trg_mask(trg, trg_pad_idx=0):
+
+ # trg = [batch size, trg len]
+
+ trg_pad_mask = (trg != trg_pad_idx).unsqueeze(1).unsqueeze(2)
+
+ # trg_pad_mask = [batch size, 1, 1, trg len]
+
+ trg_len = trg.shape[1]
+
+ trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len))).bool()
+
+ # trg_sub_mask = [trg len, trg len]
+
+ trg_mask = trg_pad_mask & trg_sub_mask
+
+ # trg_mask = [batch size, 1, trg len, trg len]
+
+ return trg_mask
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..35cd9acdb6ab6a1d1f5b5ba87270003e7b93bb60
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1 @@
+ashaar @ git+https://github.com/arbml/Ashaar.git
\ No newline at end of file
diff --git a/test.yml b/test.yml
new file mode 100644
index 0000000000000000000000000000000000000000..bc932ab55c283460995f8cc5a6be8282b236884a
--- /dev/null
+++ b/test.yml
@@ -0,0 +1,52 @@
+session_name: base
+
+data_directory: "data"
+data_type: "ashaar_proc"
+log_directory: "deep-learning-models/log_dir_ashaar"
+load_training_data: true
+load_test_data: false
+load_validation_data: true
+n_training_examples: null # null loads all training examples; set a number for faster loading
+n_test_examples: null # null loads all test examples
+n_validation_examples: null # null loads all validation examples
+test_file_name: "test.csv"
+is_data_preprocessed: false # The data file is organized as (original text | text | diacritics)
+data_separator: '|' # Required if the data already processed
+diacritics_separator: '*' # Required if the data already processed
+text_encoder: ArabicEncoderWithStartSymbol
+text_cleaner: valid_arabic_cleaners # a white list that uses only Arabic letters, punctuations, and a space
+max_len: 600 # sentences larger than this size will not be used
+max_sen_len: null
+
+max_steps: 10000
+learning_rate: 0.001
+batch_size: 32
+adam_beta1: 0.9
+adam_beta2: 0.999
+use_decay: true
+weight_decay: 0.0
+embedding_dim: 256
+use_prenet: false
+prenet_sizes: [512, 256]
+cbhg_projections: [128, 256]
+cbhg_filters: 16
+cbhg_gru_units: 256
+post_cbhg_layers_units: [256, 256]
+post_cbhg_use_batch_norm: true
+
+use_mixed_precision: false
+optimizer_type: Adam
+device: cpu
+
+# LOGGING
+evaluate_frequency: 50000000
+max_eval_batches: 100
+evaluate_with_error_rates_frequency: 1000
+n_predicted_text_tensorboard: 10 # To be written to the tensorboard
+model_save_frequency: 5000
+train_plotting_frequency: 50000000 # No plotting for this model
+n_steps_avg_losses: [100, 500, 1_000, 5_000] # command line display of average loss values for the last n steps
+error_rates_n_batches: 10000 # limits how many batches are used when computing error rates, in case evaluation is slow
+
+test_model_path: null # load the last saved model
+train_resume_model_path: null # load last saved model