linoyts HF staff committed on
Commit
9efa518
1 Parent(s): 54507b8

Update ledits/pipeline_leditspp_stable_diffusion_xl.py

Browse files
ledits/pipeline_leditspp_stable_diffusion_xl.py CHANGED
@@ -613,11 +613,10 @@ class LEditsPPPipelineStableDiffusionXL(
613
  else:
614
  # "2" because SDXL always indexes from the penultimate layer.
615
  edit_concepts_embeds = edit_concepts_embeds.hidden_states[-(clip_skip + 2)]
616
-
617
- print("SHALOM???")
618
- if avg_diff is not None and avg_diff_2 is not None:
619
- #scale=3
620
- print("SHALOM")
621
  normed_prompt_embeds = edit_concepts_embeds / edit_concepts_embeds.norm(dim=-1, keepdim=True)
622
  sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
623
  if i == 0:
@@ -626,14 +625,26 @@ class LEditsPPPipelineStableDiffusionXL(
626
  standard_weights = torch.ones_like(weights)
627
 
628
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
629
- edit_concepts_embeds = edit_concepts_embeds + (weights * avg_diff[None, :].repeat(1,tokenizer.model_max_length, 1) * scale)
 
 
 
 
 
 
630
  else:
631
  weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
632
 
633
  standard_weights = torch.ones_like(weights)
634
 
635
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
636
- edit_concepts_embeds = edit_concepts_embeds + (weights * avg_diff_2[None, :].repeat(1, tokenizer.model_max_length, 1) * scale)
 
 
 
 
 
 
637
 
638
  edit_prompt_embeds_list.append(edit_concepts_embeds)
639
  i+=1
 
613
  else:
614
  # "2" because SDXL always indexes from the penultimate layer.
615
  edit_concepts_embeds = edit_concepts_embeds.hidden_states[-(clip_skip + 2)]
616
+
617
+
618
+ if avg_diff is not None:
619
+
 
620
  normed_prompt_embeds = edit_concepts_embeds / edit_concepts_embeds.norm(dim=-1, keepdim=True)
621
  sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
622
  if i == 0:
 
625
  standard_weights = torch.ones_like(weights)
626
 
627
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
628
+ edit_concepts_embeds = edit_concepts_embeds + (
629
+ weights * avg_diff[0][None, :].repeat(1, tokenizer.model_max_length, 1) * scale)
630
+
631
+ if avg_diff_2nd is not None:
632
+ edit_concepts_embeds += (weights * avg_diff_2nd[0][None, :].repeat(1,
633
+ self.pipe.tokenizer.model_max_length,
634
+ 1) * scale_2nd)
635
  else:
636
  weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
637
 
638
  standard_weights = torch.ones_like(weights)
639
 
640
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
641
+ edit_concepts_embeds = edit_concepts_embeds + (
642
+ weights * avg_diff[1][None, :].repeat(1, tokenizer.model_max_length, 1) * scale)
643
+ if avg_diff_2nd is not None:
644
+ edit_concepts_embeds += (weights * avg_diff_2nd[1][None, :].repeat(1,
645
+ self.pipe.tokenizer_2.model_max_length,
646
+ 1) * scale_2nd)
647
+
648
 
649
  edit_prompt_embeds_list.append(edit_concepts_embeds)
650
  i+=1