load from URL param button
app.py
CHANGED
@@ -43,6 +43,22 @@ async (text_input) => {
 }
 """
 
+get_window_url_params = """
+async (url_params) => {
+    console.log(url_params);
+    const params = new URLSearchParams(window.location.search);
+    url_params = Object.fromEntries(params);
+    return [url_params];
+}
+"""
+
+
+def load_url_params(url_params):
+    if 'prompt' in url_params:
+        return gr.update(visible=True), url_params
+    else:
+        return gr.update(visible=False), url_params
+
 
 def main():
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -96,7 +112,8 @@ def main():
         plain_img = model.produce_attn_maps([base_text_prompt], [negative_text],
                                             height=height, width=width, num_inference_steps=steps,
                                             guidance_scale=guidance_weight)
+        print('time lapses to get attention maps: %.4f' %
+              (time.time()-begin_time))
         color_obj_masks, _ = get_token_maps(
             model.attention_maps, run_dir, width//8, height//8, color_target_token_ids, seed)
         model.masks, token_maps = get_token_maps(
@@ -104,7 +121,7 @@ def main():
         color_obj_masks = [transforms.functional.resize(color_obj_mask, (height, width),
                                                          interpolation=transforms.InterpolationMode.BICUBIC,
                                                          antialias=True)
+                           for color_obj_mask in color_obj_masks]
         text_format_dict['color_obj_atten'] = color_obj_masks
         model.remove_evaluation_hooks()
 
@@ -112,14 +129,15 @@ def main():
         begin_time = time.time()
         seed_everything(seed)
         rich_img = model.prompt_to_img(region_text_prompts, [negative_text],
+                                       height=height, width=width, num_inference_steps=steps,
+                                       guidance_scale=guidance_weight, use_grad_guidance=use_grad_guidance,
+                                       text_format_dict=text_format_dict)
         print('time lapses to generate image from rich text: %.4f' %
+              (time.time()-begin_time))
         return [plain_img[0], rich_img[0], token_maps]
 
     with gr.Blocks() as demo:
+        url_params = gr.JSON({}, visible=True, label="URL Params")
         gr.HTML("""<h1 style="font-weight: 900; margin-bottom: 7px;">Expressive Text-to-Image Generation with Rich Text</h1>
                 <p> <a href="https://songweige.github.io/">Songwei Ge</a>, <a href="https://taesung.me/">Taesung Park</a>, <a href="https://www.cs.cmu.edu/~junyanz/">Jun-Yan Zhu</a>, <a href="https://jbhuang0604.github.io/">Jia-Bin Huang</a> <p/>
                 <p> UMD, Adobe, CMU <p/>
@@ -150,28 +168,29 @@ def main():
                                                   value=0.5)
                 with gr.Accordion('Other Parameters', open=False):
                     steps = gr.Slider(label='Number of Steps',
+                                      minimum=0,
+                                      maximum=100,
+                                      step=1,
+                                      value=41)
                     guidance_weight = gr.Slider(label='CFG weight',
+                                                minimum=0,
+                                                maximum=50,
+                                                step=0.1,
+                                                value=8.5)
                     width = gr.Dropdown(choices=[512],
+                                        value=512,
+                                        label='Width',
+                                        visible=True)
                     height = gr.Dropdown(choices=[512],
+                                         value=512,
+                                         label='height',
+                                         visible=True)
+
                 with gr.Row():
                     with gr.Column(scale=1, min_width=100):
                         generate_button = gr.Button("Generate")
+                        load_params_button = gr.Button(
+                            "Load from URL Params", visible=True)
             with gr.Column():
                 richtext_result = gr.Image(label='Rich-text')
                 richtext_result.style(height=512)
@@ -261,7 +280,7 @@ def main():
                 None
             ],
         ]
+
        gr.Examples(examples=footnote_examples,
                    label='Footnote examples',
                    inputs=[
@@ -395,10 +414,20 @@ def main():
            outputs=[plaintext_result, richtext_result, token_map],
            _js=get_js_data
        )
+        text_input.change(
+            fn=None, inputs=[text_input], outputs=None, _js=set_js_data, queue=False)
+        # load url param prompt to textinput
+        load_params_button.click(fn=lambda x: x['prompt'], inputs=[
+                                 url_params], outputs=[text_input], queue=False)
+        demo.load(
+            fn=load_url_params,
+            inputs=[url_params],
+            outputs=[load_params_button, url_params],
+            _js=get_window_url_params
+        )
         demo.queue(concurrency_count=1)
         demo.launch(share=False)
 
 
 if __name__ == "__main__":
+    main()
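
For reference, below is a minimal, self-contained sketch of the URL-parameter pattern this commit wires into the demo. It uses the same Gradio 3.x-style `_js` hook and `gr.update` visibility toggle as the diff; the standalone layout and component names are illustrative, not the app's actual UI.

import gradio as gr

# Client-side hook: collect the page's query string into a plain dict.
get_window_url_params = """
async (url_params) => {
    const params = new URLSearchParams(window.location.search);
    url_params = Object.fromEntries(params);
    return [url_params];
}
"""


def load_url_params(url_params):
    # Keep the "Load from URL Params" button visible only when a prompt was passed in the URL.
    if 'prompt' in url_params:
        return gr.update(visible=True), url_params
    return gr.update(visible=False), url_params


with gr.Blocks() as demo:
    url_params = gr.JSON({}, visible=True, label="URL Params")
    text_input = gr.Textbox(label="Prompt")
    load_params_button = gr.Button("Load from URL Params", visible=True)

    # Copy the captured prompt into the textbox when the button is clicked.
    load_params_button.click(fn=lambda p: p['prompt'],
                             inputs=[url_params], outputs=[text_input], queue=False)

    # On page load, run the JS hook to grab the query params, then decide
    # whether the button stays visible.
    demo.load(fn=load_url_params,
              inputs=[url_params],
              outputs=[load_params_button, url_params],
              _js=get_window_url_params)

demo.launch()

Opening the page with something like `?prompt=a%20red%20apple` appended to the URL then populates `url_params`, keeps the button visible, and clicking it copies the prompt into the textbox.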