hgrif committed on
Commit
5ddb621
·
1 Parent(s): b57caec

Add rhyme.py

Browse files
Files changed (2) hide show
  1. app.py +3 -82
  2. rhyme_with_ai/rhyme.py +69 -0
app.py CHANGED
@@ -1,18 +1,13 @@
1
  import copy
2
- import functools
3
- import itertools
4
  import logging
5
- import random
6
- import string
7
- from typing import List, Optional
8
 
9
- import requests
10
  import numpy as np
11
  import tensorflow as tf
12
  import streamlit as st
13
- from gazpacho import Soup, get
14
  from transformers import BertTokenizer, TFAutoModelForMaskedLM
15
- from rhyme_with_ai.utils import color_new_words, pairwise, find_last_word, sanitize
 
16
 
17
 
18
  DEFAULT_QUERY = "Machines will take over the world soon"
@@ -102,21 +97,6 @@ def display_output(status_text, query, current_sentences, previous_sentences):
102
  query + ",<br>" + "".join(print_sentences), unsafe_allow_html=True
103
  )
104
 
105
- class TokenWeighter:
106
- def __init__(self, tokenizer):
107
- self.tokenizer_ = tokenizer
108
- self.proba = self.get_token_proba()
109
-
110
- def get_token_proba(self):
111
- valid_token_mask = self._filter_short_partial(self.tokenizer_.vocab)
112
- return valid_token_mask
113
-
114
- def _filter_short_partial(self, vocab):
115
- valid_token_ids = [v for k, v in vocab.items() if len(k) > 1 and "#" not in k]
116
- is_valid = np.zeros(len(vocab.keys()))
117
- is_valid[valid_token_ids] = 1
118
- return is_valid
119
-
120
 
121
  class RhymeGenerator:
122
  def __init__(
@@ -291,65 +271,6 @@ class RhymeGenerator:
291
 
292
 
293
 
294
- def query_rhyme_words(sentence: str, n_rhymes: int, language:str="english") -> List[str]:
295
- """Returns a list of rhyme words for a sentence.
296
-
297
- Parameters
298
- ----------
299
- sentence : Sentence that may end with punctuation
300
- n_rhymes : Maximum number of rhymes to return
301
-
302
- Returns
303
- -------
304
- List[str] -- List of words that rhyme with the final word
305
- """
306
- last_word = find_last_word(sentence)
307
- if language == "english":
308
- return query_datamuse_api(last_word, n_rhymes)
309
- elif language == "dutch":
310
- return mick_rijmwoordenboek(last_word, n_rhymes)
311
- else:
312
- raise NotImplementedError(f"Unsupported language ({language}) expected 'english' or 'dutch'.")
313
-
314
-
315
- def query_datamuse_api(word: str, n_rhymes: Optional[int] = None) -> List[str]:
316
- """Query the DataMuse API.
317
-
318
- Parameters
319
- ----------
320
- word : Word to rhyme with
321
- n_rhymes : Max rhymes to return
322
-
323
- Returns
324
- -------
325
- Rhyme words
326
- """
327
- out = requests.get(
328
- "https://api.datamuse.com/words", params={"rel_rhy": word}
329
- ).json()
330
- words = [_["word"] for _ in out]
331
- if n_rhymes is None:
332
- return words
333
- return words[:n_rhymes]
334
-
335
-
336
- @functools.lru_cache(maxsize=128, typed=False)
337
- def mick_rijmwoordenboek(word: str, n_words: int):
338
- url = f"https://rijmwoordenboek.nl/rijm/{word}"
339
- html = get(url)
340
- soup = Soup(html)
341
-
342
- results = soup.find("div", {"id": "rhymeResultsWords"}).html.split("<br>")
343
-
344
- # clean up
345
- results = [r.replace("\n", "").replace(" ", "") for r in results]
346
-
347
- # filter html and empty strings
348
- results = [r for r in results if ("<" not in r) and (len(r) > 0)]
349
-
350
- return random.sample(results, min(len(results), n_words))
351
-
352
-
353
  if __name__ == "__main__":
354
  logging.basicConfig(level=logging.INFO)
355
  main()
 
1
  import copy
 
 
2
  import logging
3
+ from typing import List
 
 
4
 
 
5
  import numpy as np
6
  import tensorflow as tf
7
  import streamlit as st
 
8
  from transformers import BertTokenizer, TFAutoModelForMaskedLM
9
+ from rhyme_with_ai.utils import color_new_words, pairwise, sanitize
10
+ from rhyme_with_ai.token_weighter import TokenWeighter
11
 
12
 
13
  DEFAULT_QUERY = "Machines will take over the world soon"
 
97
  query + ",<br>" + "".join(print_sentences), unsafe_allow_html=True
98
  )
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  class RhymeGenerator:
102
  def __init__(
 
271
 
272
 
273
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
  if __name__ == "__main__":
275
  logging.basicConfig(level=logging.INFO)
276
  main()
rhyme_with_ai/rhyme.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import random
3
+ from typing import List, Optional
4
+
5
+ import requests
6
+ from gazpacho import Soup, get
7
+
8
+ from rhyme_with_ai.utils import find_last_word
9
+
10
+
11
+ def query_rhyme_words(sentence: str, n_rhymes: int, language:str="english") -> List[str]:
12
+ """Returns a list of rhyme words for a sentence.
13
+
14
+ Parameters
15
+ ----------
16
+ sentence : Sentence that may end with punctuation
17
+ n_rhymes : Maximum number of rhymes to return
18
+
19
+ Returns
20
+ -------
21
+ List[str] -- List of words that rhyme with the final word
22
+ """
23
+ last_word = find_last_word(sentence)
24
+ if language == "english":
25
+ return query_datamuse_api(last_word, n_rhymes)
26
+ elif language == "dutch":
27
+ return mick_rijmwoordenboek(last_word, n_rhymes)
28
+ else:
29
+ raise NotImplementedError(f"Unsupported language ({language}) expected 'english' or 'dutch'.")
30
+
31
+
32
+ def query_datamuse_api(word: str, n_rhymes: Optional[int] = None) -> List[str]:
33
+ """Query the DataMuse API.
34
+
35
+ Parameters
36
+ ----------
37
+ word : Word to rhyme with
38
+ n_rhymes : Max rhymes to return
39
+
40
+ Returns
41
+ -------
42
+ Rhyme words
43
+ """
44
+ out = requests.get(
45
+ "https://api.datamuse.com/words", params={"rel_rhy": word}
46
+ ).json()
47
+ words = [_["word"] for _ in out]
48
+ if n_rhymes is None:
49
+ return words
50
+ return words[:n_rhymes]
51
+
52
+
53
+ @functools.lru_cache(maxsize=128, typed=False)
54
+ def mick_rijmwoordenboek(word: str, n_words: int):
55
+ url = f"https://rijmwoordenboek.nl/rijm/{word}"
56
+ html = get(url)
57
+ soup = Soup(html)
58
+
59
+ results = soup.find("div", {"id": "rhymeResultsWords"}).html.split("<br>")
60
+
61
+ # clean up
62
+ results = [r.replace("\n", "").replace(" ", "") for r in results]
63
+
64
+ # filter html and empty strings
65
+ results = [r for r in results if ("<" not in r) and (len(r) > 0)]
66
+
67
+ return random.sample(results, min(len(results), n_words))
68
+
69
+