Spaces:
Sleeping
Sleeping
Update clip_slider_pipeline.py
Browse files- clip_slider_pipeline.py +25 -22
clip_slider_pipeline.py
CHANGED
@@ -73,6 +73,8 @@ class CLIPSlider:
|
|
73 |
only_pooler = False,
|
74 |
normalize_scales = False, # whether to normalize the scales when avg_diff_2nd is not None
|
75 |
correlation_weight_factor = 1.0,
|
|
|
|
|
76 |
**pipeline_kwargs
|
77 |
):
|
78 |
# if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
|
@@ -83,14 +85,14 @@ class CLIPSlider:
|
|
83 |
max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
|
84 |
prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
|
85 |
|
86 |
-
if
|
87 |
denominator = abs(scale) + abs(scale_2nd)
|
88 |
scale = scale / denominator
|
89 |
scale_2nd = scale_2nd / denominator
|
90 |
if only_pooler:
|
91 |
-
prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] +
|
92 |
-
if
|
93 |
-
prompt_embeds[:, toks.argmax()] +=
|
94 |
else:
|
95 |
normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
|
96 |
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
@@ -102,9 +104,9 @@ class CLIPSlider:
|
|
102 |
|
103 |
# weights = torch.sigmoid((weights-0.5)*7)
|
104 |
prompt_embeds = prompt_embeds + (
|
105 |
-
weights *
|
106 |
-
if
|
107 |
-
prompt_embeds += weights *
|
108 |
|
109 |
|
110 |
torch.manual_seed(seed)
|
@@ -198,6 +200,8 @@ class CLIPSliderXL(CLIPSlider):
|
|
198 |
only_pooler = False,
|
199 |
normalize_scales = False,
|
200 |
correlation_weight_factor = 1.0,
|
|
|
|
|
201 |
**pipeline_kwargs
|
202 |
):
|
203 |
# if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
|
@@ -232,16 +236,15 @@ class CLIPSliderXL(CLIPSlider):
|
|
232 |
pooled_prompt_embeds = prompt_embeds[0]
|
233 |
prompt_embeds = prompt_embeds.hidden_states[-2]
|
234 |
|
235 |
-
if
|
236 |
denominator = abs(scale) + abs(scale_2nd)
|
237 |
scale = scale / denominator
|
238 |
scale_2nd = scale_2nd / denominator
|
239 |
if only_pooler:
|
240 |
-
prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] +
|
241 |
-
if
|
242 |
-
prompt_embeds[:, toks.argmax()] +=
|
243 |
else:
|
244 |
-
print(self.avg_diff)
|
245 |
normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
|
246 |
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
247 |
|
@@ -251,18 +254,18 @@ class CLIPSliderXL(CLIPSlider):
|
|
251 |
standard_weights = torch.ones_like(weights)
|
252 |
|
253 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
254 |
-
prompt_embeds = prompt_embeds + (weights *
|
255 |
-
if
|
256 |
-
prompt_embeds += (weights *
|
257 |
else:
|
258 |
weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
|
259 |
|
260 |
standard_weights = torch.ones_like(weights)
|
261 |
|
262 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
263 |
-
prompt_embeds = prompt_embeds + (weights *
|
264 |
-
if
|
265 |
-
prompt_embeds += (weights *
|
266 |
|
267 |
bs_embed, seq_len, _ = prompt_embeds.shape
|
268 |
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
|
@@ -325,7 +328,7 @@ class CLIPSlider3(CLIPSlider):
|
|
325 |
positives2 = torch.cat(positives2, dim=0)
|
326 |
negatives2 = torch.cat(negatives2, dim=0)
|
327 |
diffs2 = positives2 - negatives2
|
328 |
-
avg_diff2 = diffs2.mean(0, keepdim=True)
|
329 |
return (avg_diff, avg_diff2)
|
330 |
|
331 |
def generate(self,
|
@@ -383,7 +386,7 @@ class CLIPSlider3(CLIPSlider):
|
|
383 |
t5_prompt_embed_shape = prompt_embeds.shape[-1]
|
384 |
|
385 |
if only_pooler:
|
386 |
-
prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] +
|
387 |
else:
|
388 |
normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
|
389 |
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
@@ -393,14 +396,14 @@ class CLIPSlider3(CLIPSlider):
|
|
393 |
standard_weights = torch.ones_like(weights)
|
394 |
|
395 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
396 |
-
prompt_embeds = prompt_embeds + (weights *
|
397 |
else:
|
398 |
weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
|
399 |
|
400 |
standard_weights = torch.ones_like(weights)
|
401 |
|
402 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
403 |
-
prompt_embeds = prompt_embeds + (weights *
|
404 |
|
405 |
bs_embed, seq_len, _ = prompt_embeds.shape
|
406 |
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
|
|
|
73 |
only_pooler = False,
|
74 |
normalize_scales = False, # whether to normalize the scales when avg_diff_2nd is not None
|
75 |
correlation_weight_factor = 1.0,
|
76 |
+
avg_diff = None,
|
77 |
+
avg_diff_2nd = None,
|
78 |
**pipeline_kwargs
|
79 |
):
|
80 |
# if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
|
|
|
85 |
max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
|
86 |
prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
|
87 |
|
88 |
+
if avg_diff_2nd and normalize_scales:
|
89 |
denominator = abs(scale) + abs(scale_2nd)
|
90 |
scale = scale / denominator
|
91 |
scale_2nd = scale_2nd / denominator
|
92 |
if only_pooler:
|
93 |
+
prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff * scale
|
94 |
+
if avg_diff_2nd:
|
95 |
+
prompt_embeds[:, toks.argmax()] += avg_diff_2nd * scale_2nd
|
96 |
else:
|
97 |
normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
|
98 |
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
|
|
104 |
|
105 |
# weights = torch.sigmoid((weights-0.5)*7)
|
106 |
prompt_embeds = prompt_embeds + (
|
107 |
+
weights * avg_diff[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
|
108 |
+
if avg_diff_2nd:
|
109 |
+
prompt_embeds += weights * avg_diff_2nd[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd
|
110 |
|
111 |
|
112 |
torch.manual_seed(seed)
|
|
|
200 |
only_pooler = False,
|
201 |
normalize_scales = False,
|
202 |
correlation_weight_factor = 1.0,
|
203 |
+
avg_diff = None,
|
204 |
+
avg_diff_2nd= None,
|
205 |
**pipeline_kwargs
|
206 |
):
|
207 |
# if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
|
|
|
236 |
pooled_prompt_embeds = prompt_embeds[0]
|
237 |
prompt_embeds = prompt_embeds.hidden_states[-2]
|
238 |
|
239 |
+
if avg_diff_2nd and normalize_scales:
|
240 |
denominator = abs(scale) + abs(scale_2nd)
|
241 |
scale = scale / denominator
|
242 |
scale_2nd = scale_2nd / denominator
|
243 |
if only_pooler:
|
244 |
+
prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff[0] * scale
|
245 |
+
if avg_diff_2nd:
|
246 |
+
prompt_embeds[:, toks.argmax()] += avg_diff_2nd[0] * scale_2nd
|
247 |
else:
|
|
|
248 |
normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
|
249 |
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
250 |
|
|
|
254 |
standard_weights = torch.ones_like(weights)
|
255 |
|
256 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
257 |
+
prompt_embeds = prompt_embeds + (weights * avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
|
258 |
+
if avg_diff_2nd:
|
259 |
+
prompt_embeds += (weights * avg_diff_2nd[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd)
|
260 |
else:
|
261 |
weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
|
262 |
|
263 |
standard_weights = torch.ones_like(weights)
|
264 |
|
265 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
266 |
+
prompt_embeds = prompt_embeds + (weights * avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
|
267 |
+
if avg_diff_2nd:
|
268 |
+
prompt_embeds += (weights * avg_diff_2nd[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale_2nd)
|
269 |
|
270 |
bs_embed, seq_len, _ = prompt_embeds.shape
|
271 |
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
|
|
|
328 |
positives2 = torch.cat(positives2, dim=0)
|
329 |
negatives2 = torch.cat(negatives2, dim=0)
|
330 |
diffs2 = positives2 - negatives2
|
331 |
+
avg_diff2 = diffs2.mean(0, keepdim=True)
|
332 |
return (avg_diff, avg_diff2)
|
333 |
|
334 |
def generate(self,
|
|
|
386 |
t5_prompt_embed_shape = prompt_embeds.shape[-1]
|
387 |
|
388 |
if only_pooler:
|
389 |
+
prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff[0] * scale
|
390 |
else:
|
391 |
normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
|
392 |
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
|
|
396 |
standard_weights = torch.ones_like(weights)
|
397 |
|
398 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
399 |
+
prompt_embeds = prompt_embeds + (weights * avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
|
400 |
else:
|
401 |
weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
|
402 |
|
403 |
standard_weights = torch.ones_like(weights)
|
404 |
|
405 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
406 |
+
prompt_embeds = prompt_embeds + (weights * avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
|
407 |
|
408 |
bs_embed, seq_len, _ = prompt_embeds.shape
|
409 |
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
|