# Provenance (copied from the hosting page): author multimodalart,
# commit 4450790 "Squashing commit" (verified).
from pathlib import Path
import safetensors.torch
import torch
import tqdm
from ..log import log
from ..utils import Operation, Precision
from ..utils import output_dir as comfy_out_dir
# Static lookup tables used while pruning a checkpoint.
#   known_junk_prefix: key prefixes removed when "remove_junk" is enabled.
#   nai_keys: NovelAI-style CLIP key prefixes (left) and the
#             text_model-based prefixes they are renamed to (right).
PRUNE_DATA = dict(
    known_junk_prefix=[
        "embedding_manager.embedder.",
        "lora_te_text_model",
        "control_model.",
    ],
    nai_keys={
        "cond_stage_model.transformer.embeddings.": "cond_stage_model.transformer.text_model.embeddings.",
        "cond_stage_model.transformer.encoder.": "cond_stage_model.transformer.text_model.encoder.",
        "cond_stage_model.transformer.final_layer_norm.": "cond_stage_model.transformer.text_model.final_layer_norm.",
    },
)

# Source dtypes that may be cast to each target precision. Only floating
# dtypes are listed: position_ids in clip is int64 and
# model_ema.num_updates is int32, so integer tensors are never cast.
dtypes_to_fp16 = {torch.float32, torch.float64, torch.bfloat16}
dtypes_to_bf16 = {torch.float32, torch.float64, torch.float16}
# fp8 accepts every floating dtype accepted by either of the sets above.
dtypes_to_fp8 = dtypes_to_fp16 | dtypes_to_bf16
class MTB_ModelPruner:
@classmethod
def INPUT_TYPES(cls):
return {
"optional": {
"unet": ("MODEL",),
"clip": ("CLIP",),
"vae": ("VAE",),
},
"required": {
"save_separately": ("BOOLEAN", {"default": False}),
"save_folder": ("STRING", {"default": "checkpoints/ComfyUI"}),
"fix_clip": ("BOOLEAN", {"default": True}),
"remove_junk": ("BOOLEAN", {"default": True}),
"ema_mode": (
("disabled", "remove_ema", "ema_only"),
{"default": "remove_ema"},
),
"precision_unet": (
Precision.list_members(),
{"default": Precision.FULL.value},
),
"operation_unet": (
Operation.list_members(),
{"default": Operation.CONVERT.value},
),
"precision_clip": (
Precision.list_members(),
{"default": Precision.FULL.value},
),
"operation_clip": (
Operation.list_members(),
{"default": Operation.CONVERT.value},
),
"precision_vae": (
Precision.list_members(),
{"default": Precision.FULL.value},
),
"operation_vae": (
Operation.list_members(),
{"default": Operation.CONVERT.value},
),
},
}
OUTPUT_NODE = True
RETURN_TYPES = ()
CATEGORY = "mtb/prune"
FUNCTION = "prune"
def convert_precision(self, tensor: torch.Tensor, precision: Precision):
precision = Precision.from_str(precision)
log.debug(f"Converting to {precision}")
match precision:
case Precision.FP8:
if tensor.dtype in dtypes_to_fp8:
return tensor.to(torch.float8_e4m3fn)
log.error(f"Cannot convert {tensor.dtype} to fp8")
return tensor
case Precision.FP16:
if tensor.dtype in dtypes_to_fp16:
return tensor.half()
log.error(f"Cannot convert {tensor.dtype} to f16")
return tensor
case Precision.BF16:
if tensor.dtype in dtypes_to_bf16:
return tensor.bfloat16()
log.error(f"Cannot convert {tensor.dtype} to bf16")
return tensor
case Precision.FULL | Precision.FP32:
return tensor
def is_sdxl_model(self, clip: dict[str, torch.Tensor] | None):
if clip:
return (any(k.startswith("conditioner.embedders") for k in clip),)
return False
def has_ema(self, unet: dict[str, torch.Tensor]):
return any(k.startswith("model_ema") for k in unet)
def fix_clip(self, clip: dict[str, torch.Tensor] | None):
if self.is_sdxl_model(clip):
log.warn("[fix clip] SDXL not supported")
return
if clip is None:
return
position_id_key = (
"cond_stage_model.transformer.text_model.embeddings.position_ids"
)
if position_id_key in clip:
correct = torch.Tensor([list(range(77))]).to(torch.int64)
now = clip[position_id_key].to(torch.int64)
broken = correct.ne(now)
broken = [i for i in range(77) if broken[0][i]]
if len(broken) != 0:
clip[position_id_key] = correct
log.info(f"[Converter] Fixed broken clip\n{broken}")
else:
log.info(
"[Converter] Clip in this model is fine, skip fixing..."
)
else:
log.info("[Converter] Missing position id in model, try fixing...")
clip[position_id_key] = torch.Tensor([list(range(77))]).to(
torch.int64
)
return clip
def get_dicts(self, unet, clip, vae):
clip_sd = clip.get_sd()
state_dict = unet.model.state_dict_for_saving(
clip_sd, vae.get_sd(), None
)
unet = {
k: v
for k, v in state_dict.items()
if k.startswith("model.diffusion_model")
}
clip = {
k: v
for k, v in state_dict.items()
if k.startswith("cond_stage_model")
or k.startswith("conditioner.embedders")
}
vae = {
k: v
for k, v in state_dict.items()
if k.startswith("first_stage_model")
}
other = {
k: v
for k, v in state_dict.items()
if k not in unet and k not in vae and k not in clip
}
return (unet, clip, vae, other)
def do_remove_junk(self, tensors: dict[str, dict[str, torch.Tensor]]):
need_delete: list[str] = []
for layer in tensors:
for key in layer:
for jk in PRUNE_DATA["known_junk_prefix"]:
if key.startswith(jk):
need_delete.append(".".join([layer, key]))
for k in need_delete:
log.info(f"Removing junk data: {k}")
del tensors[k]
return tensors
def prune(
self,
*,
save_separately: bool,
save_folder: str,
fix_clip: bool,
remove_junk: bool,
ema_mode: str,
precision_unet: Precision,
precision_clip: Precision,
precision_vae: Precision,
operation_unet: str,
operation_clip: str,
operation_vae: str,
unet: dict[str, torch.Tensor] | None = None,
clip: dict[str, torch.Tensor] | None = None,
vae: dict[str, torch.Tensor] | None = None,
):
operation = {
"unet": Operation.from_str(operation_unet),
"clip": Operation.from_str(operation_clip),
"vae": Operation.from_str(operation_vae),
}
precision = {
"unet": Precision.from_str(precision_unet),
"clip": Precision.from_str(precision_clip),
"vae": Precision.from_str(precision_vae),
}
unet, clip, vae, _other = self.get_dicts(unet, clip, vae)
out_dir = Path(save_folder)
folder = out_dir.parent
if not out_dir.is_absolute():
folder = (comfy_out_dir / save_folder).parent
if not folder.exists():
if folder.parent.exists():
folder.mkdir()
else:
raise FileNotFoundError(
f"Folder {folder.parent} does not exist"
)
name = out_dir.name
save_name = f"{name}-{precision_unet}"
if ema_mode != "disabled":
save_name += f"-{ema_mode}"
if fix_clip:
save_name += "-clip-fix"
if (
any(o == Operation.CONVERT for o in operation.values())
and any(p == Precision.FP8 for p in precision.values())
and torch.__version__ < "2.1.0"
):
raise NotImplementedError(
"PyTorch 2.1.0 or newer is required for fp8 conversion"
)
if not self.is_sdxl_model(clip):
for part in [unet, vae, clip]:
if part:
nai_keys = PRUNE_DATA["nai_keys"]
for k in list(part.keys()):
for r in nai_keys:
if isinstance(k, str) and k.startswith(r):
new_key = k.replace(r, nai_keys[r])
part[new_key] = part[k]
del part[k]
log.info(
f"[Converter] Fixed novelai error key {k}"
)
break
if fix_clip:
clip = self.fix_clip(clip)
ok: dict[str, dict[str, torch.Tensor]] = {
"unet": {},
"clip": {},
"vae": {},
}
def _hf(part: str, wk: str, t: torch.Tensor):
if not isinstance(t, torch.Tensor):
log.debug("Not a torch tensor, skipping key")
return
log.debug(f"Operation {operation[part]}")
if operation[part] == Operation.CONVERT:
ok[part][wk] = self.convert_precision(
t, precision[part]
) # conv_func(t)
elif operation[part] == Operation.COPY:
ok[part][wk] = t
elif operation[part] == Operation.DELETE:
return
log.info("[Converter] Converting model...")
for part_name, part in zip(
["unet", "vae", "clip", "other"],
[unet, vae, clip],
strict=False,
):
if part:
match ema_mode:
case "remove_ema":
for k, v in tqdm.tqdm(part.items()):
if "model_ema." not in k:
_hf(part_name, k, v)
case "ema_only":
if not self.has_ema(part):
log.warn("No EMA to extract")
return
for k in tqdm.tqdm(part):
ema_k = "___"
try:
ema_k = "model_ema." + k[6:].replace(".", "")
except Exception:
pass
if ema_k in part:
_hf(part_name, k, part[ema_k])
elif not k.startswith("model_ema.") or k in [
"model_ema.num_updates",
"model_ema.decay",
]:
_hf(part_name, k, part[k])
case "disabled" | _:
for k, v in tqdm.tqdm(part.items()):
_hf(part_name, k, v)
if save_separately:
if remove_junk:
ok = self.do_remove_junk(ok)
flat_ok = {
k: v
for _, subdict in ok.items()
for k, v in subdict.items()
}
save_path = (
folder / f"{part_name}-{save_name}.safetensors"
).as_posix()
safetensors.torch.save_file(flat_ok, save_path)
ok: dict[str, dict[str, torch.Tensor]] = {
"unet": {},
"clip": {},
"vae": {},
}
if save_separately:
return ()
if remove_junk:
ok = self.do_remove_junk(ok)
flat_ok = {
k: v for _, subdict in ok.items() for k, v in subdict.items()
}
try:
safetensors.torch.save_file(
flat_ok, (folder / f"{save_name}.safetensors").as_posix()
)
except Exception as e:
log.error(e)
return ()
# Node classes exported by this module (presumably collected by the
# package's node loader — confirm against the loader).
__nodes__ = [
    MTB_ModelPruner,
]