# Optional face enhance nodes
# region imports
import sys
from pathlib import Path

import comfy.model_management as model_management
import cv2
import insightface
import numpy as np
import onnxruntime
import torch
from insightface.model_zoo.inswapper import INSwapper
from PIL import Image

from ..errors import ModelNotFound
from ..log import NullWriter, mklog
from ..utils import download_antelopev2, get_model_path, pil2tensor, tensor2pil

# endregion

log = mklog(__name__)


class MTB_LoadFaceAnalysisModel:
    """Loads a face analysis model"""

    models = []

    # Node metadata attributes.
    RETURN_TYPES = ("FACE_ANALYSIS_MODEL",)
    FUNCTION = "load_model"
    CATEGORY = "mtb/facetools"
    DEPRECATED = True

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the single required input: the analysis model pack name."""
        pack_names = ["antelopev2", "buffalo_l", "buffalo_m", "buffalo_sc"]
        return {
            "required": {
                "faceswap_model": (pack_names, {"default": "buffalo_l"}),
            },
        }

    def load_model(self, faceswap_model: str):
        """Build an insightface ``FaceAnalysis`` app for the chosen pack.

        Args:
            faceswap_model: Name of the model pack to load.

        Returns:
            A one-element tuple holding the ``FaceAnalysis`` instance.
        """
        # The antelopev2 pack goes through the project's own download helper.
        if faceswap_model == "antelopev2":
            download_antelopev2()

        root_dir = get_model_path("insightface").as_posix()
        analyser = insightface.app.FaceAnalysis(
            name=faceswap_model,
            root=root_dir,
        )
        return (analyser,)


class MTB_LoadFaceSwapModel:
    """Loads a faceswap model"""

    # Node metadata attributes.
    RETURN_TYPES = ("FACESWAP_MODEL",)
    FUNCTION = "load_model"
    CATEGORY = "mtb/facetools"
    DEPRECATED = True

    @staticmethod
    def get_models() -> list[Path]:
        """List the ``.onnx`` / ``.pth`` entries of the insightface model dir.

        Returns:
            The matching paths, or an empty list when the directory does not
            exist.
        """
        models_dir = get_model_path("insightface")
        if not models_dir.exists():
            return []

        wanted_suffixes = (".onnx", ".pth")
        return [
            entry
            for entry in models_dir.iterdir()
            if entry.suffix in wanted_suffixes
        ]

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the single required input: one of the discovered models."""
        model_names = [model.name for model in cls.get_models()]
        return {
            "required": {
                "faceswap_model": (model_names, {"default": "None"}),
            },
        }

    def load_model(self, faceswap_model: str):
        """Load the selected swapper model through onnxruntime.

        Args:
            faceswap_model: File name of the model inside the insightface
                model directory.

        Returns:
            A one-element tuple holding the ``INSwapper`` instance.

        Raises:
            ModelNotFound: If the resolved model path is falsy or does not
                exist.
        """
        model_path = get_model_path("insightface", faceswap_model)
        if not (model_path and model_path.exists()):
            raise ModelNotFound(f"{faceswap_model} ({model_path})")

        log.info(f"Loading model {model_path}")

        # Hand every provider onnxruntime reports as available to the session.
        session = onnxruntime.InferenceSession(
            path_or_bytes=model_path,
            providers=onnxruntime.get_available_providers(),
        )
        return (INSwapper(model_path, session),)


