rynmurdock commited on
Commit
a8fb4ec
1 Parent(s): b4f2949

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -29
app.py CHANGED
@@ -45,7 +45,8 @@ import spaces
45
  start_time = time.time()
46
 
47
  ####################### Setup Model
48
- from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, LCMScheduler, AutoencoderTiny, UNet2DConditionModel, AutoencoderKL
 
49
  from transformers import CLIPTextModel
50
  from huggingface_hub import hf_hub_download
51
  from safetensors.torch import load_file
@@ -55,7 +56,6 @@ import uuid
55
  import av
56
 
57
  def write_video(file_name, images, fps=17):
58
- print('Saving')
59
  container = av.open(file_name, mode="w")
60
 
61
  stream = container.add_stream("h264", rate=fps)
@@ -76,7 +76,6 @@ def write_video(file_name, images, fps=17):
76
  container.mux(packet)
77
  # Close the file
78
  container.close()
79
- print('Saved')
80
 
81
  def imio_write_video(file_name, images, fps=15):
82
  writer = imageio.get_writer(file_name, fps=fps)
@@ -128,14 +127,11 @@ pipe.to(device=DEVICE)
128
 
129
  @spaces.GPU()
130
  def generate_gpu(in_im_embs):
131
- print('start gen')
132
  in_im_embs = in_im_embs.to('cuda').unsqueeze(0).unsqueeze(0)
133
  output = pipe(prompt='', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)
134
- print('image is made')
135
  im_emb, _ = pipe.encode_image(
136
  output.frames[0][len(output.frames[0])//2], 'cuda', 1, output_hidden_state
137
  )
138
- print('im_emb is made')
139
  im_emb = im_emb.detach().to('cpu').to(torch.float32)
140
  return output, im_emb
141
 
@@ -168,7 +164,6 @@ def get_user_emb(embs, ys):
168
  embs.append(.01*torch.randn(1280))
169
  ys.append(0)
170
  ys.append(1)
171
- print('Fixing only one feedback class available.\n')
172
 
173
  indices = list(range(len(embs)))
174
  # sample only as many negatives as there are positives
@@ -177,14 +172,12 @@ def get_user_emb(embs, ys):
177
  #lower = min(len(pos_indices), len(neg_indices))
178
  #neg_indices = random.sample(neg_indices, lower)
179
  #pos_indices = random.sample(pos_indices, lower)
180
- print(len(neg_indices), len(pos_indices))
181
 
182
 
183
  # we may have just encountered a rare multi-threading diffusers issue (https://github.com/huggingface/diffusers/issues/5749);
184
  # this ends up adding a rating but losing an embedding, it seems.
185
  # let's take off a rating if so to continue without indexing errors.
186
  if len(ys) > len(embs):
187
- print('ys are longer than embs; popping latest rating')
188
  ys.pop(-1)
189
 
190
  feature_embs = np.array(torch.stack([embs[i].squeeze().to('cpu') for i in indices]).to('cpu'))
@@ -192,12 +185,10 @@ def get_user_emb(embs, ys):
192
  #feature_embs = scaler.transform(feature_embs)
193
  chosen_y = np.array([ys[i] for i in indices])
194
 
195
- print('Gathering coefficients')
196
  #lin_class = Ridge(fit_intercept=False).fit(feature_embs, chosen_y)
197
  lin_class = SVC(max_iter=20, kernel='linear', C=.1, class_weight='balanced').fit(feature_embs, chosen_y)
198
  coef_ = torch.tensor(lin_class.coef_, dtype=torch.double).detach().to('cpu')
199
  coef_ = coef_ / coef_.abs().max() * 3
200
- print('Gathered')
201
 
202
  w = 1# if len(embs) % 2 == 0 else 0
203
  im_emb = w * coef_.to(dtype=dtype)
@@ -205,7 +196,6 @@ def get_user_emb(embs, ys):
205
 
206
 
207
  def pluck_img(user_id, user_emb):
208
- print(user_id, 'user_id')
209
  not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]]
210
  while len(not_rated_rows) == 0:
211
  not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]]
@@ -231,7 +221,6 @@ def background_next_image():
231
  # not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
232
  rated_rows = prevs_df[[i[1]['user:rating'] != {' ': ' '} for i in prevs_df.iterrows()]]
233
  time.sleep(.01)
234
- print('all users have 4 or less rows rated')
235
 
236
  user_id_list = set(rated_rows['latest_user_to_rate'].to_list())
237
  for uid in user_id_list:
@@ -253,15 +242,12 @@ def background_next_image():
253
  continue
254
 
255
  if len(rated_rows) < 4:
