AlekseyCalvin committed
Commit 77881f1 · verified · 1 Parent(s): 6cb0ace

Update pipeline.py

Files changed (1):
  1. pipeline.py  +1 -188
pipeline.py CHANGED
@@ -66,7 +66,7 @@ def prepare_timesteps(
     return timesteps, num_inference_steps
 
 # FLUX pipeline function
-class FluxWithCFGPipeline(StableDiffusion3Pipeline):
+class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
 
     def __init__(
         self,
@@ -244,193 +244,6 @@ class FluxWithCFGPipeline(StableDiffusion3Pipeline):
         self.maybe_free_model_hooks()
         torch.cuda.empty_cache()
 
-    def _decode_latents_to_image(self, latents, height, width, output_type, vae=None):
-        """Decodes the given latents into an image."""
-        vae = vae or self.vae
-        latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
-        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
-        image = vae.decode(latents, return_dict=False)[0]
-        return self.image_processor.postprocess(image, output_type=output_type)[0]
-
-class FluxWithCFGPipeline(StableDiffusion3Pipeline):
-    @torch.inference_mode()
-    def __init__(
-        self,
-        transformer: FluxTransformer2DModel,
-        scheduler: FlowMatchEulerDiscreteScheduler,
-        vae: AutoencoderKL,
-        text_encoder: CLIPTextModelWithProjection,
-        tokenizer: CLIPTokenizer,
-        tokenizer_2: T5TokenizerFast,
-        tokenizer_3: None,
-        text_encoder_2: T5EncoderModel,
-        text_encoder_3: None,
-    ):
-        super().__init__()
-
-        self.register_modules(
-            vae=vae,
-            text_encoder=text_encoder,
-            text_encoder_2=text_encoder_2,
-            text_encoder_3=text_encoder_3,
-            tokenizer=tokenizer,
-            tokenizer_2=tokenizer_2,
-            tokenizer_3=tokenizer_3,
-            transformer=transformer,
-            scheduler=scheduler,
-        )
-        self.vae_scale_factor = (
-            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 16
-        )
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
-        self.tokenizer_max_length = (
-            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
-        )
-        self.default_sample_size = 64
-
-    def generate_image(
-        self,
-        prompt: Union[str, List[str]] = None,
-        prompt_2: Optional[Union[str, List[str]]] = None,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
-        negative_prompt_2: Optional[Union[str, List[str]]] = None,
-        num_inference_steps: int = 4,
-        timesteps: List[int] = None,
-        guidance_scale: float = 3.5,
-        num_images_per_prompt: Optional[int] = 1,
-        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        output_type: Optional[str] = "pil",
-        return_dict: bool = True,
-        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-        max_sequence_length: int = 300,
-    ):
-        height = height or self.default_sample_size * self.vae_scale_factor
-        width = width or self.default_sample_size * self.vae_scale_factor
-
-        # 1. Check inputs
-        self.check_inputs(
-            prompt,
-            prompt_2,
-            negative_prompt,
-            height,
-            width,
-            prompt_embeds=prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
-            max_sequence_length=max_sequence_length,
-        )
-
-        self._guidance_scale = guidance_scale
-        self._joint_attention_kwargs = joint_attention_kwargs
-        self._interrupt = False
-
-        # 2. Define call parameters
-        batch_size = 1 if isinstance(prompt, str) else len(prompt)
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        # 3. Encode prompt
-        lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
-        prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
-            prompt=prompt,
-            prompt_2=prompt_2,
-            prompt_embeds=prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
-            device=device,
-            num_images_per_prompt=num_images_per_prompt,
-            max_sequence_length=max_sequence_length,
-            lora_scale=lora_scale,
-        )
-        negative_prompt_embeds, negative_pooled_prompt_embeds, negative_text_ids = self.encode_prompt(
-            prompt=negative_prompt,
-            prompt_2=negative_prompt_2,
-            prompt_embeds=negative_prompt_embeds,
-            pooled_prompt_embeds=negative_pooled_prompt_embeds,
-            device=device,
-            num_images_per_prompt=num_images_per_prompt,
-            max_sequence_length=max_sequence_length,
-            lora_scale=lora_scale,
-        )
-
-        # 4. Prepare latent variables
-        num_channels_latents = self.transformer.config.in_channels // 4
-        latents, latent_image_ids = self.prepare_latents(
-            batch_size * num_images_per_prompt,
-            num_channels_latents,
-            height,
-            width,
-            prompt_embeds.dtype,
-            negative_prompt_embeds.dtype,
-            device,
-            generator,
-            latents,
-        )
-
-        # 5. Prepare timesteps
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-        image_seq_len = latents.shape[1]
-        mu = calculate_timestep_shift(image_seq_len)
-        timesteps, num_inference_steps = prepare_timesteps(
-            self.scheduler,
-            num_inference_steps,
-            device,
-            timesteps,
-            sigmas,
-            mu=mu,
-        )
-        self._num_timesteps = len(timesteps)
-
-        # Handle guidance
-        guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float16).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
-
-        # 6. Denoising loop
-        for i, t in enumerate(timesteps):
-            if self.interrupt:
-                continue
-
-            timestep = t.expand(latents.shape[0]).to(latents.dtype)
-
-            noise_pred = self.transformer(
-                hidden_states=latents,
-                timestep=timestep / 1000,
-                guidance=guidance,
-                pooled_projections=pooled_prompt_embeds,
-                encoder_hidden_states=prompt_embeds,
-                txt_ids=text_ids,
-                img_ids=latent_image_ids,
-                joint_attention_kwargs=self.joint_attention_kwargs,
-                return_dict=False,
-            )[0]
-
-            noise_pred_uncond = self.transformer(
-                hidden_states=latents,
-                timestep=timestep / 1000,
-                guidance=guidance,
-                pooled_projections=negative_pooled_prompt_embeds,
-                encoder_hidden_states=negative_prompt_embeds,
-                txt_ids=negative_text_ids,
-                img_ids=latent_image_ids,
-                joint_attention_kwargs=self.joint_attention_kwargs,
-                return_dict=False,
-            )[0]
-
-            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-            latents_dtype = latents.dtype
-            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-            # Yield intermediate result
-            torch.cuda.empty_cache()
-
-        # Final image
-        return self._decode_latents_to_image(latents, height, width, output_type)
-        self.maybe_free_model_hooks()
-        torch.cuda.empty_cache()
-
     def _decode_latents_to_image(self, latents, height, width, output_type, vae=None):
         """Decodes the given latents into an image."""
         vae = vae or self.vae
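
The surviving FluxWithCFGPipeline now subclasses DiffusionPipeline together with FluxLoraLoaderMixin and FromSingleFileMixin rather than StableDiffusion3Pipeline. Below is a minimal usage sketch, not part of the commit: it assumes the standard diffusers from_pretrained / load_lora_weights interfaces those base classes provide, that the retained class keeps the generate_image signature shown in the removed duplicate, and that the checkpoint id and LoRA path are placeholders.

import torch
from pipeline import FluxWithCFGPipeline  # this repo's pipeline.py

# Checkpoint id is an illustrative placeholder.
pipe = FluxWithCFGPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# FluxLoraLoaderMixin supplies load_lora_weights for FLUX-style LoRA checkpoints.
# pipe.load_lora_weights("path/to/lora.safetensors")  # path is a placeholder

# Defaults mirror the diff: 4 steps, CFG scale 3.5, T5 sequence length up to 300.
# Assuming the retained generate_image also returns a single PIL image.
image = pipe.generate_image(
    prompt="a lighthouse on a cliff at dusk, 35mm photograph",
    negative_prompt="blurry, low quality",
    num_inference_steps=4,
    guidance_scale=3.5,
)
image.save("flux_cfg_sample.png")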
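
One detail worth flagging in the deleted duplicate: its denoising loop stores the conditional transformer output in noise_pred but then combines noise_pred_text, a name that is never assigned, so the guidance line would raise a NameError. A standalone sketch of the intended classifier-free guidance step follows; the helper name cfg_combine is illustrative and not part of the file.

import torch

def cfg_combine(noise_pred_text: torch.Tensor,
                noise_pred_uncond: torch.Tensor,
                guidance_scale: float) -> torch.Tensor:
    # Classifier-free guidance: move the unconditional prediction toward the
    # conditional one, scaled by guidance_scale.
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)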