rynmurdock committed · e55aab4
Parent: 430249a
update sol probs others

Files changed:
- .gitignore  +1 -0
- app.py  +224 -211
- twitter_prompts.csv  +47 -0
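Summary of the main app.py change: the AnimateDiff video pipeline is replaced with an SDXL-Lightning image pipeline plus IP-Adapter, and the preference logic is rewritten so ratings are fit with a LinearSVC and the resulting direction is blended with a liked embedding before generating the next image. A minimal sketch of that solver idea follows, paraphrasing the new solver()/get_user_emb() in the diff; the helper names preference_direction/next_query_embedding and the exact tensor shapes are illustrative, not from the repo.

import random
import numpy as np
import torch
from sklearn.svm import LinearSVC

def preference_direction(embs, ys):
    # Fit a linear SVM on rated image embeddings; its coefficient vector points
    # from disliked toward liked examples (assumes both classes are present).
    X = torch.stack([e.squeeze().float().cpu() for e in embs]).numpy()
    y = np.array(ys) * 2 - 1  # map {0, 1} ratings to {-1, +1}
    coef = LinearSVC(class_weight='balanced').fit(X, y).coef_
    return torch.tensor(coef, dtype=torch.float32).squeeze()

def next_query_embedding(embs, ys):
    # Scale the SVM direction to the spread of a liked embedding, then blend;
    # the commit uses (1 * latest_pos + 3 * dif) / 4 as the next IP-Adapter query.
    dif = preference_direction(embs, ys)
    pos = random.choice([embs[i] for i in range(len(ys)) if ys[i] > .5]).squeeze().float().cpu()
    dif = dif / dif.std() * pos.std()
    return (1 * pos + 3 * dif) / 4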
.gitignore
ADDED
@@ -0,0 +1 @@
1 | + __pycache__*
app.py
CHANGED
@@ -6,21 +6,27 @@ import torch
6 |
7 | # lol
8 | DEVICE = 'cuda'
9 | - STEPS =
10 | output_hidden_state = False
11 | device = "cuda"
12 | dtype = torch.bfloat16
13 | - N_IMG_EMBS = 3
14 |
15 | import logging
16 | import os
17 | import imageio
18 | import gradio as gr
19 | import numpy as np
20 | - from sklearn.svm import
21 | - from sklearn import preprocessing
22 | import pandas as pd
23 | from apscheduler.schedulers.background import BackgroundScheduler
24 |
25 | import random
26 | import time
@@ -37,8 +43,12 @@ prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'l
37 | import spaces
38 | start_time = time.time()
39 |
40 | ####################### Setup Model
41 | - from diffusers import
42 | from transformers import CLIPTextModel
43 | from huggingface_hub import hf_hub_download
44 | from safetensors.torch import load_file
@@ -46,9 +56,8 @@ from PIL import Image
46 | from transformers import CLIPVisionModelWithProjection
47 | import uuid
48 | import av
49 | - import torchvision
50 |
51 | - def write_video(file_name, images, fps=
52 | container = av.open(file_name, mode="w")
53 |
54 | stream = container.add_stream("h264", rate=fps)
@@ -89,182 +98,133 @@ device_map='cuda')
89 | #unet = UNet2DConditionModel.from_pretrained(finetune_path+'/unet/').to(dtype)
90 | #text_encoder = CLIPTextModel.from_pretrained(finetune_path+'/text_encoder/').to(dtype)
91 |
92 |
93 | -
94 | -
95 | -
96 |
97 | -
98 | - pipe =
99 | - unet=unet, text_encoder=text_encoder)
100 | - pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
101 | - pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora",)
102 | - pipe.set_adapters(["lcm-lora"], [.95])
103 | - pipe.fuse_lora()
104 |
105 |
106 | - #pipe = AnimateDiffPipeline.from_pretrained('emilianJR/epiCRealism', torch_dtype=dtype, image_encoder=image_encoder)
107 | - #pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
108 | - #repo = "ByteDance/AnimateDiff-Lightning"
109 | - #ckpt = f"animatediff_lightning_4step_diffusers.safetensors"
110 |
111 |
112 | - pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15_vit-G.bin", map_location='cpu')
113 | - # This IP adapter improves outputs substantially.
114 | - pipe.set_ip_adapter_scale(.6)
115 | - pipe.unet.fuse_qkv_projections()
116 | - #pipe.enable_free_init(method="gaussian", use_fast_sampling=True)
117 |
118 | - pipe.
119 |
120 | #pipe.unet = torch.compile(pipe.unet)
121 | #pipe.vae = torch.compile(pipe.vae)
122 |
123 |
124 | - #############################################################
125 | -
126 | - from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
127 | -
128 | - pali = PaliGemmaForConditionalGeneration.from_pretrained('google/paligemma-3b-mix-224', torch_dtype=dtype).eval().to('cuda')
129 | - processor = AutoProcessor.from_pretrained('google/paligemma-3b-mix-224')
130 | -
131 | - #pali = torch.compile(pali)
132 |
133 | @spaces.GPU()
134 | - def to_wanted_embs(image_outputs, input_ids, attention_mask, cache_position=None):
135 | - inputs_embeds = pali.get_input_embeddings()(input_ids.to('cuda'))
136 | - selected_image_feature = image_outputs.to(dtype).to('cuda')
137 | - image_features = pali.multi_modal_projector(selected_image_feature)
138 | -
139 | - if cache_position is None:
140 | - cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
141 | - inputs_embeds, attention_mask, labels, position_ids = pali._merge_input_ids_with_image_features(
142 | - image_features, inputs_embeds, input_ids, attention_mask, None, None, cache_position
143 | - )
144 | - return inputs_embeds
145 | -
146 | -
147 | - # TODO cache descriptions?
148 | - @spaces.GPU(duration=20)
149 | - def generate_pali(n_embs):
150 | - prompt = 'caption en'
151 | - model_inputs = processor(text=prompt, images=torch.zeros(1, 3, 224, 224), return_tensors="pt")
152 | - # we need to get im_embs taken in here.
153 | -
154 | - descs = ''
155 | - for n, emb in enumerate(n_embs):
156 | - if n < len(n_embs)-1:
157 | - input_len = model_inputs["input_ids"].shape[-1]
158 | - input_embeds = to_wanted_embs(emb,
159 | - model_inputs["input_ids"].to(device),
160 | - model_inputs["attention_mask"].to(device))
161 | - generation = pali.generate(max_new_tokens=20, do_sample=True, top_p=.94, temperature=1.2, inputs_embeds=input_embeds)
162 | - decoded = processor.decode(generation[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
163 | - descs += f'Description: {decoded}\n'
164 | - else:
165 | - prompt = f'en {descs} Describe a new image that is similar. Description:'
166 | - model_inputs = processor(text=prompt, images=torch.zeros(1, 3, 224, 224), return_tensors="pt")
167 | - input_len = model_inputs["input_ids"].shape[-1]
168 | - input_embeds = to_wanted_embs(emb,
169 | - model_inputs["input_ids"].to(device),
170 | - model_inputs["attention_mask"].to(device))
171 | - generation = pali.generate(max_new_tokens=20, do_sample=True, top_p=.94, temperature=1.2, inputs_embeds=input_embeds)
172 | - decoded = processor.decode(generation[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
173 | -
174 | - return decoded
175 | -
176 | -
177 | -
178 | -
179 | - #############################################################
180 | -
181 | -
182 | -
183 | - @spaces.GPU(duration=20)
184 | def generate_gpu(in_im_embs, prompt='the scene'):
185 | with torch.no_grad():
186 | -
187 | output = pipe(prompt=prompt, guidance_scale=1, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)
188 | im_emb, _ = pipe.encode_image(
189 | - output.
190 | )
191 | im_emb = im_emb.detach().to('cpu').to(torch.float32)
192 | -
193 | - im = torch.nn.functional.interpolate(im, (224, 224)).to(dtype).to('cuda')
194 | - im = (im - .5) * 2
195 | - gemb = pali.vision_tower(im).last_hidden_state.detach().to('cpu').to(torch.float32)
196 | - return output, im_emb, gemb
197 |
198 |
199 | def generate(in_im_embs, prompt='the scene'):
200 | - output, im_emb
201 | - nsfw =maybe_nsfw(output.
202 | name = str(uuid.uuid4()).replace("-", "")
203 | - path = f"/tmp/{name}.
204 |
205 | if nsfw:
206 | gr.Warning("NSFW content detected.")
207 | # TODO could return an automatic dislike of auto dislike on the backend for neither as well; just would need refactoring.
208 | - return None, im_emb
209 | -
210 |
211 | - output.
212 | -
213 | - write_video(path, output.frames[0])
214 | - return path, im_emb, gemb
215 |
216 |
217 | #######################
218 |
219 | def get_user_emb(embs, ys):
220 | - # handle case where every instance of calibration videos is 'Neither' or 'Like' or 'Dislike'
221 | -
222 | - if len(list(ys)) <= 10:
223 | - aways = [torch.zeros_like(embs[0]) for i in range(10)]
224 | - embs += aways
225 | - awal = [0 for i in range(5)] + [1 for i in range(5)]
226 | - ys += awal
227 | -
228 | - indices = list(range(len(embs)))
229 | # sample only as many negatives as there are positives
230 | -
231 | -
232 | -
233 | - #neg_indices = random.sample(neg_indices, lower)
234 | - #pos_indices = random.sample(pos_indices, lower)
235 | -
236 |
237 | -
238 | -
239 | - #
240 | -
241 | -
242 | -
243 |
244 | -
245 | -
246 | -
247 | -
248 |
249 | -
250 | -
251 |
252 | - #
253 | -
254 | -
255 | -
256 | - coef_ = coef_ / coef_.abs().max()
257 |
258 | -
259 | -
260 | - return im_emb
261 |
262 |
263 | def pluck_img(user_id, user_emb):
264 | not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]]
265 | while len(not_rated_rows) == 0:
266 | not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]]
267 | - time.sleep(.
268 | # TODO optimize this lol
269 | best_sim = -100000
270 | for i in not_rated_rows.iterrows():
@@ -274,8 +234,7 @@ def pluck_img(user_id, user_emb):
274 | best_sim = sim
275 | best_row = i[1]
276 | img = best_row['paths']
277 | -
278 | - return img, text
279 |
280 |
281 | def background_next_image():
@@ -283,10 +242,10 @@ def background_next_image():
283 | # only let it get N (maybe 3) ahead of the user
284 | #not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
285 | rated_rows = prevs_df[[i[1]['user:rating'] != {' ': ' '} for i in prevs_df.iterrows()]]
286 | -
287 | # not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
288 | -
289 | - time.sleep(.01)
290 |
291 | user_id_list = set(rated_rows['latest_user_to_rate'].to_list())
292 | for uid in user_id_list:
@@ -300,22 +259,32 @@ def background_next_image():
300 | rated_from_user = rated_rows[[i[1]['from_user_id'] == uid for i in rated_rows.iterrows()]]
301 |
302 | # we pop previous ratings if there are > n
303 | - if len(rated_from_user) >=
304 | oldest = rated_from_user.iloc[0]['paths']
305 | prevs_df = prevs_df[prevs_df['paths'] != oldest]
306 | # we don't compute more after n are in the queue for them
307 | - if len(unrated_from_user) >=
308 | continue
309 |
310 | -
311 | -
312 | -
313 | -
314 | -
315 | -
316 | else:
317 | - text = '
318 | - img, embs
319 |
320 | if img:
321 | tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'text', 'gemb'])
@@ -324,7 +293,6 @@ def background_next_image():
324 | tmp_df['user:rating'] = [{' ': ' '}]
325 | tmp_df['from_user_id'] = [uid]
326 | tmp_df['text'] = [text]
327 | - tmp_df['gemb'] = [new_gem]
328 | prevs_df = pd.concat((prevs_df, tmp_df))
329 | # we can free up storage by deleting the image
330 | if len(prevs_df) > 500:
@@ -340,37 +308,52 @@ def background_next_image():
340 |
341 | def pluck_embs_ys(user_id):
342 | rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) != None for i in prevs_df.iterrows()]]
343 |
344 | embs = rated_rows['embeddings'].to_list()
345 | ys = [i[user_id] for i in rated_rows['user:rating'].to_list()]
346 | -
347 | - return embs, ys, gembs
348 |
349 | def next_image(calibrate_prompts, user_id):
350 | with torch.no_grad():
351 | if len(calibrate_prompts) > 0:
352 | cal_video = calibrate_prompts.pop(0)
353 | image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
354 | - return image, calibrate_prompts,
355 | else:
356 | - embs, ys
357 | -
358 | -
359 | -
360 |
361 |
362 |
363 | def start(_, calibrate_prompts, user_id, request: gr.Request):
364 | user_id = int(str(time.time())[-7:].replace('.', ''))
365 | - image, calibrate_prompts
366 | return [
367 | - gr.Button(value='
368 | gr.Button(value='Neither (Space)', interactive=True, visible=False),
369 | - gr.Button(value='
370 | gr.Button(value='Start', interactive=False),
371 | image,
372 | calibrate_prompts,
373 | - user_id
374 | ]
375 |
376 |
@@ -378,27 +361,34 @@ def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
378 | global prevs_df
379 |
380 |
381 | - if choice == '
382 | - choice = 1
383 | elif choice == 'Neither (Space)':
384 | - img, calibrate_prompts,
385 | - return img, calibrate_prompts,
386 | else:
387 | - choice
388 |
389 | # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
390 | # TODO skip allowing rating & just continue
391 | -
392 | print('NSFW -- choice is disliked')
393 | - choice = 0
394 |
395 | row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]
396 | # if it's still in the dataframe, add the choice
397 | if len(prevs_df.loc[row_mask, 'user:rating']) > 0:
398 | prevs_df.loc[row_mask, 'user:rating'][0][user_id] = choice
399 | prevs_df.loc[row_mask, 'latest_user_to_rate'] = [user_id]
400 | - img, calibrate_prompts,
401 | - return img, calibrate_prompts
402 |
403 | css = '''.gradio-container{max-width: 700px !important}
404 | #description{text-align: center}
@@ -461,53 +451,71 @@ Explore the latent space without text prompts based on your preferences. Learn m
461 | user_id = gr.State()
462 | # calibration videos -- this is a misnomer now :D
463 | calibrate_prompts = gr.State([
464 | - './first.
465 | - './second.
466 | - './
467 | - './
468 | - './
469 | - './sixth.mp4',
470 | ])
471 | def l():
472 | return None
473 |
474 | with gr.Row(elem_id='output-image'):
475 | - img = gr.
476 | label='Lightning',
477 | - autoplay=True,
478 | interactive=False,
479 | - height=512,
480 | - width=512,
481 | #include_audio=False,
482 | - elem_id="video_output"
483 | )
484 | - img.play(l, js='''document.querySelector('[data-testid="Lightning-player"]').loop = true''')
485 | -
486 | -
487 | with gr.Row(equal_height=True):
488 | - b3 = gr.Button(value='
489 | b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither", visible=False)
490 | -
491 | b1.click(
492 | choose,
493 | [img, b1, calibrate_prompts, user_id],
494 | - [img, calibrate_prompts,
495 | )
496 | b2.click(
497 | choose,
498 | [img, b2, calibrate_prompts, user_id],
499 | - [img, calibrate_prompts,
500 | )
501 | b3.click(
502 | choose,
503 | [img, b3, calibrate_prompts, user_id],
504 | - [img, calibrate_prompts,
505 | )
506 | with gr.Row():
507 | b4 = gr.Button(value='Start')
508 | b4.click(start,
509 | [b4, calibrate_prompts, user_id],
510 | - [b1, b2, b3, b4, img, calibrate_prompts, user_id]
511 | )
512 | with gr.Row():
513 | html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several videos and then roam. </ div><br><br><br>
@@ -518,37 +526,42 @@ Explore the latent space without text prompts based on your preferences. Learn m
518 | </ div>''')
519 |
520 | # TODO quiet logging
521 | - log = logging.getLogger('log_here')
522 | - log.setLevel(logging.ERROR)
523 |
524 | scheduler = BackgroundScheduler()
525 | - scheduler.add_job(func=background_next_image, trigger="interval", seconds=.
526 | scheduler.start()
527 |
528 |
529 | # prep our calibration videos
530 | - for im in [
531 | - './first.
532 | - './second.
533 | - './
534 | - './
535 | - './
536 | - './sixth.mp4',
537 | - './seventh.mp4',
538 | - './eigth.mp4',
539 | - './ninth.mp4',
540 | - './tenth.mp4',
541 | ]:
542 | tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'text', 'gemb'])
543 | tmp_df['paths'] = [im]
544 | image = list(imageio.imiter(im))
545 | image = image[len(image)//2]
546 | -
547 | -
548 | tmp_df['user:rating'] = [{' ': ' '}]
549 | prevs_df = pd.concat((prevs_df, tmp_df))
550 |
551 | -
552 | - demo.launch(share=True)
553 |
554 |
6 |
7 | # lol
8 | DEVICE = 'cuda'
9 | + STEPS = 8
10 | output_hidden_state = False
11 | device = "cuda"
12 | dtype = torch.bfloat16
13 |
14 | +
15 | + import spaces
16 | +
17 | + import matplotlib.pyplot as plt
18 | + import matplotlib
19 | import logging
20 | +
21 | import os
22 | import imageio
23 | import gradio as gr
24 | import numpy as np
25 | + from sklearn.svm import LinearSVC
26 | import pandas as pd
27 | from apscheduler.schedulers.background import BackgroundScheduler
28 | + import sched
29 | + import threading
30 |
31 | import random
32 | import time
43 | import spaces
44 | start_time = time.time()
45 |
46 | + prompt_list = [p for p in list(set(
47 | + pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
48 | +
49 | +
50 | ####################### Setup Model
51 | + from diffusers import EulerDiscreteScheduler, LCMScheduler, AutoencoderTiny, UNet2DConditionModel, AutoencoderKL, AutoPipelineForText2Image
52 | from transformers import CLIPTextModel
53 | from huggingface_hub import hf_hub_download
54 | from safetensors.torch import load_file
56 | from transformers import CLIPVisionModelWithProjection
57 | import uuid
58 | import av
59 |
60 | + def write_video(file_name, images, fps=16):
61 | container = av.open(file_name, mode="w")
62 |
63 | stream = container.add_stream("h264", rate=fps)
98 | #unet = UNet2DConditionModel.from_pretrained(finetune_path+'/unet/').to(dtype)
99 | #text_encoder = CLIPTextModel.from_pretrained(finetune_path+'/text_encoder/').to(dtype)
100 |
101 | + #rynmurdock/Sea_Claws
102 | + model_id = "stabilityai/stable-diffusion-xl-base-1.0"
103 | + sdxl_lightening = "ByteDance/SDXL-Lightning"
104 | + ckpt = "sdxl_lightning_8step_unet.safetensors"
105 | + unet = UNet2DConditionModel.from_config(model_id, subfolder="unet", low_cpu_mem_usage=True, device_map=DEVICE).to(torch.float16)
106 | + unet.load_state_dict(load_file(hf_hub_download(sdxl_lightening, ckpt)))
107 |
108 | + image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map=DEVICE)
109 | + pipe = AutoPipelineForText2Image.from_pretrained(model_id, unet=unet, torch_dtype=torch.float16, variant="fp16", image_encoder=image_encoder, low_cpu_mem_usage=True)
110 | + pipe.unet._load_ip_adapter_weights(torch.load(hf_hub_download('h94/IP-Adapter', 'sdxl_models/ip-adapter_sdxl_vit-h.bin')))
111 | + pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl_vit-h.bin")
112 | + pipe.register_modules(image_encoder = image_encoder)
113 | + pipe.set_ip_adapter_scale(0.8)
114 |
115 | + #pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16, low_cpu_mem_usage=True)
116 | + pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
117 |
118 | + pipe.to(device=DEVICE).to(dtype=dtype)
119 | + output_hidden_state = False
120 |
121 |
122 |
123 |
124 | + # pipe.unet.fuse_qkv_projections()
125 | + #pipe.enable_free_init(method="gaussian", use_fast_sampling=True)
126 |
127 | #pipe.unet = torch.compile(pipe.unet)
128 | #pipe.vae = torch.compile(pipe.vae)
129 |
130 |
131 |
132 | @spaces.GPU()
133 |
def generate_gpu(in_im_embs, prompt='the scene'):
134 | with torch.no_grad():
135 | + print(prompt)
136 | + in_im_embs = in_im_embs.to('cuda').unsqueeze(0)
137 | output = pipe(prompt=prompt, guidance_scale=1, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)
138 | im_emb, _ = pipe.encode_image(
139 | + output.images[0], 'cuda', 1, output_hidden_state
140 | )
141 | im_emb = im_emb.detach().to('cpu').to(torch.float32)
142 | + return output, im_emb
143 |
144 |
145 | def generate(in_im_embs, prompt='the scene'):
146 | + output, im_emb = generate_gpu(in_im_embs, prompt)
147 | + nsfw = maybe_nsfw(output.images[0])
148 | +
149 | name = str(uuid.uuid4()).replace("-", "")
150 | + path = f"/tmp/{name}.png"
151 |
152 | if nsfw:
153 | gr.Warning("NSFW content detected.")
154 | # TODO could return an automatic dislike of auto dislike on the backend for neither as well; just would need refactoring.
155 | + return None, im_emb
156 |
157 | + output.images[0].save(path)
158 | + return path, im_emb
159 |
160 |
161 | #######################
162 |
163 | +
164 | +
165 | +
166 | +
167 | +
168 | + @spaces.GPU()
169 | + def solver(embs, ys):
170 | + print('ys:', ys,'EMBS:', embs.shape, embs)
171 | + ys = torch.tensor(ys).to('cpu', dtype=torch.float32).squeeze().unsqueeze(1)
172 | +
173 | + sol = LinearSVC(class_weight='balanced').fit(np.array(embs), np.array(torch.tensor(ys).float() * 2 - 1)).coef_
174 | + return torch.tensor(sol).to('cpu', dtype=torch.float32)
175 | +
176 | +
177 | +
178 | +
179 |
def get_user_emb(embs, ys):
180 | # sample only as many negatives as there are positives
181 | + indices = range(len(ys))
182 | + pos_indices = [i for i in indices if ys[i] > .5]
183 | + neg_indices = [i for i in indices if ys[i] <= .5]
184 |
185 | + mini = min(len(pos_indices), len(neg_indices))
186 | +
187 | + if len(ys) > 20: # drop earliest of whichever of neg or pos is most abundant
188 | + if len(pos_indices) > len(neg_indices):
189 | + ind = pos_indices[0]
190 | + else:
191 | + ind = neg_indices[0]
192 | + ys.pop(ind)
193 | + embs.pop(ind)
194 | + print('Dropping at 20')
195 |
196 | + if mini < 1:
197 | + feature_embs = torch.stack([torch.randn(1280), torch.randn(1280)])
198 | + ys_t = [0, 1]
199 | + print('Not enough ratings.')
200 | + else:
201 | + indices = range(len(ys))
202 | + ys_t = [ys[i] for i in indices]
203 | + feature_embs = torch.stack([embs[e].detach().cpu() for e in indices]).squeeze()
204 | +
205 | + # scaler = preprocessing.StandardScaler().fit(feature_embs)
206 | + # feature_embs = scaler.transform(feature_embs)
207 | + # ys_t = ys
208 | +
209 | + print(np.array(feature_embs).shape, np.array(ys_t).shape)
210 |
211 | + sol = solver(feature_embs.squeeze(), ys_t)
212 | + dif = torch.tensor(sol, dtype=dtype).to(device)
213 |
214 | + # could j have a base vector of a black image
215 | + latest_pos = (random.sample([feature_embs[i] for i in range(len(ys_t)) if ys_t[i] > .5], 1)[0]).to(device, dtype)
216 | +
217 | + dif = ((dif / dif.std()) * latest_pos.std())
218 |
219 | + sol = (1*latest_pos + 3*dif)/4
220 | + return sol
221 |
222 |
223 | def pluck_img(user_id, user_emb):
224 | not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]]
225 | while len(not_rated_rows) == 0:
226 | not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]]
227 | + time.sleep(.1)
228 | # TODO optimize this lol
229 | best_sim = -100000
230 | for i in not_rated_rows.iterrows():
234 | best_sim = sim
235 | best_row = i[1]
236 | img = best_row['paths']
237 | + return img
238 |
239 |
240 |
def background_next_image():
242 | # only let it get N (maybe 3) ahead of the user
243 | #not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
244 | rated_rows = prevs_df[[i[1]['user:rating'] != {' ': ' '} for i in prevs_df.iterrows()]]
245 | + if len(rated_rows) < 4:
246 | + time.sleep(.1)
247 | # not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
248 | + return
249 |
250 | user_id_list = set(rated_rows['latest_user_to_rate'].to_list())
251 | for uid in user_id_list:
259 | rated_from_user = rated_rows[[i[1]['from_user_id'] == uid for i in rated_rows.iterrows()]]
260 |
261 | # we pop previous ratings if there are > n
262 | + if len(rated_from_user) >= 15:
263 | oldest = rated_from_user.iloc[0]['paths']
264 | prevs_df = prevs_df[prevs_df['paths'] != oldest]
265 | # we don't compute more after n are in the queue for them
266 | + if len(unrated_from_user) >= 10:
267 | continue
268 |
269 | + if len(rated_rows) < 5:
270 | + continue
271 | +
272 | + embs, ys = pluck_embs_ys(uid)
273 | +
274 | + user_emb = get_user_emb(embs, [y[1] for y in ys])
275 | +
276 | +
277 | + global glob_idx
278 | + glob_idx += 1
279 | + if glob_idx >= (len(prompt_list)-1):
280 | + glob_idx = 0
281 | +
282 | +
283 | + if glob_idx % 7 == 0:
284 | + text = prompt_list[glob_idx]
285 | else:
286 | + text = 'an image'
287 | + img, embs = generate(user_emb, text)
288 |
289 | if img:
290 | tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'text', 'gemb'])
293 | tmp_df['user:rating'] = [{' ': ' '}]
294 | tmp_df['from_user_id'] = [uid]
295 | tmp_df['text'] = [text]
296 | prevs_df = pd.concat((prevs_df, tmp_df))
297 | # we can free up storage by deleting the image
298 | if len(prevs_df) > 500:
308 |
309 |
def pluck_embs_ys(user_id):
310 | rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) != None for i in prevs_df.iterrows()]]
311 | + #not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) == None for i in prevs_df.iterrows()]]
312 | + #while len(not_rated_rows) == 0:
313 | + # not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) == None for i in prevs_df.iterrows()]]
314 | + # rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) != None for i in prevs_df.iterrows()]]
315 | + # time.sleep(.01)
316 | + # print('current user has 0 not_rated_rows')
317 |
318 | embs = rated_rows['embeddings'].to_list()
319 | ys = [i[user_id] for i in rated_rows['user:rating'].to_list()]
320 | + return embs, ys
321 |
322 |
def next_image(calibrate_prompts, user_id):
323 | with torch.no_grad():
324 | if len(calibrate_prompts) > 0:
325 | cal_video = calibrate_prompts.pop(0)
326 | image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
327 | + return image, calibrate_prompts,
328 | else:
329 | + embs, ys = pluck_embs_ys(user_id)
330 | + ys_here = [y[1] for y in ys]
331 | + user_emb = get_user_emb(embs, ys_here)
332 | + image = pluck_img(user_id, user_emb)
333 | + return image, calibrate_prompts,
334 | +
335 | +
336 | +
337 | +
338 | +
339 | +
340 |
341 |
342 |
343 |
def start(_, calibrate_prompts, user_id, request: gr.Request):
344 | user_id = int(str(time.time())[-7:].replace('.', ''))
345 | + image, calibrate_prompts = next_image(calibrate_prompts, user_id)
346 | return [
347 | + gr.Button(value='π', interactive=True),
348 | gr.Button(value='Neither (Space)', interactive=True, visible=False),
349 | + gr.Button(value='π', interactive=True),
350 | gr.Button(value='Start', interactive=False),
351 | + gr.Button(value='π Content', interactive=True, visible=False),
352 | + gr.Button(value='π Style', interactive=True, visible=False),
353 | image,
354 | calibrate_prompts,
355 | + user_id,
356 | +
357 | ]
358 |
359 |
361 | global prevs_df
362 |
363 |
364 | + if choice == 'π':
365 | + choice = [1, 1]
366 | elif choice == 'Neither (Space)':
367 | + img, calibrate_prompts, = next_image(calibrate_prompts, user_id)
368 | + return img, calibrate_prompts,
369 | + elif choice == 'π':
370 | + choice = [0, 0]
371 | + elif choice == 'π Style':
372 | + choice = [0, 1]
373 | + elif choice == 'π Content':
374 | + choice = [1, 0]
375 | else:
376 | + assert False, f'choice is {choice}'
377 |
378 | # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
379 | # TODO skip allowing rating & just continue
380 | +
381 | + if img is None:
382 | print('NSFW -- choice is disliked')
383 | + choice = [0, 0]
384 |
385 | row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]
386 | # if it's still in the dataframe, add the choice
387 | if len(prevs_df.loc[row_mask, 'user:rating']) > 0:
388 | prevs_df.loc[row_mask, 'user:rating'][0][user_id] = choice
389 | prevs_df.loc[row_mask, 'latest_user_to_rate'] = [user_id]
390 | + img, calibrate_prompts, = next_image(calibrate_prompts, user_id)
391 | + return img, calibrate_prompts
392 |
393 |
css = '''.gradio-container{max-width: 700px !important}
394 | #description{text-align: center}
451 | user_id = gr.State()
452 | # calibration videos -- this is a misnomer now :D
453 | calibrate_prompts = gr.State([
454 | + './first.png',
455 | + './second.png',
456 | + './sixth.png',
457 | + './fifth.png',
458 | + './fourth.png',
459 | ])
460 | def l():
461 | return None
462 |
463 | with gr.Row(elem_id='output-image'):
464 | + img = gr.Image(
465 | label='Lightning',
466 | + # autoplay=True,
467 | interactive=False,
468 | + # height=512,
469 | + # width=512,
470 | #include_audio=False,
471 | + elem_id="video_output",
472 | + type='filepath',
473 | )
474 | + #img.play(l, js='''document.querySelector('[data-testid="Lightning-player"]').loop = true''')
475 | +
476 | +
477 | +
478 | with gr.Row(equal_height=True):
479 | + b3 = gr.Button(value='π', interactive=False, elem_id="dislike")
480 | +
481 | b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither", visible=False)
482 | +
483 | + b1 = gr.Button(value='π', interactive=False, elem_id="like")
484 | + with gr.Row(equal_height=True):
485 | + b6 = gr.Button(value='π Style', interactive=False, elem_id="dislike like", visible=False)
486 | +
487 | + b5 = gr.Button(value='π Content', interactive=False, elem_id="like dislike", visible=False)
488 | +
489 | b1.click(
490 | choose,
491 | [img, b1, calibrate_prompts, user_id],
492 | + [img, calibrate_prompts, ],
493 | )
494 | b2.click(
495 | choose,
496 | [img, b2, calibrate_prompts, user_id],
497 | + [img, calibrate_prompts, ],
498 | )
499 | b3.click(
500 | choose,
501 | [img, b3, calibrate_prompts, user_id],
502 | + [img, calibrate_prompts, ],
503 | + )
504 | + b5.click(
505 | + choose,
506 | + [img, b5, calibrate_prompts, user_id],
507 | + [img, calibrate_prompts, ],
508 | + )
509 | + b6.click(
510 | + choose,
511 | + [img, b6, calibrate_prompts, user_id],
512 | + [img, calibrate_prompts, ],
513 | )
514 | with gr.Row():
515 | b4 = gr.Button(value='Start')
516 | b4.click(start,
517 | [b4, calibrate_prompts, user_id],
518 | + [b1, b2, b3, b4, b5, b6, img, calibrate_prompts, user_id, ]
519 | )
520 | with gr.Row():
521 | html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several videos and then roam. </ div><br><br><br>
526 | </ div>''')
527 |
528 | # TODO quiet logging
529 |
530 | scheduler = BackgroundScheduler()
531 | + scheduler.add_job(func=background_next_image, trigger="interval", seconds=.2)
532 | scheduler.start()
533 |
534 | + #thread = threading.Thread(target=background_next_image,)
535 | + #thread.start()
536 | +
537 | + # TODO shouldn't call this before gradio launch, yeah?
538 | + @spaces.GPU()
539 | + def encode_space(x):
540 | + im_emb, _ = pipe.encode_image(
541 | + image, DEVICE, 1, output_hidden_state
542 | + )
543 | + return im_emb.detach().to('cpu').to(torch.float32)
544 |
545 | # prep our calibration videos
546 | + for im, txt in [ # TODO more movement
547 | + ('./first.png', 'describe the scene: a sketch'),
548 | + ('./second.png', 'describe the scene: omens in the suburbs'),
549 | + ('./sixth.png', 'describe the scene: geometric abstract art of a windmill'),
550 | + ('./fifth.png', 'describe the scene: memento mori'),
551 | + ('./fourth.png', 'describe the scene: a green plate with anespresso'),
552 | ]:
553 | tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'text', 'gemb'])
554 | tmp_df['paths'] = [im]
555 | image = list(imageio.imiter(im))
556 | image = image[len(image)//2]
557 | + im_emb = encode_space(image)
558 | +
559 | + tmp_df['embeddings'] = [im_emb.detach().to('cpu')]
560 | tmp_df['user:rating'] = [{' ': ' '}]
561 | + tmp_df['text'] = [txt]
562 | prevs_df = pd.concat((prevs_df, tmp_df))
563 |
564 | + glob_idx = 0
565 | + demo.launch(share=True,)
566 |
567 |
|
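The prompt file added below feeds the prompt_list comprehension near the top of the new app.py. A minimal sketch of that load, matching the added line (the CSV's first column is an unnamed index and the second column holds the prompt text):

import pandas as pd

# Keep unique string prompts from the second column, as the added code does.
prompt_list = [p for p in set(pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())
               if isinstance(p, str)]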
twitter_prompts.csv
ADDED
@@ -0,0 +1,47 @@
1 | + ,0
2 | + 0,a sunset
3 | + 1,a still life in blue
4 | + 2,last day on earth
5 | + 3,the conch shell
6 | + 4,the winds of change
7 | + 5,a surrealist eye
8 | + 6,a surrealist polaroid photo of an apple
9 | + 7,metaphysics
10 | + 8,the sun is setting into my glass of tea
11 | + 9,the moon at 3am
12 | + 10,a memento mori
13 | + 11,quaking aspen tree
14 | + 12,violets and daffodils
15 | + 13,espresso
16 | + 14,sisyphus
17 | + 15,high windows of stained glass
18 | + 16,a green dog
19 | + 17,an adorable companion; it is a pig
20 | + 18,bird of paradise
21 | + 19,a complex intricate machine
22 | + 20,a white clock
23 | + 21,a film featuring the landscape Salt Lake City Utah
24 | + 22,a creature
25 | + 23,a house set aflame.
26 | + 24,a gorgeous landscape by Cy Twombly
27 | + 25,smoke rises from the caterpillar's hookah
28 | + 26,corvid in red
29 | + 27,Monet's pond
30 | + 28,Genesis
31 | + 29,Death is a black camel that kneels down so we can ride
32 | + 30,a cherry tree made of fractals
33 | + 29,the end of the sidewalk
34 | + 30,a polaroid photo of a bustling city of lights and sky scrapers
35 | + 31,The Fig Tree metaphor
36 | + 32,God killed Van Gogh.
37 | + 33,a cosmic entity alien with four eyes.
38 | + 34,a horse with 128 eyes.
39 | + 35,a being with an infinite set of eyes (it is omniscient)
40 | + 36,A sticky-note magnum opus featuring birds
41 | + 37,Moka Pot
42 | + 38,the moon is a sickle cell
43 | + 39,The Penultimate Supper
44 | + 40,Art
45 | + 41,surrealism
46 | + 42,a god made of wires & dust
47 | + 43,a dandelion blown into the universe