"""Gradio front-end for Stable Diffusion image generation in several styles.

The UI shows a query textbox and five image panes, one per style, plus a
clear button.  The per-style generation pipeline is currently disabled (it
is kept as a string literal inside the Blocks layout below), so the panes
only display the files ``out1.png`` ... ``out5.png`` until cleared.
"""
import random

import gradio as gr
# from utils import *  # presumably provides the model objects used by the
#                      # disabled generation code below -- TODO confirm

# NOTE(review): this module-level state is not read by any event handler in
# this file (only `fn_refresh` returns `out_img_list`, and `fn_refresh` is
# not wired to any event).  Kept so external importers do not break.
is_clicked = False
out_img_list = ['', '', '', '', '']
out_state_list = [False, False, False, False, False]

# (initial image file, pane label) for each style pane, in display order.
# Must contain exactly five entries: they are unpacked into out1..out5.
STYLE_PANES = (
    ('out1.png', 'Oil Painting'),
    ('out2.png', 'Low Poly HD Style'),
    ('out3.png', 'Matrix style'),
    ('out4.png', 'Dreamy Painting'),
    ('out5.png', 'Depth Map Style'),
)


def fn_query_on_load():
    """Return the default query placed in the search box."""
    return "Cats at sunset"


def fn_refresh():
    """Return ``out_img_list``.

    NOTE(review): not attached to any event in this file.
    """
    return out_img_list


with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown(
            """
            # Stable Diffusion Image Generation
            ### Enter query to generate images in various styles
            """)

    with gr.Row(visible=True):
        with gr.Column():
            with gr.Row():
                # A callable `value` lets Gradio compute the initial text.
                search_text = gr.Textbox(value=fn_query_on_load,
                                         placeholder='Search..',
                                         label=None)

            with gr.Row(visible=True):
                # One read-only 128x128 pane per style.  The per-style
                # "Submit" buttons are disabled along with the generation code.
                out1, out2, out3, out4, out5 = [
                    gr.Image(value=image_path, interactive=False,
                             width=128, height=128, label=label)
                    for image_path, label in STYLE_PANES
                ]

            with gr.Row(visible=True):
                # Created without `components`; the click handler below does
                # the actual clearing.
                clear_btn = gr.ClearButton()

                def clear_data():
                    """Reset every style pane and the query box to empty."""
                    return {
                        out1: None,
                        out2: None,
                        out3: None,
                        out4: None,
                        out5: None,
                        search_text: None,
                    }

                clear_btn.click(clear_data, None,
                                [out1, out2, out3, out4, out5, search_text])

    # Disabled generation pipeline.  It cannot be re-enabled as-is: it uses
    # names not defined in this file (`tokenizer`, `text_encoder`, `torch`,
    # `torch_device`, `pos_emb_layer`, `token_emb_layer`, `concept_embeds`,
    # `get_output_embeds`, `generate_with_embs` -- presumably from `utils`,
    # TODO confirm) and the commented-out buttons submit1..submit5.
    '''def func_generate(query, concept_idx, seed):
        prompt = query + ' in the style of bulb'
        text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
        input_ids = text_input.input_ids.to(torch_device)

        # Get token embeddings
        position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
        position_embeddings = pos_emb_layer(position_ids)
        s = seed

        token_embeddings = token_emb_layer(input_ids)
        # The new embedding - our special birb word
        replacement_token_embedding = concept_embeds[concept_idx].to(torch_device)

        # Insert this into the token embeddings
        token_embeddings[0, torch.where(input_ids[0] == 22373)] = replacement_token_embedding.to(torch_device)

        # Combine with pos embs
        input_embeddings = token_embeddings + position_embeddings

        # Feed through to get final output embs
        modified_output_embeddings = get_output_embeds(input_embeddings)

        # And generate an image with this:
        s = random.randint(s + 1, s + 30)
        g = torch.manual_seed(s)
        return generate_with_embs(text_input, modified_output_embeddings, generator=g)

    def generate_oil_painting(query):
        return { out1: func_generate(query, 0, 0) }

    def generate_low_poly_hd(query):
        return { out2: func_generate(query, 1, 30) }

    def generate_matrix_style(query):
        return { out3: func_generate(query, 2, 60) }

    def generate_dreamy_painting(query):
        return { out4: func_generate(query, 3, 90) }

    def generate_depth_map_style(query):
        return { out5: func_generate(query, 4, 120) }

    submit1.click( generate_oil_painting, search_text, out1 )
    submit2.click( generate_low_poly_hd, search_text, out2 )
    submit3.click( generate_matrix_style, search_text, out3 )
    submit4.click( generate_dreamy_painting, search_text, out4 )
    submit5.click( generate_depth_map_style, search_text, out5 )
    '''

# Launch the app.
app.launch()