# SuperPrompt-v1 / app.py
# Author: Nick088
# Commit 1aa631a (verified): switched from ChatInterface to a normal Interface
# and improved the random-seed option.
import gradio as gr
import torch
import random
from transformers import T5Tokenizer, T5ForConditionalGeneration
# Choose the device before loading the model so the weights can be loaded in a
# dtype that suits it: float16 on CUDA, float32 on CPU (half precision is
# poorly supported and slow on CPU).
if torch.cuda.is_available():
    device = "cuda"
    print("Using GPU")
else:
    device = "cpu"
    print("Using CPU")

tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
model = T5ForConditionalGeneration.from_pretrained(
    "roborovski/superprompt-v1",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
model.to(device)
def generate(
    system_prompt,
    prompt,
    max_new_tokens,
    repetition_penalty,
    temperature,
    top_p,
    top_k,
    random_seed,
    seed,
):
    """Expand ``prompt`` into a more detailed prompt with SuperPrompt-v1.

    The model input is ``"<system_prompt>, <prompt>"``; text is sampled
    (``do_sample=True``) with the given generation settings.

    Args:
        system_prompt: Instruction placed before the user prompt, e.g.
            "Expand the following prompt to add more detail:".
        prompt: The user prompt to expand.
        max_new_tokens: Maximum number of tokens to generate.
        repetition_penalty: Penalty for repeated tokens; must be > 0.
        temperature: Sampling temperature; must be > 0.
        top_p: Nucleus-sampling probability mass, in [0, 1].
        top_k: Number of highest-probability tokens kept at each step.
        random_seed: If truthy, ``seed`` is ignored and a fresh seed in
            [1, 100000] is drawn with ``random.randint``.
        seed: Seed for ``torch.manual_seed`` when ``random_seed`` is false.
            ``None`` (an empty number box) also falls back to a random seed.

    Returns:
        The generated text, decoded without special tokens such as
        ``<pad>`` and ``</s>``.
    """
    input_text = f"{system_prompt}, {prompt}"
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

    if random_seed or seed is None:
        seed = random.randint(1, 100000)
    torch.manual_seed(int(seed))

    # UI widgets may deliver whole numbers as int and others as float, while
    # transformers' logits processors type-check strictly (e.g. top_k must be
    # an int, repetition_penalty a float), so normalise the types here.
    outputs = model.generate(
        input_ids,
        max_new_tokens=int(max_new_tokens),
        repetition_penalty=float(repetition_penalty),
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=int(top_k),
    )
    # skip_special_tokens drops "<pad>"/"</s>" markers from the returned text.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Input widgets for the generator. These module-level names are passed, in
# order, as the Interface's `inputs`.
prompt = gr.Textbox(label="Prompt", interactive=True)
system_prompt = gr.Textbox(label="System Prompt", interactive=True)
max_new_tokens = gr.Slider(
    value=512,
    minimum=250,
    maximum=512,
    step=1,
    interactive=True,
    label="Max New Tokens",
    info="The maximum numbers of new tokens, controls how long is the output",
)
# transformers rejects a repetition penalty or temperature of 0 when sampling,
# so both sliders start at the first positive step instead of 0.
repetition_penalty = gr.Slider(
    value=1.2,
    minimum=0.05,
    maximum=2,
    step=0.05,
    interactive=True,
    label="Repetition Penalty",
    info="Penalize repeated tokens, making the AI repeat less itself",
)
temperature = gr.Slider(
    value=0.5,
    minimum=0.05,
    maximum=1,
    step=0.05,
    interactive=True,
    label="Temperature",
    info="Higher values produce more diverse outputs",
)
# top_p is a cumulative probability mass; values above 1 raise an error.
top_p = gr.Slider(
    value=1,
    minimum=0,
    maximum=1,
    step=0.05,
    interactive=True,
    label="Top P",
    info="Higher values sample more low-probability tokens",
)
top_k = gr.Slider(
    value=1,
    minimum=1,
    maximum=100,
    step=1,
    interactive=True,
    label="Top K",
    info="Higher k means more diverse outputs by considering a range of tokens",
)
use_random_seed = gr.Checkbox(
    value=False,
    label="Use Random Seed",
    info="Check to use a random seed which is a start point for the generation process",
)
# `visible` must be a bool. The field stays visible because gr.Interface offers
# no event hook to toggle it from the checkbox; its value is simply ignored
# when "Use Random Seed" is checked.
manual_seed = gr.Number(
    value=42,
    interactive=True,
    label="Manual Seed",
    info="A starting point to initiate the generation process",
    visible=True,
)
# Example rows list values in the same order as the Interface's `inputs`.
examples = [
    [
        "Expand the following prompt to add more detail:",  # system prompt
        "A storefront with 'Text to Image' written on it.",  # prompt
        512,  # max new tokens
        1.2,  # repetition penalty
        0.5,  # temperature
        1,  # top p
        50,  # top k
        False,  # use random seed
        42,  # manual seed
    ]
]

# gr.Interface passes input values to `fn` positionally, so `inputs` must
# follow generate(system_prompt, prompt, ...) exactly. Listing `prompt`
# first would swap the user prompt and the system prompt.
demo = gr.Interface(
    fn=generate,
    inputs=[
        system_prompt,
        prompt,
        max_new_tokens,
        repetition_penalty,
        temperature,
        top_p,
        top_k,
        use_random_seed,
        manual_seed,
    ],
    outputs=gr.Textbox(label="Better Prompt", interactive=True),
    title="SuperPrompt-v1",
    description="Make your prompts more detailed!",
    examples=examples,
    live=True,
    concurrency_limit=20,
)
demo.launch(show_api=False)