aatir committed
Commit 5894739 · Parent: c2cf3a4

rectified dtypes

Files changed (2):
  1. app.py +20 -5
  2. lib_omost/pipeline.py +9 -7
app.py CHANGED
@@ -35,21 +35,36 @@ import lib_omost.canvas as omost_canvas
 
 # SDXL
 
+# sdxl_name = 'SG161222/RealVisXL_V4.0'
+# # sdxl_name = 'stabilityai/stable-diffusion-xl-base-1.0'
+
+# tokenizer = CLIPTokenizer.from_pretrained(
+#     sdxl_name, subfolder="tokenizer")
+# tokenizer_2 = CLIPTokenizer.from_pretrained(
+#     sdxl_name, subfolder="tokenizer_2")
+# text_encoder = CLIPTextModel.from_pretrained(
+#     sdxl_name, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16", device_map="auto")
+# text_encoder_2 = CLIPTextModel.from_pretrained(
+#     sdxl_name, subfolder="text_encoder_2", torch_dtype=torch.float16, variant="fp16", device_map="auto")
+# vae = AutoencoderKL.from_pretrained(
+#     sdxl_name, subfolder="vae", torch_dtype=torch.bfloat16, variant="fp16", device_map="auto")  # bfloat16 vae
+# unet = UNet2DConditionModel.from_pretrained(
+#     sdxl_name, subfolder="unet", torch_dtype=torch.float16, variant="fp16", device_map="auto")
+
 sdxl_name = 'SG161222/RealVisXL_V4.0'
-# sdxl_name = 'stabilityai/stable-diffusion-xl-base-1.0'
 
 tokenizer = CLIPTokenizer.from_pretrained(
     sdxl_name, subfolder="tokenizer")
 tokenizer_2 = CLIPTokenizer.from_pretrained(
     sdxl_name, subfolder="tokenizer_2")
 text_encoder = CLIPTextModel.from_pretrained(
-    sdxl_name, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16", device_map="auto")
+    sdxl_name, subfolder="text_encoder", torch_dtype=torch.float32, device_map="auto")
 text_encoder_2 = CLIPTextModel.from_pretrained(
-    sdxl_name, subfolder="text_encoder_2", torch_dtype=torch.float16, variant="fp16", device_map="auto")
+    sdxl_name, subfolder="text_encoder_2", torch_dtype=torch.float32, device_map="auto")
 vae = AutoencoderKL.from_pretrained(
-    sdxl_name, subfolder="vae", torch_dtype=torch.bfloat16, variant="fp16", device_map="auto")  # bfloat16 vae
+    sdxl_name, subfolder="vae", torch_dtype=torch.float32, device_map="auto")
 unet = UNet2DConditionModel.from_pretrained(
-    sdxl_name, subfolder="unet", torch_dtype=torch.float16, variant="fp16", device_map="auto")
+    sdxl_name, subfolder="unet", torch_dtype=torch.float32, device_map="auto")
 
 unet.set_attn_processor(AttnProcessor2_0())
 vae.set_attn_processor(AttnProcessor2_0())
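
With torch_dtype=torch.float32, dropping variant="fp16" is necessary rather than cosmetic: that flag selects the separate half-precision weight files, while a float32 load reads the repository's default weights (at roughly twice the memory footprint). As a minimal sketch of the pattern the new branch follows, here is a hypothetical helper load_sdxl_modules (not part of the commit) that loads the four submodules at one dtype and asserts they agree, using only the from_pretrained calls already present in the diff:

import torch
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel

def load_sdxl_modules(name: str, dtype: torch.dtype):
    # No variant="fp16" here: float32 is the default checkpoint format,
    # and the fp16 variant files only exist for half-precision loads.
    kwargs = dict(torch_dtype=dtype, device_map="auto")
    tokenizer = CLIPTokenizer.from_pretrained(name, subfolder="tokenizer")
    tokenizer_2 = CLIPTokenizer.from_pretrained(name, subfolder="tokenizer_2")
    text_encoder = CLIPTextModel.from_pretrained(name, subfolder="text_encoder", **kwargs)
    text_encoder_2 = CLIPTextModel.from_pretrained(name, subfolder="text_encoder_2", **kwargs)
    vae = AutoencoderKL.from_pretrained(name, subfolder="vae", **kwargs)
    unet = UNet2DConditionModel.from_pretrained(name, subfolder="unet", **kwargs)
    # Every weight-bearing module should come back in the requested dtype.
    for m in (text_encoder, text_encoder_2, vae, unet):
        assert next(m.parameters()).dtype == dtype, type(m).__name__
    return tokenizer, tokenizer_2, text_encoder, text_encoder_2, vae, unet

Calling load_sdxl_modules('SG161222/RealVisXL_V4.0', torch.float32) reproduces the new branch above; a torch.float16 load would additionally need variant="fp16" on repositories that ship one.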
lib_omost/pipeline.py CHANGED
@@ -90,6 +90,7 @@ class KModel:
         return torch.cat([sigmas, sigmas.new_zeros([1])])
 
     def __call__(self, x, sigma, **extra_args):
+        dtype = torch.float32
         x_ddim_space = x / (sigma[:, None, None, None] ** 2 + self.sigma_data ** 2) ** 0.5
         t = self.timestep(sigma)
         cfg_scale = extra_args['cfg_scale']
@@ -380,6 +381,7 @@ class StableDiffusionXLOmostPipeline(StableDiffusionXLImg2ImgPipeline):
     ):
 
         device = self.unet.device
+        dtype = torch.float32
         cross_attention_kwargs = cross_attention_kwargs or {}
 
         # Sigmas
@@ -405,13 +407,13 @@ class StableDiffusionXLOmostPipeline(StableDiffusionXLImg2ImgPipeline):
 
         # Batch
 
-        latents = latents.to(device)
-        add_time_ids = add_time_ids.repeat(batch_size, 1).to(device)
-        add_neg_time_ids = add_neg_time_ids.repeat(batch_size, 1).to(device)
-        prompt_embeds = [(k.to(device), v.repeat(batch_size, 1, 1).to(noise)) for k, v in prompt_embeds]
-        negative_prompt_embeds = [(k.to(device), v.repeat(batch_size, 1, 1).to(noise)) for k, v in negative_prompt_embeds]
-        pooled_prompt_embeds = pooled_prompt_embeds.repeat(batch_size, 1).to(noise)
-        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(batch_size, 1).to(noise)
+        latents = latents.to(device).to(dtype)
+        add_time_ids = add_time_ids.repeat(batch_size, 1).to(device).to(dtype)
+        add_neg_time_ids = add_neg_time_ids.repeat(batch_size, 1).to(device).to(dtype)
+        prompt_embeds = [(k.to(device), v.repeat(batch_size, 1, 1).to(device).to(dtype)) for k, v in prompt_embeds]
+        negative_prompt_embeds = [(k.to(device), v.repeat(batch_size, 1, 1).to(device).to(dtype)) for k, v in negative_prompt_embeds]
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(batch_size, 1).to(device).to(dtype)
+        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(batch_size, 1).to(device).to(dtype)
 
         # Feeds
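
The replaced lines cast the embeddings with .to(noise), which copies both the device and the dtype from the noise tensor; the new lines pin an explicit dtype = torch.float32 so every tensor entering the UNet matches the float32 weights now loaded in app.py. A small sketch of the conversion chain, with an illustrative latent shape; note that PyTorch's Tensor.to(device, dtype) performs both steps in one call:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float32
latents = torch.randn(1, 4, 64, 64, dtype=torch.float16)  # illustrative shape

# As in the commit: move to the device first, then cast the dtype.
a = latents.to(device).to(dtype)
# Single-call equivalent; can avoid materialising the intermediate fp16 copy.
b = latents.to(device, dtype)

assert a.dtype == b.dtype == torch.float32
assert a.device == b.device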