256
- print(f'latest user {uid} has < 4 rows') # or > 7 unrated rows')
257
  continue
258
 
259
- print(uid)
260
  embs, ys = pluck_embs_ys(uid)
261
 
262
  user_emb = get_user_emb(embs, ys)
263
  img, embs = generate(user_emb)
264
- print(img)
265
  if img:
266
  tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate'])
267
  tmp_df['paths'] = [img]
@@ -276,16 +262,10 @@ def background_next_image():
276
  cands['sum_bad_ratings'] = [sum([int(t==0) for t in i.values()]) for i in cands['user:rating']]
277
  worst_row = cands.loc[cands['sum_bad_ratings']==cands['sum_bad_ratings'].max()].iloc[0]
278
  worst_path = worst_row['paths']
279
- print('Removing worst row:', worst_row, 'from prevs_df of len', len(prevs_df))
280
  if os.path.isfile(worst_path):
281
  os.remove(worst_path)
282
- else:
283
- # If it fails, inform the user.
284
- print("Error: %s file not found" % worst_path)
285
-
286
  # only keep x images & embeddings & ips, then remove the most often disliked besides calibrating
287
  prevs_df = prevs_df[prevs_df['paths'] != worst_path]
288
- print('prevs_df is now length:', len(prevs_df))
289
 
290
  def pluck_embs_ys(user_id):
291
  rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) != None for i in prevs_df.iterrows()]]
@@ -298,21 +278,17 @@ def pluck_embs_ys(user_id):
298
 
299
  embs = rated_rows['embeddings'].to_list()
300
  ys = [i[user_id] for i in rated_rows['user:rating'].to_list()]
301
- print('embs', 'ys', embs, ys)
302
  return embs, ys
303
 
304
  def next_image(calibrate_prompts, user_id):
305
- print(prevs_df)
306
 
307
  with torch.no_grad():
308
  if len(calibrate_prompts) > 0:
309
- print('######### Calibrating with sample media #########')
310
  cal_video = calibrate_prompts.pop(0)
311
  image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
312
 
313
  return image, calibrate_prompts
314
  else:
315
- print('######### Roaming #########')
316
  embs, ys = pluck_embs_ys(user_id)
317
  user_emb = get_user_emb(embs, ys)
318
  image = pluck_img(user_id, user_emb)
@@ -355,7 +331,6 @@ def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
355
  # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
356
  # TODO skip allowing rating & just continue
357
  if img == None:
358
- print('NSFW -- choice is disliked')
359
  choice = 0
360
 
361
  row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]
@@ -425,7 +400,6 @@ with gr.Blocks(css=css, head=js_head) as demo:
425
  Explore the latent space without text prompts based on your preferences. Learn more on [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/).