# region roop node
class MTB_FaceSwap:
    """Face swap using deepinsight/insightface models"""

    model = None
    model_path = None

    def __init__(self) -> None:
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs (images, face indices and both models)."""
        return {
            "required": {
                "image": ("IMAGE",),
                "reference": ("IMAGE",),
                "faces_index": ("STRING", {"default": "0"}),
                "faceanalysis_model": (
                    "FACE_ANALYSIS_MODEL",
                    {"default": "None"},
                ),
                "faceswap_model": ("FACESWAP_MODEL", {"default": "None"}),
            },
            "optional": {
                "preserve_alpha": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "swap"
    CATEGORY = "mtb/facetools"
    DEPRECATED = True

    def swap(
        self,
        image: torch.Tensor,
        reference: torch.Tensor,
        faces_index: str,
        faceanalysis_model,
        faceswap_model,
        preserve_alpha=False,
    ):
        """Swap the reference face onto the selected faces of every image.

        Args:
            image: Batch of target images; its first dimension is the batch.
            reference: Source face image; must have a batch size of 1.
            faces_index: Comma separated face indices, e.g. ``"0"`` or
                ``"0, 2"``. Entries that are not plain decimal integers are
                ignored.
            faceanalysis_model: Face analyser passed on to ``swap_face``.
            faceswap_model: Swapper model passed on to ``swap_face``.
            preserve_alpha: When true and the target converts to an RGBA
                image, its alpha channel is put back on the swapped result.

        Returns:
            A one-element tuple holding the resulting image tensor.

        Raises:
            ValueError: If ``reference`` does not have a batch size of 1.
        """
        batch_count = image.size(0)
        log.info(f"Running insightface swap (batch size: {batch_count})")

        if reference.size(0) != 1:
            raise ValueError("Reference image must have batch size 1")

        # Both values are identical for every batch item: compute them once.
        ref = tensor2pil(reference)[0]
        face_ids = _parse_face_indices(faces_index)

        def do_swap(img):
            model_management.throw_exception_if_processing_interrupted()
            img = tensor2pil(img)[0]

            alpha_channel = None
            if preserve_alpha and img.mode == "RGBA":
                alpha_channel = img.getchannel("A")
                img = img.convert("RGB")

            # Mute stdout for the duration of the swap; the previous stream is
            # restored even if the swap raises.
            with _SilencedStdout():
                swapped = swap_face(
                    faceanalysis_model, ref, img, faceswap_model, face_ids
                )

            if alpha_channel is not None:
                swapped.putalpha(alpha_channel)
            return pil2tensor(swapped)

        if batch_count == 1:
            image = do_swap(image)
        else:
            image_batch = [do_swap(image[i]) for i in range(batch_count)]
            image = torch.cat(image_batch, dim=0)

        return (image,)


# endregion


# region face swap utils
class _SilencedStdout:
    """Context manager replacing ``sys.stdout`` with a ``NullWriter``.

    On exit it restores whatever stream was installed on entry (rather than
    ``sys.__stdout__``), so redirections set up by the host application survive
    and nested uses unwind correctly.
    """

    def __enter__(self):
        self._previous = sys.stdout
        sys.stdout = NullWriter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        sys.stdout = self._previous
        # Never swallow exceptions raised inside the block.
        return False


def _parse_face_indices(faces_index: str) -> set[int]:
    """Parse a comma separated list of face indices into a set of ints.

    Surrounding whitespace is ignored. Entries that are empty or not made
    solely of decimal digits are skipped.
    """
    indices = set()
    for token in faces_index.split(","):
        token = token.strip()
        # ``isdecimal`` only accepts characters ``int()`` can convert, unlike
        # ``isnumeric`` which also matches e.g. superscripts and fractions.
        if token.isdecimal():
            indices.add(int(token))
    return indices


def get_face_single(
    face_analyser, img_data: np.ndarray, face_index=0, det_size=(640, 640)
):
    """Return the ``face_index``-th detected face, ordered left to right.

    When nothing is detected, detection is retried with ``det_size`` halved,
    as long as both dimensions are still above 320.

    Args:
        face_analyser: Object exposing ``prepare(ctx_id, det_size)`` and
            ``get(img)``.
        img_data: Image array handed to the analyser.
        face_index: Position of the wanted face once the detections are sorted
            by the first component of their bounding box.
        det_size: Detection size forwarded to ``face_analyser.prepare``.

    Returns:
        The selected face, or ``None`` when ``face_index`` is out of range.
    """
    # NOTE: the analyser is re-prepared on every call so det_size takes effect.
    face_analyser.prepare(ctx_id=0, det_size=det_size)
    faces = face_analyser.get(img_data)

    if len(faces) == 0 and det_size[0] > 320 and det_size[1] > 320:
        log.debug("No face detected, trying again with smaller image")
        det_size_half = (det_size[0] // 2, det_size[1] // 2)
        return get_face_single(
            face_analyser,
            img_data,
            face_index=face_index,
            det_size=det_size_half,
        )

    try:
        return sorted(faces, key=lambda x: x.bbox[0])[face_index]
    except IndexError:
        return None


def swap_face(
    face_analyser,
    source_img: Image.Image | list[Image.Image],
    target_img: Image.Image | list[Image.Image],
    face_swapper_model,
    faces_index: set[int] | None = None,
) -> Image.Image:
    """Paste the first face of ``source_img`` onto faces of ``target_img``.

    Args:
        face_analyser: Analyser used to detect faces in both images.
        source_img: Image providing the face to paste.
        target_img: Image whose faces get replaced.
        face_swapper_model: Model exposing ``get(img, target_face,
            source_face)``; when ``None`` nothing is swapped.
        faces_index: Indices (left to right) of the target faces to replace.
            ``None`` means ``{0}``.

    Returns:
        The swapped image, or ``target_img`` itself when no swapper model was
        given or no face was found in ``source_img``.
    """
    if faces_index is None:
        faces_index = {0}
    log.debug(f"Swapping faces: {faces_index}")

    if face_swapper_model is None:
        log.error("No face swap model provided")
        return target_img

    # The array work below is done in OpenCV's BGR channel order.
    cv_source_img = cv2.cvtColor(np.array(source_img), cv2.COLOR_RGB2BGR)
    cv_target_img = cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR)

    source_face = get_face_single(face_analyser, cv_source_img, face_index=0)
    if source_face is None:
        log.warning("No source face found")
        return target_img

    result = cv_target_img
    # Sorted only to make the processing and logging order deterministic.
    for face_num in sorted(faces_index):
        # Faces are always looked up in the untouched target so that indices
        # keep referring to the original detections.
        target_face = get_face_single(
            face_analyser, cv_target_img, face_index=face_num
        )
        if target_face is None:
            log.warning(f"No target face found for {face_num}")
            continue

        with _SilencedStdout():
            result = face_swapper_model.get(result, target_face, source_face)

    return Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))


# endregion face swap utils

__nodes__ = [MTB_FaceSwap, MTB_LoadFaceSwapModel, MTB_LoadFaceAnalysisModel]