import gradio as gr
import pandas as pd
import tensorflow as tf
from transformers import AutoTokenizer
from transformers import pipeline

from utils import format_moves

model_checkpoint = "distilgpt2"

# The fine-tuned checkpoint below is paired with the base distilgpt2 tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
generate = pipeline("text-generation",
                    model="arjunpatel/distilgpt2-finetuned-pokemon-moves",
                    tokenizer=tokenizer)  # load in the model

# Every prompt starts with this prefix; the user's move name is appended to it.
seed_text = "This move is called "

# Seeds TensorFlow's global RNG. This only influences sampling if the pipeline
# is running a TensorFlow model — TODO confirm which framework is loaded.
tf.random.set_seed(0)

# TODO: further output cleanup:
# - format the moves such that it's clearer when each move is listed
# - play with the max length parameter a bit, and try to remove sentences
#   that don't end in periods.


def _sanitize_move_name(move):
    """Normalize a user-typed move name before it is put into the prompt.

    Collapses runs of whitespace (including leading/trailing whitespace) into
    single spaces and upper-cases the first letter of every word without
    touching the rest of the word (e.g. "  king's   shield " -> "King's Shield").
    An empty or None input yields "".
    """
    words = (move or "").split()
    return " ".join(word[:1].upper() + word[1:] for word in words)


def _generate_move(move, **generation_kwargs):
    """Run the pipeline on the seeded prompt for `move` and format the result.

    `generation_kwargs` are forwarded unchanged to the text-generation pipeline.
    """
    return format_moves(generate(seed_text + move, **generation_kwargs))


def update_history(df, move_name, move_desc, generation, parameters):
    """Return a copy of `df` with one row appended describing a generated move.

    Args:
        df: Current history table (the value of the history Dataframe).
        move_name: Move name that was used in the prompt.
        move_desc: Formatted generated description.
        generation: Short label for the decoding strategy ("baseline", "beam", ...).
        parameters: Decoding parameters that were used, or the string "None".

    Returns:
        A new DataFrame; `df` itself is not modified.
    """
    # TODO: format each move description with new lines to cut down on width.
    new_row = pd.DataFrame([{"Move Name": move_name,
                             "Move Description": move_desc,
                             "Generation Type": generation,
                             "Parameters": parameters}])
    # ignore_index keeps a clean 0..n-1 index instead of piling up duplicate labels.
    return pd.concat([df, new_row], ignore_index=True)


def create_move(move, history):
    """Generate a description with the pipeline's default decoding settings."""
    move = _sanitize_move_name(move)
    generated_move = _generate_move(move, num_return_sequences=1)
    return generated_move, update_history(history, move, generated_move, "baseline", "None")


def create_greedy_search_move(move, history):
    """Generate a description with greedy decoding (sampling disabled)."""
    move = _sanitize_move_name(move)
    generated_move = _generate_move(move, do_sample=False)
    return generated_move, update_history(history, move, generated_move, "greedy", "None")


def create_beam_search_move(move, num_beams, history):
    """Generate a description with beam search using `num_beams` beams."""
    move = _sanitize_move_name(move)
    num_beams = int(num_beams)  # slider values may arrive as floats
    generated_move = _generate_move(move, num_beams=num_beams, num_return_sequences=1,
                                    do_sample=False, early_stopping=True)
    # Record the beam count actually used (previously hard-coded to 2).
    return generated_move, update_history(history, move, generated_move, "beam",
                                          {"num_beams": num_beams})


def create_sampling_search_move(move, do_sample, temperature, history):
    """Generate a description with (optional) sampling at the given temperature."""
    move = _sanitize_move_name(move)
    # top_k=0 disables top-k filtering so only the temperature reshapes the
    # distribution (the original passed the misspelled, invalid `topk`).
    generated_move = _generate_move(move, do_sample=do_sample, temperature=float(temperature),
                                    num_return_sequences=1, top_k=0)
    return generated_move, update_history(history, move, generated_move, "temperature",
                                          {"do_sample": do_sample, "temperature": temperature})


