mrfakename committed on
Commit
6f40cd8
1 Parent(s): d9eb133

Create app.py

Files changed (1)
  1. app.py +319 -0
app.py ADDED
@@ -0,0 +1,319 @@
#!/usr/bin/env python

from __future__ import annotations

import os
import random

import gradio as gr
import numpy as np
import PIL.Image
import spaces
import torch
from diffusers import AutoencoderKL, DiffusionPipeline

DESCRIPTION = """
# OpenDalle

This is a demo of <a href="https://huggingface.co/dataautogpt3/OpenDalleV1.1">OpenDalle V1.1</a> by @dataautogpt3.

It is a merge of several models and aims to provide excellent image quality. Try it out!

**The code for this demo is based on [@hysts's SD-XL demo](https://huggingface.co/spaces/hysts/SD-XL), running on an A10G GPU.**
"""
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

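# Runtime configuration, overridable via environment variables at launch
# (e.g. ENABLE_REFINER=1). Example caching is on by default when a GPU is
# available; the refiner, CPU offload, and torch.compile are off by default.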
MAX_SEED = np.iinfo(np.int32).max
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1024"))
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
ENABLE_REFINER = os.getenv("ENABLE_REFINER", "0") == "1"

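# Load the OpenDalle base pipeline (and, optionally, the SDXL refiner) in fp16,
# with the fp16-safe VAE to avoid black/NaN images at half precision.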
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
    pipe = DiffusionPipeline.from_pretrained(
        "dataautogpt3/OpenDalleV1.1",
        vae=vae,
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    )
    if ENABLE_REFINER:
        refiner = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            vae=vae,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        )

    if ENABLE_CPU_OFFLOAD:
        pipe.enable_model_cpu_offload()
        if ENABLE_REFINER:
            refiner.enable_model_cpu_offload()
    else:
        pipe.to(device)
        if ENABLE_REFINER:
            refiner.to(device)

    if USE_TORCH_COMPILE:
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
        if ENABLE_REFINER:
            refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


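# Generate one image. Without the refiner, the base pipeline decodes directly
# to a PIL image; with it, the base pipeline emits latents that the refiner
# then denoises further before decoding.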
@spaces.GPU
def generate(
    prompt: str,
    negative_prompt: str = "",
    prompt_2: str = "",
    negative_prompt_2: str = "",
    use_negative_prompt: bool = False,
    use_prompt_2: bool = False,
    use_negative_prompt_2: bool = False,
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    guidance_scale_base: float = 5.0,
    guidance_scale_refiner: float = 5.0,
    num_inference_steps_base: int = 25,
    num_inference_steps_refiner: int = 25,
    apply_refiner: bool = False,
) -> PIL.Image.Image:
    generator = torch.Generator().manual_seed(seed)

    if not use_negative_prompt:
        negative_prompt = None  # type: ignore
    if not use_prompt_2:
        prompt_2 = None  # type: ignore
    if not use_negative_prompt_2:
        negative_prompt_2 = None  # type: ignore

    if not apply_refiner:
        return pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            prompt_2=prompt_2,
            negative_prompt_2=negative_prompt_2,
            width=width,
            height=height,
            guidance_scale=guidance_scale_base,
            num_inference_steps=num_inference_steps_base,
            generator=generator,
            output_type="pil",
        ).images[0]
    else:
        latents = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            prompt_2=prompt_2,
            negative_prompt_2=negative_prompt_2,
            width=width,
            height=height,
            guidance_scale=guidance_scale_base,
            num_inference_steps=num_inference_steps_base,
            generator=generator,
            output_type="latent",
        ).images
        image = refiner(
            prompt=prompt,
            negative_prompt=negative_prompt,
            prompt_2=prompt_2,
            negative_prompt_2=negative_prompt_2,
            guidance_scale=guidance_scale_refiner,
            num_inference_steps=num_inference_steps_refiner,
            image=latents,
            generator=generator,
        ).images[0]
        return image


examples = [
    "A realistic photograph of an astronaut in a jungle, cold color palette, detailed, 8k",
    "An astronaut riding a green horse",
]

theme = gr.themes.Base(
    font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
)
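# Build the UI: a prompt row with a Run button, the result image, and an
# "Advanced options" accordion; refiner controls appear only when enabled.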
with gr.Blocks(css="footer{display:none !important}", theme=theme) as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
    result = gr.Image(label="Result", show_label=False)
    with gr.Accordion("Advanced options", open=False):
        with gr.Row():
            use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False)
            use_prompt_2 = gr.Checkbox(label="Use prompt 2", value=False)
            use_negative_prompt_2 = gr.Checkbox(label="Use negative prompt 2", value=False)
        negative_prompt = gr.Text(
            label="Negative prompt",
            max_lines=1,
            placeholder="Enter a negative prompt",
            visible=False,
        )
        prompt_2 = gr.Text(
            label="Prompt 2",
            max_lines=1,
            placeholder="Enter your prompt",
            visible=False,
        )
        negative_prompt_2 = gr.Text(
            label="Negative prompt 2",
            max_lines=1,
            placeholder="Enter a negative prompt",
            visible=False,
        )

        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        with gr.Row():
            width = gr.Slider(
                label="Width",
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=1024,
            )
            height = gr.Slider(
                label="Height",
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=1024,
            )
        apply_refiner = gr.Checkbox(label="Apply refiner", value=False, visible=ENABLE_REFINER)
        with gr.Row():
            guidance_scale_base = gr.Slider(
                label="Guidance scale for base",
                minimum=1,
                maximum=20,
                step=0.1,
                value=5.0,
            )
            num_inference_steps_base = gr.Slider(
                label="Number of inference steps for base",
                minimum=10,
                maximum=100,
                step=1,
                value=25,
            )
        with gr.Row(visible=False) as refiner_params:
            guidance_scale_refiner = gr.Slider(
                label="Guidance scale for refiner",
                minimum=1,
                maximum=20,
                step=0.1,
                value=5.0,
            )
            num_inference_steps_refiner = gr.Slider(
                label="Number of inference steps for refiner",
                minimum=10,
                maximum=100,
                step=1,
                value=25,
            )

    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=result,
        fn=generate,
        cache_examples=CACHE_EXAMPLES,
    )

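    # Show or hide the optional prompt fields and the refiner parameter row
    # whenever the corresponding checkbox is toggled.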
    use_negative_prompt.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_negative_prompt,
        outputs=negative_prompt,
        queue=False,
        api_name=False,
    )
    use_prompt_2.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_prompt_2,
        outputs=prompt_2,
        queue=False,
        api_name=False,
    )
    use_negative_prompt_2.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_negative_prompt_2,
        outputs=negative_prompt_2,
        queue=False,
        api_name=False,
    )
    apply_refiner.change(
        fn=lambda x: gr.update(visible=x),
        inputs=apply_refiner,
        outputs=refiner_params,
        queue=False,
        api_name=False,
    )

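    # Main event chain: submitting any prompt field or clicking Run first
    # resolves the seed (randomizing it if requested), then runs generate().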
    gr.on(
        triggers=[
            prompt.submit,
            negative_prompt.submit,
            prompt_2.submit,
            negative_prompt_2.submit,
            run_button.click,
        ],
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=generate,
        inputs=[
            prompt,
            negative_prompt,
            prompt_2,
            negative_prompt_2,
            use_negative_prompt,
            use_prompt_2,
            use_negative_prompt_2,
            seed,
            width,
            height,
            guidance_scale_base,
            guidance_scale_refiner,
            num_inference_steps_base,
            num_inference_steps_refiner,
            apply_refiner,
        ],
        outputs=result,
        api_name="run",
    )

if __name__ == "__main__":
    demo.queue(max_size=20, api_open=False).launch(show_api=False)
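
For local experimentation, here is a minimal, hypothetical launch sketch. It assumes the file above is saved as app.py next to it, that the dependencies are installed (including the `spaces` package, whose GPU decorator is a no-op outside a ZeroGPU Space), and that a CUDA GPU is available. The feature flags are read when app.py is imported, so they must be set beforehand.

# Hypothetical local launch of the demo above; not part of this commit.
import os

# The flags are read at import time of app.py, so set them first.
os.environ["ENABLE_REFINER"] = "1"   # also load the SDXL refiner
os.environ["CACHE_EXAMPLES"] = "0"   # skip example caching on first run

import app  # noqa: E402 -- downloads and loads the pipelines on import

# The __main__ guard in app.py did not run, so queue and launch here.
app.demo.queue(max_size=20).launch()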