File size: 3,395 Bytes
251b174
1431830
5ddb621
251b174
 
f9afcb3
6a177e5
 
5f80ca4
6a177e5
251b174
 
 
 
 
 
 
 
047392f
462d4c0
251b174
9204ef7
462d4c0
251b174
 
 
 
 
 
 
f58c040
251b174
 
 
 
 
 
 
 
a162482
251b174
 
 
 
 
3aa3b62
251b174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81d3f6d
f9afcb3
251b174
 
 
 
 
 
 
 
 
 
 
 
 
1431830
 
251b174
fc81a7b
251b174
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import copy
import logging
from typing import List

import streamlit as st
from transformers import BertTokenizer, TFAutoModelForMaskedLM

from rhyme_with_ai.utils import color_new_words, sanitize
from rhyme_with_ai.rhyme import query_rhyme_words
from rhyme_with_ai.rhyme_generator import RhymeGenerator


DEFAULT_QUERY = "Machines will take over the world soon"
N_RHYMES = 10

# Per-language settings: (Hugging Face model id, generation iterations per
# query word). Insertion order defines the order of the sidebar options.
_LANGUAGE_SETTINGS = {
    "english": ("bert-large-cased-whole-word-masking", 5),
    "dutch": ("GroNLP/bert-base-dutch-cased", 10),  # Faster model
}

LANGUAGE = st.sidebar.radio("Language", list(_LANGUAGE_SETTINGS), 0)
_settings = _LANGUAGE_SETTINGS.get(LANGUAGE)
if _settings is None:
    raise NotImplementedError(f"Unsupported language ({LANGUAGE}) expected 'english' or 'dutch'.")
MODEL_PATH, ITER_FACTOR = _settings

def main():
    """Render the app page and kick off rhyme generation for the user's line.

    Shows the credits banner and the title, reads the first line from the
    text input (falling back to ``DEFAULT_QUERY``), asks
    ``query_rhyme_words`` for up to ``N_RHYMES`` rhyme candidates in the
    selected ``LANGUAGE`` and either starts generating lines or tells the
    user that no rhyme words were found.
    """
    st.markdown(
        "<sup>Created with "
        "[Datamuse](https://www.datamuse.com/api/), "
        "[Mick's rijmwoordenboek](https://rijmwoordenboek.nl), "
        "[Hugging Face](https://huggingface.co/), "
        "[Streamlit](https://streamlit.io/) and "
        "[App Engine](https://cloud.google.com/appengine/)."
        " Read our [blog](https://blog.godatadriven.com/rhyme-with-ai) "
        "or check the "
        "[source](https://github.com/godatadriven/rhyme-with-ai).</sup>",
        unsafe_allow_html=True,
    )
    st.title("Rhyme with AI - Hi there! 👋")
    # Defensive fallback: an empty query would give nothing to rhyme with.
    query = get_query() or DEFAULT_QUERY
    rhyme_words_options = query_rhyme_words(
        query, n_rhymes=N_RHYMES, language=LANGUAGE
    )
    if rhyme_words_options:
        logging.getLogger(__name__).info("Got rhyme words: %s", rhyme_words_options)
        start_rhyming(query, rhyme_words_options)
    else:
        st.write("No rhyme words found")


def get_query():
    """Return the user's first line from the text box, sanitized.

    The text box is pre-filled with ``DEFAULT_QUERY``. If sanitizing the
    input leaves a falsy value (e.g. an empty string), ``DEFAULT_QUERY`` is
    returned instead.
    """
    raw_line = st.text_input("Write your first line and press ENTER to rhyme:", DEFAULT_QUERY)
    cleaned = sanitize(raw_line)
    return cleaned or DEFAULT_QUERY


def start_rhyming(query, rhyme_words_options):
    """Generate rhyming lines for ``query`` step by step and stream them to the page.

    Args:
        query: The user's first line.
        rhyme_words_options: Candidate rhyme words; only the first
            ``N_RHYMES`` are passed to the generator.

    The number of generation steps is ``len(query.split()) * ITER_FACTOR``.
    After each ``RhymeGenerator.mutate()`` call the latest lines are rendered
    via ``display_output`` (together with the previous step's lines) and the
    progress bar is advanced. Balloons are shown when the loop finishes.
    """
    st.markdown("## My Suggestions:")

    progress_bar = st.progress(0)
    status_text = st.empty()
    max_iter = len(query.split()) * ITER_FACTOR

    rhyme_words = rhyme_words_options[:N_RHYMES]

    model, tokenizer = load_model(MODEL_PATH)
    sentence_generator = RhymeGenerator(model, tokenizer)
    sentence_generator.start(query, rhyme_words)

    # Placeholder "previous" lines for the first step.
    current_sentences = [" "] * N_RHYMES
    for i in range(max_iter):
        # Snapshot before mutating, in case mutate() returns a list it
        # updates in place on later calls.
        previous_sentences = copy.deepcopy(current_sentences)
        current_sentences = sentence_generator.mutate()
        display_output(status_text, query, current_sentences, previous_sentences)
        # (i + 1) / max_iter reports completed work, reaches exactly 1.0 on
        # the last step, and cannot divide by zero when max_iter == 1
        # (unlike i / (max_iter - 1)).
        progress_bar.progress((i + 1) / max_iter)
    st.balloons()


@st.cache(allow_output_mutation=True)
def load_model(model_path):
    """Load the masked-LM model and BERT tokenizer stored at ``model_path``.

    Wrapped in ``st.cache`` so the result is reused across reruns instead of
    being loaded again. ``allow_output_mutation=True`` tells Streamlit not to
    check the returned objects for mutation between calls.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    masked_lm = TFAutoModelForMaskedLM.from_pretrained(model_path)
    bert_tokenizer = BertTokenizer.from_pretrained(model_path)
    return masked_lm, bert_tokenizer


def display_output(status_text, query, current_sentences, previous_sentences):
    """Render the query followed by an HTML list of generated continuations.

    Args:
        status_text: Streamlit placeholder; its content is replaced on every
            call via ``.markdown``.
        query: The user's first line, shown above the list.
        current_sentences: Latest generated lines.
        previous_sentences: Lines from the previous step, paired positionally
            with ``current_sentences`` and passed to ``color_new_words``.

    Each item shows the text after the first comma of the formatted line,
    minus its last two characters. NOTE(review): this assumes generator
    output looks like "<query>, <continuation><2-char tail>"; confirm
    against RhymeGenerator. A comma inside the query itself would still
    shift the split point.
    """
    items = []
    for new, old in zip(current_sentences, previous_sentences):
        formatted = color_new_words(new, old)
        # Split on the first comma only. The model may emit commas inside
        # the continuation, and splitting on every comma would cut the line
        # (and its rhyme word) short. With no comma at all, show the whole
        # line rather than raising IndexError.
        _, comma, continuation = formatted.partition(",")
        if not comma:
            continuation = formatted
        items.append("<li>" + continuation[:-2] + "</li>")
    status_text.markdown(
        query + ",<br>" + "".join(items), unsafe_allow_html=True
    )



if __name__ == "__main__":
    # Configure root logging at INFO so records such as the
    # "Got rhyme words" message from main() are emitted.
    logging.basicConfig(level="INFO")
    main()