Spaces:
Runtime error
Runtime error
arjunpatel
committed on
Commit
·
b4075da
1
Parent(s):
61e66c3
Formatting and explainers
Browse files- gradio_demo.py +83 -11
gradio_demo.py
CHANGED
@@ -3,6 +3,7 @@ from transformers import AutoTokenizer
|
|
3 |
from transformers import pipeline
|
4 |
from utils import format_moves
|
5 |
import pandas as pd
|
|
|
6 |
|
7 |
model_checkpoint = "distilgpt2"
|
8 |
|
@@ -13,7 +14,6 @@ generate = pipeline("text-generation",
|
|
13 |
tokenizer=tokenizer)
|
14 |
# load in the model
|
15 |
seed_text = "This move is called "
|
16 |
-
import tensorflow as tf
|
17 |
|
18 |
tf.random.set_seed(0)
|
19 |
|
@@ -80,24 +80,59 @@ demo = gr.Blocks()
|
|
80 |
with demo:
|
81 |
gr.Markdown("<h1><center>What's that Pokemon Move?</center></h1>")
|
82 |
gr.Markdown(
|
83 |
-
"This Gradio demo
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
with gr.Tabs():
|
86 |
with gr.TabItem("Standard Generation"):
|
|
|
|
|
|
|
87 |
with gr.Row():
|
88 |
text_input_baseline = gr.Textbox(label="Move",
|
89 |
-
placeholder="Type a two or three word move name here! Try \"Wonder
|
|
|
90 |
text_output_baseline = gr.Textbox(label="Move Description",
|
91 |
placeholder="Leave this blank!")
|
92 |
text_button_baseline = gr.Button("Create my move!")
|
93 |
-
with gr.TabItem("Greedy Search"):
|
94 |
-
gr.Markdown("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
with gr.Row():
|
96 |
text_input_greedy = gr.Textbox(label="Move")
|
97 |
text_output_greedy = gr.Textbox(label="Move Description")
|
98 |
text_button_greedy = gr.Button("Create my move!")
|
99 |
with gr.TabItem("Beam Search"):
|
100 |
gr.Markdown("This tab lets you learn about using beam search!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
with gr.Row():
|
102 |
num_beams = gr.Slider(minimum=2, maximum=10, value=2, step=1,
|
103 |
label="Number of Beams")
|
@@ -106,24 +141,61 @@ with demo:
|
|
106 |
text_button_beam = gr.Button("Create my move!")
|
107 |
with gr.TabItem("Sampling and Temperature Search"):
|
108 |
gr.Markdown("This tab lets you experiment with adjusting the temperature of the generator")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
with gr.Row():
|
110 |
temperature = gr.Slider(minimum=0.3, maximum=4.0, value=1.0, step=0.1,
|
111 |
label="Temperature")
|
112 |
-
sample_boolean = gr.Checkbox(label="Enable Sampling?")
|
113 |
text_input_temp = gr.Textbox(label="Move")
|
|
|
|
|
114 |
text_output_temp = gr.Textbox(label="Move Description")
|
115 |
text_button_temp = gr.Button("Create my move!")
|
116 |
with gr.TabItem("Top K and Top P Sampling"):
|
117 |
-
gr.Markdown(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
with gr.Row():
|
119 |
-
topk = gr.Slider(minimum=
|
120 |
label="Top K")
|
121 |
-
|
122 |
-
label="Top P")
|
123 |
text_input_top = gr.Textbox(label="Move")
|
|
|
|
|
|
|
124 |
text_output_top = gr.Textbox(label="Move Description")
|
125 |
text_button_top = gr.Button("Create my move!")
|
126 |
with gr.Box():
|
|
|
127 |
# Displays a dataframe with the history of moves generated, with parameters
|
128 |
history = gr.Dataframe(headers=["Move Name", "Move Description", "Generation Type", "Parameters"])
|
129 |
|
|
|
3 |
from transformers import pipeline
|
4 |
from utils import format_moves
|
5 |
import pandas as pd
|
6 |
+
import tensorflow as tf
|
7 |
|
8 |
model_checkpoint = "distilgpt2"
|
9 |
|
|
|
14 |
tokenizer=tokenizer)
|
15 |
# load in the model
|
16 |
seed_text = "This move is called "
|
|
|
17 |
|
18 |
tf.random.set_seed(0)
|
19 |
|
|
|
80 |
with demo:
|
81 |
gr.Markdown("<h1><center>What's that Pokemon Move?</center></h1>")
|
82 |
gr.Markdown(
|
83 |
+
"""This Gradio demo allows you to generate Pokemon Move descriptions given a name, and learn more about text
|
84 |
+
decoding methods in the process! Each tab aims to explain each generation methodology available for the
|
85 |
+
model. The dataframe below allows you to keep track of each move generated, to compare!""")
|
86 |
+
gr.Markdown("<h3> How does text generation work? </h3>")
|
87 |
+
gr.Markdown("""Roughly, text generation models accept an input sequence of words (or parts of words, known as tokens).
|
88 |
+
These models then output a corresponding set of words or tokens. Given the input, the model
|
89 |
+
estimates the probability of another possible word or token appearing right after the given sequence. In
|
90 |
+
other words, the model estimates conditional probabilities and ranks them in order to generate sequences
|
91 |
+
. """)
|
92 |
+
gr.Markdown("Enter a two to three word Pokemon Move name of your imagination below, with each word capitalized!")
|
93 |
+
gr.Markdown("<h3> Move Generation </h3>")
|
94 |
with gr.Tabs():
|
95 |
with gr.TabItem("Standard Generation"):
|
96 |
+
gr.Markdown(
|
97 |
+
"""The default parameters for distilgpt2 work well to generate moves. Use this tab to have fun and as
|
98 |
+
a baseline for your experiments.""")
|
99 |
with gr.Row():
|
100 |
text_input_baseline = gr.Textbox(label="Move",
|
101 |
+
placeholder="Type a two or three word move name here! Try \"Wonder "
|
102 |
+
"Shield\"!")
|
103 |
text_output_baseline = gr.Textbox(label="Move Description",
|
104 |
placeholder="Leave this blank!")
|
105 |
text_button_baseline = gr.Button("Create my move!")
|
106 |
+
with gr.TabItem("Greedy Search Decoding"):
|
107 |
+
gr.Markdown("""
|
108 |
+
|
109 |
+
Greedy search is a decoding method that relies on finding the word that has the highest estimated
|
110 |
+
probability of following the sequence thus far.
|
111 |
+
|
112 |
+
Therefore, the model \"greedily\" grabs the highest
|
113 |
+
probability word and continues generating the sentence.
|
114 |
+
|
115 |
+
This has the side effect of finding sequences that are reasonable, but avoids sequences that are
|
116 |
+
less probable but way more interesting.
|
117 |
+
Try the other decoding methods to get sentences with more variety!
|
118 |
+
""")
|
119 |
with gr.Row():
|
120 |
text_input_greedy = gr.Textbox(label="Move")
|
121 |
text_output_greedy = gr.Textbox(label="Move Description")
|
122 |
text_button_greedy = gr.Button("Create my move!")
|
123 |
with gr.TabItem("Beam Search"):
|
124 |
gr.Markdown("This tab lets you learn about using beam search!")
|
125 |
+
gr.Markdown("""Beam search is an improvement on Greedy Search. Instead of directly grabbing the word that
|
126 |
+
maximizes probability, we conduct a search with B number of candidates. We then try to find the next word
|
127 |
+
that would most likely follow each beam, and we grab the top B candidates of that search. This may
|
128 |
+
eliminate one of the original beams we started with, and that's okay! That is how the algorithm decides
|
129 |
+
on an optimal candidate. Eventually, the beam sequences terminate or are eliminated due to being too improbable.
|
130 |
+
|
131 |
+
Increasing the number of beams will increase model generation time, but also result in a more thorough search.
|
132 |
+
Decreasing the number of beams will decrease decoding time, but it may not find an optimal sentence.
|
133 |
+
|
134 |
+
Play around with the num_beams parameter to experiment! """
|
135 |
+
)
|
136 |
with gr.Row():
|
137 |
num_beams = gr.Slider(minimum=2, maximum=10, value=2, step=1,
|
138 |
label="Number of Beams")
|
|
|
141 |
text_button_beam = gr.Button("Create my move!")
|
142 |
with gr.TabItem("Sampling and Temperature Search"):
|
143 |
gr.Markdown("This tab lets you experiment with adjusting the temperature of the generator")
|
144 |
+
gr.Markdown(
|
145 |
+
"""
|
146 |
+
Greedy Search and Beam Search were both good at finding sequences that are likely to follow our input text,
|
147 |
+
but when generating cool move descriptions, we want some more variety!
|
148 |
+
|
149 |
+
Instead of choosing the word or token that is most likely to follow a given sequence, we can instead
|
150 |
+
ask the model to sample across the probability distribution of likely words. It's kind of like walking
|
151 |
+
into the tall grass and finding a Pokemon encounter. There are different encounter rates, which allow
|
152 |
+
for the most common mons to appear (looking at you, Zubat), but also account for surprise, like shinys!
|
153 |
+
|
154 |
+
We might even want to go further, though. We can rescale the probability distributions directly instead,
|
155 |
+
allowing for rare words to temporarily become more frequent. We do this using the temperature parameter.
|
156 |
+
|
157 |
+
Turn the temperature up, and rare tokens become very likely! Cool down, and we approach more sensible output.
|
158 |
+
|
159 |
+
Experiment with turning sampling on and off, and by varying temperature below!
|
160 |
+
""")
|
161 |
with gr.Row():
|
162 |
temperature = gr.Slider(minimum=0.3, maximum=4.0, value=1.0, step=0.1,
|
163 |
label="Temperature")
|
|
|
164 |
text_input_temp = gr.Textbox(label="Move")
|
165 |
+
with gr.Row():
|
166 |
+
sample_boolean = gr.Checkbox(label="Enable Sampling?")
|
167 |
text_output_temp = gr.Textbox(label="Move Description")
|
168 |
text_button_temp = gr.Button("Create my move!")
|
169 |
with gr.TabItem("Top K and Top P Sampling"):
|
170 |
+
gr.Markdown(
|
171 |
+
"""
|
172 |
+
When we want more control over the words we get to sample from, we turn to Top K and Top P decoding methods!
|
173 |
+
|
174 |
+
|
175 |
+
The Top K sampling method selects the K most probable words given a sequence, and then samples from that subset,
|
176 |
+
rather than the whole vocabulary. This effectively cuts out low probability words.
|
177 |
+
|
178 |
+
|
179 |
+
Top P also reduces the available vocabulary to sample from, but instead of choosing the number of
|
180 |
+
words or tokens in advance, we sort the vocabulary from most to least likely word, and we
|
181 |
+
grab the smallest set of words whose probabilities sum to at least P. This allows for the number of words we look at to
|
182 |
+
change while sampling, instead of being fixed.
|
183 |
+
|
184 |
+
We can even use both methods at the same time! To disable Top K, set it to 0 using the slider.
|
185 |
+
To disable Top P, set it to 1""")
|
186 |
+
|
187 |
with gr.Row():
|
188 |
+
topk = gr.Slider(minimum=0, maximum=200, value=0, step=5,
|
189 |
label="Top K")
|
190 |
+
|
|
|
191 |
text_input_top = gr.Textbox(label="Move")
|
192 |
+
with gr.Row():
|
193 |
+
topp = gr.Slider(minimum=0.10, maximum=1, value=1, step=0.05,
|
194 |
+
label="Top P")
|
195 |
text_output_top = gr.Textbox(label="Move Description")
|
196 |
text_button_top = gr.Button("Create my move!")
|
197 |
with gr.Box():
|
198 |
+
gr.Markdown("<h3> Generation History </h3>")
|
199 |
# Displays a dataframe with the history of moves generated, with parameters
|
200 |
history = gr.Dataframe(headers=["Move Name", "Move Description", "Generation Type", "Parameters"])
|
201 |
|