# ControlNeXt / utils/utils.py
# Source revision 76be739 ("update") by Eugeoter.
import math
from typing import Tuple, Union, Optional
def make_unet_conversion_map():
    """Build the SDXL UNet key-prefix table from the original (stable-diffusion) layout to HF Diffusers.

    Returns:
        A list of ``(sd_prefix, hf_prefix)`` string pairs. Every prefix ends with ``.`` so it
        can be matched against the leading part of a state-dict key. Block-level prefixes that
        name a resnet are expanded into one pair per resnet sub-layer.

    Note:
        Prefixes are generated uniformly for all three blocks (every down/up block gets
        attention and sampler slots). TODO(review): confirm which slots actually exist in an
        SDXL checkpoint. Unmatched entries are harmless when the table is used for prefix lookup.
    """
    # Sub-layer renames inside every resnet block: (stable-diffusion, HF Diffusers).
    resnet_renames = [
        ("in_layers.0.", "norm1."),
        ("in_layers.2.", "conv1."),
        ("out_layers.0.", "norm2."),
        ("out_layers.3.", "conv2."),
        ("emb_layers.1.", "time_emb_proj."),
        ("skip_connection.", "conv_shortcut."),
    ]

    layer_pairs = []
    # num_blocks is 3 in SDXL.
    for block in range(3):
        # Down block: 2 (resnet, attention) slots, starting at input_blocks.1.
        for layer in range(2):
            sd_index = 3 * block + layer + 1
            layer_pairs.append((f"input_blocks.{sd_index}.0.", f"down_blocks.{block}.resnets.{layer}."))
            layer_pairs.append((f"input_blocks.{sd_index}.1.", f"down_blocks.{block}.attentions.{layer}."))
        # Up block: 3 (resnet, attention) slots.
        for layer in range(3):
            sd_index = 3 * block + layer
            layer_pairs.append((f"output_blocks.{sd_index}.0.", f"up_blocks.{block}.resnets.{layer}."))
            layer_pairs.append((f"output_blocks.{sd_index}.1.", f"up_blocks.{block}.attentions.{layer}."))
        # Sampler slots; in SDXL the upsampler sits at sub-index 2 of the block's last output block.
        layer_pairs.append((f"input_blocks.{3 * (block + 1)}.0.op.", f"down_blocks.{block}.downsamplers.0.conv."))
        layer_pairs.append((f"output_blocks.{3 * block + 2}.2.", f"up_blocks.{block}.upsamplers.0."))

    # Middle block: attention at index 1, resnets at indices 0 and 2.
    layer_pairs.append(("middle_block.1.", "mid_block.attentions.0."))
    layer_pairs.extend((f"middle_block.{2 * layer}.", f"mid_block.resnets.{layer}.") for layer in range(2))

    conversion_map = []
    for sd_prefix, hf_prefix in layer_pairs:
        if "resnets" not in hf_prefix:
            conversion_map.append((sd_prefix, hf_prefix))
            continue
        conversion_map.extend((sd_prefix + sd_sub, hf_prefix + hf_sub) for sd_sub, hf_sub in resnet_renames)

    # Time and label (additional conditioning) embedding MLPs: Linear layers at indices 0 and 2.
    conversion_map.extend((f"time_embed.{2 * k}.", f"time_embedding.linear_{k + 1}.") for k in range(2))
    conversion_map.extend((f"label_emb.0.{2 * k}.", f"add_embedding.linear_{k + 1}.") for k in range(2))
    conversion_map += [
        ("input_blocks.0.0.", "conv_in."),
        ("out.0.", "conv_norm_out."),
        ("out.2.", "conv_out."),
    ]
    return conversion_map
def convert_unet_state_dict(src_sd, conversion_map):
    """Rename the keys of ``src_sd`` using a prefix-to-prefix ``conversion_map``.

    For each key, the last dotted component (e.g. ``weight``/``bias``) is dropped. Then
    progressively shorter dotted prefixes (each ending in ``.``) are looked up in
    ``conversion_map``. The first hit, which is the longest matching prefix, replaces that
    prefix and the rest of the key is kept.

    Args:
        src_sd: mapping of key -> value. Values are copied over unchanged.
        conversion_map: dict mapping source prefix -> destination prefix.

    Returns:
        A new dict with the converted keys.

    Raises:
        AssertionError: if no prefix of some key is present in ``conversion_map``. It is raised
            explicitly rather than via ``assert``, so the check is not stripped under ``python -O``.
            A stripped check would silently drop unmatched keys.
    """
    converted_sd = {}
    for src_key, value in src_sd.items():
        fragments = src_key.split(".")[:-1]  # remove weight/bias
        # Longest-prefix match: try fragments[:n], fragments[:n-1], ..., fragments[:1].
        for end in range(len(fragments), 0, -1):
            src_prefix = ".".join(fragments[:end]) + "."
            if src_prefix in conversion_map:
                converted_sd[conversion_map[src_prefix] + src_key[len(src_prefix):]] = value
                break
        else:
            # No prefix matched, or the key had no dotted prefix at all.
            raise AssertionError(f"key {src_key} not found in conversion map")
    return converted_sd
