radames commited on
Commit
acb3ec0
1 Parent(s): f1cb8d7

load from URL param button

Browse files
Files changed (1) hide show
  1. app.py +54 -25
app.py CHANGED
@@ -43,6 +43,22 @@ async (text_input) => {
43
  }
44
  """
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  def main():
48
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -96,7 +112,8 @@ def main():
96
  plain_img = model.produce_attn_maps([base_text_prompt], [negative_text],
97
  height=height, width=width, num_inference_steps=steps,
98
  guidance_scale=guidance_weight)
99
- print('time lapses to get attention maps: %.4f' % (time.time()-begin_time))
 
100
  color_obj_masks, _ = get_token_maps(
101
  model.attention_maps, run_dir, width//8, height//8, color_target_token_ids, seed)
102
  model.masks, token_maps = get_token_maps(
@@ -104,7 +121,7 @@ def main():
104
  color_obj_masks = [transforms.functional.resize(color_obj_mask, (height, width),
105
  interpolation=transforms.InterpolationMode.BICUBIC,
106
  antialias=True)
107
- for color_obj_mask in color_obj_masks]
108
  text_format_dict['color_obj_atten'] = color_obj_masks
109
  model.remove_evaluation_hooks()
110
 
@@ -112,14 +129,15 @@ def main():
112
  begin_time = time.time()
113
  seed_everything(seed)
114
  rich_img = model.prompt_to_img(region_text_prompts, [negative_text],
115
- height=height, width=width, num_inference_steps=steps,
116
- guidance_scale=guidance_weight, use_grad_guidance=use_grad_guidance,
117
- text_format_dict=text_format_dict)
118
  print('time lapses to generate image from rich text: %.4f' %
119
- (time.time()-begin_time))
120
  return [plain_img[0], rich_img[0], token_maps]
121
 
122
  with gr.Blocks() as demo:
 
123
  gr.HTML("""<h1 style="font-weight: 900; margin-bottom: 7px;">Expressive Text-to-Image Generation with Rich Text</h1>
124
  <p> <a href="https://songweige.github.io/">Songwei Ge</a>, <a href="https://taesung.me/">Taesung Park</a>, <a href="https://www.cs.cmu.edu/~junyanz/">Jun-Yan Zhu</a>, <a href="https://jbhuang0604.github.io/">Jia-Bin Huang</a> <p/>
125
  <p> UMD, Adobe, CMU <p/>
@@ -150,28 +168,29 @@ def main():
150
  value=0.5)
151
  with gr.Accordion('Other Parameters', open=False):
152
  steps = gr.Slider(label='Number of Steps',
153
- minimum=0,
154
- maximum=100,
155
- step=1,
156
- value=41)
157
  guidance_weight = gr.Slider(label='CFG weight',
158
- minimum=0,
159
- maximum=50,
160
- step=0.1,
161
- value=8.5)
162
  width = gr.Dropdown(choices=[512],
163
- value=512,
164
- label='Width',
165
- visible=True)
166
  height = gr.Dropdown(choices=[512],
167
- value=512,
168
- label='height',
169
- visible=True)
170
-
171
  with gr.Row():
172
  with gr.Column(scale=1, min_width=100):
173
  generate_button = gr.Button("Generate")
174
-
 
175
  with gr.Column():
176
  richtext_result = gr.Image(label='Rich-text')
177
  richtext_result.style(height=512)
@@ -261,7 +280,7 @@ def main():
261
  None
262
  ],
263
  ]
264
-
265
  gr.Examples(examples=footnote_examples,
266
  label='Footnote examples',
267
  inputs=[
@@ -395,10 +414,20 @@ def main():
395
  outputs=[plaintext_result, richtext_result, token_map],
396
  _js=get_js_data
397
  )
398
- text_input.change(fn=None, inputs=[text_input], outputs=None, _js=set_js_data, queue=False)
 
 
 
 
 
 
 
 
 
 
399
  demo.queue(concurrency_count=1)
400
  demo.launch(share=False)
401
 
402
 
403
  if __name__ == "__main__":
404
- main()
 
43
  }
