linoyts committed
Commit dc736bc
1 Parent(s): a42a306

Update clip_slider_pipeline.py

Files changed (1):
  clip_slider_pipeline.py  +153 -34
clip_slider_pipeline.py CHANGED
@@ -33,16 +33,12 @@ class CLIPSlider:
 
     def find_latent_direction(self,
                               target_word:str,
-                              opposite:str,
-                              num_iterations: int = None):
+                              opposite:str):
 
         # lets identify a latent direction by taking differences between opposites
         # target_word = "happy"
         # opposite = "sad"
-        if num_iterations is not None:
-            iterations = num_iterations
-        else:
-            iterations = self.iterations
+
 
         with torch.no_grad():
             positives = []
@@ -78,8 +74,6 @@ class CLIPSlider:
                  only_pooler = False,
                  normalize_scales = False, # whether to normalize the scales when avg_diff_2nd is not None
                  correlation_weight_factor = 1.0,
-                 avg_diff = None,
-                 avg_diff_2nd = None,
                  **pipeline_kwargs
                  ):
         # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
@@ -90,14 +84,14 @@ class CLIPSlider:
                                        max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
             prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
 
-        if avg_diff_2nd and normalize_scales:
+        if self.avg_diff_2nd and normalize_scales:
             denominator = abs(scale) + abs(scale_2nd)
             scale = scale / denominator
             scale_2nd = scale_2nd / denominator
         if only_pooler:
-            prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff * scale
-            if avg_diff_2nd:
-                prompt_embeds[:, toks.argmax()] += avg_diff_2nd * scale_2nd
+            prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff * scale
+            if self.avg_diff_2nd:
+                prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd * scale_2nd
         else:
             normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
             sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
@@ -109,9 +103,9 @@ class CLIPSlider:
 
             # weights = torch.sigmoid((weights-0.5)*7)
             prompt_embeds = prompt_embeds + (
-                    weights * avg_diff[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
-            if avg_diff_2nd:
-                prompt_embeds += weights * avg_diff_2nd[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd
+                    weights * self.avg_diff[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
+            if self.avg_diff_2nd:
+                prompt_embeds += weights * self.avg_diff_2nd[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd
 
 
         torch.manual_seed(seed)
@@ -354,22 +348,21 @@ class CLIPSliderXL_inv(CLIPSlider):
 
         return images
 
-
-class T5SliderFlux(CLIPSlider):
-
+class CLIPSliderFlux(CLIPSlider):
     def find_latent_direction(self,
                               target_word:str,
                               opposite:str,
-                              num_iterations:int ):
+                              num_iterations: int = None):
 
         # lets identify a latent direction by taking differences between opposites
         # target_word = "happy"
        # opposite = "sad"
+
        if num_iterations is not None:
            iterations = num_iterations
        else:
            iterations = self.iterations
-
+
        with torch.no_grad():
            positives = []
            negatives = []
@@ -378,6 +371,137 @@ class T5SliderFlux(CLIPSlider):
                subject = random.choice(SUBJECTS)
                pos_prompt = f"a {medium} of a {target_word} {subject}"
                neg_prompt = f"a {medium} of a {opposite} {subject}"
+                pos_toks = self.pipe.tokenizer(pos_prompt,
+                                               padding="max_length",
+                                               max_length=self.pipe.tokenizer_max_length,
+                                               truncation=True,
+                                               return_overflowing_tokens=False,
+                                               return_length=False,
+                                               return_tensors="pt",).input_ids.cuda()
+                neg_toks = self.pipe.tokenizer(neg_prompt,
+                                               padding="max_length",
+                                               max_length=self.pipe.tokenizer_max_length,
+                                               truncation=True,
+                                               return_overflowing_tokens=False,
+                                               return_length=False,
+                                               return_tensors="pt",).input_ids.cuda()
+                pos = self.pipe.text_encoder(pos_toks).pooler_output
+                neg = self.pipe.text_encoder(neg_toks).pooler_output
+                positives.append(pos)
+                negatives.append(neg)
+
+            positives = torch.cat(positives, dim=0)
+            negatives = torch.cat(negatives, dim=0)
+
+            diffs = positives - negatives
+
+            avg_diff = diffs.mean(0, keepdim=True)
+        return avg_diff
+
+    def generate(self,
+                 prompt = "a photo of a house",
+                 scale = 2,
+                 scale_2nd = 2,
+                 seed = 15,
+                 normalize_scales = False,
+                 avg_diff = None,
+                 avg_diff_2nd = None,
+                 **pipeline_kwargs
+                 ):
+        # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
+        # if pooler token only [-4,4] work well
+
+        with torch.no_grad():
+            text_inputs = self.pipe.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=77,
+                truncation=True,
+                return_overflowing_tokens=False,
+                return_length=False,
+                return_tensors="pt",
+            )
+
+            text_input_ids = text_inputs.input_ids
+            prompt_embeds = self.pipe.text_encoder(text_input_ids.to(self.device), output_hidden_states=False)
+
+            # Use pooled output of CLIPTextModel
+            prompt_embeds = prompt_embeds.pooler_output
+            pooled_prompt_embeds = prompt_embeds.to(dtype=self.pipe.text_encoder.dtype, device=self.device)
+
+            # Use pooled output of CLIPTextModel
+
+            text_inputs = self.pipe.tokenizer_2(
+                prompt,
+                padding="max_length",
+                max_length=512,
+                truncation=True,
+                return_length=False,
+                return_overflowing_tokens=False,
+                return_tensors="pt",
+            )
+            toks = text_inputs.input_ids
+            prompt_embeds = self.pipe.text_encoder_2(toks.to(self.device), output_hidden_states=False)[0]
+            dtype = self.pipe.text_encoder_2.dtype
+            prompt_embeds = prompt_embeds.to(dtype=dtype, device=self.device)
+            if avg_diff_2nd is not None and normalize_scales:
+                denominator = abs(scale) + abs(scale_2nd)
+                scale = scale / denominator
+                scale_2nd = scale_2nd / denominator
+
+            pooled_prompt_embeds = pooled_prompt_embeds + avg_diff * scale
+            if avg_diff_2nd is not None:
+                pooled_prompt_embeds += avg_diff_2nd * scale_2nd
+
+        torch.manual_seed(seed)
+        images = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
+                           **pipeline_kwargs).images
+
+        return images
+
+    def spectrum(self,
+                 prompt="a photo of a house",
+                 low_scale=-2,
+                 low_scale_2nd=-2,
+                 high_scale=2,
+                 high_scale_2nd=2,
+                 steps=5,
+                 seed=15,
+                 normalize_scales=False,
+                 **pipeline_kwargs
+                 ):
+
+        images = []
+        for i in range(steps):
+            scale = low_scale + (high_scale - low_scale) * i / (steps - 1)
+            scale_2nd = low_scale_2nd + (high_scale_2nd - low_scale_2nd) * i / (steps - 1)
+            image = self.generate(prompt, scale, scale_2nd, seed, normalize_scales, **pipeline_kwargs)
+            images.append(image[0].resize((512,512)))
+
+        canvas = Image.new('RGB', (640 * steps, 640))
+        for i, im in enumerate(images):
+            canvas.paste(im, (640 * i, 0))
+
+        return canvas
+class T5SliderFlux(CLIPSlider):
+
+    def find_latent_direction(self,
+                              target_word:str,
+                              opposite:str):
+
+        # lets identify a latent direction by taking differences between opposites
+        # target_word = "happy"
+        # opposite = "sad"
+
+
+        with torch.no_grad():
+            positives = []
+            negatives = []
+            for i in tqdm(range(self.iterations)):
+                medium = random.choice(MEDIUMS)
+                subject = random.choice(SUBJECTS)
+                pos_prompt = f"a {medium} of a {target_word} {subject}"
+                neg_prompt = f"a {medium} of a {opposite} {subject}"
 
                 pos_toks = self.pipe.tokenizer_2(pos_prompt,
                                                  return_tensors="pt",
@@ -412,9 +536,7 @@ class T5SliderFlux(CLIPSlider):
                  seed = 15,
                  only_pooler = False,
                  normalize_scales = False,
-                 correlation_weight_factor = 0.6,
-                 avg_diff = None,
-                 avg_diff_2nd = None,
+                 correlation_weight_factor = 1.0,
                  **pipeline_kwargs
                  ):
         # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
@@ -453,34 +575,31 @@ class T5SliderFlux(CLIPSlider):
             prompt_embeds = self.pipe.text_encoder_2(toks.to(self.device), output_hidden_states=False)[0]
             dtype = self.pipe.text_encoder_2.dtype
             prompt_embeds = prompt_embeds.to(dtype=dtype, device=self.device)
-            print("1", prompt_embeds.shape)
-            if avg_diff_2nd is not None and normalize_scales:
+            if self.avg_diff_2nd and normalize_scales:
                denominator = abs(scale) + abs(scale_2nd)
                scale = scale / denominator
                scale_2nd = scale_2nd / denominator
            if only_pooler:
-                prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff * scale
-                if avg_diff_2nd is not None:
-                    prompt_embeds[:, toks.argmax()] += avg_diff_2nd * scale_2nd
+                prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff * scale
+                if self.avg_diff_2nd:
+                    prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd * scale_2nd
            else:
                normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
                sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
 
                weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, prompt_embeds.shape[2])
-                print("weights", weights.shape)
 
                standard_weights = torch.ones_like(weights)
 
                weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
                prompt_embeds = prompt_embeds + (
-                        weights * avg_diff * scale)
-                print("2", prompt_embeds.shape)
-                if avg_diff_2nd is not None:
+                        weights * self.avg_diff * scale)
+                if self.avg_diff_2nd:
                    prompt_embeds += (
-                            weights * avg_diff_2nd * scale_2nd)
+                            weights * self.avg_diff_2nd * scale_2nd)
 
        torch.manual_seed(seed)
        images = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
-                           **pipeline_kwargs).images[0]
+                           **pipeline_kwargs).images
 
        return images
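
For reference, a minimal usage sketch of the CLIPSliderFlux class added in this commit (not part of the diff itself). The find_latent_direction and generate calls follow the signatures introduced above; the constructor arguments, the FLUX.1-schnell checkpoint, and the extra pipeline kwargs are assumptions, since CLIPSlider.__init__ is not shown in this diff.

# Usage sketch only: assumes CLIPSlider's constructor accepts a diffusers
# pipeline, a device, and an iteration count (not shown in this diff).
import torch
from diffusers import FluxPipeline
from clip_slider_pipeline import CLIPSliderFlux

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")
slider = CLIPSliderFlux(pipe, device="cuda", iterations=100)  # assumed signature

# Average pooled-CLIP difference over prompts built around the two words.
avg_diff = slider.find_latent_direction("happy", "sad")

# The new generate() shifts only the pooled CLIP embedding by avg_diff * scale;
# per the in-code comment, scales of roughly [-4, 4] are reasonable here.
images = slider.generate(
    "a photo of a house",
    scale=2,
    avg_diff=avg_diff,
    num_inference_steps=4,
    guidance_scale=0.0,
)
images[0].save("happy_house.png")

The spectrum() method added alongside generate() sweeps the same direction across a range of scales and pastes the resulting frames onto a single canvas.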