linoyts HF staff commited on
Commit
80d35a7
1 Parent(s): 7c81d9d

Update clip_slider_pipeline.py

Browse files
Files changed (1) hide show
  1. clip_slider_pipeline.py +7 -3
clip_slider_pipeline.py CHANGED
@@ -4,7 +4,7 @@ import random
4
  from tqdm import tqdm
5
  from constants import SUBJECTS, MEDIUMS
6
  from PIL import Image
7
-
8
  class CLIPSlider:
9
  def __init__(
10
  self,
@@ -214,7 +214,7 @@ class CLIPSliderXL(CLIPSlider):
214
  ):
215
  # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
216
  # if pooler token only [-4,4] work well
217
-
218
  text_encoders = [self.pipe.text_encoder, self.pipe.text_encoder_2]
219
  tokenizers = [self.pipe.tokenizer, self.pipe.tokenizer_2]
220
  with torch.no_grad():
@@ -282,9 +282,13 @@ class CLIPSliderXL(CLIPSlider):
282
 
283
  prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
284
  pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
285
-
 
286
  torch.manual_seed(seed)
 
287
  image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
288
  **pipeline_kwargs).images[0]
 
 
289
 
290
  return image
 
4
  from tqdm import tqdm
5
  from constants import SUBJECTS, MEDIUMS
6
  from PIL import Image
7
+ import time
8
  class CLIPSlider:
9
  def __init__(
10
  self,
 
214
  ):
215
  # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
216
  # if pooler token only [-4,4] work well
217
+ start_time = time.time()
218
  text_encoders = [self.pipe.text_encoder, self.pipe.text_encoder_2]
219
  tokenizers = [self.pipe.tokenizer, self.pipe.tokenizer_2]
220
  with torch.no_grad():
 
282
 
283
  prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
284
  pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
285
+ end_time = time.time()
286
+ print(f"generation time - before pipe: {end_time - start_time:.2f} ms")
287
  torch.manual_seed(seed)
288
+ start_time = time.time()
289
  image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
290
  **pipeline_kwargs).images[0]
291
+ end_time = time.time()
292
+ print(f"generation time - pipe: {end_time - start_time:.2f} ms")
293
 
294
  return image