mlnotes's picture
ADD visitor badge
e0f1d46
raw
history blame
No virus
3.51 kB
import io, os, base64
from PIL import Image
import gradio as gr
import shortuuid
from transformers import pipeline
# Hugging Face Hub id of the checkpoint used by the story generator.
text_generation_model = "pranavpsv/gpt2-genre-story-generator"

# Text-generation pipeline built once at import time and reused by get_story().
text_generation = pipeline(task="text-generation", model=text_generation_model)

# Callable handle to the remote Space that turns text into images; it is
# invoked by text2image_latent().
latent = gr.Interface.load("spaces/multimodalart/latentdiffusion")
def get_story(user_input, genre="sci_fi"):
    """Generate a short story that starts with *user_input*.

    The text sent to the module-level ``text_generation`` pipeline is
    ``"<BOS> <{genre}> "`` followed by the user's text.

    Args:
        user_input: Text the generated story should start from.
        genre: Genre tag inserted into the prompt. A falsy value (``None``
            or an empty string) falls back to ``"sci_fi"``.

    Returns:
        The generated text with the leading ``"<BOS> <genre> "`` prefix
        removed; the user's own text is kept.
    """
    # The genre dropdown in the UI is created without a preselected value, so
    # the click handler may pass None here. Without this fallback the prompt
    # would contain the literal tag "<None>".
    genre = genre or "sci_fi"

    prompt = f"<BOS> <{genre}> "
    stories = text_generation(
        f"{prompt}{user_input}",
        max_length=32,
        num_return_sequences=1,
    )
    story = stories[0]["generated_text"]

    # Strip only the control prefix; everything after it is shown to the user.
    return story[len(prompt):]
def text2image_latent(text, steps, width, height, images, diversity):
    """Render *text* to images via the remote ``latent`` interface.

    All arguments are forwarded unchanged, in order, to the module-level
    ``latent`` callable. Each returned image is decoded and written as a PNG
    under ``./tmp`` with a random ``shortuuid`` file name.

    Args:
        text: Prompt to render (also printed to stdout).
        steps: Forwarded to ``latent``.
        width: Forwarded to ``latent``.
        height: Forwarded to ``latent``.
        images: Forwarded to ``latent``.
        diversity: Forwarded to ``latent``.

    Returns:
        List of file paths (``./tmp/<uuid>.png``) of the saved images, in the
        order they were returned by ``latent``.
    """
    print(text)
    results = latent(text, steps, width, height, images, diversity)

    # Create the output directory once rather than once per image.
    # exist_ok=True already tolerates an existing directory, so no separate
    # os.path.exists() pre-check is needed.
    temp_dir = './tmp'
    os.makedirs(temp_dir, exist_ok=True)

    image_paths = []
    # NOTE(review): assumes results[1] is a sequence whose items carry a
    # base64 PNG data URL at index 0 -- confirm against the remote Space.
    for image in results[1]:
        image_str = image[0].replace("data:image/png;base64,", "")
        decoded_bytes = base64.decodebytes(bytes(image_str, "utf-8"))
        img = Image.open(io.BytesIO(decoded_bytes))

        # Build the path once and use it for both saving and the return value.
        image_path = f'{temp_dir}/{shortuuid.uuid()}.png'
        img.save(image_path)
        image_paths.append(image_path)
    return image_paths
# UI definition: a story-generation column, an image-settings column and a
# gallery column, followed by two rows of Markdown.
# NOTE(review): the original indentation was lost; the nesting of the
# Row/Column contexts below is a reconstruction -- confirm against the
# deployed app before relying on it.
with gr.Blocks() as demo:
    with gr.Row():
        # Story inputs and output.
        with gr.Column():
            user_input = gr.inputs.Textbox(placeholder="Type your prompt to generate an image", label="Prompt - try adding increments to your prompt such as 'a painting of', 'in the style of Picasso'", default="A giant mecha robot in Rio de Janeiro, oil on canvas")
            # No preselected value is given for the genre.
            genre_input = gr.Dropdown(["superhero","action","drama","horror","thriller","sci_fi",])
            # Holds the text returned by get_story; also the prompt for image generation.
            generated_story = gr.Textbox()
            with gr.Row():
                button_generate_story = gr.Button("Generate Story")
        # Image-generation settings, passed to text2image_latent in this order.
        with gr.Column():
            steps = gr.inputs.Slider(label="Steps - more steps can increase quality but will take longer to generate",default=50,maximum=50,minimum=1,step=1)
            width = gr.inputs.Slider(label="Width", default=256, step=32, maximum=256, minimum=32)
            height = gr.inputs.Slider(label="Height", default=256, step=32, maximum = 256, minimum=32)
            images = gr.inputs.Slider(label="Images - How many images you wish to generate", default=4, step=1, minimum=1, maximum=4)
            diversity = gr.inputs.Slider(label="Diversity scale - How different from one another you wish the images to be",default=15.0, minimum=1.0, maximum=15.0)
        # Generated images and the button that triggers their creation.
        with gr.Column():
            gallery = gr.Gallery(label="Individual images")
            with gr.Row():
                get_image_latent = gr.Button("Generate Image", css={"margin-top": "1em"})
    # Credits for the two upstream Spaces.
    with gr.Row():
        gr.Markdown("<a href='https://huggingface.co/spaces/merve/GPT-2-story-gen' target='_blank'>Story generation with GPT-2</a>, and text to image by <a href='https://huggingface.co/spaces/multimodalart/latentdiffusion' target='_blank'>Latent Diffusion</a>.")
    with gr.Row():
        gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=gradio-blocks_latent_gpt2_story)")
    # Event wiring: the story button fills generated_story from the prompt and
    # genre; the image button renders generated_story into the gallery.
    button_generate_story.click(get_story, inputs=[user_input, genre_input], outputs=generated_story)
    get_image_latent.click(text2image_latent, inputs=[generated_story,steps,width,height,images,diversity], outputs=gallery)
# Start the app with request queueing disabled.
demo.launch(enable_queue=False)