426
  ''', elem_id="description")
427
  user_id = gr.State()
428
- print('USER_ID: ',user_id)
429
  # calibration videos -- this is a misnomer now :D
430
  calibrate_prompts = gr.State([
431
  './first.mp4',
@@ -487,7 +461,7 @@ log = logging.getLogger('log_here')
487
  log.setLevel(logging.ERROR)
488
 
489
  scheduler = BackgroundScheduler()
490
- scheduler.add_job(func=background_next_image, trigger="interval", seconds=.1)
491
  scheduler.start()
492
 
493
  #thread = threading.Thread(target=background_next_image,)
 
45
  start_time = time.time()
46
 
47
  ####################### Setup Model
48
+ from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, LCMScheduler, AutoencoderTiny, UNet2DConditionModel, AutoencoderKL, utils
49
+ utils.logging.disable_progress_bar
50
  from transformers import CLIPTextModel
51
  from huggingface_hub import hf_hub_download
52
  from safetensors.torch import load_file
 
56
  import av
57
 
58
  def write_video(file_name, images, fps=17):
 
59
  container = av.open(file_name, mode="w")
60
 
61
  stream = container.add_stream("h264", rate=fps)
 
76
  container.mux(packet)
77
  # Close the file
78
  container.close()
 
79
 
80
  def imio_write_video(file_name, images, fps=15):
81
  writer = imageio.get_writer(file_name, fps=fps)
 
127
 
128
  @spaces.GPU()
129
  def generate_gpu(in_im_embs):
 
130
  in_im_embs = in_im_embs.to('cuda').unsqueeze(0).unsqueeze(0)
131
  output = pipe(prompt='', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)
 
132
  im_emb, _ = pipe.encode_image(
133
  output.frames[0][len(output.frames[0])//2], 'cuda', 1, output_hidden_state
134
  )
 
135
  im_emb = im_emb.detach().to('cpu').to(torch.float32)
136
  return output, im_emb
137
 
 
164
  embs.append(.01*torch.randn(1280))
165
  ys.append(0)
166
  ys.append(1)
 
167
 
168
  indices = list(range(len(embs)))
169
  # sample only as many negatives as there are positives
 
172
  #lower = min(len(pos_indices), len(neg_indices))
173
  #neg_indices = random.sample(neg_indices, lower)
174
  #pos_indices = random.sample(pos_indices, lower)
 
175
 
176
 
177
  # we may have just encountered a rare multi-threading diffusers issue (https://github.com/huggingface/diffusers/issues/5749);
178
  # this ends up adding a rating but losing an embedding, it seems.
179
  # let's take off a rating if so to continue without indexing errors.
180
  if len(ys) > len(embs):
 
181
  ys.pop(-1)
182
 
183
  feature_embs = np.array(torch.stack([embs[i].squeeze().to('cpu') for i in indices]).to('cpu'))
 
185
  #feature_embs = scaler.transform(feature_embs)
186
  chosen_y = np.array([ys[i] for i in indices])
187
 
 
188
  #lin_class = Ridge(fit_intercept=False).fit(feature_embs, chosen_y)
189
  lin_class = SVC(max_iter=20, kernel='linear', C=.1, class_weight='balanced').fit(feature_embs, chosen_y)
190
  coef_ = torch.tensor(lin_class.coef_, dtype=torch.double).detach().to('cpu')
191
  coef_ = coef_ / coef_.abs().max() * 3
 
192
 
193
  w = 1# if len(embs) % 2 == 0 else 0
194
  im_emb = w * coef_.to(dtype=dtype)
 
196
 
197
 
198
  def pluck_img(user_id, user_emb):
 
199
  not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]]
200
  while len(not_rated_rows) == 0:
201
  not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]]
 
221
  # not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
222
  rated_rows = prevs_df[[i[1]['user:rating'] != {' ': ' '} for i in prevs_df.iterrows()]]
223
  time.sleep(.01)
 
224
 
225
  user_id_list = set(rated_rows['latest_user_to_rate'].to_list())
226
  for uid in user_id_list:
 
242
  continue
243
 
244
  if len(rated_rows) < 4:
 
245
  continue
246
 
 
247
  embs, ys = pluck_embs_ys(uid)
248
 
249
  user_emb = get_user_emb(embs, ys)
250
  img, embs = generate(user_emb)
 
251
  if img:
252
  tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate'])
253
  tmp_df['paths'] = [img]
 
262
  cands['sum_bad_ratings'] = [sum([int(t==0) for t in i.values()]) for i in cands['user:rating']]
263
  worst_row = cands.loc[cands['sum_bad_ratings']==cands['sum_bad_ratings'].max()].iloc[0]
264
  worst_path = worst_row['paths']
 
265
  if os.path.isfile(worst_path):
266
  os.remove(worst_path)
 
 
 
 
267
  # only keep x images & embeddings & ips, then remove the most often disliked besides calibrating
268
  prevs_df = prevs_df[prevs_df['paths'] != worst_path]
 
269
 
270
  def pluck_embs_ys(user_id):
271
  rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) != None for i in prevs_df.iterrows()]]
 
278
 
279
  embs = rated_rows['embeddings'].to_list()
280
  ys = [i[user_id] for i in rated_rows['user:rating'].to_list()]
 
281
  return embs, ys
282
 
283
  def next_image(calibrate_prompts, user_id):
 
284
 
285
  with torch.no_grad():
286
  if len(calibrate_prompts) > 0:
 
287
  cal_video = calibrate_prompts.pop(0)
288
  image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
289
 
290
  return image, calibrate_prompts
291
  else:
 
292
  embs, ys = pluck_embs_ys(user_id)
293
  user_emb = get_user_emb(embs, ys)
294
  image = pluck_img(user_id, user_emb)
 
331
  # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
332
  # TODO skip allowing rating & just continue
333
  if img == None:
 
334
  choice = 0
335
 
336
  row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]
 
400
  Explore the latent space without text prompts based on your preferences. Learn more on [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/).
401
  ''', elem_id="description")
402
  user_id = gr.State()
 
403
  # calibration videos -- this is a misnomer now :D
404
  calibrate_prompts = gr.State([
405
  './first.mp4',
 
461
  log.setLevel(logging.ERROR)
462
 
463
  scheduler = BackgroundScheduler()
464
+ scheduler.add_job(func=background_next_image, trigger="interval", seconds=.3)
465
  scheduler.start()
466
 
467
  #thread = threading.Thread(target=background_next_image,)