AlekseyCalvin commited on
Commit
b44f918
·
verified ·
1 Parent(s): 2f4b4f6

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +333 -0
pipeline.py CHANGED
@@ -97,6 +97,339 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
97
  self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
98
  )
99
  self.default_sample_size = 64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  def __call__(
102
  self,
 
97
  self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
98
  )
99
  self.default_sample_size = 64
100
+ def _get_t5_prompt_embeds(
101
+ self,
102
+ prompt: Union[str, List[str]] = None,
103
+ num_images_per_prompt: int = 1,
104
+ max_sequence_length: int = 512,
105
+ device: Optional[torch.device] = None,
106
+ dtype: Optional[torch.dtype] = None,
107
+ ):
108
+ device = device or self._execution_device
109
+ dtype = dtype or self.text_encoder.dtype
110
+
111
+ prompt = [prompt] if isinstance(prompt, str) else prompt
112
+ batch_size = len(prompt)
113
+
114
+ text_inputs = self.tokenizer_2(
115
+ prompt,
116
+ padding="max_length",
117
+ max_length=max_sequence_length,
118
+ truncation=True,
119
+ return_length=False,
120
+ return_overflowing_tokens=False,
121
+ return_tensors="pt",
122
+ )
123
+ text_input_ids = text_inputs.input_ids
124
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
125
+
126
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
127
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
128
+ logger.warning(
129
+ "The following part of your input was truncated because `max_sequence_length` is set to "
130
+ f" {max_sequence_length} tokens: {removed_text}"
131
+ )
132
+
133
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
134
+
135
+ dtype = self.text_encoder_2.dtype
136
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
137
+
138
+ _, seq_len, _ = prompt_embeds.shape
139
+
140
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
141
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
142
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
143
+
144
+ return prompt_embeds
145
+
146
+ def _get_clip_prompt_embeds(
147
+ self,
148
+ prompt: Union[str, List[str]],
149
+ num_images_per_prompt: int = 1,
150
+ device: Optional[torch.device] = None,
151
+ ):
152
+ device = device or self._execution_device
153
+
154
+ prompt = [prompt] if isinstance(prompt, str) else prompt
155
+ batch_size = len(prompt)
156
+
157
+ text_inputs = self.tokenizer(
158
+ prompt,
159
+ padding="max_length",
160
+ max_length=self.tokenizer_max_length,
161
+ truncation=True,
162
+ return_overflowing_tokens=False,
163
+ return_length=False,
164
+ return_tensors="pt",
165
+ )
166
+
167
+ text_input_ids = text_inputs.input_ids
168
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
169
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
170
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
171
+ logger.warning(
172
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
173
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
174
+ )
175
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
176
+
177
+ # Use pooled output of CLIPTextModel
178
+ prompt_embeds = prompt_embeds.pooler_output
179
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
180
+
181
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
182
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
183
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
184
+
185
+ return prompt_embeds
186
+
187
+ def encode_prompt(
188
+ self,
189
+ prompt: Union[str, List[str]],
190
+ prompt_2: Union[str, List[str]],
191
+ negative_prompt: Optional[Union[str, List[str]]] = None,
192
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
193
+ device: Optional[torch.device] = None,
194
+ num_images_per_prompt: int = 1,
195
+ prompt_embeds: Optional[torch.FloatTensor] = None,
196
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
197
+ max_sequence_length: int = 512,
198
+ lora_scale: Optional[float] = None,
199
+ ):
200
+ r"""
201
+
202
+ Args:
203
+ prompt (`str` or `List[str]`, *optional*):
204
+ prompt to be encoded
205
+ prompt_2 (`str` or `List[str]`, *optional*):
206
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
207
+ used in all text-encoders
208
+ device: (`torch.device`):
209
+ torch device
210
+ num_images_per_prompt (`int`):
211
+ number of images that should be generated per prompt
212
+ prompt_embeds (`torch.FloatTensor`, *optional*):
213
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
214
+ provided, text embeddings will be generated from `prompt` input argument.
215
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
216
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
217
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
218
+ lora_scale (`float`, *optional*):
219
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
220
+ """
221
+ device = device or self._execution_device
222
+
223
+ # set lora scale so that monkey patched LoRA
224
+ # function of text encoder can correctly access it
225
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
226
+ self._lora_scale = lora_scale
227
+
228
+ # dynamically adjust the LoRA scale
229
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
230
+ scale_lora_layers(self.text_encoder, lora_scale)
231
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
232
+ scale_lora_layers(self.text_encoder_2, lora_scale)
233
+
234
+ prompt = [prompt] if isinstance(prompt, str) else prompt
235
+ negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
236
+
237
+ if prompt_embeds is None:
238
+ prompt_2 = prompt_2 or prompt
239
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
240
+
241
+ # We only use the pooled prompt output from the CLIPTextModel
242
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
243
+ prompt=prompt,
244
+ device=device,
245
+ num_images_per_prompt=num_images_per_prompt,
246
+ )
247
+ prompt_embeds = self._get_t5_prompt_embeds(
248
+ prompt=prompt_2,
249
+ num_images_per_prompt=num_images_per_prompt,
250
+ max_sequence_length=max_sequence_length,
251
+ device=device,
252
+ )
253
+
254
+ if self.text_encoder is not None:
255
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
256
+ # Retrieve the original scale by scaling back the LoRA layers
257
+ unscale_lora_layers(self.text_encoder, lora_scale)
258
+
259
+ if self.text_encoder_2 is not None:
260
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
261
+ # Retrieve the original scale by scaling back the LoRA layers
262
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
263
+
264
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
265
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
266
+
267
+ return prompt_embeds, pooled_prompt_embeds, text_ids
268
+
269
+ def check_inputs(
270
+ self,
271
+ prompt,
272
+ prompt_2,
273
+ negative_prompt,
274
+ height,
275
+ width,
276
+ prompt_embeds=None,
277
+ pooled_prompt_embeds=None,
278
+ callback_on_step_end_tensor_inputs=None,
279
+ max_sequence_length=None,
280
+ ):
281
+ if height % 8 != 0 or width % 8 != 0:
282
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
283
+
284
+ if callback_on_step_end_tensor_inputs is not None and not all(
285
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
286
+ ):
287
+ raise ValueError(
288
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
289
+ )
290
+
291
+ if prompt is not None and prompt_embeds is not None:
292
+ raise ValueError(
293
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
294
+ " only forward one of the two."
295
+ )
296
+ elif prompt_2 is not None and prompt_embeds is not None:
297
+ raise ValueError(
298
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
299
+ " only forward one of the two."
300
+ )
301
+ elif prompt is None and prompt_embeds is None:
302
+ raise ValueError(
303
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
304
+ )
305
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
306
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
307
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
308
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
309
+
310
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
311
+ raise ValueError(
312
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
313
+ )
314
+
315
+ if max_sequence_length is not None and max_sequence_length > 512:
316
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
317
+
318
+ @staticmethod
319
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
320
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
321
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
322
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
323
+
324
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
325
+
326
+ latent_image_ids = latent_image_ids.reshape(
327
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
328
+ )
329
+
330
+ return latent_image_ids.to(device=device, dtype=dtype)
331
+
332
+ @staticmethod
333
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
334
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
335
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
336
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
337
+
338
+ return latents
339
+
340
+ @staticmethod
341
+ def _unpack_latents(latents, height, width, vae_scale_factor):
342
+ batch_size, num_patches, channels = latents.shape
343
+
344
+ height = height // vae_scale_factor
345
+ width = width // vae_scale_factor
346
+
347
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
348
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
349
+
350
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
351
+
352
+ return latents
353
+
354
+ def enable_vae_slicing(self):
355
+ r"""
356
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
357
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
358
+ """
359
+ self.vae.enable_slicing()
360
+
361
+ def disable_vae_slicing(self):
362
+ r"""
363
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
364
+ computing decoding in one step.
365
+ """
366
+ self.vae.disable_slicing()
367
+
368
+ def enable_vae_tiling(self):
369
+ r"""
370
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
371
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
372
+ processing larger images.
373
+ """
374
+ self.vae.enable_tiling()
375
+
376
+ def disable_vae_tiling(self):
377
+ r"""
378
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
379
+ computing decoding in one step.
380
+ """
381
+ self.vae.disable_tiling()
382
+
383
+ def prepare_latents(
384
+ self,
385
+ batch_size,
386
+ num_channels_latents,
387
+ height,
388
+ width,
389
+ dtype,
390
+ device,
391
+ generator,
392
+ latents=None,
393
+ ):
394
+ height = 2 * (int(height) // self.vae_scale_factor)
395
+ width = 2 * (int(width) // self.vae_scale_factor)
396
+
397
+ shape = (batch_size, num_channels_latents, height, width)
398
+
399
+ if latents is not None:
400
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
401
+ return latents.to(device=device, dtype=dtype), latent_image_ids
402
+
403
+ if isinstance(generator, list) and len(generator) != batch_size:
404
+ raise ValueError(
405
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
406
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
407
+ )
408
+
409
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
410
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
411
+
412
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
413
+
414
+ return latents, latent_image_ids
415
+
416
+ @property
417
+ def guidance_scale(self):
418
+ return self._guidance_scale
419
+
420
+ @property
421
+ def joint_attention_kwargs(self):
422
+ return self._joint_attention_kwargs
423
+
424
+ @property
425
+ def num_timesteps(self):
426
+ return self._num_timesteps
427
+
428
+ @property
429
+ def interrupt(self):
430
+ return self._interrupt
431
+
432
+ @torch.no_grad()
433
 
434
  def __call__(
435
  self,