from pathlib import Path

import safetensors.torch
import torch
import tqdm

from ..log import log
from ..utils import Operation, Precision
from ..utils import output_dir as comfy_out_dir

PRUNE_DATA = {
    "known_junk_prefix": [
        "embedding_manager.embedder.",
        "lora_te_text_model",
        "control_model.",
    ],
    "nai_keys": {
        "cond_stage_model.transformer.embeddings.": "cond_stage_model.transformer.text_model.embeddings.",
        "cond_stage_model.transformer.encoder.": "cond_stage_model.transformer.text_model.encoder.",
        "cond_stage_model.transformer.final_layer_norm.": "cond_stage_model.transformer.text_model.final_layer_norm.",
    },
}
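
# For example, a NovelAI-flavored key such as (illustrative suffix)
#   cond_stage_model.transformer.encoder.layers.0.self_attn.q_proj.weight
# is remapped by "nai_keys" to the standard CLIP layout:
#   cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight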

# position_ids in clip is int64. model_ema.num_updates is int32
dtypes_to_fp16 = {torch.float32, torch.float64, torch.bfloat16}
dtypes_to_bf16 = {torch.float32, torch.float64, torch.float16}
dtypes_to_fp8 = {torch.float32, torch.float64, torch.bfloat16, torch.float16}


class MTB_ModelPruner:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "optional": {
                "unet": ("MODEL",),
                "clip": ("CLIP",),
                "vae": ("VAE",),
            },
            "required": {
                "save_separately": ("BOOLEAN", {"default": False}),
                "save_folder": ("STRING", {"default": "checkpoints/ComfyUI"}),
                "fix_clip": ("BOOLEAN", {"default": True}),
                "remove_junk": ("BOOLEAN", {"default": True}),
                "ema_mode": (
                    ("disabled", "remove_ema", "ema_only"),
                    {"default": "remove_ema"},
                ),
                "precision_unet": (
                    Precision.list_members(),
                    {"default": Precision.FULL.value},
                ),
                "operation_unet": (
                    Operation.list_members(),
                    {"default": Operation.CONVERT.value},
                ),
                "precision_clip": (
                    Precision.list_members(),
                    {"default": Precision.FULL.value},
                ),
                "operation_clip": (
                    Operation.list_members(),
                    {"default": Operation.CONVERT.value},
                ),
                "precision_vae": (
                    Precision.list_members(),
                    {"default": Precision.FULL.value},
                ),
                "operation_vae": (
                    Operation.list_members(),
                    {"default": Operation.CONVERT.value},
                ),
            },
        }

    OUTPUT_NODE = True
    RETURN_TYPES = ()
    CATEGORY = "mtb/prune"
    FUNCTION = "prune"

    def convert_precision(
        self, tensor: torch.Tensor, precision: Precision | str
    ):
        if isinstance(precision, str):
            precision = Precision.from_str(precision)
        log.debug(f"Converting to {precision}")
        match precision:
            case Precision.FP8:
                if tensor.dtype in dtypes_to_fp8:
                    return tensor.to(torch.float8_e4m3fn)
                log.error(f"Cannot convert {tensor.dtype} to fp8")
                return tensor
            case Precision.FP16:
                if tensor.dtype in dtypes_to_fp16:
                    return tensor.half()
                log.error(f"Cannot convert {tensor.dtype} to fp16")
                return tensor
            case Precision.BF16:
                if tensor.dtype in dtypes_to_bf16:
                    return tensor.bfloat16()
                log.error(f"Cannot convert {tensor.dtype} to bf16")
                return tensor
            case Precision.FULL | Precision.FP32 | _:
                return tensor

    def is_sdxl_model(self, clip: dict[str, torch.Tensor] | None):
        if clip:
            return any(k.startswith("conditioner.embedders") for k in clip)
        return False

    def has_ema(self, unet: dict[str, torch.Tensor]):
        return any(k.startswith("model_ema") for k in unet)

    def fix_clip(self, clip: dict[str, torch.Tensor] | None):
        if clip is None:
            return clip
        if self.is_sdxl_model(clip):
            log.warn("[fix clip] SDXL not supported")
            return clip

        position_id_key = (
            "cond_stage_model.transformer.text_model.embeddings.position_ids"
        )
        if position_id_key in clip:
            correct = torch.arange(77, dtype=torch.int64).unsqueeze(0)
            now = clip[position_id_key].to(torch.int64)

            broken = correct.ne(now)
            broken = [i for i in range(77) if broken[0][i]]
            if len(broken) != 0:
                clip[position_id_key] = correct
                log.info(f"[Converter] Fixed broken clip\n{broken}")
            else:
                log.info("[Converter] Clip in this model is fine, skip fixing...")
        else:
            log.info("[Converter] Missing position id in model, try fixing...")
            clip[position_id_key] = torch.arange(77, dtype=torch.int64).unsqueeze(0)

        return clip

    def get_dicts(self, unet, clip, vae):
        # unet/clip/vae are the live Comfy objects; all three are needed to
        # assemble a full checkpoint state dict.
        if unet is None or clip is None or vae is None:
            raise ValueError("unet, clip and vae inputs are all required")

        state_dict = unet.model.state_dict_for_saving(
            clip.get_sd(), vae.get_sd(), None
        )
        unet_sd = {
            k: v
            for k, v in state_dict.items()
            if k.startswith("model.diffusion_model")
        }
        clip_sd = {
            k: v
            for k, v in state_dict.items()
            if k.startswith("cond_stage_model")
            or k.startswith("conditioner.embedders")
        }
        vae_sd = {
            k: v
            for k, v in state_dict.items()
            if k.startswith("first_stage_model")
        }
        other = {
            k: v
            for k, v in state_dict.items()
            if k not in unet_sd and k not in clip_sd and k not in vae_sd
        }
        return (unet_sd, clip_sd, vae_sd, other)

    def do_remove_junk(self, tensors: dict[str, dict[str, torch.Tensor]]):
        need_delete: list[tuple[str, str]] = []
        for part, weights in tensors.items():
            for key in weights:
                for junk in PRUNE_DATA["known_junk_prefix"]:
                    if key.startswith(junk):
                        need_delete.append((part, key))

        for part, key in need_delete:
            log.info(f"Removing junk data: {part}.{key}")
            del tensors[part][key]

        return tensors

    def prune(
        self,
        *,
        save_separately: bool,
        save_folder: str,
        fix_clip: bool,
        remove_junk: bool,
        ema_mode: str,
        precision_unet: str,
        precision_clip: str,
        precision_vae: str,
        operation_unet: str,
        operation_clip: str,
        operation_vae: str,
        unet=None,
        clip=None,
        vae=None,
    ):
        operation = {
            "unet": Operation.from_str(operation_unet),
            "clip": Operation.from_str(operation_clip),
            "vae": Operation.from_str(operation_vae),
        }
        precision = {
            "unet": Precision.from_str(precision_unet),
            "clip": Precision.from_str(precision_clip),
            "vae": Precision.from_str(precision_vae),
        }

        unet, clip, vae, _other = self.get_dicts(unet, clip, vae)

        out_dir = Path(save_folder)
        folder = out_dir.parent
        if not out_dir.is_absolute():
            folder = (comfy_out_dir / save_folder).parent

        if not folder.exists():
            if folder.parent.exists():
                folder.mkdir()
            else:
                raise FileNotFoundError(f"Folder {folder.parent} does not exist")

        name = out_dir.name
        save_name = f"{name}-{precision_unet}"
        if ema_mode != "disabled":
            save_name += f"-{ema_mode}"
        if fix_clip:
            save_name += "-clip-fix"

        if (
            any(
                operation[p] == Operation.CONVERT
                and precision[p] == Precision.FP8
                for p in operation
            )
            and torch.__version__ < "2.1.0"
        ):
            raise NotImplementedError(
                "PyTorch 2.1.0 or newer is required for fp8 conversion"
            )

        if not self.is_sdxl_model(clip):
            nai_keys = PRUNE_DATA["nai_keys"]
            for part in [unet, vae, clip]:
                if not part:
                    continue
                for k in list(part.keys()):
                    for bad, good in nai_keys.items():
                        if k.startswith(bad):
                            part[good + k[len(bad):]] = part.pop(k)
                            log.info(f"[Converter] Fixed novelai error key {k}")
                            break

        if fix_clip:
            clip = self.fix_clip(clip)

        ok: dict[str, dict[str, torch.Tensor]] = {
            "unet": {},
            "clip": {},
            "vae": {},
        }

        def _hf(part: str, wk: str, t: torch.Tensor):
            # apply the per-part operation to a single weight
            if not isinstance(t, torch.Tensor):
                log.debug("Not a torch tensor, skipping key")
                return
            log.debug(f"Operation {operation[part]}")
            if operation[part] == Operation.CONVERT:
                ok[part][wk] = self.convert_precision(t, precision[part])
            elif operation[part] == Operation.COPY:
                ok[part][wk] = t
            elif operation[part] == Operation.DELETE:
                return
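
        # EMA lookups below assume the stable-diffusion checkpoint convention
        # where the EMA copy of a weight drops the leading "model." and all
        # remaining dots, e.g. (illustrative key):
        #   model.diffusion_model.out.2.weight -> model_ema.diffusion_modelout2weight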
        log.info("[Converter] Converting model...")
        for part_name, part in zip(
            ["unet", "vae", "clip"],
            [unet, vae, clip],
            strict=True,
        ):
            if not part:
                continue

            match ema_mode:
                case "remove_ema":
                    for k, v in tqdm.tqdm(part.items()):
                        if "model_ema." not in k:
                            _hf(part_name, k, v)
                case "ema_only":
                    if not self.has_ema(part):
                        # vae/clip carry no EMA weights; keep their keys as-is
                        log.warn(f"No EMA weights in {part_name}, copying as-is")
                    for k in tqdm.tqdm(part):
                        ema_k = "___"
                        try:
                            ema_k = "model_ema." + k[6:].replace(".", "")
                        except Exception:
                            pass
                        if ema_k in part:
                            # swap the live weight for its EMA counterpart
                            _hf(part_name, k, part[ema_k])
                        elif not k.startswith("model_ema.") or k in [
                            "model_ema.num_updates",
                            "model_ema.decay",
                        ]:
                            _hf(part_name, k, part[k])
                case "disabled" | _:
                    for k, v in tqdm.tqdm(part.items()):
                        _hf(part_name, k, v)

            if save_separately:
                if remove_junk:
                    ok = self.do_remove_junk(ok)
                flat_ok = {
                    k: v
                    for _, subdict in ok.items()
                    for k, v in subdict.items()
                }
                save_path = (
                    folder / f"{part_name}-{save_name}.safetensors"
                ).as_posix()
                safetensors.torch.save_file(flat_ok, save_path)
                # start from empty buckets so each part is saved on its own
                ok = {"unet": {}, "clip": {}, "vae": {}}

        if save_separately:
            return ()

        if remove_junk:
            ok = self.do_remove_junk(ok)

        flat_ok = {
            k: v for _, subdict in ok.items() for k, v in subdict.items()
        }
        try:
            safetensors.torch.save_file(
                flat_ok, (folder / f"{save_name}.safetensors").as_posix()
            )
        except Exception as e:
            log.error(e)

        return ()


__nodes__ = [MTB_ModelPruner]
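
# Illustrative output naming, assuming Precision string values like "fp16"
# (these come from ..utils and are not shown here): with the defaults
# save_folder="checkpoints/ComfyUI", ema_mode="remove_ema", fix_clip=True and
# precision_unet="fp16", the node writes
#   <comfy output dir>/checkpoints/ComfyUI-fp16-remove_ema-clip-fix.safetensors
# or, when save_separately is enabled, one file per part:
#   unet-ComfyUI-fp16-..., clip-ComfyUI-fp16-..., vae-ComfyUI-fp16-...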