Spaces:
Build error
Build error
import logging | |
from typing import List | |
import numpy as np | |
import tensorflow as tf | |
from transformers import BertTokenizer, TFAutoModelForMaskedLM | |
from rhyme_with_ai.token_weighter import TokenWeighter | |
from rhyme_with_ai.utils import pairwise | |
class RhymeGenerator:
    """Generate rhyming second lines with a masked language model.

    ``start`` builds one candidate per rhyme word: the query, a run of mask
    tokens, the rhyme word and a period. Each call to ``mutate`` then masks
    one position per candidate and re-samples it from the model's
    predictions, weighted by the token weighter.
    """

    def __init__(
        self,
        model: "TFAutoModelForMaskedLM",
        tokenizer: "BertTokenizer",
        token_weighter: "TokenWeighter" = None,
    ):
        """Generate rhymes.

        Parameters
        ----------
        model : Model for masked language modelling
        tokenizer : Tokenizer for model
        token_weighter : Class that weighs tokens. When omitted, a
            ``TokenWeighter(tokenizer)`` is created.
        """
        self.model = model
        self.tokenizer = tokenizer
        if token_weighter is None:
            token_weighter = TokenWeighter(tokenizer)
        self.token_weighter = token_weighter
        self._logger = logging.getLogger(__name__)

        # Both are populated by `start`.
        # tokenized_rhymes_: padded token ids, one row per rhyme word.
        # position_probas_: per-row sampling probability of each position;
        # non-zero only where `start` placed a mask token.
        self.tokenized_rhymes_ = None
        self.position_probas_ = None

        # Easy access.
        self.comma_token_id = self.tokenizer.encode(",", add_special_tokens=False)[0]
        self.period_token_id = self.tokenizer.encode(".", add_special_tokens=False)[0]
        self.mask_token_id = self.tokenizer.mask_token_id

    def start(self, query: str, rhyme_words: List[str]) -> None:
        """Start the sentence generator.

        Parameters
        ----------
        query : Seed sentence
        rhyme_words : Rhyme words for next sentence

        Raises
        ------
        ValueError
            If ``rhyme_words`` is empty or ``query`` tokenizes to nothing.
        """
        self._logger.info("Got sentence %s", query)
        if not rhyme_words:
            raise ValueError("At least one rhyme word is required.")

        tokenized_rhymes = [
            self._initialize_rhymes(query, rhyme_word) for rhyme_word in rhyme_words
        ]

        # Make same length.
        self.tokenized_rhymes_ = tf.keras.preprocessing.sequence.pad_sequences(
            tokenized_rhymes, padding="post", value=self.tokenizer.pad_token_id
        )

        # Uniform probability over the mask positions of each row. Every row
        # holds at least one mask (see `_initialize_rhymes`), so the row sum
        # is never zero.
        is_mask = self.tokenized_rhymes_ == self.mask_token_id
        self.position_probas_ = is_mask / is_mask.sum(axis=1, keepdims=True)

    def _initialize_rhymes(self, query: str, rhyme_word: str) -> List[int]:
        """Initialize the rhymes.

        * Tokenize input
        * Append a comma if the sentence does not end in it (might add better
          predictions as it shows the two sentence parts are related)
        * Make second line as long as the original (but always keep at least
          one mask so there is something to generate)
        * Add a period

        Parameters
        ----------
        query : First line
        rhyme_word : Last word for second line

        Returns
        -------
        Tokenized rhyme lines

        Raises
        ------
        ValueError
            If ``query`` tokenizes to an empty sequence.
        """
        query_token_ids = self.tokenizer.encode(query, add_special_tokens=False)
        if not query_token_ids:
            raise ValueError("Query must contain at least one token.")
        rhyme_word_token_ids = self.tokenizer.encode(
            rhyme_word, add_special_tokens=False
        )

        if query_token_ids[-1] != self.comma_token_id:
            query_token_ids.append(self.comma_token_id)

        magic_correction = len(rhyme_word_token_ids) + 1  # 1 for comma
        # A short query combined with a long rhyme word would otherwise yield
        # zero masks, which leaves nothing to sample and makes the position
        # probabilities in `start` divide by zero.
        n_masks = max(1, len(query_token_ids) - magic_correction)
        return (
            query_token_ids
            + [self.mask_token_id] * n_masks
            + rhyme_word_token_ids
            + [self.period_token_id]
        )

    def mutate(self):
        """Mutate the current rhymes.

        Returns
        -------
        Mutated rhymes, decoded to one string per rhyme word

        Raises
        ------
        RuntimeError
            If called before ``start``.
        """
        if self.tokenized_rhymes_ is None:
            raise RuntimeError("Call start() before mutate().")

        self.tokenized_rhymes_ = self._mutate(
            self.tokenized_rhymes_, self.position_probas_, self.token_weighter.proba
        )

        return [
            self.tokenizer.convert_tokens_to_string(
                self.tokenizer.convert_ids_to_tokens(
                    token_ids, skip_special_tokens=True
                )
            )
            for token_ids in self.tokenized_rhymes_
        ]

    def _mutate(
        self,
        tokenized_rhymes: np.ndarray,
        position_probas: np.ndarray,
        token_id_probas: np.ndarray,
    ) -> np.ndarray:
        """Mask one position per row, predict, and fill in a sampled token.

        ``tokenized_rhymes`` is modified in place and also returned.
        """
        replacements = []
        for i in range(tokenized_rhymes.shape[0]):
            mask_idx, masked_token_ids = self._mask_token(
                tokenized_rhymes[i], position_probas[i]
            )
            tokenized_rhymes[i] = masked_token_ids
            replacements.append(mask_idx)

        # One model call for all rows.
        predictions = self._predict_masked_tokens(tokenized_rhymes)

        for i, replace_ix in enumerate(replacements):
            tokenized_rhymes[i, replace_ix] = self._draw_replacement(
                predictions[i], token_id_probas, replace_ix
            )
        return tokenized_rhymes

    def _mask_token(self, token_ids, position_probas):
        """Mask line and return index to update."""
        token_ids = self._mask_repeats(token_ids, position_probas)
        ix = self._locate_mask(token_ids, position_probas)
        token_ids[ix] = self.mask_token_id
        return ix, token_ids

    def _locate_mask(self, token_ids, position_probas):
        """Update masks or a random token."""
        if self.mask_token_id in token_ids:
            # Already masks present, just return the last.
            # We used to return the first but this returns worse predictions.
            return np.where(token_ids == self.mask_token_id)[0][-1]
        # No mask left: draw a position according to its probability.
        return np.random.choice(len(position_probas), p=position_probas)

    def _mask_repeats(self, token_ids, position_probas):
        """Repeated tokens are generally of less quality.

        Re-masks both members of an identical adjacent pair, but only at
        positions with a non-zero position probability. The last two tokens
        are left out of the comparison.
        """
        # NOTE(review): assumes `pairwise` yields consecutive (a, b) pairs,
        # so index `ii` is the position of the first member - confirm
        # against rhyme_with_ai.utils.
        repeats = [
            ii for ii, ids in enumerate(pairwise(token_ids[:-2])) if ids[0] == ids[1]
        ]
        for ii in repeats:
            if position_probas[ii] > 0:
                token_ids[ii] = self.mask_token_id
            if position_probas[ii + 1] > 0:
                token_ids[ii + 1] = self.mask_token_id
        return token_ids

    def _predict_masked_tokens(self, tokenized_rhymes):
        """Run the model on all rows and return the first element of its output."""
        return self.model(tf.constant(tokenized_rhymes))[0]

    def _draw_replacement(self, predictions, token_probas, replace_ix):
        """Get probability, weigh and draw."""
        # TODO (HG): Can't we softmax when calling the model?
        # Work in float64: np.random.choice checks that `p` sums to 1 within
        # a tight tolerance, which float32 normalisation can miss.
        probas = (
            tf.nn.softmax(predictions[replace_ix]).numpy().astype(np.float64)
            * token_probas
        )
        probas /= probas.sum()
        return np.random.choice(len(probas), p=probas)