44
  """
45
 
46
+ get_window_url_params = """
47
+ async (url_params) => {
48
+ console.log(url_params);
49
+ const params = new URLSearchParams(window.location.search);
50
+ url_params = Object.fromEntries(params);
51
+ return [url_params];
52
+ }
53
+ """
54
+
55
+
56
+ def load_url_params(url_params):
57
+ if 'prompt' in url_params:
58
+ return gr.update(visible=True), url_params
59
+ else:
60
+ return gr.update(visible=False), url_params
61
+
62
 
63
  def main():
64
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
112
  plain_img = model.produce_attn_maps([base_text_prompt], [negative_text],
113
  height=height, width=width, num_inference_steps=steps,
114
  guidance_scale=guidance_weight)
115
+ print('time lapses to get attention maps: %.4f' %
116
+ (time.time()-begin_time))
117
  color_obj_masks, _ = get_token_maps(
118
  model.attention_maps, run_dir, width//8, height//8, color_target_token_ids, seed)
119
  model.masks, token_maps = get_token_maps(
 
121
  color_obj_masks = [transforms.functional.resize(color_obj_mask, (height, width),
122
  interpolation=transforms.InterpolationMode.BICUBIC,
123
  antialias=True)
124
+ for color_obj_mask in color_obj_masks]
125
  text_format_dict['color_obj_atten'] = color_obj_masks
126
  model.remove_evaluation_hooks()
127
 
 
129
  begin_time = time.time()
130
  seed_everything(seed)
131
  rich_img = model.prompt_to_img(region_text_prompts, [negative_text],
132
+ height=height, width=width, num_inference_steps=steps,
133
+ guidance_scale=guidance_weight, use_grad_guidance=use_grad_guidance,
134
+ text_format_dict=text_format_dict)
135
  print('time lapses to generate image from rich text: %.4f' %
136
+ (time.time()-begin_time))
137
  return [plain_img[0], rich_img[0], token_maps]
138
 
139
  with gr.Blocks() as demo:
140
+ url_params = gr.JSON({}, visible=True, label="URL Params")
141
  gr.HTML("""<h1 style="font-weight: 900; margin-bottom: 7px;">Expressive Text-to-Image Generation with Rich Text</h1>
142
  <p> <a href="https://songweige.github.io/">Songwei Ge</a>, <a href="https://taesung.me/">Taesung Park</a>, <a href="https://www.cs.cmu.edu/~junyanz/">Jun-Yan Zhu</a>, <a href="https://jbhuang0604.github.io/">Jia-Bin Huang</a> <p/>
143
  <p> UMD, Adobe, CMU <p/>
 
168
  value=0.5)
169
  with gr.Accordion('Other Parameters', open=False):
170
  steps = gr.Slider(label='Number of Steps',
171
+ minimum=0,
172
+ maximum=100,
173
+ step=1,
174
+ value=41)
175
  guidance_weight = gr.Slider(label='CFG weight',
176
+ minimum=0,
177
+ maximum=50,
178
+ step=0.1,
179
+ value=8.5)
180
  width = gr.Dropdown(choices=[512],
181
+ value=512,
182
+ label='Width',
183
+ visible=True)
184
  height = gr.Dropdown(choices=[512],
185
+ value=512,
186
+ label='height',
187
+ visible=True)
188
+
189
  with gr.Row():
190
  with gr.Column(scale=1, min_width=100):
191
  generate_button = gr.Button("Generate")
192
+ load_params_button = gr.Button(
193
+ "Load from URL Params", visible=True)
194
  with gr.Column():
195
  richtext_result = gr.Image(label='Rich-text')
196
  richtext_result.style(height=512)
 
280
  None
281
  ],
282
  ]
283
+
284
  gr.Examples(examples=footnote_examples,
285
  label='Footnote examples',
286
  inputs=[
 
414
  outputs=[plaintext_result, richtext_result, token_map],
415
  _js=get_js_data
416
  )
417
+ text_input.change(
418
+ fn=None, inputs=[text_input], outputs=None, _js=set_js_data, queue=False)
419
+ # load url param prompt to textinput
420
+ load_params_button.click(fn=lambda x: x['prompt'], inputs=[
421
+ url_params], outputs=[text_input], queue=False)
422
+ demo.load(
423
+ fn=load_url_params,
424
+ inputs=[url_params],
425
+ outputs=[load_params_button, url_params],
426
+ _js=get_window_url_params
427
+ )
428
  demo.queue(concurrency_count=1)
429
  demo.launch(share=False)
430
 
431
 
432
  if __name__ == "__main__":
433
+ main()