import math
from typing import Tuple, Union, Optional

def make_unet_conversion_map():
    unet_conversion_map_layer = []

    for i in range(3):  # num_blocks is 3 in sdxl
        # loop over downblocks/upblocks
        for j in range(2):
            # loop over resnets/attentions for downblocks
            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

            if i < 3:
                # no attention layers in down_blocks.3
                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

        for j in range(3):
            # loop over resnets/attentions for upblocks
            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

            # unlike SD1.x (which has no attention in up_blocks.0), every SDXL
            # up_block has attention layers, so the old "if i > 0" guard is dropped
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

        if i < 3:
            # no downsample in down_blocks.3
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

            # no upsample in up_blocks.3
            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}."  # changed for sdxl: the upsampler always sits at index 2
            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

    hf_mid_atn_prefix = "mid_block.attentions.0."
    sd_mid_atn_prefix = "middle_block.1."
    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

    for j in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{j}."
        sd_mid_res_prefix = f"middle_block.{2*j}."
        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
    unet_conversion_map_resnet = [
        # (stable-diffusion, HF Diffusers)
        ("in_layers.0.", "norm1."),
        ("in_layers.2.", "conv1."),
        ("out_layers.0.", "norm2."),
        ("out_layers.3.", "conv2."),
        ("emb_layers.1.", "time_emb_proj."),
        ("skip_connection.", "conv_shortcut."),
    ]
    unet_conversion_map = []
    for sd, hf in unet_conversion_map_layer:
        if "resnets" in hf:
            # resnet prefixes expand into one entry per inner layer (norm/conv/etc.)
            for sd_res, hf_res in unet_conversion_map_resnet:
                unet_conversion_map.append((sd + sd_res, hf + hf_res))
        else:
            unet_conversion_map.append((sd, hf))

    for j in range(2):
        hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
        sd_time_embed_prefix = f"time_embed.{j*2}."
        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))

    for j in range(2):
        hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
        sd_label_embed_prefix = f"label_emb.0.{j*2}."
        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))

    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
    unet_conversion_map.append(("out.0.", "conv_norm_out."))
    unet_conversion_map.append(("out.2.", "conv_out."))

    return unet_conversion_map
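

# Example (illustrative sketch, not part of the conversion API): inspect a few of
# the (stable-diffusion, HF Diffusers) prefix pairs the map produces.  The helper
# name _demo_conversion_map is hypothetical.
def _demo_conversion_map():
    conversion_map = make_unet_conversion_map()
    for sd_prefix, hf_prefix in conversion_map[:3]:
        print(f"{sd_prefix!r} -> {hf_prefix!r}")
    # first pair: 'input_blocks.1.0.in_layers.0.' -> 'down_blocks.0.resnets.0.norm1.'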


def convert_unet_state_dict(src_sd, conversion_map):
    converted_sd = {}
    for src_key, value in src_sd.items():
        src_key_fragments = src_key.split(".")[:-1]  # remove weight/bias
        # trim fragments from the right until some prefix matches the map
        while len(src_key_fragments) > 0:
            src_key_prefix = ".".join(src_key_fragments) + "."
            if src_key_prefix in conversion_map:
                converted_prefix = conversion_map[src_key_prefix]
                converted_key = converted_prefix + src_key[len(src_key_prefix):]
                converted_sd[converted_key] = value
                break
            src_key_fragments.pop(-1)
        assert len(src_key_fragments) > 0, f"key {src_key} not found in conversion map"

    return converted_sd
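

# Example (illustrative sketch): convert_unet_state_dict drops the trailing
# "weight"/"bias" fragment, then shortens the prefix from the right until it
# finds a match.  The toy map and string "tensors" below are hypothetical.
def _demo_convert_unet_state_dict():
    toy_map = {"input_blocks.0.0.": "conv_in."}
    toy_sd = {"input_blocks.0.0.weight": "W", "input_blocks.0.0.bias": "b"}
    print(convert_unet_state_dict(toy_sd, toy_map))
    # -> {'conv_in.weight': 'W', 'conv_in.bias': 'b'}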


def convert_sdxl_unet_state_dict_to_diffusers(sd):
    unet_conversion_map = make_unet_conversion_map()

    # renamed loop variables so the comprehension doesn't shadow the `sd` argument
    conversion_dict = {sd_prefix: hf_prefix for sd_prefix, hf_prefix in unet_conversion_map}
    return convert_unet_state_dict(sd, conversion_dict)


def extract_unet_state_dict(state_dict):
    unet_sd = {}
    UNET_KEY_PREFIX = "model.diffusion_model."
    for k, v in state_dict.items():
        if k.startswith(UNET_KEY_PREFIX):
            unet_sd[k[len(UNET_KEY_PREFIX):]] = v
    return unet_sd
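

# Example (illustrative sketch): end-to-end conversion of a single-file SDXL
# checkpoint to Diffusers key names.  The checkpoint path is hypothetical, and
# loading assumes the `safetensors` package is installed.
def _demo_checkpoint_conversion():
    from safetensors.torch import load_file

    state_dict = load_file("sd_xl_base_1.0.safetensors")  # hypothetical path
    unet_sd = extract_unet_state_dict(state_dict)  # strip "model.diffusion_model."
    du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet_sd)
    log_model_info(du_unet_sd, "converted U-Net")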


def log_model_info(model, name):
    sd = model.state_dict() if hasattr(model, "state_dict") else model
    print(
        f"{name}:",
        f" number of parameters: {sum(p.numel() for p in sd.values())}",
        f" dtype: {sd[next(iter(sd))].dtype}",
        sep="\n",
    )


def around_reso(
    img_w,
    img_h,
    reso: Union[Tuple[int, int], int],
    divisible: Optional[int] = None,
    max_width=None,
    max_height=None,
) -> Tuple[int, int]:
    r"""
    Scale (img_w, img_h) to roughly match the pixel budget reso[0]*reso[1]
    while preserving the aspect ratio:

    w*h = reso[0]*reso[1]
    w/h = img_w/img_h
    => w = img_ar*h
    => img_ar*h^2 = reso[0]*reso[1]
    => h = sqrt(reso[0]*reso[1] / img_ar)
    """
    reso = reso if isinstance(reso, tuple) else (reso, reso)
    divisible = divisible or 1
    if (
        img_w * img_h <= reso[0] * reso[1]
        and (not max_width or img_w <= max_width)
        and (not max_height or img_h <= max_height)
        and img_w % divisible == 0
        and img_h % divisible == 0
    ):
        # already within the pixel budget and size limits, and already divisible
        return (img_w, img_h)
    img_ar = img_w / img_h
    around_h = math.sqrt(reso[0] * reso[1] / img_ar)
    around_w = img_ar * around_h // divisible * divisible
    if max_width and around_w > max_width:
        # too wide: clamp width and rescale height proportionally
        around_h = around_h * max_width // around_w
        around_w = max_width
    elif max_height and around_h > max_height:
        # too tall: clamp height and rescale width proportionally
        around_w = around_w * max_height // around_h
        around_h = max_height
    around_h = min(around_h, max_height) if max_height else around_h
    around_w = min(around_w, max_width) if max_width else around_w
    around_h = int(around_h // divisible * divisible)
    around_w = int(around_w // divisible * divisible)
    return (around_w, around_h)
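

# Worked example (illustrative sketch): fit a 1920x1080 image into a 1024x1024
# pixel budget with both sides divisible by 64.  h = sqrt(1024*1024 / (16/9)) = 768,
# w = (16/9)*768 = 1365.33, floored to a multiple of 64 -> 1344, so
# w*h = 1344*768 <= 1024*1024 and the 16:9 aspect ratio is roughly preserved.
def _demo_around_reso():
    print(around_reso(1920, 1080, reso=1024, divisible=64))  # (1344, 768)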