piyushgrover commited on
Commit
d13c68f
β€’
1 Parent(s): 0093752
Files changed (1) hide show
  1. app.py +76 -31
app.py CHANGED
@@ -28,27 +28,53 @@ with gr.Blocks() as app:
28
  with gr.Row():
29
  search_text = gr.Textbox(value=fn_query_on_load, placeholder='Search..', label=None)
30
 
31
- with gr.Row():
32
- submit_btn = gr.Button("Submit", variant='primary')
33
- clear_btn = gr.ClearButton()
34
 
35
  with gr.Row(visible=True):
36
- output_images = gr.Gallery(value=fn_refresh, interactive=False, every=5)
 
 
 
 
 
 
 
 
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  def clear_data():
40
  return {
41
- output_images: None,
 
 
 
 
42
  search_text: None
43
  }
44
 
45
 
46
- clear_btn.click(clear_data, None, [output_images, search_text])
47
 
48
 
49
- def func_generate(query):
50
- global is_clicked
51
- is_clicked = True
52
  prompt = query + ' in the style of bulb'
53
  text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True,
54
  return_tensors="pt")
@@ -58,44 +84,63 @@ with gr.Blocks() as app:
58
  position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
59
  position_embeddings = pos_emb_layer(position_ids)
60
 
61
- s = 0
62
- for i in range(5):
63
- token_embeddings = token_emb_layer(input_ids)
64
- # The new embedding - our special birb word
65
- replacement_token_embedding = concept_embeds[i].to(torch_device)
66
 
67
- # Insert this into the token embeddings
68
- token_embeddings[0, torch.where(input_ids[0] == 22373)] = replacement_token_embedding.to(torch_device)
69
 
70
- # Combine with pos embs
71
- input_embeddings = token_embeddings + position_embeddings
72
 
73
- # Feed through to get final output embs
74
- modified_output_embeddings = get_output_embeds(input_embeddings)
75
 
76
- # And generate an image with this:
77
 
78
- s = random.randint(s + 1, s + 30)
79
- g = torch.manual_seed(s)
80
- output = generate_with_embs(text_input, modified_output_embeddings, output=out_img_list[i], generator=g)
81
- #output_images.append(dict(seed=s, output=output))
 
 
 
 
 
 
82
 
83
- is_clicked = False
 
 
 
 
84
 
85
- return None
 
 
 
 
86
 
 
 
 
 
 
87
 
88
- submit_btn.click(
89
  func_generate,
90
- [search_text],
91
- None
92
  )
93
 
94
 
95
  '''
96
  Launch the app
97
  '''
98
- app.queue.launch(share=True)
99
 
100
 
101
 
 
28
  with gr.Row():
29
  search_text = gr.Textbox(value=fn_query_on_load, placeholder='Search..', label=None)
30
 
 
 
 
31
 
32
  with gr.Row(visible=True):
33
+ with gr.Column():
34
+ with gr.Row():
35
+ out1 = gr.Image(interactive=False, label='Oil Painting')
36
+ submit1 = gr.Button("submit", variant='primary')
37
+
38
+ with gr.Column():
39
+ with gr.Row():
40
+ out2 = gr.Image(interactive=False, label='Low Poly HD Style')
41
+ submit2 = gr.Button("submit", variant='primary')
42
 
43
+ with gr.Column():
44
+ with gr.Row():
45
+ out3 = gr.Image(interactive=False, label='Matrix style')
46
+ submit3 = gr.Button("submit", variant='primary')
47
+
48
+
49
+ with gr.Column():
50
+ with gr.Row():
51
+ out4 = gr.Image(interactive=False, label='Dreamy Painting')
52
+ submit4 = gr.Button("submit", variant='primary')
53
+
54
+ with gr.Column():
55
+ with gr.Row():
56
+ out5 = gr.Image(interactive=False, label='Depth Map Style')
57
+ submit5 = gr.Button("submit", variant='primary')
58
+
59
+
60
+ with gr.Row(visible=True):
61
+ clear_btn = gr.ClearButton()
62
 
63
  def clear_data():
64
  return {
65
+ out1: None,
66
+ out2: None,
67
+ out3: None,
68
+ out4: None,
69
+ out5: None,
70
  search_text: None
71
  }
72
 
73
 
74
+ clear_btn.click(clear_data, None, [out1, out2, out3, out4, out5, search_text])
75
 
76
 
77
+ def func_generate(query, concept_idx, seed):
 
 
78
  prompt = query + ' in the style of bulb'
79
  text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True,
80
  return_tensors="pt")
 
84
  position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
85
  position_embeddings = pos_emb_layer(position_ids)
86
 
87
+ s = seed
88
+
89
+ token_embeddings = token_emb_layer(input_ids)
90
+ # The new embedding - our special birb word
91
+ replacement_token_embedding = concept_embeds[concept_idx].to(torch_device)
92
 
93
+ # Insert this into the token embeddings
94
+ token_embeddings[0, torch.where(input_ids[0] == 22373)] = replacement_token_embedding.to(torch_device)
95
 
96
+ # Combine with pos embs
97
+ input_embeddings = token_embeddings + position_embeddings
98
 
99
+ # Feed through to get final output embs
100
+ modified_output_embeddings = get_output_embeds(input_embeddings)
101
 
102
+ # And generate an image with this:
103
 
104
+ s = random.randint(s + 1, s + 30)
105
+ g = torch.manual_seed(s)
106
+ return generate_with_embs(text_input, modified_output_embeddings, generator=g)
107
+
108
+
109
+ submit1.click(
110
+ func_generate,
111
+ [search_text, 0, 0],
112
+ out1
113
+ )
114
 
115
+ submit2.click(
116
+ func_generate,
117
+ [search_text, 1, 30],
118
+ out2
119
+ )
120
 
121
+ submit3.click(
122
+ func_generate,
123
+ [search_text, 2, 60],
124
+ out3
125
+ )
126
 
127
+ submit4.click(
128
+ func_generate,
129
+ [search_text, 3, 90],
130
+ out4
131
+ )
132
 
133
+ submit5.click(
134
  func_generate,
135
+ [search_text, 4, 120],
136
+ out5
137
  )
138
 
139
 
140
  '''
141
  Launch the app
142
  '''
143
+ app.launch()
144
 
145
 
146