dome272 committed
Commit 9793d8c
Parent: f429fd8

upload files

app.py ADDED
@@ -0,0 +1,263 @@
import os
import random
import gradio as gr
import numpy as np
import PIL.Image
import torch
from typing import Iterator, List
from diffusers.utils import numpy_to_pil
from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline
from diffusers.pipelines.wuerstchen import WuerstchenPrior, default_stage_c_timesteps
from previewer.modules import Previewer

DESCRIPTION = "# Würstchen"
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

MAX_SEED = np.iinfo(np.int32).max
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1536"))
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
PREVIEW_IMAGES = True

dtype = torch.float16
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    prior_pipeline = WuerstchenPriorPipeline.from_pretrained("warp-ai/wuerstchen-prior", torch_dtype=dtype)
    decoder_pipeline = WuerstchenDecoderPipeline.from_pretrained("warp-ai/wuerstchen", torch_dtype=dtype)
    if ENABLE_CPU_OFFLOAD:
        prior_pipeline.enable_model_cpu_offload()
        decoder_pipeline.enable_model_cpu_offload()
    else:
        prior_pipeline.to(device)
        decoder_pipeline.to(device)

    if USE_TORCH_COMPILE:
        prior_pipeline.prior = torch.compile(prior_pipeline.prior, mode="reduce-overhead", fullgraph=True)
        decoder_pipeline.decoder = torch.compile(decoder_pipeline.decoder, mode="reduce-overhead", fullgraph=True)

    if PREVIEW_IMAGES:
        previewer = Previewer()
        previewer.load_state_dict(
            torch.load("previewer/text2img_wurstchen_b_v1_previewer_100k.pt", map_location="cpu")["state_dict"]
        )  # previewer checkpoint shipped in this repo (see previewer/ below)
        previewer.eval().requires_grad_(False).to(device).to(dtype)

        def callback_prior(i, t, latents):
            output = previewer(latents)
            output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy())
            return output
    else:
        previewer = None
        callback_prior = None
else:
    prior_pipeline = None
    decoder_pipeline = None


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


def generate(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    prior_num_inference_steps: int = 60,
    # prior_timesteps: List[float] = None,
    prior_guidance_scale: float = 4.0,
    decoder_num_inference_steps: int = 12,
    # decoder_timesteps: List[float] = None,
    decoder_guidance_scale: float = 0.0,
    num_images_per_prompt: int = 2,
) -> Iterator[List[PIL.Image.Image]]:
    generator = torch.Generator().manual_seed(seed)

    prior_output = prior_pipeline(
        prompt=prompt,
        height=height,
        width=width,
        timesteps=default_stage_c_timesteps,
        negative_prompt=negative_prompt,
        guidance_scale=prior_guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        generator=generator,
        callback=callback_prior,
    )

    if PREVIEW_IMAGES:
        for _ in range(len(default_stage_c_timesteps)):
            r = next(prior_output)
            if isinstance(r, list):
                yield r
        prior_output = r

    decoder_output = decoder_pipeline(
        image_embeddings=prior_output.image_embeddings,
        prompt=prompt,
        num_inference_steps=decoder_num_inference_steps,
        # timesteps=decoder_timesteps,
        guidance_scale=decoder_guidance_scale,
        negative_prompt=negative_prompt,
        num_images_per_prompt=num_images_per_prompt,
        generator=generator,
        output_type="pil",
    ).images
    yield decoder_output


examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
]

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Gallery(label="Result", show_label=False)
    with gr.Accordion("Advanced options", open=False):
        negative_prompt = gr.Text(
            label="Negative prompt",
            max_lines=1,
            placeholder="Enter a negative prompt",
        )

        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        with gr.Row():
            width = gr.Slider(
                label="Width",
                minimum=768,
                maximum=MAX_IMAGE_SIZE,
                step=128,
                value=1024,
            )
            height = gr.Slider(
                label="Height",
                minimum=768,
                maximum=MAX_IMAGE_SIZE,
                step=128,
                value=1024,
            )
            num_images_per_prompt = gr.Slider(
                label="Number of images",
                minimum=1,
                maximum=6,
                step=1,
                value=2,
            )
        with gr.Row():
            prior_guidance_scale = gr.Slider(
                label="Prior guidance scale",
                minimum=1,
                maximum=20,
                step=0.1,
                value=4.0,
            )
            prior_num_inference_steps = gr.Slider(
                label="Prior inference steps",
                minimum=10,
                maximum=100,
                step=1,
                value=60,
            )

            decoder_guidance_scale = gr.Slider(
                label="Decoder guidance scale",
                minimum=0,  # the default of 0.0 disables classifier-free guidance in the decoder
                maximum=20,
                step=0.1,
                value=0.0,
            )
            decoder_num_inference_steps = gr.Slider(
                label="Decoder inference steps",
                minimum=10,
                maximum=100,
                step=1,
                value=12,
            )

    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=result,
        fn=generate,
        cache_examples=CACHE_EXAMPLES,
    )

    inputs = [
        prompt,
        negative_prompt,
        seed,
        width,
        height,
        prior_num_inference_steps,
        # prior_timesteps,
        prior_guidance_scale,
        decoder_num_inference_steps,
        # decoder_timesteps,
        decoder_guidance_scale,
        num_images_per_prompt,
    ]
    prompt.submit(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=generate,
        inputs=inputs,
        outputs=result,
        api_name="run",
    )
    negative_prompt.submit(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=generate,
        inputs=inputs,
        outputs=result,
        api_name=False,
    )
    run_button.click(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=generate,
        inputs=inputs,
        outputs=result,
        api_name=False,
    )

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
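
Note that `generate` is a generator: the Gradio gallery receives intermediate previewer frames while the prior is still denoising, and the final decoded images arrive as the last yield. A minimal way to exercise that streaming contract outside the UI is sketched below; this is a hypothetical smoke test, not part of the commit, and it assumes a CUDA machine so the pipelines above were actually loaded (`wuerstchen_{i}.png` is an illustrative filename):

# Hypothetical smoke test for the streaming behaviour of generate() above.
# Assumes torch.cuda.is_available() is True, so prior/decoder pipelines exist.
if torch.cuda.is_available():
    frames = list(generate("An astronaut riding a green horse", seed=42))
    previews, final_images = frames[:-1], frames[-1]  # preview batches, then decoder output
    print(f"{len(previews)} preview batches streamed before the final images")
    for i, image in enumerate(final_images):
        image.save(f"wuerstchen_{i}.png")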
previewer/modules.py ADDED
@@ -0,0 +1,36 @@
from torch import nn


# Effnet 16x16 to 64x64 previewer
class Previewer(nn.Module):
    def __init__(self, c_in=16, c_hidden=512, c_out=3):
        super().__init__()
        self.blocks = nn.Sequential(
            nn.Conv2d(c_in, c_hidden, kernel_size=1),  # 16 channels to 512 channels
            nn.GELU(),
            nn.BatchNorm2d(c_hidden),

            nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden),

            nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2),  # 16 -> 32
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 2),

            nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 2),

            nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2),  # 32 -> 64
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 4),

            nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 4),

            nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
        )

    def forward(self, x):
        return self.blocks(x)
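
The two stride-2 transposed convolutions each double the spatial resolution, so 16×16 Stage C latents come out as 64×64 RGB previews. A quick shape check (hypothetical test code, not part of the commit):

# Hypothetical shape check for the Previewer module above.
import torch
from previewer.modules import Previewer

model = Previewer().eval()  # eval mode: BatchNorm uses running stats
with torch.no_grad():
    latents = torch.randn(1, 16, 16, 16)  # (batch, c_in=16, height=16, width=16)
    preview = model(latents)
print(preview.shape)  # torch.Size([1, 3, 64, 64])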
previewer/text2img_wurstchen_b_v1_previewer_100k.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:76e82483253b24430b20e3e0c98ec2f9aeb45f0b487f7b330bac044b5de0d6f7
size 45244773
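
This last file is a Git LFS pointer, not the checkpoint itself: the actual ~45 MB previewer weights that `app.py` loads are fetched by Git LFS when the repository is cloned with LFS enabled.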