# Spaces: Runtime error
# (Hosting-page status text captured together with the source; not part of the program.)
import gradio as gr | |
from utils import * | |
import random | |
# Module-level state shared between the UI callbacks.

# Set to True while func_generate is running and back to False when it ends.
is_clicked = False

# One slot per generated image; fn_refresh serves this list to the gallery.
out_img_list = [None, None, None, None, None]

# One flag per image slot. NOTE(review): not read or written anywhere in the
# visible code - confirm whether it is still needed.
out_state_list = [False, False, False, False, False]
def fn_query_on_load() -> str:
    """Return the default query used as the search box's initial value."""
    return "Cats at sunset"
def fn_refresh() -> list:
    """Return the images generated so far, for display in the gallery.

    This is used as the gallery's ``value`` callable, so it is re-evaluated
    on load and on every periodic refresh. Slots in ``out_img_list`` that
    have not been filled yet hold ``None``; those are dropped because the
    gallery cannot render ``None`` as an image.

    Returns:
        A new list with the non-``None`` entries of ``out_img_list``, in
        slot order (empty when nothing has been generated yet).
    """
    return [img for img in out_img_list if img is not None]
# Build the UI: a title, a query box with Submit/Clear buttons, and a gallery
# that polls fn_refresh for newly generated images.
# NOTE(review): the nesting below was reconstructed from a copy of the file
# that had lost its indentation - confirm the row/column layout.
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown(
            """
            # Stable Diffusion Image Generation
            ### Enter query to generate images in various styles
            """)
    with gr.Row(visible=True):
        with gr.Column():
            with gr.Row():
                search_text = gr.Textbox(value=fn_query_on_load, placeholder='Search..', label=None)
            with gr.Row():
                submit_btn = gr.Button("Submit", variant='primary')
                clear_btn = gr.ClearButton()
    with gr.Row(visible=True):
        # The gallery re-evaluates fn_refresh every 5 seconds.
        output_images = gr.Gallery(value=fn_refresh, interactive=False, every=5)

    def clear_data():
        """Reset the gallery and the search box.

        The stored image slots are emptied as well; otherwise the gallery's
        periodic refresh would bring the old images straight back.

        Returns:
            A mapping from each output component to its new value (None).
        """
        # Mutate in place so every reference to the shared list sees the reset.
        out_img_list[:] = [None] * len(out_img_list)
        return {
            output_images: None,
            search_text: None
        }

    clear_btn.click(clear_data, None, [output_images, search_text])

    def func_generate(query):
        """Generate one image per style concept for ``query``.

        The prompt is ``query`` followed by ' in the style of bulb'. For each
        image slot, the embedding at the placeholder token position is
        replaced by the matching entry of ``concept_embeds`` before the
        prompt is encoded and passed to ``generate_with_embs``.

        Args:
            query: Text from the search box; ``None`` (the value the Clear
                button writes) is treated as an empty string.

        Returns:
            None. Results are stored in ``out_img_list``, which the gallery
            picks up on its next refresh.
        """
        global is_clicked
        is_clicked = True
        try:
            prompt = (query or '') + ' in the style of bulb'
            text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length,
                                   truncation=True, return_tensors="pt")
            input_ids = text_input.input_ids.to(torch_device)

            # Position embeddings do not depend on the style, so compute them once.
            position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
            position_embeddings = pos_emb_layer(position_ids)

            # Presumably the tokenizer id of the placeholder word ('bulb')
            # appended to the prompt - TODO confirm.
            placeholder_token_id = 22373

            seed = 0
            for i in range(len(out_img_list)):
                # Fresh token embeddings each round: they are edited in place below.
                token_embeddings = token_emb_layer(input_ids)

                # Swap the placeholder token's embedding for this style's concept.
                replacement_token_embedding = concept_embeds[i].to(torch_device)
                token_embeddings[0, torch.where(input_ids[0] == placeholder_token_id)] = replacement_token_embedding

                # Combine with the position embeddings and run the text
                # encoder to get the final output embeddings.
                input_embeddings = token_embeddings + position_embeddings
                modified_output_embeddings = get_output_embeds(input_embeddings)

                # Each seed is drawn strictly above the previous one, so no
                # two images in a run share a seed.
                seed = random.randint(seed + 1, seed + 30)
                generator = torch.manual_seed(seed)
                output = generate_with_embs(text_input, modified_output_embeddings,
                                            output=out_img_list[i], generator=generator)
                # Keep the result so fn_refresh can show it; previously the
                # return value was dropped and the gallery stayed empty.
                # NOTE(review): assumes generate_with_embs returns the
                # generated image - confirm against utils.
                if output is not None:
                    out_img_list[i] = output
        finally:
            # Always clear the busy flag, even when generation fails.
            is_clicked = False
        return None

    submit_btn.click(
        func_generate,
        [search_text],
        None
    )
# Launch the app.
# queue() is a method that returns the Blocks instance, so it must be called
# before chaining launch(); `app.queue.launch` would look up `launch` on the
# bound method itself and raise AttributeError at startup.
app.queue().launch(share=True)