linoyts HF staff commited on
Commit
4ea931b
·
verified ·
1 Parent(s): a7295c8

Update clip_slider_pipeline.py

Browse files
Files changed (1) hide show
  1. clip_slider_pipeline.py +30 -20
clip_slider_pipeline.py CHANGED
@@ -10,17 +10,20 @@ class CLIPSlider:
10
  self,
11
  sd_pipe,
12
  device: torch.device,
13
- target_word: str,
14
- opposite: str,
15
  target_word_2nd: str = "",
16
  opposite_2nd: str = "",
17
  iterations: int = 300,
18
  ):
19
 
20
  self.device = device
21
- self.pipe = sd_pipe.to(self.device)
22
  self.iterations = iterations
23
- self.avg_diff = self.find_latent_direction(target_word, opposite)
 
 
 
24
  if target_word_2nd != "" or opposite_2nd != "":
25
  self.avg_diff_2nd = self.find_latent_direction(target_word_2nd, opposite_2nd)
26
  else:
@@ -29,12 +32,15 @@ class CLIPSlider:
29
 
30
  def find_latent_direction(self,
31
  target_word:str,
32
- opposite:str):
33
 
34
  # lets identify a latent direction by taking differences between opposites
35
  # target_word = "happy"
36
  # opposite = "sad"
37
-
 
 
 
38
 
39
  with torch.no_grad():
40
  positives = []
@@ -70,6 +76,8 @@ class CLIPSlider:
70
  only_pooler = False,
71
  normalize_scales = False, # whether to normalize the scales when avg_diff_2nd is not None
72
  correlation_weight_factor = 1.0,
 
 
73
  **pipeline_kwargs
74
  ):
75
  # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
@@ -80,14 +88,14 @@ class CLIPSlider:
80
  max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
81
  prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
82
 
83
- if self.avg_diff_2nd and normalize_scales:
84
  denominator = abs(scale) + abs(scale_2nd)
85
  scale = scale / denominator
86
  scale_2nd = scale_2nd / denominator
87
  if only_pooler:
88
- prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff * scale
89
- if self.avg_diff_2nd:
90
- prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd * scale_2nd
91
  else:
92
  normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
93
  sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
@@ -99,9 +107,9 @@ class CLIPSlider:
99
 
100
  # weights = torch.sigmoid((weights-0.5)*7)
101
  prompt_embeds = prompt_embeds + (
102
- weights * self.avg_diff[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
103
- if self.avg_diff_2nd:
104
- prompt_embeds += weights * self.avg_diff_2nd[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd
105
 
106
 
107
  torch.manual_seed(seed)
@@ -399,6 +407,8 @@ class T5SliderFlux(CLIPSlider):
399
  only_pooler = False,
400
  normalize_scales = False,
401
  correlation_weight_factor = 1.0,
 
 
402
  **pipeline_kwargs
403
  ):
404
  # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
@@ -438,14 +448,14 @@ class T5SliderFlux(CLIPSlider):
438
  dtype = self.pipe.text_encoder_2.dtype
439
  prompt_embeds = prompt_embeds.to(dtype=dtype, device=self.device)
440
  print("1", prompt_embeds.shape)
441
- if self.avg_diff_2nd and normalize_scales:
442
  denominator = abs(scale) + abs(scale_2nd)
443
  scale = scale / denominator
444
  scale_2nd = scale_2nd / denominator
445
  if only_pooler:
446
- prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff * scale
447
- if self.avg_diff_2nd:
448
- prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd * scale_2nd
449
  else:
450
  normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
451
  sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
@@ -457,11 +467,11 @@ class T5SliderFlux(CLIPSlider):
457
 
458
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
459
  prompt_embeds = prompt_embeds + (
460
- weights * self.avg_diff * scale)
461
  print("2", prompt_embeds.shape)
462
- if self.avg_diff_2nd:
463
  prompt_embeds += (
464
- weights * self.avg_diff_2nd * scale_2nd)
465
 
466
  torch.manual_seed(seed)
467
  images = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
 
10
  self,
11
  sd_pipe,
12
  device: torch.device,
13
+ target_word: str = "",
14
+ opposite: str = "",
15
  target_word_2nd: str = "",
16
  opposite_2nd: str = "",
17
  iterations: int = 300,
18
  ):
19
 
20
  self.device = device
21
+ self.pipe = sd_pipe.to(self.device, torch.float16)
22
  self.iterations = iterations
23
+ if target_word != "" or opposite != "":
24
+ self.avg_diff = self.find_latent_direction(target_word, opposite)
25
+ else:
26
+ self.avg_diff = None
27
  if target_word_2nd != "" or opposite_2nd != "":
28
  self.avg_diff_2nd = self.find_latent_direction(target_word_2nd, opposite_2nd)
29
  else:
 
32
 
33
  def find_latent_direction(self,
34
  target_word:str,
35
+ opposite:str, num_iterations: int = None):
36
 
37
  # lets identify a latent direction by taking differences between opposites
38
  # target_word = "happy"
39
  # opposite = "sad"
40
+ if num_iterations is not None:
41
+ iterations = num_iterations
42
+ else:
43
+ iterations = self.iterations
44
 
45
  with torch.no_grad():
46
  positives = []
 
76
  only_pooler = False,
77
  normalize_scales = False, # whether to normalize the scales when avg_diff_2nd is not None
78
  correlation_weight_factor = 1.0,
79
+ avg_diff = None,
80
+ avg_diff_2nd = None,
81
  **pipeline_kwargs
82
  ):
83
  # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
 
88
  max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
89
  prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
90
 
91
+ if avg_diff_2nd and normalize_scales:
92
  denominator = abs(scale) + abs(scale_2nd)
93
  scale = scale / denominator
94
  scale_2nd = scale_2nd / denominator
95
  if only_pooler:
96
+ prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff * scale
97
+ if avg_diff_2nd:
98
+ prompt_embeds[:, toks.argmax()] += avg_diff_2nd * scale_2nd
99
  else:
100
  normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
101
  sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
 
107
 
108
  # weights = torch.sigmoid((weights-0.5)*7)
109
  prompt_embeds = prompt_embeds + (
110
+ weights * avg_diff[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
111
+ if avg_diff_2nd:
112
+ prompt_embeds += weights * avg_diff_2nd[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd
113
 
114
 
115
  torch.manual_seed(seed)
 
407
  only_pooler = False,
408
  normalize_scales = False,
409
  correlation_weight_factor = 1.0,
410
+ avg_diff = None,
411
+ avg_diff_2nd = None,
412
  **pipeline_kwargs
413
  ):
414
  # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
 
448
  dtype = self.pipe.text_encoder_2.dtype
449
  prompt_embeds = prompt_embeds.to(dtype=dtype, device=self.device)
450
  print("1", prompt_embeds.shape)
451
+ if avg_diff_2nd and normalize_scales:
452
  denominator = abs(scale) + abs(scale_2nd)
453
  scale = scale / denominator
454
  scale_2nd = scale_2nd / denominator
455
  if only_pooler:
456
+ prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff * scale
457
+ if avg_diff_2nd:
458
+ prompt_embeds[:, toks.argmax()] += avg_diff_2nd * scale_2nd
459
  else:
460
  normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
461
  sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
 
467
 
468
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
469
  prompt_embeds = prompt_embeds + (
470
+ weights * avg_diff * scale)
471
  print("2", prompt_embeds.shape)
472
+ if avg_diff_2nd:
473
  prompt_embeds += (
474
+ weights * avg_diff_2nd * scale_2nd)
475
 
476
  torch.manual_seed(seed)
477
  images = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,