linoyts committed
Commit 545388a
Parent: f5c6998

revert taking avg_diff out from attributes

Files changed (1):
  clip_slider_pipeline.py  +24 −27
clip_slider_pipeline.py CHANGED
@@ -17,8 +17,8 @@ class CLIPSlider:
             iterations: int = 300,
     ):
 
-        #self.device = device
-        self.pipe = sd_pipe
+        self.device = device
+        self.pipe = sd_pipe.to(self.device)
         self.iterations = iterations
         if target_word != "" or opposite != "":
             self.avg_diff = self.find_latent_direction(target_word, opposite)
@@ -73,8 +73,6 @@ class CLIPSlider:
                  only_pooler = False,
                  normalize_scales = False, # whether to normalize the scales when avg_diff_2nd is not None
                  correlation_weight_factor = 1.0,
-                 avg_diff = None,
-                 avg_diff_2nd = None,
                  **pipeline_kwargs
                  ):
         # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
@@ -85,14 +83,14 @@ class CLIPSlider:
                               max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
         prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
 
-        if avg_diff_2nd and normalize_scales:
+        if self.avg_diff_2nd and normalize_scales:
             denominator = abs(scale) + abs(scale_2nd)
             scale = scale / denominator
             scale_2nd = scale_2nd / denominator
         if only_pooler:
-            prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff * scale
-            if avg_diff_2nd:
-                prompt_embeds[:, toks.argmax()] += avg_diff_2nd * scale_2nd
+            prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff * scale
+            if self.avg_diff_2nd:
+                prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd * scale_2nd
         else:
             normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
             sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
@@ -104,9 +102,9 @@ class CLIPSlider:
 
             # weights = torch.sigmoid((weights-0.5)*7)
             prompt_embeds = prompt_embeds + (
-                    weights * avg_diff[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
-            if avg_diff_2nd:
-                prompt_embeds += weights * avg_diff_2nd[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd
+                    weights * self.avg_diff[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
+            if self.avg_diff_2nd:
+                prompt_embeds += weights * self.avg_diff_2nd[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd
 
 
         torch.manual_seed(seed)
@@ -200,8 +198,6 @@ class CLIPSliderXL(CLIPSlider):
                  only_pooler = False,
                  normalize_scales = False,
                  correlation_weight_factor = 1.0,
-                 avg_diff = None,
-                 avg_diff_2nd= None,
                  **pipeline_kwargs
                  ):
         # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
@@ -236,15 +232,16 @@ class CLIPSliderXL(CLIPSlider):
             pooled_prompt_embeds = prompt_embeds[0]
             prompt_embeds = prompt_embeds.hidden_states[-2]
 
-        if avg_diff_2nd and normalize_scales:
+        if self.avg_diff_2nd and normalize_scales:
             denominator = abs(scale) + abs(scale_2nd)
             scale = scale / denominator
             scale_2nd = scale_2nd / denominator
         if only_pooler:
-            prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff[0] * scale
-            if avg_diff_2nd:
-                prompt_embeds[:, toks.argmax()] += avg_diff_2nd[0] * scale_2nd
+            prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff[0] * scale
+            if self.avg_diff_2nd:
+                prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd[0] * scale_2nd
         else:
+            print(self.avg_diff)
             normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
             sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
 
@@ -254,18 +251,18 @@ class CLIPSliderXL(CLIPSlider):
                 standard_weights = torch.ones_like(weights)
 
                 weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
-                prompt_embeds = prompt_embeds + (weights * avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
-                if avg_diff_2nd:
-                    prompt_embeds += (weights * avg_diff_2nd[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd)
+                prompt_embeds = prompt_embeds + (weights * self.avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
+                if self.avg_diff_2nd:
+                    prompt_embeds += (weights * self.avg_diff_2nd[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd)
             else:
                 weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
 
                 standard_weights = torch.ones_like(weights)
 
                 weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
-                prompt_embeds = prompt_embeds + (weights * avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
-                if avg_diff_2nd:
-                    prompt_embeds += (weights * avg_diff_2nd[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale_2nd)
+                prompt_embeds = prompt_embeds + (weights * self.avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
+                if self.avg_diff_2nd:
+                    prompt_embeds += (weights * self.avg_diff_2nd[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale_2nd)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
@@ -328,7 +325,7 @@ class CLIPSlider3(CLIPSlider):
         positives2 = torch.cat(positives2, dim=0)
         negatives2 = torch.cat(negatives2, dim=0)
         diffs2 = positives2 - negatives2
-        avg_diff2 = diffs2.mean(0, keepdim=True)
+        avg_diff2 = diffs2.mean(0, keepdim=True)
         return (avg_diff, avg_diff2)
 
     def generate(self,
@@ -386,7 +383,7 @@ class CLIPSlider3(CLIPSlider):
         t5_prompt_embed_shape = prompt_embeds.shape[-1]
 
         if only_pooler:
-            prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff[0] * scale
+            prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff[0] * scale
         else:
             normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
             sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
@@ -396,14 +393,14 @@ class CLIPSlider3(CLIPSlider):
             standard_weights = torch.ones_like(weights)
 
             weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
-            prompt_embeds = prompt_embeds + (weights * avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
+            prompt_embeds = prompt_embeds + (weights * self.avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
         else:
             weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
 
             standard_weights = torch.ones_like(weights)
 
             weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
-            prompt_embeds = prompt_embeds + (weights * avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
+            prompt_embeds = prompt_embeds + (weights * self.avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
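
The net effect of the revert: the latent direction lives on the instance again, with find_latent_direction() running in __init__ and generate() reading self.avg_diff / self.avg_diff_2nd instead of taking them as arguments. A minimal usage sketch under that assumption follows; the checkpoint name, the prompt text, and generate()'s prompt/seed parameters are illustrative rather than confirmed by this diff.

    import torch
    from diffusers import StableDiffusionPipeline

    from clip_slider_pipeline import CLIPSlider

    # Illustrative checkpoint; any pipeline CLIPSlider wraps should work the same way.
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    )

    # After this revert, __init__ moves the pipe to `device` and precomputes
    # the direction: self.avg_diff = self.find_latent_direction(target_word, opposite)
    slider = CLIPSlider(
        sd_pipe=pipe,
        device=torch.device("cuda"),
        target_word="smiling",
        opposite="frowning",
        iterations=300,
    )

    # generate() no longer accepts avg_diff / avg_diff_2nd keyword arguments;
    # it reads the attributes set above.
    image = slider.generate("a portrait photo", scale=1.5, seed=0)

Before this commit, the direction had to be threaded through each generate() call as an explicit avg_diff argument; the attribute-based flow above is what the revert restores.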