Spaces:
Sleeping
Sleeping
Create clip_slider_pipeline.py
Browse files- clip_slider_pipeline.py +421 -0
clip_slider_pipeline.py
ADDED
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import diffusers
|
2 |
+
import torch
|
3 |
+
import random
|
4 |
+
from tqdm import tqdm
|
5 |
+
from constants import SUBJECTS, MEDIUMS
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
class CLIPSlider:
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
sd_pipe,
|
12 |
+
device: torch.device,
|
13 |
+
target_word: str,
|
14 |
+
opposite: str,
|
15 |
+
target_word_2nd: str = "",
|
16 |
+
opposite_2nd: str = "",
|
17 |
+
iterations: int = 300,
|
18 |
+
):
|
19 |
+
|
20 |
+
self.device = device
|
21 |
+
self.pipe = sd_pipe.to(self.device)
|
22 |
+
self.iterations = iterations
|
23 |
+
self.avg_diff = self.find_latent_direction(target_word, opposite)
|
24 |
+
if target_word_2nd != "" or opposite_2nd != "":
|
25 |
+
self.avg_diff_2nd = self.find_latent_direction(target_word_2nd, opposite_2nd)
|
26 |
+
else:
|
27 |
+
self.avg_diff_2nd = None
|
28 |
+
|
29 |
+
|
30 |
+
def find_latent_direction(self,
|
31 |
+
target_word:str,
|
32 |
+
opposite:str):
|
33 |
+
|
34 |
+
# lets identify a latent direction by taking differences between opposites
|
35 |
+
# target_word = "happy"
|
36 |
+
# opposite = "sad"
|
37 |
+
|
38 |
+
|
39 |
+
with torch.no_grad():
|
40 |
+
positives = []
|
41 |
+
negatives = []
|
42 |
+
for i in tqdm(range(self.iterations)):
|
43 |
+
medium = random.choice(MEDIUMS)
|
44 |
+
subject = random.choice(SUBJECTS)
|
45 |
+
pos_prompt = f"a {medium} of a {target_word} {subject}"
|
46 |
+
neg_prompt = f"a {medium} of a {opposite} {subject}"
|
47 |
+
pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
|
48 |
+
max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
|
49 |
+
neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
|
50 |
+
max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
|
51 |
+
pos = self.pipe.text_encoder(pos_toks).pooler_output
|
52 |
+
neg = self.pipe.text_encoder(neg_toks).pooler_output
|
53 |
+
positives.append(pos)
|
54 |
+
negatives.append(neg)
|
55 |
+
|
56 |
+
positives = torch.cat(positives, dim=0)
|
57 |
+
negatives = torch.cat(negatives, dim=0)
|
58 |
+
|
59 |
+
diffs = positives - negatives
|
60 |
+
|
61 |
+
avg_diff = diffs.mean(0, keepdim=True)
|
62 |
+
return avg_diff
|
63 |
+
|
64 |
+
|
65 |
+
def generate(self,
|
66 |
+
prompt = "a photo of a house",
|
67 |
+
scale = 2.,
|
68 |
+
scale_2nd = 0., # scale for the 2nd dim directions when avg_diff_2nd is not None
|
69 |
+
seed = 15,
|
70 |
+
only_pooler = False,
|
71 |
+
normalize_scales = False, # whether to normalize the scales when avg_diff_2nd is not None
|
72 |
+
correlation_weight_factor = 1.0,
|
73 |
+
**pipeline_kwargs
|
74 |
+
):
|
75 |
+
# if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
|
76 |
+
# if pooler token only [-4,4] work well
|
77 |
+
|
78 |
+
with torch.no_grad():
|
79 |
+
toks = self.pipe.tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True,
|
80 |
+
max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
|
81 |
+
prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
|
82 |
+
|
83 |
+
if self.avg_diff_2nd and normalize_scales:
|
84 |
+
denominator = abs(scale) + abs(scale_2nd)
|
85 |
+
scale = scale / denominator
|
86 |
+
scale_2nd = scale_2nd / denominator
|
87 |
+
if only_pooler:
|
88 |
+
prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff * scale
|
89 |
+
if self.avg_diff_2nd:
|
90 |
+
prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd * scale_2nd
|
91 |
+
else:
|
92 |
+
normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
|
93 |
+
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
94 |
+
weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 768)
|
95 |
+
|
96 |
+
standard_weights = torch.ones_like(weights)
|
97 |
+
|
98 |
+
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
99 |
+
|
100 |
+
# weights = torch.sigmoid((weights-0.5)*7)
|
101 |
+
prompt_embeds = prompt_embeds + (
|
102 |
+
weights * self.avg_diff[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
|
103 |
+
if self.avg_diff_2nd:
|
104 |
+
prompt_embeds += weights * self.avg_diff_2nd[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd
|
105 |
+
|
106 |
+
|
107 |
+
torch.manual_seed(seed)
|
108 |
+
image = self.pipe(prompt_embeds=prompt_embeds, **pipeline_kwargs).images
|
109 |
+
|
110 |
+
return image
|
111 |
+
|
112 |
+
def spectrum(self,
|
113 |
+
prompt="a photo of a house",
|
114 |
+
low_scale=-2,
|
115 |
+
low_scale_2nd=-2,
|
116 |
+
high_scale=2,
|
117 |
+
high_scale_2nd=2,
|
118 |
+
steps=5,
|
119 |
+
seed=15,
|
120 |
+
only_pooler=False,
|
121 |
+
normalize_scales=False,
|
122 |
+
correlation_weight_factor=1.0,
|
123 |
+
**pipeline_kwargs
|
124 |
+
):
|
125 |
+
|
126 |
+
images = []
|
127 |
+
for i in range(steps):
|
128 |
+
scale = low_scale + (high_scale - low_scale) * i / (steps - 1)
|
129 |
+
scale_2nd = low_scale_2nd + (high_scale_2nd - low_scale_2nd) * i / (steps - 1)
|
130 |
+
image = self.generate(prompt, scale, scale_2nd, seed, only_pooler, normalize_scales, correlation_weight_factor, **pipeline_kwargs)
|
131 |
+
images.append(image[0])
|
132 |
+
|
133 |
+
canvas = Image.new('RGB', (640 * steps, 640))
|
134 |
+
for i, im in enumerate(images):
|
135 |
+
canvas.paste(im, (640 * i, 0))
|
136 |
+
|
137 |
+
return canvas
|
138 |
+
|
139 |
+
class CLIPSliderXL(CLIPSlider):
|
140 |
+
|
141 |
+
def find_latent_direction(self,
|
142 |
+
target_word:str,
|
143 |
+
opposite:str):
|
144 |
+
|
145 |
+
# lets identify a latent direction by taking differences between opposites
|
146 |
+
# target_word = "happy"
|
147 |
+
# opposite = "sad"
|
148 |
+
|
149 |
+
|
150 |
+
with torch.no_grad():
|
151 |
+
positives = []
|
152 |
+
negatives = []
|
153 |
+
positives2 = []
|
154 |
+
negatives2 = []
|
155 |
+
for i in tqdm(range(self.iterations)):
|
156 |
+
medium = random.choice(MEDIUMS)
|
157 |
+
subject = random.choice(SUBJECTS)
|
158 |
+
pos_prompt = f"a {medium} of a {target_word} {subject}"
|
159 |
+
neg_prompt = f"a {medium} of a {opposite} {subject}"
|
160 |
+
|
161 |
+
pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
|
162 |
+
max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
|
163 |
+
neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
|
164 |
+
max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
|
165 |
+
pos = self.pipe.text_encoder(pos_toks).pooler_output
|
166 |
+
neg = self.pipe.text_encoder(neg_toks).pooler_output
|
167 |
+
positives.append(pos)
|
168 |
+
negatives.append(neg)
|
169 |
+
|
170 |
+
pos_toks2 = self.pipe.tokenizer_2(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
|
171 |
+
max_length=self.pipe.tokenizer_2.model_max_length).input_ids.cuda()
|
172 |
+
neg_toks2 = self.pipe.tokenizer_2(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
|
173 |
+
max_length=self.pipe.tokenizer_2.model_max_length).input_ids.cuda()
|
174 |
+
pos2 = self.pipe.text_encoder_2(pos_toks2).text_embeds
|
175 |
+
neg2 = self.pipe.text_encoder_2(neg_toks2).text_embeds
|
176 |
+
positives2.append(pos2)
|
177 |
+
negatives2.append(neg2)
|
178 |
+
|
179 |
+
positives = torch.cat(positives, dim=0)
|
180 |
+
negatives = torch.cat(negatives, dim=0)
|
181 |
+
diffs = positives - negatives
|
182 |
+
avg_diff = diffs.mean(0, keepdim=True)
|
183 |
+
|
184 |
+
positives2 = torch.cat(positives2, dim=0)
|
185 |
+
negatives2 = torch.cat(negatives2, dim=0)
|
186 |
+
diffs2 = positives2 - negatives2
|
187 |
+
avg_diff2 = diffs2.mean(0, keepdim=True)
|
188 |
+
return (avg_diff, avg_diff2)
|
189 |
+
|
190 |
+
def generate(self,
|
191 |
+
prompt = "a photo of a house",
|
192 |
+
scale = 2,
|
193 |
+
scale_2nd = 2,
|
194 |
+
seed = 15,
|
195 |
+
only_pooler = False,
|
196 |
+
normalize_scales = False,
|
197 |
+
correlation_weight_factor = 1.0,
|
198 |
+
**pipeline_kwargs
|
199 |
+
):
|
200 |
+
# if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
|
201 |
+
# if pooler token only [-4,4] work well
|
202 |
+
|
203 |
+
text_encoders = [self.pipe.text_encoder, self.pipe.text_encoder_2]
|
204 |
+
tokenizers = [self.pipe.tokenizer, self.pipe.tokenizer_2]
|
205 |
+
with torch.no_grad():
|
206 |
+
# toks = pipe.tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=77).input_ids.cuda()
|
207 |
+
# prompt_embeds = pipe.text_encoder(toks).last_hidden_state
|
208 |
+
|
209 |
+
prompt_embeds_list = []
|
210 |
+
|
211 |
+
for i, text_encoder in enumerate(text_encoders):
|
212 |
+
|
213 |
+
tokenizer = tokenizers[i]
|
214 |
+
text_inputs = tokenizer(
|
215 |
+
prompt,
|
216 |
+
padding="max_length",
|
217 |
+
max_length=tokenizer.model_max_length,
|
218 |
+
truncation=True,
|
219 |
+
return_tensors="pt",
|
220 |
+
)
|
221 |
+
toks = text_inputs.input_ids
|
222 |
+
|
223 |
+
prompt_embeds = text_encoder(
|
224 |
+
toks.to(text_encoder.device),
|
225 |
+
output_hidden_states=True,
|
226 |
+
)
|
227 |
+
|
228 |
+
# We are only ALWAYS interested in the pooled output of the final text encoder
|
229 |
+
pooled_prompt_embeds = prompt_embeds[0]
|
230 |
+
prompt_embeds = prompt_embeds.hidden_states[-2]
|
231 |
+
|
232 |
+
if self.avg_diff_2nd and normalize_scales:
|
233 |
+
denominator = abs(scale) + abs(scale_2nd)
|
234 |
+
scale = scale / denominator
|
235 |
+
scale_2nd = scale_2nd / denominator
|
236 |
+
if only_pooler:
|
237 |
+
prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff[0] * scale
|
238 |
+
if self.avg_diff_2nd:
|
239 |
+
prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd[0] * scale_2nd
|
240 |
+
else:
|
241 |
+
normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
|
242 |
+
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
243 |
+
|
244 |
+
if i == 0:
|
245 |
+
weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 768)
|
246 |
+
|
247 |
+
standard_weights = torch.ones_like(weights)
|
248 |
+
|
249 |
+
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
250 |
+
prompt_embeds = prompt_embeds + (weights * self.avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
|
251 |
+
if self.avg_diff_2nd:
|
252 |
+
prompt_embeds += (weights * self.avg_diff_2nd[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd)
|
253 |
+
else:
|
254 |
+
weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
|
255 |
+
|
256 |
+
standard_weights = torch.ones_like(weights)
|
257 |
+
|
258 |
+
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
259 |
+
prompt_embeds = prompt_embeds + (weights * self.avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
|
260 |
+
if self.avg_diff_2nd:
|
261 |
+
prompt_embeds += (weights * self.avg_diff_2nd[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale_2nd)
|
262 |
+
|
263 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
264 |
+
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
|
265 |
+
prompt_embeds_list.append(prompt_embeds)
|
266 |
+
|
267 |
+
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
268 |
+
pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
|
269 |
+
|
270 |
+
torch.manual_seed(seed)
|
271 |
+
image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
|
272 |
+
**pipeline_kwargs).images
|
273 |
+
|
274 |
+
return image
|
275 |
+
|
276 |
+
|
277 |
+
class CLIPSlider3(CLIPSlider):
|
278 |
+
def find_latent_direction(self,
|
279 |
+
target_word:str,
|
280 |
+
opposite:str):
|
281 |
+
|
282 |
+
# lets identify a latent direction by taking differences between opposites
|
283 |
+
# target_word = "happy"
|
284 |
+
# opposite = "sad"
|
285 |
+
|
286 |
+
|
287 |
+
with torch.no_grad():
|
288 |
+
positives = []
|
289 |
+
negatives = []
|
290 |
+
positives2 = []
|
291 |
+
negatives2 = []
|
292 |
+
for i in tqdm(range(self.iterations)):
|
293 |
+
medium = random.choice(MEDIUMS)
|
294 |
+
subject = random.choice(SUBJECTS)
|
295 |
+
pos_prompt = f"a {medium} of a {target_word} {subject}"
|
296 |
+
neg_prompt = f"a {medium} of a {opposite} {subject}"
|
297 |
+
|
298 |
+
pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
|
299 |
+
max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
|
300 |
+
neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
|
301 |
+
max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
|
302 |
+
pos = self.pipe.text_encoder(pos_toks).text_embeds
|
303 |
+
neg = self.pipe.text_encoder(neg_toks).text_embeds
|
304 |
+
positives.append(pos)
|
305 |
+
negatives.append(neg)
|
306 |
+
|
307 |
+
pos_toks2 = self.pipe.tokenizer_2(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
|
308 |
+
max_length=self.pipe.tokenizer_2.model_max_length).input_ids.cuda()
|
309 |
+
neg_toks2 = self.pipe.tokenizer_2(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
|
310 |
+
max_length=self.pipe.tokenizer_2.model_max_length).input_ids.cuda()
|
311 |
+
pos2 = self.pipe.text_encoder_2(pos_toks2).text_embeds
|
312 |
+
neg2 = self.pipe.text_encoder_2(neg_toks2).text_embeds
|
313 |
+
positives2.append(pos2)
|
314 |
+
negatives2.append(neg2)
|
315 |
+
|
316 |
+
positives = torch.cat(positives, dim=0)
|
317 |
+
negatives = torch.cat(negatives, dim=0)
|
318 |
+
diffs = positives - negatives
|
319 |
+
avg_diff = diffs.mean(0, keepdim=True)
|
320 |
+
|
321 |
+
positives2 = torch.cat(positives2, dim=0)
|
322 |
+
negatives2 = torch.cat(negatives2, dim=0)
|
323 |
+
diffs2 = positives2 - negatives2
|
324 |
+
avg_diff2 = diffs2.mean(0, keepdim=True)
|
325 |
+
return (avg_diff, avg_diff2)
|
326 |
+
|
327 |
+
def generate(self,
|
328 |
+
prompt = "a photo of a house",
|
329 |
+
scale = 2,
|
330 |
+
seed = 15,
|
331 |
+
only_pooler = False,
|
332 |
+
correlation_weight_factor = 1.0,
|
333 |
+
** pipeline_kwargs
|
334 |
+
):
|
335 |
+
# if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
|
336 |
+
# if pooler token only [-4,4] work well
|
337 |
+
clip_text_encoders = [self.pipe.text_encoder, self.pipe.text_encoder_2]
|
338 |
+
clip_tokenizers = [self.pipe.tokenizer, self.pipe.tokenizer_2]
|
339 |
+
with torch.no_grad():
|
340 |
+
# toks = pipe.tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=77).input_ids.cuda()
|
341 |
+
# prompt_embeds = pipe.text_encoder(toks).last_hidden_state
|
342 |
+
|
343 |
+
clip_prompt_embeds_list = []
|
344 |
+
clip_pooled_prompt_embeds_list = []
|
345 |
+
for i, text_encoder in enumerate(clip_text_encoders):
|
346 |
+
|
347 |
+
if i < 2:
|
348 |
+
tokenizer = clip_tokenizers[i]
|
349 |
+
text_inputs = tokenizer(
|
350 |
+
prompt,
|
351 |
+
padding="max_length",
|
352 |
+
max_length=tokenizer.model_max_length,
|
353 |
+
truncation=True,
|
354 |
+
return_tensors="pt",
|
355 |
+
)
|
356 |
+
toks = text_inputs.input_ids
|
357 |
+
|
358 |
+
prompt_embeds = text_encoder(
|
359 |
+
toks.to(text_encoder.device),
|
360 |
+
output_hidden_states=True,
|
361 |
+
)
|
362 |
+
|
363 |
+
# We are only ALWAYS interested in the pooled output of the final text encoder
|
364 |
+
pooled_prompt_embeds = prompt_embeds[0]
|
365 |
+
pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
|
366 |
+
clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
|
367 |
+
prompt_embeds = prompt_embeds.hidden_states[-2]
|
368 |
+
else:
|
369 |
+
text_inputs = self.pipe.tokenizer_3(
|
370 |
+
prompt,
|
371 |
+
padding="max_length",
|
372 |
+
max_length=self.tokenizer_max_length,
|
373 |
+
truncation=True,
|
374 |
+
add_special_tokens=True,
|
375 |
+
return_tensors="pt",
|
376 |
+
)
|
377 |
+
toks = text_inputs.input_ids
|
378 |
+
prompt_embeds = self.pipe.text_encoder_3(toks.to(self.device))[0]
|
379 |
+
t5_prompt_embed_shape = prompt_embeds.shape[-1]
|
380 |
+
|
381 |
+
if only_pooler:
|
382 |
+
prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff[0] * scale
|
383 |
+
else:
|
384 |
+
normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
|
385 |
+
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
386 |
+
if i == 0:
|
387 |
+
weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 768)
|
388 |
+
|
389 |
+
standard_weights = torch.ones_like(weights)
|
390 |
+
|
391 |
+
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
392 |
+
prompt_embeds = prompt_embeds + (weights * self.avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
|
393 |
+
else:
|
394 |
+
weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
|
395 |
+
|
396 |
+
standard_weights = torch.ones_like(weights)
|
397 |
+
|
398 |
+
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
399 |
+
prompt_embeds = prompt_embeds + (weights * self.avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
|
400 |
+
|
401 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
402 |
+
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
|
403 |
+
if i < 2:
|
404 |
+
clip_prompt_embeds_list.append(prompt_embeds)
|
405 |
+
|
406 |
+
clip_prompt_embeds = torch.concat(clip_prompt_embeds_list, dim=-1)
|
407 |
+
clip_pooled_prompt_embeds = torch.concat(clip_pooled_prompt_embeds_list, dim=-1)
|
408 |
+
|
409 |
+
clip_prompt_embeds = torch.nn.functional.pad(
|
410 |
+
clip_prompt_embeds, (0, t5_prompt_embed_shape - clip_prompt_embeds.shape[-1])
|
411 |
+
)
|
412 |
+
|
413 |
+
prompt_embeds = torch.cat([clip_prompt_embeds, prompt_embeds], dim=-2)
|
414 |
+
|
415 |
+
|
416 |
+
|
417 |
+
torch.manual_seed(seed)
|
418 |
+
image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=clip_pooled_prompt_embeds,
|
419 |
+
**pipeline_kwargs).images
|
420 |
+
|
421 |
+
return image
|