pokemon-move-generator-app / gradio_demo.py
arjunpatel's picture
History up and running
61e66c3
raw
history blame
7.04 kB
import gradio as gr
from transformers import AutoTokenizer
from transformers import pipeline
from utils import format_moves
import pandas as pd
model_checkpoint = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
generate = pipeline("text-generation",
model="arjunpatel/distilgpt2-finetuned-pokemon-moves",
tokenizer=tokenizer)
# load in the model
seed_text = "This move is called "
import tensorflow as tf
tf.random.set_seed(0)
# need a function to sanitize imputs
# - remove extra spaces
# - make sure each word is capitalized
# - format the moves such that it's clearer when each move is listed
# - play with the max length parameter abit, and try to remove sentences that don't end in periods.
def update_history(df, move_name, move_desc, generation, parameters):
# needs to format each move description with new lines to cut down on width
new_row = [{"Move Name": move_name,
"Move Description": move_desc,
"Generation Type": generation,
"Parameters": parameters}]
return pd.concat([df, pd.DataFrame(new_row)])
def create_move(move, history):
generated_move = format_moves(generate(seed_text + move, num_return_sequences=1))
return generated_move, update_history(history, move, generated_move,
"baseline", "None")
def create_greedy_search_move(move, history):
generated_move = format_moves(generate(seed_text + move, do_sample=False))
return generated_move, update_history(history, move, generated_move,
"greedy", "None")
def create_beam_search_move(move, num_beams, history):
generated_move = format_moves(generate(seed_text + move, num_beams=num_beams,
num_return_sequences=1,
do_sample=False, early_stopping=True))
return generated_move, update_history(history, move, generated_move,
"beam", {"num_beams": 2})
def create_sampling_search_move(move, do_sample, temperature, history):
generated_move = format_moves(generate(seed_text + move, do_sample=do_sample, temperature=float(temperature),
num_return_sequences=1, topk=0))
return generated_move, update_history(history, move, generated_move,
"temperature", {"do_sample": do_sample,
"temperature": temperature})
def create_top_search_move(move, topk, topp, history):
generated_move = format_moves(generate(
seed_text + move,
do_sample=True,
num_return_sequences=1,
top_k=topk,
top_p=topp,
force_word_ids=tokenizer.encode("The user", return_tensors='tf')))
return generated_move, update_history(history, move, generated_move,
"top", {"top k": topk,
"top p": topp})
demo = gr.Blocks()
with demo:
gr.Markdown("<h1><center>What's that Pokemon Move?</center></h1>")
gr.Markdown(
"This Gradio demo is a small GPT-2 model fine-tuned on a dataset of Pokemon moves! It'll generate a move description given a name.")
gr.Markdown("Enter a two to three word Pokemon Move name of your imagination below!")
with gr.Tabs():
with gr.TabItem("Standard Generation"):
with gr.Row():
text_input_baseline = gr.Textbox(label="Move",
placeholder="Type a two or three word move name here! Try \"Wonder Shield\"!")
text_output_baseline = gr.Textbox(label="Move Description",
placeholder="Leave this blank!")
text_button_baseline = gr.Button("Create my move!")
with gr.TabItem("Greedy Search"):
gr.Markdown("This tab lets you learn about using greedy search!")
with gr.Row():
text_input_greedy = gr.Textbox(label="Move")
text_output_greedy = gr.Textbox(label="Move Description")
text_button_greedy = gr.Button("Create my move!")
with gr.TabItem("Beam Search"):
gr.Markdown("This tab lets you learn about using beam search!")
with gr.Row():
num_beams = gr.Slider(minimum=2, maximum=10, value=2, step=1,
label="Number of Beams")
text_input_beam = gr.Textbox(label="Move")
text_output_beam = gr.Textbox(label="Move Description")
text_button_beam = gr.Button("Create my move!")
with gr.TabItem("Sampling and Temperature Search"):
gr.Markdown("This tab lets you experiment with adjusting the temperature of the generator")
with gr.Row():
temperature = gr.Slider(minimum=0.3, maximum=4.0, value=1.0, step=0.1,
label="Temperature")
sample_boolean = gr.Checkbox(label="Enable Sampling?")
text_input_temp = gr.Textbox(label="Move")
text_output_temp = gr.Textbox(label="Move Description")
text_button_temp = gr.Button("Create my move!")
with gr.TabItem("Top K and Top P Sampling"):
gr.Markdown("This tab lets you learn about Top K and Top P Sampling")
with gr.Row():
topk = gr.Slider(minimum=10, maximum=100, value=0, step=5,
label="Top K")
topp = gr.Slider(minimum=0.10, maximum=0.95, value=1, step=0.05,
label="Top P")
text_input_top = gr.Textbox(label="Move")
text_output_top = gr.Textbox(label="Move Description")
text_button_top = gr.Button("Create my move!")
with gr.Box():
# Displays a dataframe with the history of moves generated, with parameters
history = gr.Dataframe(headers=["Move Name", "Move Description", "Generation Type", "Parameters"])
text_button_baseline.click(create_move, inputs=[text_input_baseline, history],
outputs=[text_output_baseline, history])
text_button_greedy.click(create_greedy_search_move, inputs=[text_input_greedy, history],
outputs=[text_output_greedy, history])
text_button_temp.click(create_sampling_search_move, inputs=[text_input_temp, sample_boolean, temperature, history],
outputs=[text_output_temp, history])
text_button_beam.click(create_beam_search_move, inputs=[text_input_beam, num_beams, history],
outputs=[text_output_beam, history])
text_button_top.click(create_top_search_move, inputs=[text_input_top, topk, topp, history],
outputs=[text_output_top, history])
demo.launch(share=True)