hgrif commited on
Commit
451b919
Β·
1 Parent(s): 251b174

Fix old type hints

Browse files
Files changed (1) hide show
  1. app.py +17 -24
app.py CHANGED
@@ -11,8 +11,6 @@ import numpy as np
11
  import tensorflow as tf
12
  import streamlit as st
13
  from gazpacho import Soup, get
14
- from transformers.modeling_tf_bert import TFBertForPreTraining
15
- from transformers.tokenization_bert import PreTrainedTokenizer
16
  from transformers import BertTokenizer, TFBertForMaskedLM
17
 
18
 
@@ -102,12 +100,27 @@ def display_output(status_text, query, current_sentences, previous_sentences):
102
  query + ",<br>" + "".join(print_sentences), unsafe_allow_html=True
103
  )
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  class RhymeGenerator:
107
  def __init__(
108
  self,
109
- model: TFBertForPreTraining,
110
- tokenizer: PreTrainedTokenizer,
111
  token_weighter: TokenWeighter = None,
112
  ):
113
  """Generate rhymes.
@@ -335,26 +348,6 @@ def mick_rijmwoordenboek(word: str, n_words: int):
335
  return random.sample(results, min(len(results), n_words))
336
 
337
 
338
- import numpy as np
339
-
340
-
341
- class TokenWeighter:
342
- def __init__(self, tokenizer):
343
- self.tokenizer_ = tokenizer
344
- self.proba = self.get_token_proba()
345
-
346
- def get_token_proba(self):
347
- valid_token_mask = self._filter_short_partial(self.tokenizer_.vocab)
348
- return valid_token_mask
349
-
350
- def _filter_short_partial(self, vocab):
351
- valid_token_ids = [v for k, v in vocab.items() if len(k) > 1 and "#" not in k]
352
- is_valid = np.zeros(len(vocab.keys()))
353
- is_valid[valid_token_ids] = 1
354
- return is_valid
355
-
356
-
357
-
358
  def color_new_words(new: str, old: str, color: str = "#eefa66") -> str:
359
  """Color new words in strings with a span."""
360
 
 
11
  import tensorflow as tf
12
  import streamlit as st
13
  from gazpacho import Soup, get
 
 
14
  from transformers import BertTokenizer, TFBertForMaskedLM
15
 
16
 
 
100
  query + ",<br>" + "".join(print_sentences), unsafe_allow_html=True
101
  )
102
 
103
+ class TokenWeighter:
104
+ def __init__(self, tokenizer):
105
+ self.tokenizer_ = tokenizer
106
+ self.proba = self.get_token_proba()
107
+
108
+ def get_token_proba(self):
109
+ valid_token_mask = self._filter_short_partial(self.tokenizer_.vocab)
110
+ return valid_token_mask
111
+
112
+ def _filter_short_partial(self, vocab):
113
+ valid_token_ids = [v for k, v in vocab.items() if len(k) > 1 and "#" not in k]
114
+ is_valid = np.zeros(len(vocab.keys()))
115
+ is_valid[valid_token_ids] = 1
116
+ return is_valid
117
+
118
 
119
  class RhymeGenerator:
120
  def __init__(
121
  self,
122
+ model: TFBertForMaskedLM,
123
+ tokenizer: BertTokenizer,
124
  token_weighter: TokenWeighter = None,
125
  ):
126
  """Generate rhymes.
 
348
  return random.sample(results, min(len(results), n_words))
349
 
350
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351
  def color_new_words(new: str, old: str, color: str = "#eefa66") -> str:
352
  """Color new words in strings with a span."""
353