Text-to-Image
daoyuan98 committed
Commit a6c206d
1 Parent(s): 39dc4f4

Create app.py

Files changed (1)
  1. app.py +367 -0
app.py ADDED
@@ -0,0 +1,367 @@
+ import os
+ import random
+ from dataclasses import dataclass
+ from typing import Any, Dict, List, Optional, Union
+
+ import gradio as gr
+ import numpy as np
+ import spaces
+ import torch
+ from diffusers import DiffusionPipeline, FluxPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
+ from huggingface_hub import hf_hub_download
+ from safetensors.torch import load_file as load_sft
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+
+ # FluxParams and AutoEncoderParams are assumed to be exported by the local model module
+ # alongside Flux; adjust this import if they live elsewhere in the repo.
+ from model import Flux, FluxParams, AutoEncoderParams
+
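+ # calculate_shift() linearly interpolates the flow-matching timestep shift ("mu")
+ # between base_shift and max_shift as a function of the latent sequence length.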
+ def calculate_shift(
+     image_seq_len,
+     base_seq_len: int = 256,
+     max_seq_len: int = 4096,
+     base_shift: float = 0.5,
+     max_shift: float = 1.16,
+ ):
+     m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+     b = base_shift - m * base_seq_len
+     mu = image_seq_len * m + b
+     return mu
+
+
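+ # retrieve_timesteps() follows the diffusers convention: custom `timesteps` or `sigmas`
+ # are forwarded to scheduler.set_timesteps() when given, otherwise a plain
+ # num_inference_steps schedule is used.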
+ def retrieve_timesteps(
+     scheduler,
+     num_inference_steps: Optional[int] = None,
+     device: Optional[Union[str, torch.device]] = None,
+     timesteps: Optional[List[int]] = None,
+     sigmas: Optional[List[float]] = None,
+     **kwargs,
+ ):
+     if timesteps is not None and sigmas is not None:
+         raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+     if timesteps is not None:
+         scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+         num_inference_steps = len(timesteps)
+     elif sigmas is not None:
+         scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+         num_inference_steps = len(timesteps)
+     else:
+         scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+     return timesteps, num_inference_steps
+
+
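+ # Streaming variant of FluxPipeline.__call__: it decodes the running latents with the
+ # pipeline VAE and yields a preview image after every denoising step, then decodes the
+ # final latents with `good_vae` for the image that is ultimately displayed.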
+ @torch.inference_mode()
+ def flux_pipe_call_that_returns_an_iterable_of_images(
+     self,
+     prompt: Union[str, List[str]] = None,
+     prompt_2: Optional[Union[str, List[str]]] = None,
+     height: Optional[int] = None,
+     width: Optional[int] = None,
+     num_inference_steps: int = 28,
+     timesteps: List[int] = None,
+     guidance_scale: float = 3.5,
+     num_images_per_prompt: Optional[int] = 1,
+     generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+     latents: Optional[torch.FloatTensor] = None,
+     prompt_embeds: Optional[torch.FloatTensor] = None,
+     pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+     output_type: Optional[str] = "pil",
+     return_dict: bool = True,
+     joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+     max_sequence_length: int = 512,
+     good_vae: Optional[Any] = None,
+ ):
+     height = height or self.default_sample_size * self.vae_scale_factor
+     width = width or self.default_sample_size * self.vae_scale_factor
+
+     # 1. Check inputs
+     self.check_inputs(
+         prompt,
+         prompt_2,
+         height,
+         width,
+         prompt_embeds=prompt_embeds,
+         pooled_prompt_embeds=pooled_prompt_embeds,
+         max_sequence_length=max_sequence_length,
+     )
+
+     self._guidance_scale = guidance_scale
+     self._joint_attention_kwargs = joint_attention_kwargs
+     self._interrupt = False
+
+     # 2. Define call parameters
+     batch_size = 1 if isinstance(prompt, str) else len(prompt)
+     device = self._execution_device
+
+     # 3. Encode prompt
+     lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
+     prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
+         prompt=prompt,
+         prompt_2=prompt_2,
+         prompt_embeds=prompt_embeds,
+         pooled_prompt_embeds=pooled_prompt_embeds,
+         device=device,
+         num_images_per_prompt=num_images_per_prompt,
+         max_sequence_length=max_sequence_length,
+         lora_scale=lora_scale,
+     )
+
+     # 4. Prepare latent variables
+     num_channels_latents = self.transformer.config.in_channels // 4
+     latents, latent_image_ids = self.prepare_latents(
+         batch_size * num_images_per_prompt,
+         num_channels_latents,
+         height,
+         width,
+         prompt_embeds.dtype,
+         device,
+         generator,
+         latents,
+     )
+
+     # 5. Prepare timesteps
+     sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+     image_seq_len = latents.shape[1]
+     mu = calculate_shift(
+         image_seq_len,
+         self.scheduler.config.base_image_seq_len,
+         self.scheduler.config.max_image_seq_len,
+         self.scheduler.config.base_shift,
+         self.scheduler.config.max_shift,
+     )
+     timesteps, num_inference_steps = retrieve_timesteps(
+         self.scheduler,
+         num_inference_steps,
+         device,
+         timesteps,
+         sigmas,
+         mu=mu,
+     )
+     self._num_timesteps = len(timesteps)
+
+     # Handle guidance
+     guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
+
+     # 6. Denoising loop
+     for i, t in enumerate(timesteps):
+         if self.interrupt:
+             continue
+
+         timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+         noise_pred = self.transformer(
+             hidden_states=latents,
+             timestep=timestep / 1000,
+             guidance=guidance,
+             pooled_projections=pooled_prompt_embeds,
+             encoder_hidden_states=prompt_embeds,
+             txt_ids=text_ids,
+             img_ids=latent_image_ids,
+             joint_attention_kwargs=self.joint_attention_kwargs,
+             return_dict=False,
+         )[0]
+
+         # Yield intermediate result
+         latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+         latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+         image = self.vae.decode(latents_for_image, return_dict=False)[0]
+         yield self.image_processor.postprocess(image, output_type=output_type)[0]
+
+         latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+         torch.cuda.empty_cache()
+
+     # Final image using good_vae
+     latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+     latents = (latents / good_vae.config.scaling_factor) + good_vae.config.shift_factor
+     image = good_vae.decode(latents, return_dict=False)[0]
+     self.maybe_free_model_hooks()
+     torch.cuda.empty_cache()
+     yield self.image_processor.postprocess(image, output_type=output_type)[0]
+
+
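+ # ModelSpec bundles the Flux-Mini transformer hyper-parameters with the Hugging Face
+ # repo and filenames used to fetch its weights.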
+ @dataclass
+ class ModelSpec:
+     params: FluxParams
+     ae_params: Optional[AutoEncoderParams]
+     ckpt_path: Optional[str]
+     ae_path: Optional[str]
+     repo_id: str
+     repo_flow: str
+     repo_ae: str
+     repo_id_ae: str
+
+
+ config = ModelSpec(
+     repo_id="TencentARC/flux-mini",
+     repo_flow="flux-mini.safetensors",
+     repo_id_ae="black-forest-labs/FLUX.1-dev",
+     repo_ae="ae.safetensors",
+     ckpt_path=os.getenv("FLUX_MINI", None),
+     # The autoencoder is loaded below via diffusers' AutoencoderKL, so these two
+     # fields are not used at runtime.
+     ae_params=None,
+     ae_path=None,
+     params=FluxParams(
+         in_channels=64,
+         vec_in_dim=768,
+         context_in_dim=4096,
+         hidden_size=3072,
+         mlp_ratio=4.0,
+         num_heads=24,
+         depth=5,
+         depth_single_blocks=10,
+         axes_dim=[16, 56, 56],
+         theta=10_000,
+         qkv_bias=True,
+         guidance_embed=True,
+     ),
+ )
+
+
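+ # load_flow_model2() resolves the transformer checkpoint (local path from the FLUX_MINI
+ # env var, or a download from the Hugging Face Hub) and loads it into the Flux module.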
+ def load_flow_model2(device: str = "cuda", hf_download: bool = True):
+     ckpt_path = config.ckpt_path
+     if (
+         ckpt_path is None
+         and config.repo_id is not None
+         and config.repo_flow is not None
+         and hf_download
+     ):
+         ckpt_path = hf_hub_download(config.repo_id, config.repo_flow.replace("sft", "safetensors"))
+
+     model = Flux(config.params)
+     if ckpt_path is not None:
+         sd = load_sft(ckpt_path, device=str(device))
+         missing, unexpected = model.load_state_dict(sd, strict=True)
+     return model
+
+
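+ # Load the individual FLUX.1-dev components (scheduler, VAE, CLIP and T5 text encoders)
+ # and the distilled Flux-Mini transformer, then assemble them into a FluxPipeline.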
+ dtype = torch.bfloat16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Schedulers and tokenizers are plain Python objects with no .to(); only the nn.Module
+ # components are moved to the device.
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="scheduler")
+ vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
+ text_encoder = CLIPTextModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder", torch_dtype=dtype).to(device)
+ tokenizer = CLIPTokenizer.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer")
+ text_encoder_2 = T5EncoderModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder_2", torch_dtype=dtype).to(device)
+ tokenizer_2 = T5TokenizerFast.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer_2")
+ transformer = load_flow_model2(device).to(device, dtype=dtype)
+
+ # No separate preview VAE is loaded in this Space, so the full AutoencoderKL also serves
+ # as the `good_vae` used for the final decode in the streaming call.
+ good_vae = vae
+
+ pipe = FluxPipeline(
+     scheduler,
+     vae,
+     text_encoder,
+     tokenizer,
+     text_encoder_2,
+     tokenizer_2,
+     transformer,
+ )
+ torch.cuda.empty_cache()
+
+ MAX_SEED = np.iinfo(np.int32).max
+ MAX_IMAGE_SIZE = 2048
+
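+ # Bind the streaming call above to the pipeline instance via the descriptor protocol,
+ # so it can be used like a regular bound method.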
+ pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
+
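+ # infer() runs on ZeroGPU (75 s budget) and streams each intermediate preview, together
+ # with the seed actually used, back to the Gradio UI.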
+ @spaces.GPU(duration=75)
+ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     generator = torch.Generator().manual_seed(seed)
+
+     for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
+         prompt=prompt,
+         guidance_scale=guidance_scale,
+         num_inference_steps=num_inference_steps,
+         width=width,
+         height=height,
+         generator=generator,
+         output_type="pil",
+         good_vae=good_vae,
+     ):
+         yield img, seed
+
+ examples = [
+     "thousands of luminous oysters on a shore reflecting and refracting the sunset",
+     "profile of sad Socrates, full body, high detail, dramatic scene, Epic dynamic action, wide angle, cinematic, hyper realistic, concept art, warm muted tones as painted by Bernie Wrightson, Frank Frazetta,",
+     "ghosts, astronauts, robots, cats, superhero costumes, line drawings, naive, simple, exploring a strange planet, coloured pencil crayons, , black canvas background, drawn by 5 year old child",
+ ]
+
+ css = """
+ #col-container {
+     margin: 0 auto;
+     max-width: 520px;
+ }
+ """
+
+ with gr.Blocks(css=css) as demo:
+
+     with gr.Column(elem_id="col-container"):
+         gr.Markdown(f"""# FLUX-Mini
+ A 3.2B param rectified flow transformer distilled from [FLUX.1 [dev]](https://blackforestlabs.ai/)
+ [[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)]
+ """)
+
+         with gr.Row():
+
+             prompt = gr.Text(
+                 label="Prompt",
+                 show_label=False,
+                 max_lines=1,
+                 placeholder="Enter your prompt",
+                 container=False,
+             )
+
+             run_button = gr.Button("Run", scale=0)
+
+         result = gr.Image(label="Result", show_label=False)
+
+         with gr.Accordion("Advanced Settings", open=False):
+
+             seed = gr.Slider(
+                 label="Seed",
+                 minimum=0,
+                 maximum=MAX_SEED,
+                 step=1,
+                 value=0,
+             )
+
+             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+             with gr.Row():
+
+                 width = gr.Slider(
+                     label="Width",
+                     minimum=256,
+                     maximum=MAX_IMAGE_SIZE,
+                     step=32,
+                     value=1024,
+                 )
+
+                 height = gr.Slider(
+                     label="Height",
+                     minimum=256,
+                     maximum=MAX_IMAGE_SIZE,
+                     step=32,
+                     value=1024,
+                 )
+
+             with gr.Row():
+
+                 guidance_scale = gr.Slider(
+                     label="Guidance Scale",
+                     minimum=1,
+                     maximum=15,
+                     step=0.1,
+                     value=3.5,
+                 )
+
+                 num_inference_steps = gr.Slider(
+                     label="Number of inference steps",
+                     minimum=1,
+                     maximum=50,
+                     step=1,
+                     value=28,
+                 )
+
+         gr.Examples(
+             examples=examples,
+             fn=infer,
+             inputs=[prompt],
+             outputs=[result, seed],
+             cache_examples="lazy",
+         )
+
+     gr.on(
+         triggers=[run_button.click, prompt.submit],
+         fn=infer,
+         inputs=[prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
+         outputs=[result, seed],
+     )
+
+ demo.launch()