def convert_sdxl_unet_state_dict_to_diffusers(sd):
    """Rename an SDXL UNet state dict from stable-diffusion key names to HF Diffusers key names.

    Keys must already be relative to the UNet, e.g. ``input_blocks.1.0.in_layers.0.weight``.
    The ``model.diffusion_model.`` prefix stripped by ``extract_unet_state_dict`` must be
    absent, otherwise no table entry matches.

    Args:
        sd: UNet state dict with stable-diffusion key names.

    Returns:
        A new dict with Diffusers key names and the same values.
    """
    prefix_table = dict(make_unet_conversion_map())
    return convert_unet_state_dict(sd, prefix_table)
def extract_unet_state_dict(state_dict):
    """Return the UNet sub-dict of a full checkpoint state dict.

    Keeps only the entries whose key starts with ``model.diffusion_model.`` and strips that
    prefix from the key. Values are passed through unchanged and insertion order is preserved.
    """
    prefix = "model.diffusion_model."
    cut = len(prefix)
    return {key[cut:]: tensor for key, tensor in state_dict.items() if key.startswith(prefix)}
def log_model_info(model, name):
    """Print a short summary of ``model``: its parameter count and the dtype of its first entry.

    Args:
        model: an object exposing ``state_dict()``, or a mapping used directly as the state
            dict. Each value must provide ``numel()`` and ``dtype``.
        name: label printed on the first line.

    An empty state dict reports ``0`` parameters and ``dtype: None``. Previously it leaked a
    ``StopIteration`` from ``next(iter(...))``.
    """
    sd = model.state_dict() if hasattr(model, "state_dict") else model
    num_params = sum(p.numel() for p in sd.values())
    # Only the first entry's dtype is reported; guard against an empty state dict.
    first_dtype = next(iter(sd.values())).dtype if sd else None
    print(
        f"{name}:",
        f" number of parameters: {num_params}",
        f" dtype: {first_dtype}",
        sep='\n'
    )
def around_reso(img_w, img_h, reso: Union[Tuple[int, int], int], divisible: Optional[int] = None, max_width=None, max_height=None) -> Tuple[int, int]:
    r"""
    Pick a ``(width, height)`` whose area is close to that of ``reso`` while keeping the image aspect ratio.

    With target area ``A = reso_w * reso_h`` and aspect ratio ``ar = img_w / img_h``:

        w * h = A,  w / h = ar
        => w = ar * h
        => ar * h^2 = A
        => h = sqrt(A / ar)

    Args:
        img_w, img_h: source image size.
        reso: target resolution as ``(w, h)`` (tuple or list) or a single int meaning
            ``(reso, reso)``. Only its area is used.
        divisible: if given, both output sides are rounded down to a multiple of it.
        max_width, max_height: optional upper bounds on the output sides. When a bound is
            exceeded, both sides are scaled down together so the aspect ratio is kept.

    Returns:
        ``(width, height)`` as ints. The input size is returned unchanged when it already fits
        the target area and bounds and both sides are multiples of ``divisible``.
    """
    reso = tuple(reso) if isinstance(reso, (tuple, list)) else (reso, reso)
    divisible = divisible or 1

    fits = (
        img_w * img_h <= reso[0] * reso[1]
        and (not max_width or img_w <= max_width)
        and (not max_height or img_h <= max_height)
        and img_w % divisible == 0
        and img_h % divisible == 0
    )
    if fits:
        return (img_w, img_h)

    img_ar = img_w / img_h
    around_h = math.sqrt(reso[0] * reso[1] / img_ar)
    around_w = img_ar * around_h // divisible * divisible

    # Apply each bound in turn (not `elif`): after shrinking to fit max_width the height can
    # still exceed max_height. Clamping only the height would then distort the aspect ratio.
    if max_width and around_w > max_width:
        around_h = around_h * max_width // around_w
        around_w = max_width
    if max_height and around_h > max_height:
        around_w = around_w * max_height // around_h
        around_h = max_height

    # Kept as a safeguard; after the proportional scaling above these are no-ops.
    around_h = min(around_h, max_height) if max_height else around_h
    around_w = min(around_w, max_width) if max_width else around_w
    around_h = int(around_h // divisible * divisible)
    around_w = int(around_w // divisible * divisible)
    return (around_w, around_h)