# pokemon-move-generator-app / gradio_demo.py
# Author: arjunpatel - commit 658b022 ("skeleton version of demo working")
# (copied from the Hugging Face file view: raw | history | blame, 2.15 kB)
import gradio as gr
from transformers import AutoTokenizer
from transformers import pipeline
# Tokenizer comes from the plain "distilgpt2" checkpoint; presumably the
# fine-tuned model below was trained from it (inferred from its repo name -
# confirm the vocabularies match).
model_checkpoint = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Text-generation pipeline backed by the fine-tuned Pokemon-moves model;
# create_move() feeds it the seed prompt.
generate = pipeline(
    task="text-generation",
    tokenizer=tokenizer,
    model="arjunpatel/distilgpt2-finetuned-pokemon-moves",
)
def filter_text(generated_move):
    """Trim generated text so that only its first sentence is kept.

    Intended to remove any moves the model generated after the first one.
    Sentences are delimited naively by ".".

    Args:
        generated_move: Raw text produced by the generation pipeline.

    Returns:
        If the text contains two or more full stops, its first sentence
        (including its terminating full stop). Otherwise the text is
        returned unchanged.
    """
    sentences = generated_move.split(".")
    # split(".") yields count(".") + 1 pieces, so "> 2" means the text has at
    # least two full stops, i.e. more than one complete sentence.
    if len(sentences) > 2:
        # Restore the full stop removed by split() so the kept sentence is
        # returned exactly as it appeared in the input.
        return sentences[0] + "."
    return generated_move
def create_move(move):
    """Generate text for a move named *move* with the fine-tuned model.

    The prompt sent to the pipeline is "This move is called " followed by
    the move name.

    Args:
        move: Name of the move, as entered by the user.

    Returns:
        The "generated_text" field of the single sequence returned by the
        pipeline.
    """
    seed_text = "This move is called "
    # Only the first sequence is ever used. Requesting exactly one avoids
    # generating a second sequence just to throw it away.
    outputs = generate(seed_text + move, num_return_sequences=1,
                       no_repeat_ngram_size=4)
    return outputs[0]["generated_text"]
# # demo = gr.Interface(fn=greet, inputs = "text", outputs="text")
#
# gr.Interface(fn=create_move,
# inputs="text", outputs="text").launch()
# # demo.launch()
def filler_move(test_move, temperature):
    """Placeholder handler for the "Temperature Search" tab.

    Echoes the move name back together with the selected temperature, e.g.
    ``filler_move("Tackle", 1.5)`` returns ``"Tackle with temperature 1.5"``.
    No text generation happens here.
    """
    pieces = (test_move, " with temperature ", str(temperature))
    return "".join(pieces)
# Gradio UI with two tabs: plain generation and a temperature-search tab.
demo = gr.Blocks()

with demo:
    gr.Markdown("What's that Pokemon Move?")
    with gr.Tabs():
        with gr.TabItem("Standard Generation"):
            with gr.Row():
                text_input_baseline = gr.Textbox(label="Move Name")
                text_output_baseline = gr.Textbox(label="Move Description")
            text_button_baseline = gr.Button("Create my move!")
        with gr.TabItem("Temperature Search"):
            with gr.Row():
                temperature = gr.Slider(minimum=0.3, maximum=4, value=1,
                                        step=0.1, label="Temperature")
                text_input_temp = gr.Textbox(label="Move Name")
                text_output_temp = gr.Textbox(label="Move Description")
            text_button_temp = gr.Button("Create my move!")

    # Event listeners are registered inside the Blocks context.
    # Standard tab: run the fine-tuned model on the entered move name.
    # Previously this button had no handler at all.
    text_button_baseline.click(create_move, inputs=text_input_baseline,
                               outputs=text_output_baseline)
    # Temperature tab: still wired to the filler_move placeholder, which only
    # echoes its inputs; no model call happens on this tab yet.
    text_button_temp.click(filler_move, inputs=[text_input_temp, temperature],
                           outputs=text_output_temp)

demo.launch()