rectified dtypes
- app.py  +20 -5
- lib_omost/pipeline.py  +9 -7
app.py
CHANGED
@@ -35,21 +35,36 @@ import lib_omost.canvas as omost_canvas
 
 # SDXL
 
+# sdxl_name = 'SG161222/RealVisXL_V4.0'
+# # sdxl_name = 'stabilityai/stable-diffusion-xl-base-1.0'
+
+# tokenizer = CLIPTokenizer.from_pretrained(
+#     sdxl_name, subfolder="tokenizer")
+# tokenizer_2 = CLIPTokenizer.from_pretrained(
+#     sdxl_name, subfolder="tokenizer_2")
+# text_encoder = CLIPTextModel.from_pretrained(
+#     sdxl_name, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16", device_map="auto")
+# text_encoder_2 = CLIPTextModel.from_pretrained(
+#     sdxl_name, subfolder="text_encoder_2", torch_dtype=torch.float16, variant="fp16", device_map="auto")
+# vae = AutoencoderKL.from_pretrained(
+#     sdxl_name, subfolder="vae", torch_dtype=torch.bfloat16, variant="fp16", device_map="auto")  # bfloat16 vae
+# unet = UNet2DConditionModel.from_pretrained(
+#     sdxl_name, subfolder="unet", torch_dtype=torch.float16, variant="fp16", device_map="auto")
+
 sdxl_name = 'SG161222/RealVisXL_V4.0'
-# sdxl_name = 'stabilityai/stable-diffusion-xl-base-1.0'
 
 tokenizer = CLIPTokenizer.from_pretrained(
     sdxl_name, subfolder="tokenizer")
 tokenizer_2 = CLIPTokenizer.from_pretrained(
     sdxl_name, subfolder="tokenizer_2")
 text_encoder = CLIPTextModel.from_pretrained(
-    sdxl_name, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16", device_map="auto")
+    sdxl_name, subfolder="text_encoder", torch_dtype=torch.float32, device_map="auto")
 text_encoder_2 = CLIPTextModel.from_pretrained(
-    sdxl_name, subfolder="text_encoder_2", torch_dtype=torch.float16, variant="fp16", device_map="auto")
+    sdxl_name, subfolder="text_encoder_2", torch_dtype=torch.float32, device_map="auto")
 vae = AutoencoderKL.from_pretrained(
-    sdxl_name, subfolder="vae", torch_dtype=torch.bfloat16, variant="fp16", device_map="auto")  # bfloat16 vae
+    sdxl_name, subfolder="vae", torch_dtype=torch.float32, device_map="auto")
 unet = UNet2DConditionModel.from_pretrained(
-    sdxl_name, subfolder="unet", torch_dtype=torch.float16, variant="fp16", device_map="auto")
+    sdxl_name, subfolder="unet", torch_dtype=torch.float32, device_map="auto")
 
 unet.set_attn_processor(AttnProcessor2_0())
 vae.set_attn_processor(AttnProcessor2_0())
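For reference, a minimal sketch of what the rewritten loading block amounts to, assuming the same RealVisXL checkpoint and the diffusers/transformers calls used above; the dtype check at the end is illustrative and not part of this commit:

import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

sdxl_name = 'SG161222/RealVisXL_V4.0'

# Same calls as the new app.py lines: full-precision weights, no fp16 variant.
tokenizer = CLIPTokenizer.from_pretrained(sdxl_name, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(
    sdxl_name, subfolder="text_encoder", torch_dtype=torch.float32, device_map="auto")
vae = AutoencoderKL.from_pretrained(
    sdxl_name, subfolder="vae", torch_dtype=torch.float32, device_map="auto")
unet = UNet2DConditionModel.from_pretrained(
    sdxl_name, subfolder="unet", torch_dtype=torch.float32, device_map="auto")

# Illustrative check (not in the commit): every parameter should now be float32.
for name, module in [("text_encoder", text_encoder), ("vae", vae), ("unet", unet)]:
    assert all(p.dtype == torch.float32 for p in module.parameters()), name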
lib_omost/pipeline.py
CHANGED
@@ -90,6 +90,7 @@ class KModel:
         return torch.cat([sigmas, sigmas.new_zeros([1])])
 
     def __call__(self, x, sigma, **extra_args):
+        dtype = torch.float32
         x_ddim_space = x / (sigma[:, None, None, None] ** 2 + self.sigma_data ** 2) ** 0.5
         t = self.timestep(sigma)
         cfg_scale = extra_args['cfg_scale']
@@ -380,6 +381,7 @@ class StableDiffusionXLOmostPipeline(StableDiffusionXLImg2ImgPipeline):
     ):
 
         device = self.unet.device
+        dtype = torch.float32
         cross_attention_kwargs = cross_attention_kwargs or {}
 
         # Sigmas
@@ -405,13 +407,13 @@ class StableDiffusionXLOmostPipeline(StableDiffusionXLImg2ImgPipeline):
 
         # Batch
 
-        latents = latents.to(device)
-        add_time_ids = add_time_ids.repeat(batch_size, 1).to(device)
-        add_neg_time_ids = add_neg_time_ids.repeat(batch_size, 1).to(device)
-        prompt_embeds = [(k.to(device), v.repeat(batch_size, 1, 1).to(device)) for k, v in prompt_embeds]
-        negative_prompt_embeds = [(k.to(device), v.repeat(batch_size, 1, 1).to(device)) for k, v in negative_prompt_embeds]
-        pooled_prompt_embeds = pooled_prompt_embeds.repeat(batch_size, 1).to(device)
-        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(batch_size, 1).to(device)
+        latents = latents.to(device).to(dtype)
+        add_time_ids = add_time_ids.repeat(batch_size, 1).to(device).to(dtype)
+        add_neg_time_ids = add_neg_time_ids.repeat(batch_size, 1).to(device).to(dtype)
+        prompt_embeds = [(k.to(device), v.repeat(batch_size, 1, 1).to(device).to(dtype)) for k, v in prompt_embeds]
+        negative_prompt_embeds = [(k.to(device), v.repeat(batch_size, 1, 1).to(device).to(dtype)) for k, v in negative_prompt_embeds]
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(batch_size, 1).to(device).to(dtype)
+        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(batch_size, 1).to(device).to(dtype)
 
         # Feeds
 
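The .to(device).to(dtype) chain added above is just two sequential casts; a small standalone sketch of the same pattern, with placeholder shapes that are not taken from the pipeline:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float32   # the working dtype introduced by this commit
batch_size = 2          # placeholder value

# Placeholder tensors standing in for the pipeline's latents / time ids.
latents = torch.randn(1, 4, 128, 128, dtype=torch.float16)
add_time_ids = torch.zeros(1, 6, dtype=torch.float16)

# Same pattern as the diff: move to the device, then cast to float32.
latents = latents.to(device).to(dtype)
add_time_ids = add_time_ids.repeat(batch_size, 1).to(device).to(dtype)

print(latents.dtype, add_time_ids.dtype)   # torch.float32 torch.float32

The two calls could equally be written as a single .to(device=device, dtype=dtype); the chained form mirrors how the existing lines already ended in .to(device).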