# piyushgrover's picture
# Update app.py
# ceafc04
# raw
# history blame
# 4.43 kB
import gradio as gr
#from utils import *
import random
# Module-level UI state. NOTE(review): is_clicked appears unused by the active
# code; it is only meaningful to the commented-out generation handlers below.
is_clicked = False
# Paths of the five styled output images (empty until generation runs).
out_img_list = [''] * 5
# One completion flag per output image slot.
out_state_list = [False] * 5


def fn_query_on_load():
    """Return the default query used to seed the search box when the app loads."""
    default_query = "Cats at sunset"
    return default_query


def fn_refresh():
    """Return the shared list of output image paths."""
    current_images = out_img_list
    return current_images
# Assemble the Gradio UI: a header, one search box, five fixed-size preview
# images (one per diffusion style), and a clear button that resets everything.
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown(
            """
            # Stable Diffusion Image Generation
            ### Enter query to generate images in various styles
            """)

    with gr.Row(visible=True):
        with gr.Column():
            with gr.Row():
                # Textbox is seeded via fn_query_on_load ("Cats at sunset").
                search_text = gr.Textbox(value=fn_query_on_load, placeholder='Search..', label=None)

            with gr.Row(visible=True):
                # Five static 128x128 previews, one per style; interaction is
                # disabled. The per-style submit buttons are currently
                # commented out (their handlers are disabled below).
                #with gr.Column():
                out1 = gr.Image(value="out1.png", interactive=False, width=128, height=128, label='Oil Painting')
                #submit1 = gr.Button("Submit", variant='primary')
                #with gr.Column():
                out2 = gr.Image(value="out2.png", interactive=False, width=128, height=128, label='Low Poly HD Style')
                #submit2 = gr.Button("Submit", variant='primary')
                #with gr.Column():
                out3 = gr.Image(value="out3.png", interactive=False, width=128, height=128, label='Matrix style')
                #submit3 = gr.Button("Submit", variant='primary')
                #with gr.Column():
                out4 = gr.Image(value="out4.png", interactive=False, width=128, height=128, label='Dreamy Painting')
                #submit4 = gr.Button("Submit", variant='primary')
                #with gr.Column():
                out5 = gr.Image(value="out5.png", interactive=False, width=128, height=128, label='Depth Map Style')
                #submit5 = gr.Button("Submit", variant='primary')

    with gr.Row(visible=True):
        clear_btn = gr.ClearButton()

    def clear_data():
        # Map every component to None so Gradio resets its displayed value.
        return {
            out1: None,
            out2: None,
            out3: None,
            out4: None,
            out5: None,
            search_text: None
        }

    # Wire the clear button: no inputs; clears all five images and the textbox.
    clear_btn.click(clear_data, None, [out1, out2, out3, out4, out5, search_text])
# NOTE(review): a large block of disabled generation code (func_generate plus
# the five per-style handlers and their submit-button wiring) previously lived
# here inside a module-level triple-quoted string. Such a string is not a
# comment: it is built at import time and immediately discarded, and it keeps
# ~70 lines of dead code in the module. It has been removed; recover it from
# version control if the submit buttons are re-enabled.

# Launch the app (HF Spaces runs this module directly).
app.launch()