linoyts HF staff commited on
Commit
59ce957
1 Parent(s): 31f811d

Update clip_slider_pipeline.py

Browse files
Files changed (1) hide show
  1. clip_slider_pipeline.py +25 -22
clip_slider_pipeline.py CHANGED
@@ -73,6 +73,8 @@ class CLIPSlider:
73
  only_pooler = False,
74
  normalize_scales = False, # whether to normalize the scales when avg_diff_2nd is not None
75
  correlation_weight_factor = 1.0,
 
 
76
  **pipeline_kwargs
77
  ):
78
  # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
@@ -83,14 +85,14 @@ class CLIPSlider:
83
  max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
84
  prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
85
 
86
- if self.avg_diff_2nd and normalize_scales:
87
  denominator = abs(scale) + abs(scale_2nd)
88
  scale = scale / denominator
89
  scale_2nd = scale_2nd / denominator
90
  if only_pooler:
91
- prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff * scale
92
- if self.avg_diff_2nd:
93
- prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd * scale_2nd
94
  else:
95
  normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
96
  sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
@@ -102,9 +104,9 @@ class CLIPSlider:
102
 
103
  # weights = torch.sigmoid((weights-0.5)*7)
104
  prompt_embeds = prompt_embeds + (
105
- weights * self.avg_diff[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
106
- if self.avg_diff_2nd:
107
- prompt_embeds += weights * self.avg_diff_2nd[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd
108
 
109
 
110
  torch.manual_seed(seed)
@@ -198,6 +200,8 @@ class CLIPSliderXL(CLIPSlider):
198
  only_pooler = False,
199
  normalize_scales = False,
200
  correlation_weight_factor = 1.0,
 
 
201
  **pipeline_kwargs
202
  ):
203
  # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
@@ -232,16 +236,15 @@ class CLIPSliderXL(CLIPSlider):
232
  pooled_prompt_embeds = prompt_embeds[0]
233
  prompt_embeds = prompt_embeds.hidden_states[-2]
234
 
235
- if self.avg_diff_2nd and normalize_scales:
236
  denominator = abs(scale) + abs(scale_2nd)
237
  scale = scale / denominator
238
  scale_2nd = scale_2nd / denominator
239
  if only_pooler:
240
- prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff[0] * scale
241
- if self.avg_diff_2nd:
242
- prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd[0] * scale_2nd
243
  else:
244
- print(self.avg_diff)
245
  normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
246
  sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
247
 
@@ -251,18 +254,18 @@ class CLIPSliderXL(CLIPSlider):
251
  standard_weights = torch.ones_like(weights)
252
 
253
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
254
- prompt_embeds = prompt_embeds + (weights * self.avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
255
- if self.avg_diff_2nd:
256
- prompt_embeds += (weights * self.avg_diff_2nd[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd)
257
  else:
258
  weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
259
 
260
  standard_weights = torch.ones_like(weights)
261
 
262
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
263
- prompt_embeds = prompt_embeds + (weights * self.avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
264
- if self.avg_diff_2nd:
265
- prompt_embeds += (weights * self.avg_diff_2nd[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale_2nd)
266
 
267
  bs_embed, seq_len, _ = prompt_embeds.shape
268
  prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
@@ -325,7 +328,7 @@ class CLIPSlider3(CLIPSlider):
325
  positives2 = torch.cat(positives2, dim=0)
326
  negatives2 = torch.cat(negatives2, dim=0)
327
  diffs2 = positives2 - negatives2
328
- avg_diff2 = diffs2.mean(0, keepdim=True)
329
  return (avg_diff, avg_diff2)
330
 
331
  def generate(self,
@@ -383,7 +386,7 @@ class CLIPSlider3(CLIPSlider):
383
  t5_prompt_embed_shape = prompt_embeds.shape[-1]
384
 
385
  if only_pooler:
386
- prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff[0] * scale
387
  else:
388
  normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
389
  sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
@@ -393,14 +396,14 @@ class CLIPSlider3(CLIPSlider):
393
  standard_weights = torch.ones_like(weights)
394
 
395
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
396
- prompt_embeds = prompt_embeds + (weights * self.avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
397
  else:
398
  weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
399
 
400
  standard_weights = torch.ones_like(weights)
401
 
402
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
403
- prompt_embeds = prompt_embeds + (weights * self.avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
404
 
405
  bs_embed, seq_len, _ = prompt_embeds.shape
406
  prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
 
73
  only_pooler = False,
74
  normalize_scales = False, # whether to normalize the scales when avg_diff_2nd is not None
75
  correlation_weight_factor = 1.0,
76
+ avg_diff = None,
77
+ avg_diff_2nd = None,
78
  **pipeline_kwargs
79
  ):
80
  # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
 
85
  max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
86
  prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
87
 
88
+ if avg_diff_2nd and normalize_scales:
89
  denominator = abs(scale) + abs(scale_2nd)
90
  scale = scale / denominator
91
  scale_2nd = scale_2nd / denominator
92
  if only_pooler:
93
+ prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff * scale
94
+ if avg_diff_2nd:
95
+ prompt_embeds[:, toks.argmax()] += avg_diff_2nd * scale_2nd
96
  else:
97
  normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
98
  sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
 
104
 
105
  # weights = torch.sigmoid((weights-0.5)*7)
106
  prompt_embeds = prompt_embeds + (
107
+ weights * avg_diff[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
108
+ if avg_diff_2nd:
109
+ prompt_embeds += weights * avg_diff_2nd[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd
110
 
111
 
112
  torch.manual_seed(seed)
 
200
  only_pooler = False,
201
  normalize_scales = False,
202
  correlation_weight_factor = 1.0,
203
+ avg_diff = None,
204
+ avg_diff_2nd= None,
205
  **pipeline_kwargs
206
  ):
207
  # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
 
236
  pooled_prompt_embeds = prompt_embeds[0]
237
  prompt_embeds = prompt_embeds.hidden_states[-2]
238
 
239
+ if avg_diff_2nd and normalize_scales:
240
  denominator = abs(scale) + abs(scale_2nd)
241
  scale = scale / denominator
242
  scale_2nd = scale_2nd / denominator
243
  if only_pooler:
244
+ prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff[0] * scale
245
+ if avg_diff_2nd:
246
+ prompt_embeds[:, toks.argmax()] += avg_diff_2nd[0] * scale_2nd
247
  else:
 
248
  normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
249
  sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
250
 
 
254
  standard_weights = torch.ones_like(weights)
255
 
256
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
257
+ prompt_embeds = prompt_embeds + (weights * avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
258
+ if avg_diff_2nd:
259
+ prompt_embeds += (weights * avg_diff_2nd[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd)
260
  else:
261
  weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
262
 
263
  standard_weights = torch.ones_like(weights)
264
 
265
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
266
+ prompt_embeds = prompt_embeds + (weights * avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
267
+ if avg_diff_2nd:
268
+ prompt_embeds += (weights * avg_diff_2nd[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale_2nd)
269
 
270
  bs_embed, seq_len, _ = prompt_embeds.shape
271
  prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
 
328
  positives2 = torch.cat(positives2, dim=0)
329
  negatives2 = torch.cat(negatives2, dim=0)
330
  diffs2 = positives2 - negatives2
331
+ avg_diff2 = diffs2.mean(0, keepdim=True)
332
  return (avg_diff, avg_diff2)
333
 
334
  def generate(self,
 
386
  t5_prompt_embed_shape = prompt_embeds.shape[-1]
387
 
388
  if only_pooler:
389
+ prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff[0] * scale
390
  else:
391
  normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
392
  sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
 
396
  standard_weights = torch.ones_like(weights)
397
 
398
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
399
+ prompt_embeds = prompt_embeds + (weights * avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
400
  else:
401
  weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
402
 
403
  standard_weights = torch.ones_like(weights)
404
 
405
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
406
+ prompt_embeds = prompt_embeds + (weights * avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
407
 
408
  bs_embed, seq_len, _ = prompt_embeds.shape
409
  prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)