rynmurdock committed on
Commit
e55aab4
•
1 Parent(s): 430249a

update sol probs others

Files changed (3)
  1. .gitignore +1 -0
  2. app.py +224 -211
  3. twitter_prompts.csv +47 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__*
app.py CHANGED
@@ -6,21 +6,27 @@ import torch
 
 # lol
 DEVICE = 'cuda'
- STEPS = 6
 output_hidden_state = False
 device = "cuda"
 dtype = torch.bfloat16
- N_IMG_EMBS = 3
 
 import logging
 import os
 import imageio
 import gradio as gr
 import numpy as np
- from sklearn.svm import SVC
- from sklearn import preprocessing
 import pandas as pd
 from apscheduler.schedulers.background import BackgroundScheduler
 
 import random
 import time
@@ -37,8 +43,12 @@ prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'l
 import spaces
 start_time = time.time()
 
 ####################### Setup Model
- from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, LCMScheduler, AutoencoderTiny, UNet2DConditionModel, AutoencoderKL
 from transformers import CLIPTextModel
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
@@ -46,9 +56,8 @@ from PIL import Image
 from transformers import CLIPVisionModelWithProjection
 import uuid
 import av
- import torchvision
 
- def write_video(file_name, images, fps=17):
 container = av.open(file_name, mode="w")
 
 stream = container.add_stream("h264", rate=fps)
@@ -89,182 +98,133 @@ device_map='cuda')
89
  #unet = UNet2DConditionModel.from_pretrained(finetune_path+'/unet/').to(dtype)
90
  #text_encoder = CLIPTextModel.from_pretrained(finetune_path+'/text_encoder/').to(dtype)
91
 
92
 
93
- unet = UNet2DConditionModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='unet',).to(dtype).to('cpu')
94
- text_encoder = CLIPTextModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='text_encoder',
95
- device_map='cpu').to(dtype)
 
 
 
96
 
97
- adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
98
- pipe = AnimateDiffPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", motion_adapter=adapter, image_encoder=image_encoder, torch_dtype=dtype,
99
- unet=unet, text_encoder=text_encoder)
100
- pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
101
- pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora",)
102
- pipe.set_adapters(["lcm-lora"], [.95])
103
- pipe.fuse_lora()
104
 
 
 
105
 
106
- #pipe = AnimateDiffPipeline.from_pretrained('emilianJR/epiCRealism', torch_dtype=dtype, image_encoder=image_encoder)
107
- #pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
108
- #repo = "ByteDance/AnimateDiff-Lightning"
109
- #ckpt = f"animatediff_lightning_4step_diffusers.safetensors"
110
 
111
 
112
- pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15_vit-G.bin", map_location='cpu')
113
- # This IP adapter improves outputs substantially.
114
- pipe.set_ip_adapter_scale(.6)
115
- pipe.unet.fuse_qkv_projections()
116
- #pipe.enable_free_init(method="gaussian", use_fast_sampling=True)
117
 
118
- pipe.to(device=DEVICE)
 
119
 
120
  #pipe.unet = torch.compile(pipe.unet)
121
  #pipe.vae = torch.compile(pipe.vae)
122
 
123
 
124
- #############################################################
125
-
126
- from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
127
-
128
- pali = PaliGemmaForConditionalGeneration.from_pretrained('google/paligemma-3b-mix-224', torch_dtype=dtype).eval().to('cuda')
129
- processor = AutoProcessor.from_pretrained('google/paligemma-3b-mix-224')
130
-
131
- #pali = torch.compile(pali)
132
 
133
  @spaces.GPU()
134
- def to_wanted_embs(image_outputs, input_ids, attention_mask, cache_position=None):
135
- inputs_embeds = pali.get_input_embeddings()(input_ids.to('cuda'))
136
- selected_image_feature = image_outputs.to(dtype).to('cuda')
137
- image_features = pali.multi_modal_projector(selected_image_feature)
138
-
139
- if cache_position is None:
140
- cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
141
- inputs_embeds, attention_mask, labels, position_ids = pali._merge_input_ids_with_image_features(
142
- image_features, inputs_embeds, input_ids, attention_mask, None, None, cache_position
143
- )
144
- return inputs_embeds
145
-
146
-
147
- # TODO cache descriptions?
148
- @spaces.GPU(duration=20)
149
- def generate_pali(n_embs):
150
- prompt = 'caption en'
151
- model_inputs = processor(text=prompt, images=torch.zeros(1, 3, 224, 224), return_tensors="pt")
152
- # we need to get im_embs taken in here.
153
-
154
- descs = ''
155
- for n, emb in enumerate(n_embs):
156
- if n < len(n_embs)-1:
157
- input_len = model_inputs["input_ids"].shape[-1]
158
- input_embeds = to_wanted_embs(emb,
159
- model_inputs["input_ids"].to(device),
160
- model_inputs["attention_mask"].to(device))
161
- generation = pali.generate(max_new_tokens=20, do_sample=True, top_p=.94, temperature=1.2, inputs_embeds=input_embeds)
162
- decoded = processor.decode(generation[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
163
- descs += f'Description: {decoded}\n'
164
- else:
165
- prompt = f'en {descs} Describe a new image that is similar. Description:'
166
- model_inputs = processor(text=prompt, images=torch.zeros(1, 3, 224, 224), return_tensors="pt")
167
- input_len = model_inputs["input_ids"].shape[-1]
168
- input_embeds = to_wanted_embs(emb,
169
- model_inputs["input_ids"].to(device),
170
- model_inputs["attention_mask"].to(device))
171
- generation = pali.generate(max_new_tokens=20, do_sample=True, top_p=.94, temperature=1.2, inputs_embeds=input_embeds)
172
- decoded = processor.decode(generation[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
173
-
174
- return decoded
175
-
176
-
177
-
178
-
179
- #############################################################
180
-
181
-
182
-
183
- @spaces.GPU(duration=20)
184
  def generate_gpu(in_im_embs, prompt='the scene'):
185
  with torch.no_grad():
186
- in_im_embs = in_im_embs.to('cuda').unsqueeze(0).unsqueeze(0)
 
187
  output = pipe(prompt=prompt, guidance_scale=1, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)
188
  im_emb, _ = pipe.encode_image(
189
- output.frames[0][len(output.frames[0])//2], 'cuda', 1, output_hidden_state
190
  )
191
  im_emb = im_emb.detach().to('cpu').to(torch.float32)
192
- im = torchvision.transforms.ToTensor()(output.frames[0][len(output.frames[0])//2]).unsqueeze(0)
193
- im = torch.nn.functional.interpolate(im, (224, 224)).to(dtype).to('cuda')
194
- im = (im - .5) * 2
195
- gemb = pali.vision_tower(im).last_hidden_state.detach().to('cpu').to(torch.float32)
196
- return output, im_emb, gemb
197
 
198
 
199
  def generate(in_im_embs, prompt='the scene'):
200
- output, im_emb, gemb = generate_gpu(in_im_embs, prompt)
201
- nsfw =maybe_nsfw(output.frames[0][len(output.frames[0])//2])
 
202
  name = str(uuid.uuid4()).replace("-", "")
203
- path = f"/tmp/{name}.mp4"
204
 
205
  if nsfw:
206
  gr.Warning("NSFW content detected.")
207
  # TODO could return an automatic dislike of auto dislike on the backend for neither as well; just would need refactoring.
208
- return None, im_emb, gemb
209
-
210
 
211
- output.frames[0] = output.frames[0] + list(reversed(output.frames[0]))
212
-
213
- write_video(path, output.frames[0])
214
- return path, im_emb, gemb
215
 
216
 
217
  #######################
218
 
219
  def get_user_emb(embs, ys):
220
- # handle case where every instance of calibration videos is 'Neither' or 'Like' or 'Dislike'
221
-
222
- if len(list(ys)) <= 10:
223
- aways = [torch.zeros_like(embs[0]) for i in range(10)]
224
- embs += aways
225
- awal = [0 for i in range(5)] + [1 for i in range(5)]
226
- ys += awal
227
-
228
- indices = list(range(len(embs)))
229
  # sample only as many negatives as there are positives
230
- pos_indices = [i for i in indices if ys[i] == 1]
231
- neg_indices = [i for i in indices if ys[i] == 0]
232
- #lower = min(len(pos_indices), len(neg_indices))
233
- #neg_indices = random.sample(neg_indices, lower)
234
- #pos_indices = random.sample(pos_indices, lower)
235
-
236
 
237
- # we may have just encountered a rare multi-threading diffusers issue (https://github.com/huggingface/diffusers/issues/5749);
238
- # this ends up adding a rating but losing an embedding, it seems.
239
- # let's take off a rating if so to continue without indexing errors.
240
- if len(ys) > len(embs):
241
- print('ys are longer than embs; popping latest rating')
242
- ys.pop(-1)
 
 
 
 
243
 
244
- feature_embs = torch.stack([embs[i].squeeze().to('cpu') for i in indices]).to('cpu')
245
- #scaler = preprocessing.StandardScaler().fit(feature_embs)
246
- #feature_embs = scaler.transform(feature_embs)
247
- chosen_y = np.array([ys[i] for i in indices])
248
 
249
- if feature_embs.norm() != 0:
250
- feature_embs = feature_embs / feature_embs.norm()
251
 
252
- #lin_class = Ridge(fit_intercept=False).fit(feature_embs, chosen_y)
253
- #class_weight='balanced'
254
- lin_class = SVC(max_iter=500, kernel='linear', C=.1, ).fit(feature_embs.squeeze(), chosen_y)
255
- coef_ = torch.tensor(lin_class.coef_, dtype=torch.float32).detach().to('cpu')
256
- coef_ = coef_ / coef_.abs().max()
257
 
258
- w = 1# if len(embs) % 2 == 0 else 0
259
- im_emb = w * coef_.to(dtype=dtype)
260
- return im_emb
261
 
262
 
263
  def pluck_img(user_id, user_emb):
264
  not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]]
265
  while len(not_rated_rows) == 0:
266
  not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]]
267
- time.sleep(.001)
268
  # TODO optimize this lol
269
  best_sim = -100000
270
  for i in not_rated_rows.iterrows():
@@ -274,8 +234,7 @@ def pluck_img(user_id, user_emb):
274
  best_sim = sim
275
  best_row = i[1]
276
  img = best_row['paths']
277
- text = best_row.get('text', '')
278
- return img, text
279
 
280
 
281
  def background_next_image():
@@ -283,10 +242,10 @@ def background_next_image():
283
  # only let it get N (maybe 3) ahead of the user
284
  #not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
285
  rated_rows = prevs_df[[i[1]['user:rating'] != {' ': ' '} for i in prevs_df.iterrows()]]
286
- while len(rated_rows) < 5:
 
287
  # not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
288
- rated_rows = prevs_df[[i[1]['user:rating'] != {' ': ' '} for i in prevs_df.iterrows()]]
289
- time.sleep(.01)
290
 
291
  user_id_list = set(rated_rows['latest_user_to_rate'].to_list())
292
  for uid in user_id_list:
@@ -300,22 +259,32 @@ def background_next_image():
300
  rated_from_user = rated_rows[[i[1]['from_user_id'] == uid for i in rated_rows.iterrows()]]
301
 
302
  # we pop previous ratings if there are > n
303
- if len(rated_from_user) >= 25:
304
  oldest = rated_from_user.iloc[0]['paths']
305
  prevs_df = prevs_df[prevs_df['paths'] != oldest]
306
  # we don't compute more after n are in the queue for them
307
- if len(unrated_from_user) >= 20:
308
  continue
309
 
310
- embs, ys, gembs = pluck_embs_ys(uid)
311
- user_emb = get_user_emb(embs, ys) * 3
312
- pos_gembs = [g for g, y in zip(gembs, ys) if y == 1]
313
- if len(pos_gembs) > 4:
314
- hist_gem = random.sample(pos_gembs, N_IMG_EMBS) # rng n embeddings
315
- text = generate_pali(hist_gem)
316
  else:
317
- text = 'the scene'
318
- img, embs, new_gem = generate(user_emb, text)
319
 
320
  if img:
321
  tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'text', 'gemb'])
@@ -324,7 +293,6 @@ def background_next_image():
324
  tmp_df['user:rating'] = [{' ': ' '}]
325
  tmp_df['from_user_id'] = [uid]
326
  tmp_df['text'] = [text]
327
- tmp_df['gemb'] = [new_gem]
328
  prevs_df = pd.concat((prevs_df, tmp_df))
329
  # we can free up storage by deleting the image
330
  if len(prevs_df) > 500:
@@ -340,37 +308,52 @@ def background_next_image():
340
 
341
  def pluck_embs_ys(user_id):
342
  rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) != None for i in prevs_df.iterrows()]]
343
 
344
  embs = rated_rows['embeddings'].to_list()
345
  ys = [i[user_id] for i in rated_rows['user:rating'].to_list()]
346
- gembs = rated_rows['gemb'].to_list()
347
- return embs, ys, gembs
348
 
349
  def next_image(calibrate_prompts, user_id):
350
  with torch.no_grad():
351
  if len(calibrate_prompts) > 0:
352
  cal_video = calibrate_prompts.pop(0)
353
  image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
354
- return image, calibrate_prompts, ''
355
  else:
356
- embs, ys, gembs = pluck_embs_ys(user_id)
357
- user_emb = get_user_emb(embs, ys) * 3
358
- image, text = pluck_img(user_id, user_emb)
359
- return image, calibrate_prompts, text
360
 
361
 
362
 
363
  def start(_, calibrate_prompts, user_id, request: gr.Request):
364
  user_id = int(str(time.time())[-7:].replace('.', ''))
365
- image, calibrate_prompts, text = next_image(calibrate_prompts, user_id)
366
  return [
367
- gr.Button(value='Like (L)', interactive=True),
368
  gr.Button(value='Neither (Space)', interactive=True, visible=False),
369
- gr.Button(value='Dislike (A)', interactive=True),
370
  gr.Button(value='Start', interactive=False),
 
 
371
  image,
372
  calibrate_prompts,
373
- user_id
 
374
  ]
375
 
376
 
@@ -378,27 +361,34 @@ def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
378
  global prevs_df
379
 
380
 
381
- if choice == 'Like (L)':
382
- choice = 1
383
  elif choice == 'Neither (Space)':
384
- img, calibrate_prompts, text = next_image(calibrate_prompts, user_id)
385
- return img, calibrate_prompts, text
386
  else:
387
- choice = 0
388
 
389
  # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
390
  # TODO skip allowing rating & just continue
391
- if img == None:
 
392
  print('NSFW -- choice is disliked')
393
- choice = 0
394
 
395
  row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]
396
  # if it's still in the dataframe, add the choice
397
  if len(prevs_df.loc[row_mask, 'user:rating']) > 0:
398
  prevs_df.loc[row_mask, 'user:rating'][0][user_id] = choice
399
  prevs_df.loc[row_mask, 'latest_user_to_rate'] = [user_id]
400
- img, calibrate_prompts, text = next_image(calibrate_prompts, user_id)
401
- return img, calibrate_prompts, text
402
 
403
  css = '''.gradio-container{max-width: 700px !important}
404
  #description{text-align: center}
@@ -461,53 +451,71 @@ Explore the latent space without text prompts based on your preferences. Learn m
461
  user_id = gr.State()
462
  # calibration videos -- this is a misnomer now :D
463
  calibrate_prompts = gr.State([
464
- './first.mp4',
465
- './second.mp4',
466
- './third.mp4',
467
- './fourth.mp4',
468
- './fifth.mp4',
469
- './sixth.mp4',
470
  ])
471
  def l():
472
  return None
473
 
474
  with gr.Row(elem_id='output-image'):
475
- img = gr.Video(
476
  label='Lightning',
477
- autoplay=True,
478
  interactive=False,
479
- height=512,
480
- width=512,
481
  #include_audio=False,
482
- elem_id="video_output"
 
483
  )
484
- img.play(l, js='''document.querySelector('[data-testid="Lightning-player"]').loop = true''')
485
- with gr.Row():
486
- text = gr.Textbox(interactive=False, visible=True, label='Text')
 
487
  with gr.Row(equal_height=True):
488
- b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
 
489
  b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither", visible=False)
490
- b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")
491
  b1.click(
492
  choose,
493
  [img, b1, calibrate_prompts, user_id],
494
- [img, calibrate_prompts, text],
495
  )
496
  b2.click(
497
  choose,
498
  [img, b2, calibrate_prompts, user_id],
499
- [img, calibrate_prompts, text],
500
  )
501
  b3.click(
502
  choose,
503
  [img, b3, calibrate_prompts, user_id],
504
- [img, calibrate_prompts, text],
505
  )
506
  with gr.Row():
507
  b4 = gr.Button(value='Start')
508
  b4.click(start,
509
  [b4, calibrate_prompts, user_id],
510
- [b1, b2, b3, b4, img, calibrate_prompts, user_id]
511
  )
512
  with gr.Row():
513
  html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several videos and then roam. </ div><br><br><br>
@@ -518,37 +526,42 @@ Explore the latent space without text prompts based on your preferences. Learn m
518
  </ div>''')
519
 
520
  # TODO quiet logging
521
- log = logging.getLogger('log_here')
522
- log.setLevel(logging.ERROR)
523
 
524
  scheduler = BackgroundScheduler()
525
- scheduler.add_job(func=background_next_image, trigger="interval", seconds=.5)
526
  scheduler.start()
527
 
528
 
529
  # prep our calibration videos
530
- for im in [
531
- './first.mp4',
532
- './second.mp4',
533
- './third.mp4',
534
- './fourth.mp4',
535
- './fifth.mp4',
536
- './sixth.mp4',
537
- './seventh.mp4',
538
- './eigth.mp4',
539
- './ninth.mp4',
540
- './tenth.mp4',
541
  ]:
542
  tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'text', 'gemb'])
543
  tmp_df['paths'] = [im]
544
  image = list(imageio.imiter(im))
545
  image = image[len(image)//2]
546
- tmp_df['embeddings'] = [torch.load(im.replace('mp4', 'im_.pt'))]
547
- tmp_df['gemb'] = [torch.load(im.replace('mp4', 'gemb_.pt'))]
 
548
  tmp_df['user:rating'] = [{' ': ' '}]
 
549
  prevs_df = pd.concat((prevs_df, tmp_df))
550
 
551
-
552
- demo.launch(share=True)
553
 
554
 
 
 
 # lol
 DEVICE = 'cuda'
+ STEPS = 8
 output_hidden_state = False
 device = "cuda"
 dtype = torch.bfloat16
 
+
+ import spaces
+
+ import matplotlib.pyplot as plt
+ import matplotlib
 import logging
+
 import os
 import imageio
 import gradio as gr
 import numpy as np
+ from sklearn.svm import LinearSVC
 import pandas as pd
 from apscheduler.schedulers.background import BackgroundScheduler
+ import sched
+ import threading
 
 import random
 import time
 
 import spaces
 start_time = time.time()
 
+ prompt_list = [p for p in list(set(
+ pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
+
+
 ####################### Setup Model
+ from diffusers import EulerDiscreteScheduler, LCMScheduler, AutoencoderTiny, UNet2DConditionModel, AutoencoderKL, AutoPipelineForText2Image
 from transformers import CLIPTextModel
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
 
 from transformers import CLIPVisionModelWithProjection
 import uuid
 import av
 
+ def write_video(file_name, images, fps=16):
 container = av.open(file_name, mode="w")
 
 stream = container.add_stream("h264", rate=fps)
98
  #unet = UNet2DConditionModel.from_pretrained(finetune_path+'/unet/').to(dtype)
99
  #text_encoder = CLIPTextModel.from_pretrained(finetune_path+'/text_encoder/').to(dtype)
100
 
101
+ #rynmurdock/Sea_Claws
102
+ model_id = "stabilityai/stable-diffusion-xl-base-1.0"
103
+ sdxl_lightening = "ByteDance/SDXL-Lightning"
104
+ ckpt = "sdxl_lightning_8step_unet.safetensors"
105
+ unet = UNet2DConditionModel.from_config(model_id, subfolder="unet", low_cpu_mem_usage=True, device_map=DEVICE).to(torch.float16)
106
+ unet.load_state_dict(load_file(hf_hub_download(sdxl_lightening, ckpt)))
107
 
108
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map=DEVICE)
109
+ pipe = AutoPipelineForText2Image.from_pretrained(model_id, unet=unet, torch_dtype=torch.float16, variant="fp16", image_encoder=image_encoder, low_cpu_mem_usage=True)
110
+ pipe.unet._load_ip_adapter_weights(torch.load(hf_hub_download('h94/IP-Adapter', 'sdxl_models/ip-adapter_sdxl_vit-h.bin')))
111
+ pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl_vit-h.bin")
112
+ pipe.register_modules(image_encoder = image_encoder)
113
+ pipe.set_ip_adapter_scale(0.8)
114
 
115
+ #pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16, low_cpu_mem_usage=True)
116
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
 
 
 
 
 
117
 
118
+ pipe.to(device=DEVICE).to(dtype=dtype)
119
+ output_hidden_state = False
120
 
 
 
 
 
121
 
122
 
 
 
 
 
 
123
 
124
+ # pipe.unet.fuse_qkv_projections()
125
+ #pipe.enable_free_init(method="gaussian", use_fast_sampling=True)
126
 
127
  #pipe.unet = torch.compile(pipe.unet)
128
  #pipe.vae = torch.compile(pipe.vae)
129
 
130
 
131
 
132
  @spaces.GPU()
133
  def generate_gpu(in_im_embs, prompt='the scene'):
134
  with torch.no_grad():
135
+ print(prompt)
136
+ in_im_embs = in_im_embs.to('cuda').unsqueeze(0)
137
  output = pipe(prompt=prompt, guidance_scale=1, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)
138
  im_emb, _ = pipe.encode_image(
139
+ output.images[0], 'cuda', 1, output_hidden_state
140
  )
141
  im_emb = im_emb.detach().to('cpu').to(torch.float32)
142
+ return output, im_emb
 
 
 
 
143
 
144
 
145
  def generate(in_im_embs, prompt='the scene'):
146
+ output, im_emb = generate_gpu(in_im_embs, prompt)
147
+ nsfw = maybe_nsfw(output.images[0])
148
+
149
  name = str(uuid.uuid4()).replace("-", "")
150
+ path = f"/tmp/{name}.png"
151
 
152
  if nsfw:
153
  gr.Warning("NSFW content detected.")
154
  # TODO could return an automatic dislike of auto dislike on the backend for neither as well; just would need refactoring.
155
+ return None, im_emb
 
156
 
157
+ output.images[0].save(path)
158
+ return path, im_emb
 
 
159
 
160
 
161
  #######################
162
 
163
+
164
+
165
+
166
+
167
+
168
+ @spaces.GPU()
169
+ def solver(embs, ys):
170
+ print('ys:', ys,'EMBS:', embs.shape, embs)
171
+ ys = torch.tensor(ys).to('cpu', dtype=torch.float32).squeeze().unsqueeze(1)
172
+
173
+ sol = LinearSVC(class_weight='balanced').fit(np.array(embs), np.array(torch.tensor(ys).float() * 2 - 1)).coef_
174
+ return torch.tensor(sol).to('cpu', dtype=torch.float32)
175
+
176
+
177
+
178
+
179
  def get_user_emb(embs, ys):
 
 
 
 
 
 
 
 
 
180
  # sample only as many negatives as there are positives
181
+ indices = range(len(ys))
182
+ pos_indices = [i for i in indices if ys[i] > .5]
183
+ neg_indices = [i for i in indices if ys[i] <= .5]
 
 
 
184
 
185
+ mini = min(len(pos_indices), len(neg_indices))
186
+
187
+ if len(ys) > 20: # drop earliest of whichever of neg or pos is most abundant
188
+ if len(pos_indices) > len(neg_indices):
189
+ ind = pos_indices[0]
190
+ else:
191
+ ind = neg_indices[0]
192
+ ys.pop(ind)
193
+ embs.pop(ind)
194
+ print('Dropping at 20')
195
 
196
+ if mini < 1:
197
+ feature_embs = torch.stack([torch.randn(1280), torch.randn(1280)])
198
+ ys_t = [0, 1]
199
+ print('Not enough ratings.')
200
+ else:
201
+ indices = range(len(ys))
202
+ ys_t = [ys[i] for i in indices]
203
+ feature_embs = torch.stack([embs[e].detach().cpu() for e in indices]).squeeze()
204
+
205
+ # scaler = preprocessing.StandardScaler().fit(feature_embs)
206
+ # feature_embs = scaler.transform(feature_embs)
207
+ # ys_t = ys
208
+
209
+ print(np.array(feature_embs).shape, np.array(ys_t).shape)
210
 
211
+ sol = solver(feature_embs.squeeze(), ys_t)
212
+ dif = torch.tensor(sol, dtype=dtype).to(device)
213
 
214
+ # could j have a base vector of a black image
215
+ latest_pos = (random.sample([feature_embs[i] for i in range(len(ys_t)) if ys_t[i] > .5], 1)[0]).to(device, dtype)
216
+
217
+ dif = ((dif / dif.std()) * latest_pos.std())
 
218
 
219
+ sol = (1*latest_pos + 3*dif)/4
220
+ return sol
 
221
 
222
 
223
  def pluck_img(user_id, user_emb):
224
  not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]]
225
  while len(not_rated_rows) == 0:
226
  not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]]
227
+ time.sleep(.1)
228
  # TODO optimize this lol
229
  best_sim = -100000
230
  for i in not_rated_rows.iterrows():
 
234
  best_sim = sim
235
  best_row = i[1]
236
  img = best_row['paths']
237
+ return img
 
238
 
239
 
240
  def background_next_image():
 
242
  # only let it get N (maybe 3) ahead of the user
243
  #not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
244
  rated_rows = prevs_df[[i[1]['user:rating'] != {' ': ' '} for i in prevs_df.iterrows()]]
245
+ if len(rated_rows) < 4:
246
+ time.sleep(.1)
247
  # not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
248
+ return
 
249
 
250
  user_id_list = set(rated_rows['latest_user_to_rate'].to_list())
251
  for uid in user_id_list:
 
259
  rated_from_user = rated_rows[[i[1]['from_user_id'] == uid for i in rated_rows.iterrows()]]
260
 
261
  # we pop previous ratings if there are > n
262
+ if len(rated_from_user) >= 15:
263
  oldest = rated_from_user.iloc[0]['paths']
264
  prevs_df = prevs_df[prevs_df['paths'] != oldest]
265
  # we don't compute more after n are in the queue for them
266
+ if len(unrated_from_user) >= 10:
267
  continue
268
 
269
+ if len(rated_rows) < 5:
270
+ continue
271
+
272
+ embs, ys = pluck_embs_ys(uid)
273
+
274
+ user_emb = get_user_emb(embs, [y[1] for y in ys])
275
+
276
+
277
+ global glob_idx
278
+ glob_idx += 1
279
+ if glob_idx >= (len(prompt_list)-1):
280
+ glob_idx = 0
281
+
282
+
283
+ if glob_idx % 7 == 0:
284
+ text = prompt_list[glob_idx]
285
  else:
286
+ text = 'an image'
287
+ img, embs = generate(user_emb, text)
288
 
289
  if img:
290
  tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'text', 'gemb'])
 
293
  tmp_df['user:rating'] = [{' ': ' '}]
294
  tmp_df['from_user_id'] = [uid]
295
  tmp_df['text'] = [text]
 
296
  prevs_df = pd.concat((prevs_df, tmp_df))
297
  # we can free up storage by deleting the image
298
  if len(prevs_df) > 500:
 
308
 
309
  def pluck_embs_ys(user_id):
310
  rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) != None for i in prevs_df.iterrows()]]
311
+ #not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) == None for i in prevs_df.iterrows()]]
312
+ #while len(not_rated_rows) == 0:
313
+ # not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) == None for i in prevs_df.iterrows()]]
314
+ # rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) != None for i in prevs_df.iterrows()]]
315
+ # time.sleep(.01)
316
+ # print('current user has 0 not_rated_rows')
317
 
318
  embs = rated_rows['embeddings'].to_list()
319
  ys = [i[user_id] for i in rated_rows['user:rating'].to_list()]
320
+ return embs, ys
 
321
 
322
  def next_image(calibrate_prompts, user_id):
323
  with torch.no_grad():
324
  if len(calibrate_prompts) > 0:
325
  cal_video = calibrate_prompts.pop(0)
326
  image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
327
+ return image, calibrate_prompts,
328
  else:
329
+ embs, ys = pluck_embs_ys(user_id)
330
+ ys_here = [y[1] for y in ys]
331
+ user_emb = get_user_emb(embs, ys_here)
332
+ image = pluck_img(user_id, user_emb)
333
+ return image, calibrate_prompts,
334
+
335
+
336
+
337
+
338
+
339
+
340
 
341
 
342
 
343
  def start(_, calibrate_prompts, user_id, request: gr.Request):
344
  user_id = int(str(time.time())[-7:].replace('.', ''))
345
+ image, calibrate_prompts = next_image(calibrate_prompts, user_id)
346
  return [
347
+ gr.Button(value='👍', interactive=True),
348
  gr.Button(value='Neither (Space)', interactive=True, visible=False),
349
+ gr.Button(value='👎', interactive=True),
350
  gr.Button(value='Start', interactive=False),
351
+ gr.Button(value='👍 Content', interactive=True, visible=False),
352
+ gr.Button(value='👍 Style', interactive=True, visible=False),
353
  image,
354
  calibrate_prompts,
355
+ user_id,
356
+
357
  ]
358
 
359
 
 
361
  global prevs_df
362
 
363
 
364
+ if choice == '👍':
365
+ choice = [1, 1]
366
  elif choice == 'Neither (Space)':
367
+ img, calibrate_prompts, = next_image(calibrate_prompts, user_id)
368
+ return img, calibrate_prompts,
369
+ elif choice == '👎':
370
+ choice = [0, 0]
371
+ elif choice == '👍 Style':
372
+ choice = [0, 1]
373
+ elif choice == '👍 Content':
374
+ choice = [1, 0]
375
  else:
376
+ assert False, f'choice is {choice}'
377
 
378
  # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
379
  # TODO skip allowing rating & just continue
380
+
381
+ if img is None:
382
  print('NSFW -- choice is disliked')
383
+ choice = [0, 0]
384
 
385
  row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]
386
  # if it's still in the dataframe, add the choice
387
  if len(prevs_df.loc[row_mask, 'user:rating']) > 0:
388
  prevs_df.loc[row_mask, 'user:rating'][0][user_id] = choice
389
  prevs_df.loc[row_mask, 'latest_user_to_rate'] = [user_id]
390
+ img, calibrate_prompts, = next_image(calibrate_prompts, user_id)
391
+ return img, calibrate_prompts
392
 
393
  css = '''.gradio-container{max-width: 700px !important}
394
  #description{text-align: center}
 
451
  user_id = gr.State()
452
  # calibration videos -- this is a misnomer now :D
453
  calibrate_prompts = gr.State([
454
+ './first.png',
455
+ './second.png',
456
+ './sixth.png',
457
+ './fifth.png',
458
+ './fourth.png',
 
459
  ])
460
  def l():
461
  return None
462
 
463
  with gr.Row(elem_id='output-image'):
464
+ img = gr.Image(
465
  label='Lightning',
466
+ # autoplay=True,
467
  interactive=False,
468
+ # height=512,
469
+ # width=512,
470
  #include_audio=False,
471
+ elem_id="video_output",
472
+ type='filepath',
473
  )
474
+ #img.play(l, js='''document.querySelector('[data-testid="Lightning-player"]').loop = true''')
475
+
476
+
477
+
478
  with gr.Row(equal_height=True):
479
+ b3 = gr.Button(value='👎', interactive=False, elem_id="dislike")
480
+
481
  b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither", visible=False)
482
+
483
+ b1 = gr.Button(value='👍', interactive=False, elem_id="like")
484
+ with gr.Row(equal_height=True):
485
+ b6 = gr.Button(value='👍 Style', interactive=False, elem_id="dislike like", visible=False)
486
+
487
+ b5 = gr.Button(value='👍 Content', interactive=False, elem_id="like dislike", visible=False)
488
+
489
  b1.click(
490
  choose,
491
  [img, b1, calibrate_prompts, user_id],
492
+ [img, calibrate_prompts, ],
493
  )
494
  b2.click(
495
  choose,
496
  [img, b2, calibrate_prompts, user_id],
497
+ [img, calibrate_prompts, ],
498
  )
499
  b3.click(
500
  choose,
501
  [img, b3, calibrate_prompts, user_id],
502
+ [img, calibrate_prompts, ],
503
+ )
504
+ b5.click(
505
+ choose,
506
+ [img, b5, calibrate_prompts, user_id],
507
+ [img, calibrate_prompts, ],
508
+ )
509
+ b6.click(
510
+ choose,
511
+ [img, b6, calibrate_prompts, user_id],
512
+ [img, calibrate_prompts, ],
513
  )
514
  with gr.Row():
515
  b4 = gr.Button(value='Start')
516
  b4.click(start,
517
  [b4, calibrate_prompts, user_id],
518
+ [b1, b2, b3, b4, b5, b6, img, calibrate_prompts, user_id, ]
519
  )
520
  with gr.Row():
521
  html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several videos and then roam. </ div><br><br><br>
 
526
  </ div>''')
527
 
528
  # TODO quiet logging
 
 
529
 
530
  scheduler = BackgroundScheduler()
531
+ scheduler.add_job(func=background_next_image, trigger="interval", seconds=.2)
532
  scheduler.start()
533
 
534
+ #thread = threading.Thread(target=background_next_image,)
535
+ #thread.start()
536
+
537
+ # TODO shouldn't call this before gradio launch, yeah?
538
+ @spaces.GPU()
539
+ def encode_space(x):
540
+ im_emb, _ = pipe.encode_image(
541
+ image, DEVICE, 1, output_hidden_state
542
+ )
543
+ return im_emb.detach().to('cpu').to(torch.float32)
544
 
545
  # prep our calibration videos
546
+ for im, txt in [ # TODO more movement
547
+ ('./first.png', 'describe the scene: a sketch'),
548
+ ('./second.png', 'describe the scene: omens in the suburbs'),
549
+ ('./sixth.png', 'describe the scene: geometric abstract art of a windmill'),
550
+ ('./fifth.png', 'describe the scene: memento mori'),
551
+ ('./fourth.png', 'describe the scene: a green plate with an espresso'),
 
 
 
 
 
552
  ]:
553
  tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'text', 'gemb'])
554
  tmp_df['paths'] = [im]
555
  image = list(imageio.imiter(im))
556
  image = image[len(image)//2]
557
+ im_emb = encode_space(image)
558
+
559
+ tmp_df['embeddings'] = [im_emb.detach().to('cpu')]
560
  tmp_df['user:rating'] = [{' ': ' '}]
561
+ tmp_df['text'] = [txt]
562
  prevs_df = pd.concat((prevs_df, tmp_df))
563
 
564
+ glob_idx = 0
565
+ demo.launch(share=True,)
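
For readers skimming the diff: the new `solver` / `get_user_emb` pair replaces the old SVC fit. It maps ratings {0,1} to {-1,+1}, fits a `LinearSVC(class_weight='balanced')` on the rated image embeddings, treats the coefficient vector as a preference direction, rescales it to the spread of a liked embedding, and blends the two 1:3. The snippet below is a minimal standalone sketch of that computation with toy data; the 1280-dim size, the label mapping, and the 1:3 blend come from the diff, while the helper name, the toy embeddings, and picking the last positive (the app samples a random liked embedding) are illustrative assumptions.

# Standalone sketch, not the app itself: turn liked/disliked embeddings into a
# preference direction the way solver()/get_user_emb() do in this commit.
import numpy as np
import torch
from sklearn.svm import LinearSVC

def preference_direction(embs: torch.Tensor, ys: list) -> torch.Tensor:
    # embs: (n, d) float32 embeddings; ys: per-item ratings in {0, 1}
    labels = np.array(ys) * 2 - 1                      # {0,1} -> {-1,+1}, as in the diff
    coef = LinearSVC(class_weight='balanced').fit(embs.numpy(), labels).coef_
    dif = torch.tensor(coef, dtype=torch.float32).squeeze()

    # rescale the direction to the spread of a liked embedding, then blend 1:3
    latest_pos = embs[[i for i, y in enumerate(ys) if y > .5][-1]]  # assumption: last positive
    dif = dif / dif.std() * latest_pos.std()
    return (1 * latest_pos + 3 * dif) / 4

# toy usage with random 1280-dim embeddings and six ratings
emb = torch.randn(6, 1280)
print(preference_direction(emb, [1, 0, 1, 0, 0, 1]).shape)  # torch.Size([1280])
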
twitter_prompts.csv ADDED
@@ -0,0 +1,47 @@
+ ,0
+ 0,a sunset
+ 1,a still life in blue
+ 2,last day on earth
+ 3,the conch shell
+ 4,the winds of change
+ 5,a surrealist eye
+ 6,a surrealist polaroid photo of an apple
+ 7,metaphysics
+ 8,the sun is setting into my glass of tea
+ 9,the moon at 3am
+ 10,a memento mori
+ 11,quaking aspen tree
+ 12,violets and daffodils
+ 13,espresso
+ 14,sisyphus
+ 15,high windows of stained glass
+ 16,a green dog
+ 17,an adorable companion; it is a pig
+ 18,bird of paradise
+ 19,a complex intricate machine
+ 20,a white clock
+ 21,a film featuring the landscape Salt Lake City Utah
+ 22,a creature
+ 23,a house set aflame.
+ 24,a gorgeous landscape by Cy Twombly
+ 25,smoke rises from the caterpillar's hookah
+ 26,corvid in red
+ 27,Monet's pond
+ 28,Genesis
+ 29,Death is a black camel that kneels down so we can ride
+ 30,a cherry tree made of fractals
+ 29,the end of the sidewalk
+ 30,a polaroid photo of a bustling city of lights and sky scrapers
+ 31,The Fig Tree metaphor
+ 32,God killed Van Gogh.
+ 33,a cosmic entity alien with four eyes.
+ 34,a horse with 128 eyes.
+ 35,a being with an infinite set of eyes (it is omniscient)
+ 36,A sticky-note magnum opus featuring birds
+ 37,Moka Pot
+ 38,the moon is a sickle cell
+ 39,The Penultimate Supper
+ 40,Art
+ 41,surrealism
+ 42,a god made of wires & dust
+ 43,a dandelion blown into the universe
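
As the diff to app.py shows, this file is read once at startup into `prompt_list` (column 1, deduplicated, strings only), and the background loop only pulls a prompt from it on every seventh generation, falling back to 'an image' otherwise. Below is a hedged standalone rendering of that logic, assuming the CSV sits in the working directory; the helper name and the modulo-based index reset are simplifications of the diff's `glob_idx` bookkeeping.

# Standalone sketch of how app.py consumes twitter_prompts.csv in this commit.
import pandas as pd

prompt_list = [p for p in set(pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())
               if isinstance(p, str)]

glob_idx = 0
def next_prompt() -> str:
    # advance the global index, wrapping before the end of the list,
    # and only use a real prompt on every 7th call
    global glob_idx
    glob_idx = (glob_idx + 1) % (len(prompt_list) - 1)
    return prompt_list[glob_idx] if glob_idx % 7 == 0 else 'an image'

print([next_prompt() for _ in range(8)])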