AlanB commited on
Commit
33c1548
·
verified ·
1 Parent(s): 5cbad96

Merge final release update

Browse files
Files changed (1) hide show
  1. pipeline.py +99 -58
pipeline.py CHANGED
@@ -16,6 +16,7 @@ import inspect
16
  from typing import Any, Callable, Dict, List, Optional, Union
17
 
18
  import numpy as np
 
19
  import torch
20
  from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
21
 
@@ -80,6 +81,7 @@ EXAMPLE_DOC_STRING = """
80
  """
81
 
82
 
 
83
  def calculate_shift(
84
  image_seq_len,
85
  base_seq_len: int = 256,
@@ -235,6 +237,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
235
  )
236
  self.default_sample_size = 64
237
 
 
238
  def _get_t5_prompt_embeds(
239
  self,
240
  prompt: Union[str, List[str]] = None,
@@ -281,6 +284,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
281
 
282
  return prompt_embeds
283
 
 
284
  def _get_clip_prompt_embeds(
285
  self,
286
  prompt: Union[str, List[str]],
@@ -317,11 +321,12 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
317
  prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
318
 
319
  # duplicate text embeddings for each generation per prompt, using mps friendly method
320
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
321
  prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
322
 
323
  return prompt_embeds
324
 
 
325
  def encode_prompt(
326
  self,
327
  prompt: Union[str, List[str]],
@@ -368,10 +373,6 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
368
  scale_lora_layers(self.text_encoder_2, lora_scale)
369
 
370
  prompt = [prompt] if isinstance(prompt, str) else prompt
371
- if prompt is not None:
372
- batch_size = len(prompt)
373
- else:
374
- batch_size = prompt_embeds.shape[0]
375
 
376
  if prompt_embeds is None:
377
  prompt_2 = prompt_2 or prompt
@@ -401,11 +402,11 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
401
  unscale_lora_layers(self.text_encoder_2, lora_scale)
402
 
403
  dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
404
- text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
405
- text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
406
 
407
  return prompt_embeds, pooled_prompt_embeds, text_ids
408
 
 
409
  def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
410
  if isinstance(generator, list):
411
  image_latents = [
@@ -421,12 +422,12 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
421
  return image_latents
422
 
423
  # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
424
- def get_timesteps(self, timesteps, num_inference_steps, strength, device):
425
  # get the original timestep using init_timestep
426
  init_timestep = min(num_inference_steps * strength, num_inference_steps)
427
 
428
  t_start = int(max(num_inference_steps - init_timestep, 0))
429
- timesteps = timesteps[t_start * self.scheduler.order :]
430
  if hasattr(self.scheduler, "set_begin_index"):
431
  self.scheduler.set_begin_index(t_start * self.scheduler.order)
432
 
@@ -436,12 +437,16 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
436
  self,
437
  prompt,
438
  prompt_2,
 
 
439
  strength,
440
  height,
441
  width,
 
442
  prompt_embeds=None,
443
  pooled_prompt_embeds=None,
444
  callback_on_step_end_tensor_inputs=None,
 
445
  max_sequence_length=None,
446
  ):
447
  if strength < 0 or strength > 1:
@@ -481,10 +486,24 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
481
  "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
482
  )
483
 
 
 
 
 
 
 
 
 
 
 
 
 
 
484
  if max_sequence_length is not None and max_sequence_length > 512:
485
  raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
486
 
487
  @staticmethod
 
488
  def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
489
  latent_image_ids = torch.zeros(height // 2, width // 2, 3)
490
  latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
@@ -492,14 +511,14 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
492
 
493
  latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
494
 
495
- latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
496
  latent_image_ids = latent_image_ids.reshape(
497
- batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
498
  )
499
 
500
  return latent_image_ids.to(device=device, dtype=dtype)
501
 
502
  @staticmethod
 
503
  def _pack_latents(latents, batch_size, num_channels_latents, height, width):
504
  latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
505
  latents = latents.permute(0, 2, 4, 1, 3, 5)
@@ -508,6 +527,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
508
  return latents
509
 
510
  @staticmethod
 
511
  def _unpack_latents(latents, height, width, vae_scale_factor):
512
  batch_size, num_patches, channels = latents.shape
513
 
@@ -523,6 +543,8 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
523
 
524
  def prepare_latents(
525
  self,
 
 
526
  batch_size,
527
  num_channels_latents,
528
  height,
@@ -531,9 +553,6 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
531
  device,
532
  generator,
533
  latents=None,
534
- image=None,
535
- timestep=None,
536
- is_strength_max=None,
537
  ):
538
  if isinstance(generator, list) and len(generator) != batch_size:
539
  raise ValueError(
@@ -541,27 +560,33 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
541
  f" size of {batch_size}. Make sure the batch size matches the length of the generators."
542
  )
543
 
544
- if (image is None or timestep is None) and not is_strength_max:
545
- raise ValueError(
546
- "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
547
- "However, either the image or the noise timestep has not been provided."
548
- )
549
-
550
  height = 2 * (int(height) // self.vae_scale_factor)
551
  width = 2 * (int(width) // self.vae_scale_factor)
552
 
553
  shape = (batch_size, num_channels_latents, height, width)
554
  latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
555
- # return latents.to(device=device, dtype=dtype), latent_image_ids
 
 
 
 
 
 
 
 
 
 
 
 
 
556
 
557
  if latents is None:
558
- image = image.to(device=device, dtype=dtype)
559
- image_latents = self._encode_vae_image(image=image, generator=generator)
560
  else:
561
- image_latents = latents
 
562
 
563
- noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
564
- latents = noise if is_strength_max else self.scheduler.scale_noise(image_latents, timestep, noise)
565
  noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
566
  image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
567
  latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
@@ -572,6 +597,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
572
  mask,
573
  masked_image,
574
  batch_size,
 
575
  num_images_per_prompt,
576
  height,
577
  width,
@@ -579,12 +605,12 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
579
  device,
580
  generator,
581
  ):
 
 
582
  # resize the mask to latents shape as we concatenate the mask to the latents
583
  # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
584
  # and half precision
585
- mask = torch.nn.functional.interpolate(
586
- mask, size=(2 * height // self.vae_scale_factor, 2 * width // self.vae_scale_factor)
587
- )
588
  mask = mask.to(device=device, dtype=dtype)
589
 
590
  batch_size = batch_size * num_images_per_prompt
@@ -618,6 +644,22 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
618
 
619
  # aligning device to prevent device errors when concating it with the latent model input
620
  masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
621
  return mask, masked_image_latents
622
 
623
  @property
@@ -644,6 +686,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
644
  prompt_2: Optional[Union[str, List[str]]] = None,
645
  image: PipelineImageInput = None,
646
  mask_image: PipelineImageInput = None,
 
647
  height: Optional[int] = None,
648
  width: Optional[int] = None,
649
  padding_mask_crop: Optional[int] = None,
@@ -686,6 +729,9 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
686
  color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
687
  H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
688
  1)`, or `(H, W)`.
 
 
 
689
  height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
690
  The height in pixels of the generated image. This is set to 1024 by default for the best results.
691
  width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -766,19 +812,22 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
766
  self.check_inputs(
767
  prompt,
768
  prompt_2,
 
 
769
  strength,
770
  height,
771
  width,
 
772
  prompt_embeds=prompt_embeds,
773
  pooled_prompt_embeds=pooled_prompt_embeds,
774
  callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
 
775
  max_sequence_length=max_sequence_length,
776
  )
777
 
778
  self._guidance_scale = guidance_scale
779
  self._joint_attention_kwargs = joint_attention_kwargs
780
  self._interrupt = False
781
- is_strength_max = strength == 1.0
782
 
783
  # 2. Preprocess mask and image
784
  if padding_mask_crop is not None:
@@ -840,7 +889,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
840
  sigmas,
841
  mu=mu,
842
  )
843
- timesteps, num_inference_steps = self.get_timesteps(timesteps, num_inference_steps, strength, device)
844
 
845
  if num_inference_steps < 1:
846
  raise ValueError(
@@ -853,6 +902,8 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
853
  num_channels_latents = self.transformer.config.in_channels // 4
854
 
855
  latents, noise, original_image_latents, latent_image_ids = self.prepare_latents(
 
 
856
  batch_size * num_images_per_prompt,
857
  num_channels_latents,
858
  height,
@@ -861,9 +912,6 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
861
  device,
862
  generator,
863
  latents,
864
- init_image,
865
- latent_timestep,
866
- is_strength_max,
867
  )
868
 
869
  # start diff diff preparation
@@ -876,6 +924,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
876
  original_mask,
877
  masked_image,
878
  batch_size,
 
879
  num_images_per_prompt,
880
  height,
881
  width,
@@ -885,41 +934,29 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
885
  )
886
 
887
  mask_thresholds = torch.arange(num_inference_steps, dtype=original_mask.dtype) / num_inference_steps
888
- mask_thresholds = mask_thresholds.unsqueeze(1).unsqueeze(1).to(device)
889
- masks = (original_mask > mask_thresholds)
890
- masks = self._pack_latents(
891
- masks.repeat(num_channels_latents, 1, 1, 1).permute(1, 0, 2, 3),
892
- len(mask_thresholds),
893
- num_channels_latents,
894
- 2 * (int(height) // self.vae_scale_factor),
895
- 2 * (int(width) // self.vae_scale_factor),
896
- )
897
  # end diff diff preparation
898
 
899
  num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
900
- self._num_timesteps = len(timesteps)
 
 
 
 
 
 
901
 
902
  # 6. Denoising loop
903
- latents_dtype = latents.dtype
904
- # for 64 channel transformer only.
905
- image_latent = original_image_latents
906
  with self.progress_bar(total=num_inference_steps) as progress_bar:
907
  for i, t in enumerate(timesteps):
908
  if self.interrupt:
909
  continue
910
 
911
- timestep = t.expand(latents.shape[0]).to(latents_dtype)
912
-
913
- # handle guidance
914
- if self.transformer.config.guidance_embeds:
915
- guidance = torch.tensor([guidance_scale], device=device)
916
- guidance = guidance.expand(latents.shape[0])
917
- else:
918
- guidance = None
919
-
920
  noise_pred = self.transformer(
921
  hidden_states=latents,
922
- # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
923
  timestep=timestep / 1000,
924
  guidance=guidance,
925
  pooled_projections=pooled_prompt_embeds,
@@ -931,8 +968,12 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
931
  )[0]
932
 
933
  # compute the previous noisy sample x_t -> x_t-1
 
934
  latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
935
 
 
 
 
936
  if i < len(timesteps) - 1:
937
  noise_timestep = timesteps[i + 1]
938
  image_latent = self.scheduler.scale_noise(
@@ -981,4 +1022,4 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
981
  if not return_dict:
982
  return (image,)
983
 
984
- return FluxPipelineOutput(images=image)
 
16
  from typing import Any, Callable, Dict, List, Optional, Union
17
 
18
  import numpy as np
19
+ import PIL.Image
20
  import torch
21
  from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
22
 
 
81
  """
82
 
83
 
84
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
85
  def calculate_shift(
86
  image_seq_len,
87
  base_seq_len: int = 256,
 
237
  )
238
  self.default_sample_size = 64
239
 
240
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
241
  def _get_t5_prompt_embeds(
242
  self,
243
  prompt: Union[str, List[str]] = None,
 
284
 
285
  return prompt_embeds
286
 
287
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
288
  def _get_clip_prompt_embeds(
289
  self,
290
  prompt: Union[str, List[str]],
 
321
  prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
322
 
323
  # duplicate text embeddings for each generation per prompt, using mps friendly method
324
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
325
  prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
326
 
327
  return prompt_embeds
328
 
329
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
330
  def encode_prompt(
331
  self,
332
  prompt: Union[str, List[str]],
 
373
  scale_lora_layers(self.text_encoder_2, lora_scale)
374
 
375
  prompt = [prompt] if isinstance(prompt, str) else prompt
 
 
 
 
376
 
377
  if prompt_embeds is None:
378
  prompt_2 = prompt_2 or prompt
 
402
  unscale_lora_layers(self.text_encoder_2, lora_scale)
403
 
404
  dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
405
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
 
406
 
407
  return prompt_embeds, pooled_prompt_embeds, text_ids
408
 
409
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
410
  def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
411
  if isinstance(generator, list):
412
  image_latents = [
 
422
  return image_latents
423
 
424
  # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
425
+ def get_timesteps(self, num_inference_steps, strength, device):
426
  # get the original timestep using init_timestep
427
  init_timestep = min(num_inference_steps * strength, num_inference_steps)
428
 
429
  t_start = int(max(num_inference_steps - init_timestep, 0))
430
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
431
  if hasattr(self.scheduler, "set_begin_index"):
432
  self.scheduler.set_begin_index(t_start * self.scheduler.order)
433
 
 
437
  self,
438
  prompt,
439
  prompt_2,
440
+ image,
441
+ mask_image,
442
  strength,
443
  height,
444
  width,
445
+ output_type,
446
  prompt_embeds=None,
447
  pooled_prompt_embeds=None,
448
  callback_on_step_end_tensor_inputs=None,
449
+ padding_mask_crop=None,
450
  max_sequence_length=None,
451
  ):
452
  if strength < 0 or strength > 1:
 
486
  "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
487
  )
488
 
489
+ if padding_mask_crop is not None:
490
+ if not isinstance(image, PIL.Image.Image):
491
+ raise ValueError(
492
+ f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
493
+ )
494
+ if not isinstance(mask_image, PIL.Image.Image):
495
+ raise ValueError(
496
+ f"The mask image should be a PIL image when inpainting mask crop, but is of type"
497
+ f" {type(mask_image)}."
498
+ )
499
+ if output_type != "pil":
500
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
501
+
502
  if max_sequence_length is not None and max_sequence_length > 512:
503
  raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
504
 
505
  @staticmethod
506
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
507
  def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
508
  latent_image_ids = torch.zeros(height // 2, width // 2, 3)
509
  latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
 
511
 
512
  latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
513
 
 
514
  latent_image_ids = latent_image_ids.reshape(
515
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
516
  )
517
 
518
  return latent_image_ids.to(device=device, dtype=dtype)
519
 
520
  @staticmethod
521
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
522
  def _pack_latents(latents, batch_size, num_channels_latents, height, width):
523
  latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
524
  latents = latents.permute(0, 2, 4, 1, 3, 5)
 
527
  return latents
528
 
529
  @staticmethod
530
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
531
  def _unpack_latents(latents, height, width, vae_scale_factor):
532
  batch_size, num_patches, channels = latents.shape
533
 
 
543
 
544
  def prepare_latents(
545
  self,
546
+ image,
547
+ timestep,
548
  batch_size,
549
  num_channels_latents,
550
  height,
 
553
  device,
554
  generator,
555
  latents=None,
 
 
 
556
  ):
557
  if isinstance(generator, list) and len(generator) != batch_size:
558
  raise ValueError(
 
560
  f" size of {batch_size}. Make sure the batch size matches the length of the generators."
561
  )
562
 
 
 
 
 
 
 
563
  height = 2 * (int(height) // self.vae_scale_factor)
564
  width = 2 * (int(width) // self.vae_scale_factor)
565
 
566
  shape = (batch_size, num_channels_latents, height, width)
567
  latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
568
+
569
+ image = image.to(device=device, dtype=dtype)
570
+ image_latents = self._encode_vae_image(image=image, generator=generator)
571
+
572
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
573
+ # expand init_latents for batch_size
574
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
575
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
576
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
577
+ raise ValueError(
578
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
579
+ )
580
+ else:
581
+ image_latents = torch.cat([image_latents], dim=0)
582
 
583
  if latents is None:
584
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
585
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
586
  else:
587
+ noise = latents.to(device)
588
+ latents = noise
589
 
 
 
590
  noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
591
  image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
592
  latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
 
597
  mask,
598
  masked_image,
599
  batch_size,
600
+ num_channels_latents,
601
  num_images_per_prompt,
602
  height,
603
  width,
 
605
  device,
606
  generator,
607
  ):
608
+ height = 2 * (int(height) // self.vae_scale_factor)
609
+ width = 2 * (int(width) // self.vae_scale_factor)
610
  # resize the mask to latents shape as we concatenate the mask to the latents
611
  # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
612
  # and half precision
613
+ mask = torch.nn.functional.interpolate(mask, size=(height, width))
 
 
614
  mask = mask.to(device=device, dtype=dtype)
615
 
616
  batch_size = batch_size * num_images_per_prompt
 
644
 
645
  # aligning device to prevent device errors when concating it with the latent model input
646
  masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
647
+
648
+ masked_image_latents = self._pack_latents(
649
+ masked_image_latents,
650
+ batch_size,
651
+ num_channels_latents,
652
+ height,
653
+ width,
654
+ )
655
+ mask = self._pack_latents(
656
+ mask.repeat(1, num_channels_latents, 1, 1),
657
+ batch_size,
658
+ num_channels_latents,
659
+ height,
660
+ width,
661
+ )
662
+
663
  return mask, masked_image_latents
664
 
665
  @property
 
686
  prompt_2: Optional[Union[str, List[str]]] = None,
687
  image: PipelineImageInput = None,
688
  mask_image: PipelineImageInput = None,
689
+ masked_image_latents: PipelineImageInput = None,
690
  height: Optional[int] = None,
691
  width: Optional[int] = None,
692
  padding_mask_crop: Optional[int] = None,
 
729
  color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
730
  H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
731
  1)`, or `(H, W)`.
732
+ mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
733
+ `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
734
+ latents tensor will ge generated by `mask_image`.
735
  height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
736
  The height in pixels of the generated image. This is set to 1024 by default for the best results.
737
  width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
 
812
  self.check_inputs(
813
  prompt,
814
  prompt_2,
815
+ image,
816
+ mask_image,
817
  strength,
818
  height,
819
  width,
820
+ output_type=output_type,
821
  prompt_embeds=prompt_embeds,
822
  pooled_prompt_embeds=pooled_prompt_embeds,
823
  callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
824
+ padding_mask_crop=padding_mask_crop,
825
  max_sequence_length=max_sequence_length,
826
  )
827
 
828
  self._guidance_scale = guidance_scale
829
  self._joint_attention_kwargs = joint_attention_kwargs
830
  self._interrupt = False
 
831
 
832
  # 2. Preprocess mask and image
833
  if padding_mask_crop is not None:
 
889
  sigmas,
890
  mu=mu,
891
  )
892
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
893
 
894
  if num_inference_steps < 1:
895
  raise ValueError(
 
902
  num_channels_latents = self.transformer.config.in_channels // 4
903
 
904
  latents, noise, original_image_latents, latent_image_ids = self.prepare_latents(
905
+ init_image,
906
+ latent_timestep,
907
  batch_size * num_images_per_prompt,
908
  num_channels_latents,
909
  height,
 
912
  device,
913
  generator,
914
  latents,
 
 
 
915
  )
916
 
917
  # start diff diff preparation
 
924
  original_mask,
925
  masked_image,
926
  batch_size,
927
+ num_channels_latents,
928
  num_images_per_prompt,
929
  height,
930
  width,
 
934
  )
935
 
936
  mask_thresholds = torch.arange(num_inference_steps, dtype=original_mask.dtype) / num_inference_steps
937
+ mask_thresholds = mask_thresholds.reshape(-1, 1, 1, 1).to(device)
938
+ masks = original_mask > mask_thresholds
 
 
 
 
 
 
 
939
  # end diff diff preparation
940
 
941
  num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
942
+
943
+ # handle guidance
944
+ if self.transformer.config.guidance_embeds:
945
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
946
+ guidance = guidance.expand(latents.shape[0])
947
+ else:
948
+ guidance = None
949
 
950
  # 6. Denoising loop
 
 
 
951
  with self.progress_bar(total=num_inference_steps) as progress_bar:
952
  for i, t in enumerate(timesteps):
953
  if self.interrupt:
954
  continue
955
 
956
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
957
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
 
 
 
 
 
 
958
  noise_pred = self.transformer(
959
  hidden_states=latents,
 
960
  timestep=timestep / 1000,
961
  guidance=guidance,
962
  pooled_projections=pooled_prompt_embeds,
 
968
  )[0]
969
 
970
  # compute the previous noisy sample x_t -> x_t-1
971
+ latents_dtype = latents.dtype
972
  latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
973
 
974
+ # for 64 channel transformer only.
975
+ image_latent = original_image_latents
976
+
977
  if i < len(timesteps) - 1:
978
  noise_timestep = timesteps[i + 1]
979
  image_latent = self.scheduler.scale_noise(
 
1022
  if not return_dict:
1023
  return (image,)
1024
 
1025
+ return FluxPipelineOutput(images=image)