linoyts HF staff commited on
Commit
85c4d8a
1 Parent(s): b0b3cd4

Update ledits/pipeline_leditspp_stable_diffusion_xl.py

Browse files
ledits/pipeline_leditspp_stable_diffusion_xl.py CHANGED
@@ -882,6 +882,8 @@ class LEditsPPPipelineStableDiffusionXL(
882
  avg_diff_2 = None,
883
  correlation_weight_factor = 0.7,
884
  scale=2,
 
 
885
  **kwargs,
886
  ):
887
  r"""
@@ -1014,9 +1016,10 @@ class LEditsPPPipelineStableDiffusionXL(
1014
 
1015
  eta = self.eta
1016
  num_images_per_prompt = 1
1017
- latents = self.init_latents
 
1018
 
1019
- zs = self.zs
1020
  self.scheduler.set_timesteps(len(self.scheduler.timesteps))
1021
 
1022
  if use_intersect_mask:
@@ -1094,6 +1097,7 @@ class LEditsPPPipelineStableDiffusionXL(
1094
  # self.scheduler.set_timesteps(num_inference_steps, device=device)
1095
 
1096
  timesteps = self.inversion_steps
 
1097
  t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
1098
 
1099
  if use_cross_attn_mask:
@@ -1698,7 +1702,8 @@ class LEditsPPPipelineStableDiffusionXL(
1698
  if num_zero_noise_steps > 0:
1699
  zs[-num_zero_noise_steps:] = torch.zeros_like(zs[-num_zero_noise_steps:])
1700
  self.zs = zs
1701
- return LEditsPPInversionPipelineOutput(images=resized, vae_reconstruction_images=image_rec)
 
1702
 
1703
 
1704
  # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.rescale_noise_cfg
 
882
  avg_diff_2 = None,
883
  correlation_weight_factor = 0.7,
884
  scale=2,
885
+ init_latents: [torch.Tensor] = None,
886
+ zs: [torch.Tensor] = None,
887
  **kwargs,
888
  ):
889
  r"""
 
1016
 
1017
  eta = self.eta
1018
  num_images_per_prompt = 1
1019
+ #latents = self.init_latents
1020
+ latents = init_latents
1021
 
1022
+ #zs = self.zs
1023
  self.scheduler.set_timesteps(len(self.scheduler.timesteps))
1024
 
1025
  if use_intersect_mask:
 
1097
  # self.scheduler.set_timesteps(num_inference_steps, device=device)
1098
 
1099
  timesteps = self.inversion_steps
1100
+ timesteps = inversion_steps
1101
  t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
1102
 
1103
  if use_cross_attn_mask:
 
1702
  if num_zero_noise_steps > 0:
1703
  zs[-num_zero_noise_steps:] = torch.zeros_like(zs[-num_zero_noise_steps:])
1704
  self.zs = zs
1705
+ #return LEditsPPInversionPipelineOutput(images=resized, vae_reconstruction_images=image_rec)
1706
+ return xts[-1], zs
1707
 
1708
 
1709
  # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.rescale_noise_cfg