Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import re
|
3 |
+
from poems import SAMPLE_POEMS
|
4 |
+
|
5 |
+
import langid
|
6 |
+
import numpy as np
|
7 |
+
import streamlit as st
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from icu_tokenizer import Tokenizer
|
11 |
+
from transformers import pipeline
|
12 |
+
|
13 |
+
MODELS = {
|
14 |
+
"ALBERTI": "linhd-postdata/alberti-bert-base-multilingual-cased",
|
15 |
+
"mBERT": "bert-base-multilingual-cased"
|
16 |
+
}
|
17 |
+
|
18 |
+
TOPK = 50
|
19 |
+
st.set_page_config(layout="wide")
|
20 |
+
|
21 |
+
|
22 |
+
def mask_line(line, language="es", restrictive=True):
|
23 |
+
tokenizer = Tokenizer(lang=language)
|
24 |
+
token_list = tokenizer.tokenize(line)
|
25 |
+
if lang != "zh":
|
26 |
+
restrictive = not all([len(token) <= 3 for token in token_list])
|
27 |
+
random_num = random.randint(0, len(token_list) - 1)
|
28 |
+
random_word = token_list[random_num]
|
29 |
+
if not restrictive:
|
30 |
+
token_list[random_num] = "[MASK]"
|
31 |
+
masked_l = " ".join(token_list)
|
32 |
+
return masked_l
|
33 |
+
elif len(random_word) > 3 or (lang == "zh" and random_word.isalpha()):
|
34 |
+
token_list[random_num] = "[MASK]"
|
35 |
+
masked_l = " ".join(token_list)
|
36 |
+
return masked_l
|
37 |
+
else:
|
38 |
+
return mask_line(line, language)
|
39 |
+
|
40 |
+
|
41 |
+
def filter_candidates(candidates, get_any_candidate=False):
|
42 |
+
cand_list = []
|
43 |
+
score_list = []
|
44 |
+
for candidate in candidates:
|
45 |
+
if not get_any_candidate and candidate["token_str"][:2] != "##" and candidate["token_str"].isalpha():
|
46 |
+
cand = candidate["sequence"]
|
47 |
+
score = candidate["score"]
|
48 |
+
cand_list.append(cand)
|
49 |
+
score_list.append('{0:.5f}'.format(score))
|
50 |
+
elif get_any_candidate:
|
51 |
+
cand = candidate["sequence"]
|
52 |
+
score = candidate["score"]
|
53 |
+
cand_list.append(cand)
|
54 |
+
score_list.append('{0:.5f}'.format(score))
|
55 |
+
if len(score_list) == TOPK:
|
56 |
+
break
|
57 |
+
if len(cand_list) < 1:
|
58 |
+
return filter_candidates(candidates, get_any_candidate=True)
|
59 |
+
else:
|
60 |
+
return cand_list[0]
|
61 |
+
|
62 |
+
|
63 |
+
def infer_candidates(nlp, line):
|
64 |
+
line = re.sub("β", "-", line)
|
65 |
+
line = re.sub("β", "-", line)
|
66 |
+
line = re.sub("β", "'", line)
|
67 |
+
line = re.sub("β¦", "...", line)
|
68 |
+
inputs = nlp._parse_and_tokenize(line)
|
69 |
+
outputs = nlp._forward(inputs, return_tensors=True)
|
70 |
+
input_ids = inputs["input_ids"][0]
|
71 |
+
masked_index = torch.nonzero(input_ids == nlp.tokenizer.mask_token_id,
|
72 |
+
as_tuple=False)
|
73 |
+
logits = outputs[0, masked_index.item(), :]
|
74 |
+
probs = logits.softmax(dim=0)
|
75 |
+
values, predictions = probs.topk(TOPK)
|
76 |
+
result = []
|
77 |
+
for v, p in zip(values.tolist(), predictions.tolist()):
|
78 |
+
tokens = input_ids.numpy()
|
79 |
+
tokens[masked_index] = p
|
80 |
+
# Filter padding out:
|
81 |
+
tokens = tokens[np.where(tokens != nlp.tokenizer.pad_token_id)]
|
82 |
+
l = []
|
83 |
+
token_list = [nlp.tokenizer.decode([token], skip_special_tokens=True) for token in tokens]
|
84 |
+
for idx, token in enumerate(token_list):
|
85 |
+
if token.startswith('##'):
|
86 |
+
l[-1] += token[2:]
|
87 |
+
elif idx == masked_index.item():
|
88 |
+
l += ['<b style="color: #ff0000;">', token, "</b>"]
|
89 |
+
else:
|
90 |
+
l += [token]
|
91 |
+
sequence = " ".join(l).strip()
|
92 |
+
result.append(
|
93 |
+
{
|
94 |
+
"sequence": sequence,
|
95 |
+
"score": v,
|
96 |
+
"token": p,
|
97 |
+
"token_str": nlp.tokenizer.decode(p),
|
98 |
+
"masked_index": masked_index.item()
|
99 |
+
}
|
100 |
+
)
|
101 |
+
return result
|
102 |
+
|
103 |
+
|
104 |
+
def rewrite_poem(poem, ml_model=MODELS["ALBERTI"], masking=True, language="es"):
|
105 |
+
nlp = pipeline("fill-mask", model=ml_model)
|
106 |
+
unmasked_lines = []
|
107 |
+
masked_lines = []
|
108 |
+
for line in poem:
|
109 |
+
if line == "":
|
110 |
+
unmasked_lines.append("")
|
111 |
+
masked_lines.append("")
|
112 |
+
continue
|
113 |
+
if masking:
|
114 |
+
masked_line = mask_line(line, language)
|
115 |
+
else:
|
116 |
+
masked_line = line
|
117 |
+
masked_lines.append(masked_line)
|
118 |
+
unmasked_line_candidates = infer_candidates(nlp, masked_line)
|
119 |
+
unmasked_line = filter_candidates(unmasked_line_candidates)
|
120 |
+
unmasked_lines.append(unmasked_line)
|
121 |
+
unmasked_poem = "<br>".join(unmasked_lines)
|
122 |
+
return unmasked_poem, masked_lines
|
123 |
+
|
124 |
+
|
125 |
+
instructions_text_0 = st.sidebar.markdown(
|
126 |
+
"""# ALBERTI vs BERT π₯
|
127 |
+
|
128 |
+
We present ALBERTI, our BERT-based multilingual model for poetry.""")
|
129 |
+
|
130 |
+
instructions_text_1 = st.sidebar.markdown(
|
131 |
+
"""We have trained bert on a huge (for poetry, that is) corpus of
|
132 |
+
multilingual poetry to try to get a more 'poetic' model. This is the result
|
133 |
+
of our work.
|
134 |
+
|
135 |
+
You can find more information on the [project's site](https://huggingface.co/flax-community/alberti-bert-base-multilingual-cased)""")
|
136 |
+
|
137 |
+
sample_chooser = st.sidebar.selectbox(
|
138 |
+
"Choose a poem",
|
139 |
+
list(SAMPLE_POEMS.keys())
|
140 |
+
)
|
141 |
+
|
142 |
+
instructions_text_2 = st.sidebar.markdown("""# How to use
|
143 |
+
|
144 |
+
You can choose from a list of example poems in Spanish, English, French, German,
|
145 |
+
Chinese and Arabic, but you can also paste a poem, or write it yourself!
|
146 |
+
|
147 |
+
Then click on 'Rewrite!' to do the masking and the fill-mask task on the chosen
|
148 |
+
poem, randomly masking one word per verse, and get the two new versions for each of the models.
|
149 |
+
|
150 |
+
The list of languages used on the training of ALBERTI are:
|
151 |
+
|
152 |
+
* Arabic
|
153 |
+
* Chinese
|
154 |
+
* Czech
|
155 |
+
* English
|
156 |
+
* Finnish
|
157 |
+
* French
|
158 |
+
* German
|
159 |
+
* Hungarian
|
160 |
+
* Italian
|
161 |
+
* Portuguese
|
162 |
+
* Russian
|
163 |
+
* Spanish""")
|
164 |
+
|
165 |
+
col1, col2, col3 = st.columns(3)
|
166 |
+
|
167 |
+
st.markdown(
|
168 |
+
"""
|
169 |
+
<style>
|
170 |
+
label {
|
171 |
+
font-size: 1rem !important;
|
172 |
+
font-weight: bold !important;
|
173 |
+
}
|
174 |
+
.block-container {
|
175 |
+
padding-left: 1rem !important;
|
176 |
+
padding-right: 1rem !important;
|
177 |
+
}
|
178 |
+
</style>
|
179 |
+
""", unsafe_allow_html=True)
|
180 |
+
|
181 |
+
if sample_chooser:
|
182 |
+
model_list = set(MODELS.values())
|
183 |
+
user_input = col1.text_area("Input poem",
|
184 |
+
"\n".join(SAMPLE_POEMS[sample_chooser]),
|
185 |
+
height=600)
|
186 |
+
poem = user_input.split("\n")
|
187 |
+
rewrite_button = col1.button("Rewrite!")
|
188 |
+
if "[MASK]" in user_input or "<mask>" in user_input:
|
189 |
+
col1.error("You don't have to mask the poem, we'll do it for you!")
|
190 |
+
|
191 |
+
if rewrite_button:
|
192 |
+
lang = langid.classify(user_input)[0]
|
193 |
+
unmasked_poem, masked_poem = rewrite_poem(poem, language=lang)
|
194 |
+
user_input_2 = col2.write(f"""<b>Output poem from ALBERTI</b>
|
195 |
+
|
196 |
+
|
197 |
+
{unmasked_poem}""", unsafe_allow_html=True)
|
198 |
+
unmasked_poem_2, _ = rewrite_poem(masked_poem, ml_model=MODELS["mBERT"],
|
199 |
+
masking=False)
|
200 |
+
user_input_3 = col3.write(f"""<b>Output poem from mBERT</b>
|
201 |
+
|
202 |
+
{unmasked_poem_2}""", unsafe_allow_html=True)
|