Fix old type hints
app.py
CHANGED
@@ -11,8 +11,6 @@ import numpy as np
 import tensorflow as tf
 import streamlit as st
 from gazpacho import Soup, get
-from transformers.modeling_tf_bert import TFBertForPreTraining
-from transformers.tokenization_bert import PreTrainedTokenizer
 from transformers import BertTokenizer, TFBertForMaskedLM
 
 
@@ -102,12 +100,27 @@ def display_output(status_text, query, current_sentences, previous_sentences):
         query + ",<br>" + "".join(print_sentences), unsafe_allow_html=True
     )
 
+class TokenWeighter:
+    def __init__(self, tokenizer):
+        self.tokenizer_ = tokenizer
+        self.proba = self.get_token_proba()
+
+    def get_token_proba(self):
+        valid_token_mask = self._filter_short_partial(self.tokenizer_.vocab)
+        return valid_token_mask
+
+    def _filter_short_partial(self, vocab):
+        valid_token_ids = [v for k, v in vocab.items() if len(k) > 1 and "#" not in k]
+        is_valid = np.zeros(len(vocab.keys()))
+        is_valid[valid_token_ids] = 1
+        return is_valid
+
 
 class RhymeGenerator:
     def __init__(
         self,
-        model: TFBertForPreTraining,
-        tokenizer: PreTrainedTokenizer,
+        model: TFBertForMaskedLM,
+        tokenizer: BertTokenizer,
         token_weighter: TokenWeighter = None,
     ):
         """Generate rhymes.
@@ -335,26 +348,6 @@ def mick_rijmwoordenboek(word: str, n_words: int):
     return random.sample(results, min(len(results), n_words))
 
 
-import numpy as np
-
-
-class TokenWeighter:
-    def __init__(self, tokenizer):
-        self.tokenizer_ = tokenizer
-        self.proba = self.get_token_proba()
-
-    def get_token_proba(self):
-        valid_token_mask = self._filter_short_partial(self.tokenizer_.vocab)
-        return valid_token_mask
-
-    def _filter_short_partial(self, vocab):
-        valid_token_ids = [v for k, v in vocab.items() if len(k) > 1 and "#" not in k]
-        is_valid = np.zeros(len(vocab.keys()))
-        is_valid[valid_token_ids] = 1
-        return is_valid
-
-
-
 def color_new_words(new: str, old: str, color: str = "#eefa66") -> str:
     """Color new words in strings with a span."""
 
|