File size: 4,426 Bytes
477daa4
05ebee8
477daa4
 
 
 
0093752
477daa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4eaa34
ceafc04
 
a34dec2
d13c68f
a34dec2
ceafc04
a34dec2
477daa4
a34dec2
ceafc04
a34dec2
d13c68f
 
a34dec2
ceafc04
a34dec2
d13c68f
a34dec2
ceafc04
a34dec2
d13c68f
 
 
477daa4
 
 
d13c68f
 
 
 
 
477daa4
 
 
 
d13c68f
477daa4
 
05ebee8
477daa4
 
 
 
 
 
 
 
 
d13c68f
 
 
 
 
477daa4
d13c68f
 
477daa4
d13c68f
 
477daa4
d13c68f
 
477daa4
d13c68f
477daa4
d13c68f
 
 
 
 
c19465d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d13c68f
c19465d
 
d13c68f
 
477daa4
d13c68f
c19465d
 
d13c68f
 
477daa4
d13c68f
c19465d
 
d13c68f
 
477daa4
d13c68f
c19465d
 
d13c68f
 
477daa4
d13c68f
c19465d
 
d13c68f
477daa4
05ebee8
477daa4
 
 
 
d13c68f
477daa4
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import gradio as gr
#from utils import *
import random


# Module-level state. In this file only `out_img_list` is read (by
# `fn_refresh`); `is_clicked` and `out_state_list` are initialised here
# and not referenced again.
is_clicked = False
# Five slots each: empty-string image entries and False state flags.
out_img_list = [''] * 5
out_state_list = [False] * 5

def fn_query_on_load() -> str:
    """Return the default query text (used as the search box's initial value)."""
    default_query = "Cats at sunset"
    return default_query

def fn_refresh():
    """Return the module-level ``out_img_list`` (the list object itself, not a copy)."""
    current_images = out_img_list
    return current_images


with gr.Blocks() as app:
    # Page header.
    with gr.Row():
        gr.Markdown(
            """
            # Stable Diffusion Image Generation
            ### Enter query to generate images in various styles
            """)

    # Search box; `fn_query_on_load` is passed as its initial value.
    with gr.Row(visible=True):
        with gr.Column():
            with gr.Row():
                search_text = gr.Textbox(value=fn_query_on_load, placeholder='Search..', label=None)

    # One non-interactive preview per style, all placed in the same row with
    # width=128 and height=128.
    # NOTE: no event handler in this file generates images from the query;
    # each preview only displays the image file named in the table below.
    with gr.Row(visible=True):
        out1, out2, out3, out4, out5 = [
            gr.Image(value=image_file, interactive=False, width=128, height=128, label=style_label)
            for image_file, style_label in (
                ("out1.png", 'Oil Painting'),
                ("out2.png", 'Low Poly HD Style'),
                ("out3.png", 'Matrix style'),
                ("out4.png", 'Dreamy Painting'),
                ("out5.png", 'Depth Map Style'),
            )
        ]

    with gr.Row(visible=True):
        clear_btn = gr.ClearButton()

    def clear_data():
        """Return a component -> None mapping for all five previews and the search box."""
        # dict.fromkeys defaults every value to None, one entry per component.
        return dict.fromkeys((out1, out2, out3, out4, out5, search_text))

    # The output list must name every component that clear_data returns a value for.
    clear_btn.click(clear_data, None, [out1, out2, out3, out4, out5, search_text])

# Launch the app
app.launch()