def create_top_search_move(move, topk, topp, history):
    """Generate a description with top-k / top-p (nucleus) sampling."""
    move = _sanitize_move_name(move)
    top_k = int(topk)  # generate expects an integer k
    top_p = float(topp)
    # The previous `force_word_ids` argument was removed: it is not a valid
    # generate argument, and forced-word (constrained) decoding cannot be
    # combined with do_sample=True anyway.
    generated_move = _generate_move(move, do_sample=True, num_return_sequences=1,
                                    top_k=top_k, top_p=top_p)
    return generated_move, update_history(history, move, generated_move, "top",
                                          {"top k": top_k, "top p": top_p})


demo = gr.Blocks()

with demo:
    # The title literal was previously broken across several physical lines,
    # which is a syntax error for a normal string literal.
    gr.Markdown("# What's that Pokemon Move?")
    gr.Markdown(
        "This Gradio demo is a small GPT-2 model fine-tuned on a dataset of Pokemon moves! "
        "It'll generate a move description given a name.")
    gr.Markdown("Enter a two to three word Pokemon Move name of your imagination below!")

    with gr.Tabs():
        with gr.TabItem("Standard Generation"):
            with gr.Row():
                text_input_baseline = gr.Textbox(
                    label="Move",
                    placeholder="Type a two or three word move name here! Try \"Wonder Shield\"!")
                text_output_baseline = gr.Textbox(label="Move Description",
                                                  placeholder="Leave this blank!")
            text_button_baseline = gr.Button("Create my move!")

        with gr.TabItem("Greedy Search"):
            gr.Markdown("This tab lets you learn about using greedy search!")
            with gr.Row():
                text_input_greedy = gr.Textbox(label="Move")
                text_output_greedy = gr.Textbox(label="Move Description")
            text_button_greedy = gr.Button("Create my move!")

        with gr.TabItem("Beam Search"):
            gr.Markdown("This tab lets you learn about using beam search!")
            with gr.Row():
                num_beams = gr.Slider(minimum=2, maximum=10, value=2, step=1,
                                      label="Number of Beams")
                text_input_beam = gr.Textbox(label="Move")
                text_output_beam = gr.Textbox(label="Move Description")
            text_button_beam = gr.Button("Create my move!")

        with gr.TabItem("Sampling and Temperature Search"):
            gr.Markdown("This tab lets you experiment with adjusting the temperature of the generator")
            with gr.Row():
                temperature = gr.Slider(minimum=0.3, maximum=4.0, value=1.0, step=0.1,
                                        label="Temperature")
                sample_boolean = gr.Checkbox(label="Enable Sampling?")
                text_input_temp = gr.Textbox(label="Move")
                text_output_temp = gr.Textbox(label="Move Description")
            text_button_temp = gr.Button("Create my move!")

        with gr.TabItem("Top K and Top P Sampling"):
            gr.Markdown("This tab lets you learn about Top K and Top P Sampling")
            with gr.Row():
                # Ranges now contain the defaults: top_k=0 and top_p=1.0 both mean
                # "filter disabled", so the default is plain sampling.
                topk = gr.Slider(minimum=0, maximum=100, value=0, step=5, label="Top K")
                topp = gr.Slider(minimum=0.10, maximum=1.0, value=1.0, step=0.05, label="Top P")
                text_input_top = gr.Textbox(label="Move")
                text_output_top = gr.Textbox(label="Move Description")
            text_button_top = gr.Button("Create my move!")

    with gr.Box():
        # Displays a dataframe with the history of moves generated, with parameters
        history = gr.Dataframe(headers=["Move Name", "Move Description",
                                        "Generation Type", "Parameters"])

    # Each handler returns (description, updated history table).
    text_button_baseline.click(create_move,
                               inputs=[text_input_baseline, history],
                               outputs=[text_output_baseline, history])
    text_button_greedy.click(create_greedy_search_move,
                             inputs=[text_input_greedy, history],
                             outputs=[text_output_greedy, history])
    text_button_temp.click(create_sampling_search_move,
                           inputs=[text_input_temp, sample_boolean, temperature, history],
                           outputs=[text_output_temp, history])
    text_button_beam.click(create_beam_search_move,
                           inputs=[text_input_beam, num_beams, history],
                           outputs=[text_output_beam, history])
    text_button_top.click(create_top_search_move,
                          inputs=[text_input_top, topk, topp, history],
                          outputs=[text_output_top, history])

demo.launch(share=True)