hgrif commited on
Commit
251b174
Β·
1 Parent(s): 9360c74

Raw copy-paste

Browse files
Files changed (2) hide show
  1. app.py +406 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import functools
3
+ import itertools
4
+ import logging
5
+ import random
6
+ import string
7
+ from typing import List, Optional
8
+
9
+ import requests
10
+ import numpy as np
11
+ import tensorflow as tf
12
+ import streamlit as st
13
+ from gazpacho import Soup, get
14
+ from transformers.modeling_tf_bert import TFBertForPreTraining
15
+ from transformers.tokenization_bert import PreTrainedTokenizer
16
+ from transformers import BertTokenizer, TFBertForMaskedLM
17
+
18
+
19
+ DEFAULT_QUERY = "Machines will take over the world soon"
20
+ N_RHYMES = 10
21
+ ITER_FACTOR = 5
22
+
23
+
24
+ LANGUAGE = st.sidebar.radio("Language", ["english", "dutch"],0)
25
+ if LANGUAGE == "english":
26
+ MODEL_PATH = "./data/bert-large-cased-whole-word-masking-finetuned-squad"
27
+ elif LANGUAGE == "dutch":
28
+ MODEL_PATH = "./data/wietsedv/bert-base-dutch-cased"
29
+ else:
30
+ raise NotImplementedError(f"Unsupported language ({LANGUAGE}) expected 'english' or 'dutch'.")
31
+
32
+ def main():
33
+ st.markdown(
34
+ "<sup>Created with "
35
+ "[Datamuse](https://www.datamuse.com/api/), "
36
+ "[Mick's rijmwoordenboek](https://rijmwoordenboek.nl)"
37
+ "[Hugging Face](https://huggingface.co/), "
38
+ "[Streamlit](https://streamlit.io/) and "
39
+ "[App Engine](https://cloud.google.com/appengine/)."
40
+ " Read our [blog](https://blog.godatadriven.com/rhyme-with-ai) "
41
+ "or check the "
42
+ "[source](https://github.com/godatadriven/rhyme-with-ai).</sup>",
43
+ unsafe_allow_html=True,
44
+ )
45
+ st.title("Rhyme with AI")
46
+ query = get_query()
47
+ if not query:
48
+ query = DEFAULT_QUERY
49
+ rhyme_words_options = query_rhyme_words(query, n_rhymes=N_RHYMES,language=LANGUAGE)
50
+ if rhyme_words_options:
51
+ start_rhyming(query, rhyme_words_options)
52
+ else:
53
+ st.write("No rhyme words found")
54
+
55
+
56
+ def get_query():
57
+ q = sanitize(
58
+ st.text_input("Write your first line and press ENTER to rhyme:", DEFAULT_QUERY)
59
+ )
60
+ if not q:
61
+ return DEFAULT_QUERY
62
+ return q
63
+
64
+
65
+ def start_rhyming(query, rhyme_words_options):
66
+ st.markdown("## My Suggestions:")
67
+
68
+ progress_bar = st.progress(0)
69
+ status_text = st.empty()
70
+ max_iter = len(query.split()) * ITER_FACTOR
71
+
72
+ rhyme_words = rhyme_words_options[:N_RHYMES]
73
+
74
+ model, tokenizer = load_model(MODEL_PATH)
75
+ sentence_generator = RhymeGenerator(model, tokenizer)
76
+ sentence_generator.start(query, rhyme_words)
77
+
78
+ current_sentences = [" " for _ in range(N_RHYMES)]
79
+ for i in range(max_iter):
80
+ previous_sentences = copy.deepcopy(current_sentences)
81
+ current_sentences = sentence_generator.mutate()
82
+ display_output(status_text, query, current_sentences, previous_sentences)
83
+ progress_bar.progress(i / (max_iter - 1))
84
+ st.balloons()
85
+
86
+
87
+ @st.cache(allow_output_mutation=True)
88
+ def load_model(model_path):
89
+ return (
90
+ TFBertForMaskedLM.from_pretrained(model_path),
91
+ BertTokenizer.from_pretrained(model_path),
92
+ )
93
+
94
+
95
+ def display_output(status_text, query, current_sentences, previous_sentences):
96
+ print_sentences = []
97
+ for new, old in zip(current_sentences, previous_sentences):
98
+ formatted = color_new_words(new, old)
99
+ after_comma = "<li>" + formatted.split(",")[1][:-2] + "</li>"
100
+ print_sentences.append(after_comma)
101
+ status_text.markdown(
102
+ query + ",<br>" + "".join(print_sentences), unsafe_allow_html=True
103
+ )
104
+
105
+
106
+ class RhymeGenerator:
107
+ def __init__(
108
+ self,
109
+ model: TFBertForPreTraining,
110
+ tokenizer: PreTrainedTokenizer,
111
+ token_weighter: TokenWeighter = None,
112
+ ):
113
+ """Generate rhymes.
114
+
115
+ Parameters
116
+ ----------
117
+ model : Model for masked language modelling
118
+ tokenizer : Tokenizer for model
119
+ token_weighter : Class that weighs tokens
120
+ """
121
+
122
+ self.model = model
123
+ self.tokenizer = tokenizer
124
+ if token_weighter is None:
125
+ token_weighter = TokenWeighter(tokenizer)
126
+ self.token_weighter = token_weighter
127
+ self._logger = logging.getLogger(__name__)
128
+
129
+ self.tokenized_rhymes_ = None
130
+ self.position_probas_ = None
131
+
132
+ # Easy access.
133
+ self.comma_token_id = self.tokenizer.encode(",", add_special_tokens=False)[0]
134
+ self.period_token_id = self.tokenizer.encode(".", add_special_tokens=False)[0]
135
+ self.mask_token_id = self.tokenizer.mask_token_id
136
+
137
+ def start(self, query: str, rhyme_words: List[str]) -> None:
138
+ """Start the sentence generator.
139
+
140
+ Parameters
141
+ ----------
142
+ query : Seed sentence
143
+ rhyme_words : Rhyme words for next sentence
144
+ """
145
+ # TODO: What if no content?
146
+ self._logger.info("Got sentence %s", query)
147
+ tokenized_rhymes = [
148
+ self._initialize_rhymes(query, rhyme_word) for rhyme_word in rhyme_words
149
+ ]
150
+ # Make same length.
151
+ self.tokenized_rhymes_ = tf.keras.preprocessing.sequence.pad_sequences(
152
+ tokenized_rhymes, padding="post", value=self.tokenizer.pad_token_id
153
+ )
154
+ p = self.tokenized_rhymes_ == self.tokenizer.mask_token_id
155
+ self.position_probas_ = p / p.sum(1).reshape(-1, 1)
156
+
157
+ def _initialize_rhymes(self, query: str, rhyme_word: str) -> List[int]:
158
+ """Initialize the rhymes.
159
+
160
+ * Tokenize input
161
+ * Append a comma if the sentence does not end in it (might add better predictions as it
162
+ shows the two sentence parts are related)
163
+ * Make second line as long as the original
164
+ * Add a period
165
+
166
+ Parameters
167
+ ----------
168
+ query : First line
169
+ rhyme_word : Last word for second line
170
+
171
+ Returns
172
+ -------
173
+ Tokenized rhyme lines
174
+ """
175
+
176
+ query_token_ids = self.tokenizer.encode(query, add_special_tokens=False)
177
+ rhyme_word_token_ids = self.tokenizer.encode(
178
+ rhyme_word, add_special_tokens=False
179
+ )
180
+
181
+ if query_token_ids[-1] != self.comma_token_id:
182
+ query_token_ids.append(self.comma_token_id)
183
+
184
+ magic_correction = len(rhyme_word_token_ids) + 1 # 1 for comma
185
+ return (
186
+ query_token_ids
187
+ + [self.tokenizer.mask_token_id] * (len(query_token_ids) - magic_correction)
188
+ + rhyme_word_token_ids
189
+ + [self.period_token_id]
190
+ )
191
+
192
+ def mutate(self):
193
+ """Mutate the current rhymes.
194
+
195
+ Returns
196
+ -------
197
+ Mutated rhymes
198
+ """
199
+ self.tokenized_rhymes_ = self._mutate(
200
+ self.tokenized_rhymes_, self.position_probas_, self.token_weighter.proba
201
+ )
202
+
203
+ rhymes = []
204
+ for i in range(len(self.tokenized_rhymes_)):
205
+ rhymes.append(
206
+ self.tokenizer.convert_tokens_to_string(
207
+ self.tokenizer.convert_ids_to_tokens(
208
+ self.tokenized_rhymes_[i], skip_special_tokens=True
209
+ )
210
+ )
211
+ )
212
+ return rhymes
213
+
214
+ def _mutate(
215
+ self,
216
+ tokenized_rhymes: np.ndarray,
217
+ position_probas: np.ndarray,
218
+ token_id_probas: np.ndarray,
219
+ ) -> np.ndarray:
220
+
221
+ replacements = []
222
+ for i in range(tokenized_rhymes.shape[0]):
223
+ mask_idx, masked_token_ids = self._mask_token(
224
+ tokenized_rhymes[i], position_probas[i]
225
+ )
226
+ tokenized_rhymes[i] = masked_token_ids
227
+ replacements.append(mask_idx)
228
+
229
+ predictions = self._predict_masked_tokens(tokenized_rhymes)
230
+
231
+ for i, token_ids in enumerate(tokenized_rhymes):
232
+ replace_ix = replacements[i]
233
+ token_ids[replace_ix] = self._draw_replacement(
234
+ predictions[i], token_id_probas, replace_ix
235
+ )
236
+ tokenized_rhymes[i] = token_ids
237
+
238
+ return tokenized_rhymes
239
+
240
+ def _mask_token(self, token_ids, position_probas):
241
+ """Mask line and return index to update."""
242
+ token_ids = self._mask_repeats(token_ids, position_probas)
243
+ ix = self._locate_mask(token_ids, position_probas)
244
+ token_ids[ix] = self.mask_token_id
245
+ return ix, token_ids
246
+
247
+ def _locate_mask(self, token_ids, position_probas):
248
+ """Update masks or a random token."""
249
+ if self.mask_token_id in token_ids:
250
+ # Already masks present, just return the last.
251
+ # We used to return thee first but this returns worse predictions.
252
+ return np.where(token_ids == self.tokenizer.mask_token_id)[0][-1]
253
+ return np.random.choice(range(len(position_probas)), p=position_probas)
254
+
255
+ def _mask_repeats(self, token_ids, position_probas):
256
+ """Repeated tokens are generally of less quality."""
257
+ repeats = [
258
+ ii for ii, ids in enumerate(pairwise(token_ids[:-2])) if ids[0] == ids[1]
259
+ ]
260
+ for ii in repeats:
261
+ if position_probas[ii] > 0:
262
+ token_ids[ii] = self.mask_token_id
263
+ if position_probas[ii + 1] > 0:
264
+ token_ids[ii + 1] = self.mask_token_id
265
+ return token_ids
266
+
267
+ def _predict_masked_tokens(self, tokenized_rhymes):
268
+ return self.model(tf.constant(tokenized_rhymes))[0]
269
+
270
+ def _draw_replacement(self, predictions, token_probas, replace_ix):
271
+ """Get probability, weigh and draw."""
272
+ # TODO (HG): Can't we softmax when calling the model?
273
+ probas = tf.nn.softmax(predictions[replace_ix]).numpy() * token_probas
274
+ probas /= probas.sum()
275
+ return np.random.choice(range(len(probas)), p=probas)
276
+
277
+
278
+
279
+ def query_rhyme_words(sentence: str, n_rhymes: int, language:str="english") -> List[str]:
280
+ """Returns a list of rhyme words for a sentence.
281
+
282
+ Parameters
283
+ ----------
284
+ sentence : Sentence that may end with punctuation
285
+ n_rhymes : Maximum number of rhymes to return
286
+
287
+ Returns
288
+ -------
289
+ List[str] -- List of words that rhyme with the final word
290
+ """
291
+ last_word = find_last_word(sentence)
292
+ if language == "english":
293
+ return query_datamuse_api(last_word, n_rhymes)
294
+ elif language == "dutch":
295
+ return mick_rijmwoordenboek(last_word, n_rhymes)
296
+ else:
297
+ raise NotImplementedError(f"Unsupported language ({language}) expected 'english' or 'dutch'.")
298
+
299
+
300
+ def query_datamuse_api(word: str, n_rhymes: Optional[int] = None) -> List[str]:
301
+ """Query the DataMuse API.
302
+
303
+ Parameters
304
+ ----------
305
+ word : Word to rhyme with
306
+ n_rhymes : Max rhymes to return
307
+
308
+ Returns
309
+ -------
310
+ Rhyme words
311
+ """
312
+ out = requests.get(
313
+ "https://api.datamuse.com/words", params={"rel_rhy": word}
314
+ ).json()
315
+ words = [_["word"] for _ in out]
316
+ if n_rhymes is None:
317
+ return words
318
+ return words[:n_rhymes]
319
+
320
+
321
+ @functools.lru_cache(maxsize=128, typed=False)
322
+ def mick_rijmwoordenboek(word: str, n_words: int):
323
+ url = f"https://rijmwoordenboek.nl/rijm/{word}"
324
+ html = get(url)
325
+ soup = Soup(html)
326
+
327
+ results = soup.find("div", {"id": "rhymeResultsWords"}).html.split("<br />")
328
+
329
+ # clean up
330
+ results = [r.replace("\n", "").replace(" ", "") for r in results]
331
+
332
+ # filter html and empty strings
333
+ results = [r for r in results if ("<" not in r) and (len(r) > 0)]
334
+
335
+ return random.sample(results, min(len(results), n_words))
336
+
337
+
338
+ import numpy as np
339
+
340
+
341
+ class TokenWeighter:
342
+ def __init__(self, tokenizer):
343
+ self.tokenizer_ = tokenizer
344
+ self.proba = self.get_token_proba()
345
+
346
+ def get_token_proba(self):
347
+ valid_token_mask = self._filter_short_partial(self.tokenizer_.vocab)
348
+ return valid_token_mask
349
+
350
+ def _filter_short_partial(self, vocab):
351
+ valid_token_ids = [v for k, v in vocab.items() if len(k) > 1 and "#" not in k]
352
+ is_valid = np.zeros(len(vocab.keys()))
353
+ is_valid[valid_token_ids] = 1
354
+ return is_valid
355
+
356
+
357
+
358
+ def color_new_words(new: str, old: str, color: str = "#eefa66") -> str:
359
+ """Color new words in strings with a span."""
360
+
361
+ def find_diff(new_, old_):
362
+ return [ii for ii, (n, o) in enumerate(zip(new_, old_)) if n != o]
363
+
364
+ new_words = new.split()
365
+ old_words = old.split()
366
+ forward = find_diff(new_words, old_words)
367
+ backward = find_diff(new_words[::-1], old_words[::-1])
368
+
369
+ if not forward or not backward:
370
+ # No difference
371
+ return new
372
+
373
+ start, end = forward[0], len(new_words) - backward[0]
374
+ return (
375
+ " ".join(new_words[:start])
376
+ + " "
377
+ + f'<span style="background-color: {color}">'
378
+ + " ".join(new_words[start:end])
379
+ + "</span>"
380
+ + " "
381
+ + " ".join(new_words[end:])
382
+ )
383
+
384
+
385
+ def find_last_word(s):
386
+ """Find the last word in a string."""
387
+ # Note: will break on \n, \r, etc.
388
+ alpha_only_sentence = "".join([c for c in s if (c.isalpha() or (c == " "))]).strip()
389
+ return alpha_only_sentence.split()[-1]
390
+
391
+
392
+ def pairwise(iterable):
393
+ """s -> (s0,s1), (s1,s2), (s2, s3), ..."""
394
+ # https://stackoverflow.com/questions/5434891/iterate-a-list-as-pair-current-next-in-python
395
+ a, b = itertools.tee(iterable)
396
+ next(b, None)
397
+ return zip(a, b)
398
+
399
+
400
+ def sanitize(s):
401
+ """Remove punctuation from a string."""
402
+ return s.translate(str.maketrans("", "", string.punctuation))
403
+
404
+
405
+ if __name__ == "__main__":
406
+ main()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gazpacho
2
+ numpy
3
+ requests
4
+ tensorflow
5
+ transformers