# image_modification / prompt_app.py
# From commit 2afcb7e ("Add InstructPix2Pix") by timbrooks (1.98 kB).
from __future__ import annotations
from argparse import ArgumentParser
import datasets
import gradio as gr
import numpy as np
import openai
from dataset_creation.generate_txt_dataset import generate
def main(openai_model: str):
    """Serve a Gradio demo that turns an image caption into an edit instruction.

    Loads the training split of ``ChristophSchuhmann/improved_aesthetics_6.5plus``,
    shuffles its ``TEXT`` captions once, and launches a UI with three buttons:
    clear all text boxes, fill the input with the next shuffled caption, and
    call ``generate`` to produce an instruction plus an edited caption.

    Args:
        openai_model: Model name passed through unchanged to ``generate``.
    """
    dataset = datasets.load_dataset("ChristophSchuhmann/improved_aesthetics_6.5plus", split="train")
    # Read only the caption column, then permute it. Indexing the whole dataset
    # with the permutation (``dataset[perm]["TEXT"]``) materialises every column
    # just to keep one. ``list(...)`` keeps this a plain list even on `datasets`
    # versions that return a lazy column object for ``dataset["TEXT"]``.
    texts = list(dataset["TEXT"])
    captions = [texts[i] for i in np.random.permutation(len(texts))]
    index = 0  # position of the next caption handed out by click_random

    def click_random():
        """Return the next shuffled caption, wrapping around at the end."""
        nonlocal index
        output = captions[index]
        index = (index + 1) % len(captions)
        return output

    def click_generate(caption: str):
        """Generate an edit for ``caption``.

        Returns whatever ``generate`` returns (presumably an
        (instruction, edited caption) pair, since it feeds two textboxes), or a
        pair of failure markers when ``generate`` returns ``None``.

        Raises:
            gr.Error: if the caption is empty or whitespace-only.
        """
        # Also reject whitespace-only input; it would otherwise be passed to generate.
        if not caption or not caption.strip():
            raise gr.Error("Input caption is missing!")
        edit_output = generate(openai_model, caption)
        if edit_output is None:
            return "Failed :(", "Failed :("
        return edit_output

    with gr.Blocks(css="footer {visibility: hidden}") as demo:
        txt_input = gr.Textbox(lines=3, label="Input Caption", interactive=True, placeholder="Type image caption here...")  # fmt: skip
        txt_edit = gr.Textbox(lines=1, label="GPT-3 Instruction", interactive=False)
        txt_output = gr.Textbox(lines=3, label="GPT3 Edited Caption", interactive=False)

        with gr.Row():
            clear_btn = gr.Button("Clear")
            random_btn = gr.Button("Random Input")
            generate_btn = gr.Button("Generate Instruction + Edited Caption")

        clear_btn.click(fn=lambda: ("", "", ""), inputs=[], outputs=[txt_input, txt_edit, txt_output])
        random_btn.click(fn=click_random, inputs=[], outputs=[txt_input])
        generate_btn.click(fn=click_generate, inputs=[txt_input], outputs=[txt_edit, txt_output])

    # share=True asks Gradio to also expose a public share link.
    demo.launch(share=True)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("openai-api-key", type=str)
parser.add_argument("openai-model", type=str)
args = parser.parse_args()
openai.api_key = args.openai_api_key
main(args.openai_model)