Spaces:
Sleeping
Sleeping
from pathlib import Path | |
import safetensors.torch | |
import torch | |
import tqdm | |
from ..log import log | |
from ..utils import Operation, Precision | |
from ..utils import output_dir as comfy_out_dir | |
# Shared prefix of the CLIP text-encoder keys in SD1.x checkpoints.
_CLIP_TRANSFORMER_PREFIX = "cond_stage_model.transformer."

PRUNE_DATA = dict(
    # Key prefixes of data that never belongs in a plain checkpoint
    # (textual-inversion managers, LoRA text-encoder weights, ControlNet).
    known_junk_prefix=[
        "embedding_manager.embedder.",
        "lora_te_text_model",
        "control_model.",
    ],
    # Old NovelAI-style CLIP key prefixes mapped to the layout that nests
    # everything under ``transformer.text_model.``.
    nai_keys={
        f"{_CLIP_TRANSFORMER_PREFIX}{suffix}": (
            f"{_CLIP_TRANSFORMER_PREFIX}text_model.{suffix}"
        )
        for suffix in ("embeddings.", "encoder.", "final_layer_norm.")
    },
)

# Source dtypes that may be down-cast to each target precision. Integer
# dtypes are intentionally absent: position_ids in clip is int64 and
# model_ema.num_updates is int32, and those must keep their dtype.
_FULL_FLOAT_DTYPES = {torch.float32, torch.float64}
dtypes_to_fp16 = _FULL_FLOAT_DTYPES | {torch.bfloat16}
dtypes_to_bf16 = _FULL_FLOAT_DTYPES | {torch.float16}
dtypes_to_fp8 = dtypes_to_fp16 | dtypes_to_bf16
class MTB_ModelPruner: | |
def INPUT_TYPES(cls): | |
return { | |
"optional": { | |
"unet": ("MODEL",), | |
"clip": ("CLIP",), | |
"vae": ("VAE",), | |
}, | |
"required": { | |
"save_separately": ("BOOLEAN", {"default": False}), | |
"save_folder": ("STRING", {"default": "checkpoints/ComfyUI"}), | |
"fix_clip": ("BOOLEAN", {"default": True}), | |
"remove_junk": ("BOOLEAN", {"default": True}), | |
"ema_mode": ( | |
("disabled", "remove_ema", "ema_only"), | |
{"default": "remove_ema"}, | |
), | |
"precision_unet": ( | |
Precision.list_members(), | |
{"default": Precision.FULL.value}, | |
), | |
"operation_unet": ( | |
Operation.list_members(), | |
{"default": Operation.CONVERT.value}, | |
), | |
"precision_clip": ( | |
Precision.list_members(), | |
{"default": Precision.FULL.value}, | |
), | |
"operation_clip": ( | |
Operation.list_members(), | |
{"default": Operation.CONVERT.value}, | |
), | |
"precision_vae": ( | |
Precision.list_members(), | |
{"default": Precision.FULL.value}, | |
), | |
"operation_vae": ( | |
Operation.list_members(), | |
{"default": Operation.CONVERT.value}, | |
), | |
}, | |
} | |
OUTPUT_NODE = True | |
RETURN_TYPES = () | |
CATEGORY = "mtb/prune" | |
FUNCTION = "prune" | |
def convert_precision(self, tensor: torch.Tensor, precision: Precision): | |
precision = Precision.from_str(precision) | |
log.debug(f"Converting to {precision}") | |
match precision: | |
case Precision.FP8: | |
if tensor.dtype in dtypes_to_fp8: | |
return tensor.to(torch.float8_e4m3fn) | |
log.error(f"Cannot convert {tensor.dtype} to fp8") | |
return tensor | |
case Precision.FP16: | |
if tensor.dtype in dtypes_to_fp16: | |
return tensor.half() | |
log.error(f"Cannot convert {tensor.dtype} to f16") | |
return tensor | |
case Precision.BF16: | |
if tensor.dtype in dtypes_to_bf16: | |
return tensor.bfloat16() | |
log.error(f"Cannot convert {tensor.dtype} to bf16") | |
return tensor | |
case Precision.FULL | Precision.FP32: | |
return tensor | |
def is_sdxl_model(self, clip: dict[str, torch.Tensor] | None): | |
if clip: | |
return (any(k.startswith("conditioner.embedders") for k in clip),) | |
return False | |
def has_ema(self, unet: dict[str, torch.Tensor]): | |
return any(k.startswith("model_ema") for k in unet) | |
def fix_clip(self, clip: dict[str, torch.Tensor] | None): | |
if self.is_sdxl_model(clip): | |
log.warn("[fix clip] SDXL not supported") | |
return | |
if clip is None: | |
return | |
position_id_key = ( | |
"cond_stage_model.transformer.text_model.embeddings.position_ids" | |
) | |
if position_id_key in clip: | |
correct = torch.Tensor([list(range(77))]).to(torch.int64) | |
now = clip[position_id_key].to(torch.int64) | |
broken = correct.ne(now) | |
broken = [i for i in range(77) if broken[0][i]] | |
if len(broken) != 0: | |
clip[position_id_key] = correct | |
log.info(f"[Converter] Fixed broken clip\n{broken}") | |
else: | |
log.info( | |
"[Converter] Clip in this model is fine, skip fixing..." | |
) | |
else: | |
log.info("[Converter] Missing position id in model, try fixing...") | |
clip[position_id_key] = torch.Tensor([list(range(77))]).to( | |
torch.int64 | |
) | |
return clip | |
def get_dicts(self, unet, clip, vae): | |
clip_sd = clip.get_sd() | |
state_dict = unet.model.state_dict_for_saving( | |
clip_sd, vae.get_sd(), None | |
) | |
unet = { | |
k: v | |
for k, v in state_dict.items() | |
if k.startswith("model.diffusion_model") | |
} | |
clip = { | |
k: v | |
for k, v in state_dict.items() | |
if k.startswith("cond_stage_model") | |
or k.startswith("conditioner.embedders") | |
} | |
vae = { | |
k: v | |
for k, v in state_dict.items() | |
if k.startswith("first_stage_model") | |
} | |
other = { | |
k: v | |
for k, v in state_dict.items() | |
if k not in unet and k not in vae and k not in clip | |
} | |
return (unet, clip, vae, other) | |
def do_remove_junk(self, tensors: dict[str, dict[str, torch.Tensor]]): | |
need_delete: list[str] = [] | |
for layer in tensors: | |
for key in layer: | |
for jk in PRUNE_DATA["known_junk_prefix"]: | |
if key.startswith(jk): | |
need_delete.append(".".join([layer, key])) | |
for k in need_delete: | |
log.info(f"Removing junk data: {k}") | |
del tensors[k] | |
return tensors | |
def prune( | |
self, | |
*, | |
save_separately: bool, | |
save_folder: str, | |
fix_clip: bool, | |
remove_junk: bool, | |
ema_mode: str, | |
precision_unet: Precision, | |
precision_clip: Precision, | |
precision_vae: Precision, | |
operation_unet: str, | |
operation_clip: str, | |
operation_vae: str, | |
unet: dict[str, torch.Tensor] | None = None, | |
clip: dict[str, torch.Tensor] | None = None, | |
vae: dict[str, torch.Tensor] | None = None, | |
): | |
operation = { | |
"unet": Operation.from_str(operation_unet), | |
"clip": Operation.from_str(operation_clip), | |
"vae": Operation.from_str(operation_vae), | |
} | |
precision = { | |
"unet": Precision.from_str(precision_unet), | |
"clip": Precision.from_str(precision_clip), | |
"vae": Precision.from_str(precision_vae), | |
} | |
unet, clip, vae, _other = self.get_dicts(unet, clip, vae) | |
out_dir = Path(save_folder) | |
folder = out_dir.parent | |
if not out_dir.is_absolute(): | |
folder = (comfy_out_dir / save_folder).parent | |
if not folder.exists(): | |
if folder.parent.exists(): | |
folder.mkdir() | |
else: | |
raise FileNotFoundError( | |
f"Folder {folder.parent} does not exist" | |
) | |
name = out_dir.name | |
save_name = f"{name}-{precision_unet}" | |
if ema_mode != "disabled": | |
save_name += f"-{ema_mode}" | |
if fix_clip: | |
save_name += "-clip-fix" | |
if ( | |
any(o == Operation.CONVERT for o in operation.values()) | |
and any(p == Precision.FP8 for p in precision.values()) | |
and torch.__version__ < "2.1.0" | |
): | |
raise NotImplementedError( | |
"PyTorch 2.1.0 or newer is required for fp8 conversion" | |
) | |
if not self.is_sdxl_model(clip): | |
for part in [unet, vae, clip]: | |
if part: | |
nai_keys = PRUNE_DATA["nai_keys"] | |
for k in list(part.keys()): | |
for r in nai_keys: | |
if isinstance(k, str) and k.startswith(r): | |
new_key = k.replace(r, nai_keys[r]) | |
part[new_key] = part[k] | |
del part[k] | |
log.info( | |
f"[Converter] Fixed novelai error key {k}" | |
) | |
break | |
if fix_clip: | |
clip = self.fix_clip(clip) | |
ok: dict[str, dict[str, torch.Tensor]] = { | |
"unet": {}, | |
"clip": {}, | |
"vae": {}, | |
} | |
def _hf(part: str, wk: str, t: torch.Tensor): | |
if not isinstance(t, torch.Tensor): | |
log.debug("Not a torch tensor, skipping key") | |
return | |
log.debug(f"Operation {operation[part]}") | |
if operation[part] == Operation.CONVERT: | |
ok[part][wk] = self.convert_precision( | |
t, precision[part] | |
) # conv_func(t) | |
elif operation[part] == Operation.COPY: | |
ok[part][wk] = t | |
elif operation[part] == Operation.DELETE: | |
return | |
log.info("[Converter] Converting model...") | |
for part_name, part in zip( | |
["unet", "vae", "clip", "other"], | |
[unet, vae, clip], | |
strict=False, | |
): | |
if part: | |
match ema_mode: | |
case "remove_ema": | |
for k, v in tqdm.tqdm(part.items()): | |
if "model_ema." not in k: | |
_hf(part_name, k, v) | |
case "ema_only": | |
if not self.has_ema(part): | |
log.warn("No EMA to extract") | |
return | |
for k in tqdm.tqdm(part): | |
ema_k = "___" | |
try: | |
ema_k = "model_ema." + k[6:].replace(".", "") | |
except Exception: | |
pass | |
if ema_k in part: | |
_hf(part_name, k, part[ema_k]) | |
elif not k.startswith("model_ema.") or k in [ | |
"model_ema.num_updates", | |
"model_ema.decay", | |
]: | |
_hf(part_name, k, part[k]) | |
case "disabled" | _: | |
for k, v in tqdm.tqdm(part.items()): | |
_hf(part_name, k, v) | |
if save_separately: | |
if remove_junk: | |
ok = self.do_remove_junk(ok) | |
flat_ok = { | |
k: v | |
for _, subdict in ok.items() | |
for k, v in subdict.items() | |
} | |
save_path = ( | |
folder / f"{part_name}-{save_name}.safetensors" | |
).as_posix() | |
safetensors.torch.save_file(flat_ok, save_path) | |
ok: dict[str, dict[str, torch.Tensor]] = { | |
"unet": {}, | |
"clip": {}, | |
"vae": {}, | |
} | |
if save_separately: | |
return () | |
if remove_junk: | |
ok = self.do_remove_junk(ok) | |
flat_ok = { | |
k: v for _, subdict in ok.items() for k, v in subdict.items() | |
} | |
try: | |
safetensors.torch.save_file( | |
flat_ok, (folder / f"{save_name}.safetensors").as_posix() | |
) | |
except Exception as e: | |
log.error(e) | |
return () | |
# Node classes exposed by this module (presumably collected by the package's
# node loader — confirm against the registration code).
__nodes__ = [
    MTB_ModelPruner,
]