Zaid committed on
Commit 2fb81a4
1 Parent(s): 9f20162

first commit

.gitignore ADDED
@@ -0,0 +1,4 @@
+ deep-learning-models/*
+ deep-learning-models.zip
+ __MACOSX/*
+ __pycache__/*
app.py ADDED
@@ -0,0 +1,139 @@
+ import os
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # silence TensorFlow's C++ logging
+ import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from Ashaar.utils import get_output_df, get_highlighted_patterns_html
+ from Ashaar.bait_analysis import BaitAnalysis
+ from langs import *
+ import json
+ import argparse
+
+ arg_parser = argparse.ArgumentParser()
+ # --lang selects the UI language: 'ar' (default) or 'en'.
+ arg_parser.add_argument('--lang', type=str, default='ar')
+ args = arg_parser.parse_args()
+ lang = args.lang
+
+ if lang == 'ar':
+     TITLE = TITLE_ar
+     DESCRIPTION = DESCRIPTION_ar
+     textbox_trg_text = textbox_trg_text_ar
+     textbox_inp_text = textbox_inp_text_ar
+     btn_trg_text = btn_trg_text_ar
+     btn_inp_text = btn_inp_text_ar
+     css = """ #textbox{ direction: RTL;}"""
+ else:
+     TITLE = TITLE_en
+     DESCRIPTION = DESCRIPTION_en
+     textbox_trg_text = textbox_trg_text_en
+     textbox_inp_text = textbox_inp_text_en
+     btn_trg_text = btn_trg_text_en
+     btn_inp_text = btn_inp_text_en
+     css = ""
+
+ # Character-level conditional GPT-2 tokenizer and model for poem completion.
+ gpt_tokenizer = AutoTokenizer.from_pretrained('arbml/ashaar_tokenizer')
+ model = AutoModelForCausalLM.from_pretrained('arbml/Ashaar_model')
+
+ # Mappings between meter/theme names and their special conditioning tokens.
+ theme_to_token = json.load(open("extra/theme_tokens.json", "r"))
+ token_to_theme = {t: m for m, t in theme_to_token.items()}
+ meter_to_token = json.load(open("extra/meter_tokens.json", "r"))
+ token_to_meter = {t: m for m, t in meter_to_token.items()}
+
+ analysis = BaitAnalysis()
+ meter, theme, qafiyah = "", "", ""
+
+ def analyze(poem):
+     # Cache the results globally so generate() can condition on them later.
+     global meter, theme, qafiyah
+     shatrs = poem.split("\n")
+     # Join consecutive hemistichs (shatrs) into verses (baits) separated by ' # '.
+     baits = [' # '.join(shatrs[2 * i:2 * i + 2]) for i in range(len(shatrs) // 2)]
+     output = analysis.analyze(baits, override_tashkeel=True)
+     meter = output['meter']
+     qafiyah = output['qafiyah'][0]
+     theme = output['theme'][-1]
+
+     df = get_output_df(output)
+     return get_highlighted_patterns_html(df)
+
+ def generate(inputs, top_p=0.9):
+     baits = inputs.split('\n')
+     print(baits)
+     # Wrap each pair of hemistichs with the model's verse-separator tokens.
+     poem = ' '.join(['<|bsep|> ' + baits[i] + ' <|vsep|> ' + baits[i + 1] + ' </|bsep|>' for i in range(0, len(baits), 2)])
+     print(poem)
+     # Conditioning prompt: meter token, qafiyah, theme token, then the poem so far.
+     prompt = f"""
+ {meter_to_token[meter]} {qafiyah} {theme_to_token[theme]}
+ <|psep|>
+ {poem}
+ """.strip()
+     print(prompt)
+     encoded_input = gpt_tokenizer(prompt, return_tensors='pt')
+     # Nucleus sampling; top_p must lie in (0, 1] to have an effect.
+     output = model.generate(**encoded_input, max_length=512, top_p=top_p, do_sample=True)
+
+     # Decode token by token, mapping verse separators back to line breaks
+     # and stopping after ten lines or at a meter/theme token.
+     result = ""
+     line_cnts = 0
+     for beam in output[:, len(encoded_input.input_ids[0]):]:
+         if line_cnts >= 10:
+             break
+         for token in beam:
+             if line_cnts >= 10:
+                 break
+             decoded = gpt_tokenizer.decode(token)
+             if 'meter' in decoded or 'theme' in decoded:
+                 break
+             if decoded in ["<|vsep|>", "</|bsep|>"]:
+                 result += "\n"
+                 line_cnts += 1
+             elif decoded in ['<|bsep|>', '<|psep|>', '</|psep|>']:
+                 pass  # drop structural tokens
+             else:
+                 result += decoded
+         else:
+             # The beam ended without a stop condition; do not scan further beams.
+             break
+     # return theme + " " + f"من بحر {meter} مع قافية بحر ({qafiyah})" + "\n" + result
+     return result
+
+ examples = [
+     [
+         """القلب أعلم يا عذول بدائه
+ وأحق منك بجفنه وبمائه"""
+     ],
+     [
+         """ألا ليت شعري هل أبيتن ليلة
+ بجنب الغضى أزجي الغلاص النواجيا"""
+     ],
+ ]
+
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
+     with gr.Row():
+         with gr.Column():
+             gr.HTML(TITLE)
+             gr.HTML(DESCRIPTION)
+
+     with gr.Row():
+         with gr.Column():
+             textbox_output = gr.Textbox(lines=10, label=textbox_trg_text, elem_id="textbox")
+         with gr.Column():
+             inputs = gr.Textbox(lines=10, label=textbox_inp_text, elem_id="textbox")
+
+     with gr.Row():
+         with gr.Column():
+             trg_btn = gr.Button(btn_trg_text)
+         with gr.Column():
+             inp_btn = gr.Button(btn_inp_text)
+
+     with gr.Row():
+         html_output = gr.HTML()
+
+     # The two textboxes swap roles between the English and Arabic interfaces.
+     if lang == 'en':
+         gr.Examples(examples, textbox_output)
+         inp_btn.click(generate, inputs=textbox_output, outputs=inputs)
+         trg_btn.click(analyze, inputs=textbox_output, outputs=html_output)
+     else:
+         gr.Examples(examples, inputs)
+         trg_btn.click(generate, inputs=inputs, outputs=textbox_output)
+         inp_btn.click(analyze, inputs=inputs, outputs=html_output)
+
+ demo.launch(server_name="0.0.0.0", share=True, debug=True)
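
Note: the app is started with a language flag, e.g. "python app.py --lang ar". For reference, below is a minimal sketch of the conditioning prompt that generate() assembles. The meter and theme tokens are real entries from extra/meter_tokens.json and extra/theme_tokens.json, but the qafiyah letter and hemistich text are illustrative placeholders:

    # Hypothetical values standing in for the outputs of analyze().
    meter_token = "<|meter_13|>"   # الطويل (taweel)
    theme_token = "<|theme_8|>"    # قصيدة غزل (love poem)
    qafiyah = "ب"                  # rhyme letter
    poem = "<|bsep|> first hemistich <|vsep|> second hemistich </|bsep|>"

    # Same layout as the f-string in generate(): meter, qafiyah, theme, then the poem.
    prompt = f"{meter_token} {qafiyah} {theme_token}\n<|psep|>\n{poem}"
    print(prompt)

The model continues this prompt, and the decoding loop maps <|vsep|> and </|bsep|> back to line breaks while dropping the other structural tokens.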
extra/labels.txt ADDED
@@ -0,0 +1,17 @@
+ saree
+ kamel
+ mutakareb
+ mutadarak
+ munsareh
+ madeed
+ mujtath
+ ramal
+ baseet
+ khafeef
+ taweel
+ wafer
+ hazaj
+ rajaz
+ mudhare
+ muqtadheb
+ prose
extra/labels_ar.txt ADDED
@@ -0,0 +1,17 @@
+ السريع
+ الكامل
+ المتقارب
+ المتدارك
+ المنسرح
+ المديد
+ المجتث
+ الرمل
+ البسيط
+ الخفيف
+ الطويل
+ الوافر
+ الهزج
+ الرجز
+ المضارع
+ المقتضب
+ النثر
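
The two label files are parallel: line i of extra/labels.txt is the transliteration of line i of extra/labels_ar.txt (16 meters plus prose). A minimal sketch of building the mapping between them, assuming both files are read from extra/:

    import pathlib

    # Read the parallel label files; corresponding lines describe the same meter.
    en = pathlib.Path("extra/labels.txt").read_text(encoding="utf-8").splitlines()
    ar = pathlib.Path("extra/labels_ar.txt").read_text(encoding="utf-8").splitlines()
    en_to_ar = dict(zip(en, ar))
    print(en_to_ar["taweel"])  # الطويل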
extra/meter_tokens.json ADDED
@@ -0,0 +1 @@
+ {"\u0627\u0644\u062e\u0641\u064a\u0641": "<|meter_0|>", "\u0627\u0644\u0645\u0636\u0627\u0631\u0639": "<|meter_1|>", "\u0627\u0644\u0645\u062c\u062a\u062b": "<|meter_2|>", "\u0627\u0644\u0631\u0645\u0644": "<|meter_3|>", "\u0627\u0644\u0628\u0633\u064a\u0637": "<|meter_4|>", "\u0627\u0644\u0645\u062a\u0642\u0627\u0631\u0628": "<|meter_5|>", "\u0627\u0644\u0648\u0627\u0641\u0631": "<|meter_6|>", "\u0627\u0644\u0645\u0642\u062a\u0636\u0628": "<|meter_7|>", "\u0627\u0644\u0645\u062f\u064a\u062f": "<|meter_8|>", "\u0627\u0644\u0646\u062b\u0631": "<|meter_9|>", "\u0627\u0644\u0647\u0632\u062c": "<|meter_10|>", "\u0627\u0644\u0645\u062a\u062f\u0627\u0631\u0643": "<|meter_11|>", "\u0627\u0644\u0645\u0646\u0633\u0631\u062d": "<|meter_12|>", "\u0627\u0644\u0637\u0648\u064a\u0644": "<|meter_13|>", "\u0627\u0644\u0643\u0627\u0645\u0644": "<|meter_14|>", "\u0627\u0644\u0631\u062c\u0632": "<|meter_15|>", "\u0627\u0644\u0633\u0631\u064a\u0639": "<|meter_16|>"}
extra/theme_tokens.json ADDED
@@ -0,0 +1 @@
+ {"\u0642\u0635\u064a\u062f\u0629 \u0642\u0635\u064a\u0631\u0647": "<|theme_0|>", "\u0642\u0635\u064a\u062f\u0629 \u0645\u062f\u062d": "<|theme_1|>", "\u0642\u0635\u064a\u062f\u0629 \u0648\u0637\u0646\u064a\u0647": "<|theme_2|>", "\u0642\u0635\u064a\u062f\u0629 \u0631\u0648\u0645\u0646\u0633\u064a\u0647": "<|theme_3|>", "\u0642\u0635\u064a\u062f\u0629 \u0647\u062c\u0627\u0621": "<|theme_4|>", "\u0642\u0635\u064a\u062f\u0629 \u0627\u0639\u062a\u0630\u0627\u0631": "<|theme_5|>", "\u0642\u0635\u064a\u062f\u0629 \u0633\u064a\u0627\u0633\u064a\u0629": "<|theme_6|>", "\u0642\u0635\u064a\u062f\u0629 \u0641\u0631\u0627\u0642": "<|theme_7|>", "\u0642\u0635\u064a\u062f\u0629 \u063a\u0632\u0644": "<|theme_8|>", "\u0642\u0635\u064a\u062f\u0629 \u0630\u0645": "<|theme_9|>", "\u0642\u0635\u064a\u062f\u0629 \u0631\u062b\u0627\u0621": "<|theme_10|>", "null": "<|theme_11|>", "\u0642\u0635\u064a\u062f\u0629 \u0634\u0648\u0642": "<|theme_12|>", "\u0642\u0635\u064a\u062f\u0629 \u0627\u0644\u0645\u0639\u0644\u0642\u0627\u062a": "<|theme_13|>", "\u0642\u0635\u064a\u062f\u0629 \u0627\u0644\u0627\u0646\u0627\u0634\u064a\u062f": "<|theme_14|>", "\u0642\u0635\u064a\u062f\u0629 \u062d\u0632\u064a\u0646\u0647": "<|theme_15|>", "\u0642\u0635\u064a\u062f\u0629 \u0639\u062a\u0627\u0628": "<|theme_16|>", "\u0642\u0635\u064a\u062f\u0629 \u0639\u0627\u0645\u0647": "<|theme_17|>", "\u0642\u0635\u064a\u062f\u0629 \u062f\u064a\u0646\u064a\u0629": "<|theme_18|>"}
extra/theme_tokens.txt ADDED
File without changes
langs.py ADDED
@@ -0,0 +1,43 @@
+ IMG = """<p align='center'>
+ <img src='https://raw.githubusercontent.com/ARBML/Ashaar/master/images/ashaar_icon.png' width='150px' alt='logo for Ashaar'/>
+ </p>
+
+ """
+ TITLE_ar = """<h1 style="font-size: 30px;" align="center">أَشْعــَـار: تحليل وإنشاء الشعر العربي</h1>"""
+ DESCRIPTION_ar = IMG
+
+ DESCRIPTION_ar += """ <p dir='rtl'>
+ هذا البرنامج يتيح للمستخدم تحليل وإنشاء الشعر العربي.
+ لإنشاء الشعر العربي تم تدريب نموذج يقوم باستخدام البحر والقافية والعاطفة لإنشاء إكمال للقصيدة بناءً على هذه الشروط.
+ بالإضافة إلى نموذج إنشاء الشعر يحتوي البرنامج على نماذج لتصنيف الحقبة الزمنية والعاطفة والبحر وكذلك تشكيل الشعر العربي بالإضافة إلى إكمال الشعر.
+ قمنا بتوفير الشفرة البرمجية كلها على
+ <a href='https://github.com/ARBML/Ashaar'>GitHub</a>.
+
+ </p>
+ """
+
+ TITLE_en = """<h1 style="font-size: 30px;" align="center">Ashaar: Arabic Poetry Analysis and Generation</h1>"""
+ DESCRIPTION_en = IMG
+
+ DESCRIPTION_en += """
+ The demo provides a way to generate an analysis of a poem and also to complete it.
+ The generative model is a character-based conditional GPT-2 model. The pipeline contains many models for
+ classification, diacritization, and conditional generation. Check our <a href='https://github.com/ARBML/Ashaar'>GitHub</a> for more technical details
+ about this work. The demo has two basic pipelines: Analyze, which predicts the meter, era, theme, diacritized text, qafiyah, and arudi style,
+ and Generate, which takes the input text, meter, theme, and qafiyah to generate the full poem.
+ """
+
+ btn_trg_text_ar = "إنشاء"
+ btn_inp_text_ar = "تحليل"
+
+ btn_inp_text_en = "Generate"
+ btn_trg_text_en = "Analyze"
+
+ textbox_inp_text_ar = "القصيدة المدخلة"
+ textbox_trg_text_ar = "القصيدة المنشئة"
+
+ textbox_trg_text_en = "Input Poem"
+ textbox_inp_text_en = "Generated Poem"
requirements.txt ADDED
@@ -0,0 +1 @@
+ ashaar @ git+https://github.com/arbml/Ashaar.git
test.yml ADDED
@@ -0,0 +1,52 @@
+ session_name: base
+
+ data_directory: "data"
+ data_type: "ashaar_proc"
+ log_directory: "deep-learning-models/log_dir_ashaar"
+ load_training_data: true
+ load_test_data: false
+ load_validation_data: true
+ n_training_examples: null # null loads all training examples, good for fast loading
+ n_test_examples: null # null loads all test examples
+ n_validation_examples: null # null loads all validation examples
+ test_file_name: "test.csv"
+ is_data_preprocessed: false # the data file is organized as (original text | text | diacritics)
+ data_separator: '|' # required if the data is already processed
+ diacritics_separator: '*' # required if the data is already processed
+ text_encoder: ArabicEncoderWithStartSymbol
+ text_cleaner: valid_arabic_cleaners # a whitelist that allows only Arabic letters, punctuation, and spaces
+ max_len: 600 # sentences longer than this will not be used
+ max_sen_len: null
+
+ max_steps: 10000
+ learning_rate: 0.001
+ batch_size: 32
+ adam_beta1: 0.9
+ adam_beta2: 0.999
+ use_decay: true
+ weight_decay: 0.0
+ embedding_dim: 256
+ use_prenet: false
+ prenet_sizes: [512, 256]
+ cbhg_projections: [128, 256]
+ cbhg_filters: 16
+ cbhg_gru_units: 256
+ post_cbhg_layers_units: [256, 256]
+ post_cbhg_use_batch_norm: true
+
+ use_mixed_precision: false
+ optimizer_type: Adam
+ device: cuda
+
+ # LOGGING
+ evaluate_frequency: 50000000
+ max_eval_batches: 100
+ evaluate_with_error_rates_frequency: 1000
+ n_predicted_text_tensorboard: 10 # number of predictions written to TensorBoard
+ model_save_frequency: 5000
+ train_plotting_frequency: 50000000 # no plotting for this model
+ n_steps_avg_losses: [100, 500, 1_000, 5_000] # command-line display of average loss values for the last n steps
+ error_rates_n_batches: 10000 # if calculating error rates is slow, limit the number of batches used
+
+ test_model_path: null # load the last saved model
+ train_resume_model_path: null # load the last saved model
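
A minimal sketch of consuming this training config from Python, assuming PyYAML is available (the key names come from the file above):

    import yaml

    # Parse the YAML config into a plain dict.
    with open("test.yml", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    print(config["session_name"])   # base
    print(config["batch_size"])     # 32
    print(config["log_directory"])  # deep-learning-models/log_dir_ashaar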