versae commited on
Commit
34795be
β€’
1 Parent(s): e112560

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +202 -0
app.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import re
3
+ from poems import SAMPLE_POEMS
4
+
5
+ import langid
6
+ import numpy as np
7
+ import streamlit as st
8
+ import torch
9
+
10
+ from icu_tokenizer import Tokenizer
11
+ from transformers import pipeline
12
+
13
+ MODELS = {
14
+ "ALBERTI": "linhd-postdata/alberti-bert-base-multilingual-cased",
15
+ "mBERT": "bert-base-multilingual-cased"
16
+ }
17
+
18
+ TOPK = 50
19
+ st.set_page_config(layout="wide")
20
+
21
+
22
+ def mask_line(line, language="es", restrictive=True):
23
+ tokenizer = Tokenizer(lang=language)
24
+ token_list = tokenizer.tokenize(line)
25
+ if lang != "zh":
26
+ restrictive = not all([len(token) <= 3 for token in token_list])
27
+ random_num = random.randint(0, len(token_list) - 1)
28
+ random_word = token_list[random_num]
29
+ if not restrictive:
30
+ token_list[random_num] = "[MASK]"
31
+ masked_l = " ".join(token_list)
32
+ return masked_l
33
+ elif len(random_word) > 3 or (lang == "zh" and random_word.isalpha()):
34
+ token_list[random_num] = "[MASK]"
35
+ masked_l = " ".join(token_list)
36
+ return masked_l
37
+ else:
38
+ return mask_line(line, language)
39
+
40
+
41
+ def filter_candidates(candidates, get_any_candidate=False):
42
+ cand_list = []
43
+ score_list = []
44
+ for candidate in candidates:
45
+ if not get_any_candidate and candidate["token_str"][:2] != "##" and candidate["token_str"].isalpha():
46
+ cand = candidate["sequence"]
47
+ score = candidate["score"]
48
+ cand_list.append(cand)
49
+ score_list.append('{0:.5f}'.format(score))
50
+ elif get_any_candidate:
51
+ cand = candidate["sequence"]
52
+ score = candidate["score"]
53
+ cand_list.append(cand)
54
+ score_list.append('{0:.5f}'.format(score))
55
+ if len(score_list) == TOPK:
56
+ break
57
+ if len(cand_list) < 1:
58
+ return filter_candidates(candidates, get_any_candidate=True)
59
+ else:
60
+ return cand_list[0]
61
+
62
+
63
+ def infer_candidates(nlp, line):
64
+ line = re.sub("–", "-", line)
65
+ line = re.sub("β€”", "-", line)
66
+ line = re.sub("’", "'", line)
67
+ line = re.sub("…", "...", line)
68
+ inputs = nlp._parse_and_tokenize(line)
69
+ outputs = nlp._forward(inputs, return_tensors=True)
70
+ input_ids = inputs["input_ids"][0]
71
+ masked_index = torch.nonzero(input_ids == nlp.tokenizer.mask_token_id,
72
+ as_tuple=False)
73
+ logits = outputs[0, masked_index.item(), :]
74
+ probs = logits.softmax(dim=0)
75
+ values, predictions = probs.topk(TOPK)
76
+ result = []
77
+ for v, p in zip(values.tolist(), predictions.tolist()):
78
+ tokens = input_ids.numpy()
79
+ tokens[masked_index] = p
80
+ # Filter padding out:
81
+ tokens = tokens[np.where(tokens != nlp.tokenizer.pad_token_id)]
82
+ l = []
83
+ token_list = [nlp.tokenizer.decode([token], skip_special_tokens=True) for token in tokens]
84
+ for idx, token in enumerate(token_list):
85
+ if token.startswith('##'):
86
+ l[-1] += token[2:]
87
+ elif idx == masked_index.item():
88
+ l += ['<b style="color: #ff0000;">', token, "</b>"]
89
+ else:
90
+ l += [token]
91
+ sequence = " ".join(l).strip()
92
+ result.append(
93
+ {
94
+ "sequence": sequence,
95
+ "score": v,
96
+ "token": p,
97
+ "token_str": nlp.tokenizer.decode(p),
98
+ "masked_index": masked_index.item()
99
+ }
100
+ )
101
+ return result
102
+
103
+
104
+ def rewrite_poem(poem, ml_model=MODELS["ALBERTI"], masking=True, language="es"):
105
+ nlp = pipeline("fill-mask", model=ml_model)
106
+ unmasked_lines = []
107
+ masked_lines = []
108
+ for line in poem:
109
+ if line == "":
110
+ unmasked_lines.append("")
111
+ masked_lines.append("")
112
+ continue
113
+ if masking:
114
+ masked_line = mask_line(line, language)
115
+ else:
116
+ masked_line = line
117
+ masked_lines.append(masked_line)
118
+ unmasked_line_candidates = infer_candidates(nlp, masked_line)
119
+ unmasked_line = filter_candidates(unmasked_line_candidates)
120
+ unmasked_lines.append(unmasked_line)
121
+ unmasked_poem = "<br>".join(unmasked_lines)
122
+ return unmasked_poem, masked_lines
123
+
124
+
125
+ instructions_text_0 = st.sidebar.markdown(
126
+ """# ALBERTI vs BERT πŸ₯Š
127
+
128
+ We present ALBERTI, our BERT-based multilingual model for poetry.""")
129
+
130
+ instructions_text_1 = st.sidebar.markdown(
131
+ """We have trained bert on a huge (for poetry, that is) corpus of
132
+ multilingual poetry to try to get a more 'poetic' model. This is the result
133
+ of our work.
134
+
135
+ You can find more information on the [project's site](https://huggingface.co/flax-community/alberti-bert-base-multilingual-cased)""")
136
+
137
+ sample_chooser = st.sidebar.selectbox(
138
+ "Choose a poem",
139
+ list(SAMPLE_POEMS.keys())
140
+ )
141
+
142
+ instructions_text_2 = st.sidebar.markdown("""# How to use
143
+
144
+ You can choose from a list of example poems in Spanish, English, French, German,
145
+ Chinese and Arabic, but you can also paste a poem, or write it yourself!
146
+
147
+ Then click on 'Rewrite!' to do the masking and the fill-mask task on the chosen
148
+ poem, randomly masking one word per verse, and get the two new versions for each of the models.
149
+
150
+ The list of languages used on the training of ALBERTI are:
151
+
152
+ * Arabic
153
+ * Chinese
154
+ * Czech
155
+ * English
156
+ * Finnish
157
+ * French
158
+ * German
159
+ * Hungarian
160
+ * Italian
161
+ * Portuguese
162
+ * Russian
163
+ * Spanish""")
164
+
165
+ col1, col2, col3 = st.columns(3)
166
+
167
+ st.markdown(
168
+ """
169
+ <style>
170
+ label {
171
+ font-size: 1rem !important;
172
+ font-weight: bold !important;
173
+ }
174
+ .block-container {
175
+ padding-left: 1rem !important;
176
+ padding-right: 1rem !important;
177
+ }
178
+ </style>
179
+ """, unsafe_allow_html=True)
180
+
181
+ if sample_chooser:
182
+ model_list = set(MODELS.values())
183
+ user_input = col1.text_area("Input poem",
184
+ "\n".join(SAMPLE_POEMS[sample_chooser]),
185
+ height=600)
186
+ poem = user_input.split("\n")
187
+ rewrite_button = col1.button("Rewrite!")
188
+ if "[MASK]" in user_input or "<mask>" in user_input:
189
+ col1.error("You don't have to mask the poem, we'll do it for you!")
190
+
191
+ if rewrite_button:
192
+ lang = langid.classify(user_input)[0]
193
+ unmasked_poem, masked_poem = rewrite_poem(poem, language=lang)
194
+ user_input_2 = col2.write(f"""<b>Output poem from ALBERTI</b>
195
+
196
+
197
+ {unmasked_poem}""", unsafe_allow_html=True)
198
+ unmasked_poem_2, _ = rewrite_poem(masked_poem, ml_model=MODELS["mBERT"],
199
+ masking=False)
200
+ user_input_3 = col3.write(f"""<b>Output poem from mBERT</b>
201
+
202
+ {unmasked_poem_2}""", unsafe_allow_html=True)