"""ComfyUI nodes for editing and animating portraits with LivePortrait.

The module exposes a shared ``LP_Engine`` (model loading, face detection,
cropping) and the node classes registered in ``NODE_CLASS_MAPPINGS``.
"""
import os
import sys
import numpy as np
import torch
import cv2
from PIL import Image
import folder_paths
import comfy.utils
import time
import copy
import dill
import yaml
from ultralytics import YOLO

current_file_path = os.path.abspath(__file__)
current_directory = os.path.dirname(current_file_path)

from .LivePortrait.live_portrait_wrapper import LivePortraitWrapper
from .LivePortrait.utils.camera import get_rotation_matrix
from .LivePortrait.config.inference_config import InferenceConfig

from .LivePortrait.modules.spade_generator import SPADEDecoder
from .LivePortrait.modules.warping_network import WarpingNetwork
from .LivePortrait.modules.motion_extractor import MotionExtractor
from .LivePortrait.modules.appearance_feature_extractor import AppearanceFeatureExtractor
from .LivePortrait.modules.stitching_retargeting_network import StitchingRetargetingNetwork
from collections import OrderedDict

# Lazily selected torch device, shared by the whole module (see get_device()).
cur_device = None


def get_device():
    """Return the torch device to use, choosing it on first call.

    Preference order is CUDA, then MPS, then CPU. The choice is cached in the
    module-level ``cur_device`` so the message is printed only once.
    """
    global cur_device
    if cur_device is None:
        if torch.cuda.is_available():
            cur_device = torch.device('cuda')
            print("Uses CUDA device.")
        elif torch.backends.mps.is_available():
            cur_device = torch.device('mps')
            print("Uses MPS device.")
        else:
            cur_device = torch.device('cpu')
            print("Uses CPU device.")
    return cur_device


def tensor2pil(image):
    """Convert a tensor with values scaled to 0..1 into a ``uint8`` PIL image."""
    return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))


def pil2tensor(image):
    """Convert an image/array with 0..255 values to a float32 tensor in 0..1 with a leading batch axis."""
    return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)


def rgb_crop(rgb, region):
    """Crop ``rgb`` to ``region`` given as ``[x1, y1, x2, y2]`` (rows indexed by y, columns by x)."""
    return rgb[region[1]:region[3], region[0]:region[2]]


def rgb_crop_batch(rgbs, region):
    """Crop every image of a batch (leading axis) to ``region`` = ``[x1, y1, x2, y2]``."""
    return rgbs[:, region[1]:region[3], region[0]:region[2]]


def get_rgb_size(rgb):
    """Return ``(shape[1], shape[0])`` of ``rgb``, i.e. (width, height) for a row-major image."""
    return rgb.shape[1], rgb.shape[0]


def create_transform_matrix(x, y, s_x, s_y):
    """Build a 2x3 float32 affine matrix that scales by ``(s_x, s_y)`` and translates by ``(x, y)``."""
    return np.float32([[s_x, 0, x], [0, s_y, y]])


def get_model_dir(m):
    """Return the model directory registered for ``m``, or ``<models_dir>/m`` as a fallback."""
    try:
        return folder_paths.get_folder_paths(m)[0]
    except Exception:
        # Best effort: any lookup failure falls back to the default models folder.
        return os.path.join(folder_paths.models_dir, m)


def calc_crop_limit(center, img_size, crop_size):
    """Shrink a 1-D crop window centred on ``center`` so that it fits in ``[0, img_size]``.

    Returns ``(start, end, crop_size)`` with the possibly reduced ``crop_size``.
    """
    pos = center - crop_size / 2
    if pos < 0:
        # Window starts before 0: shrink symmetrically around the centre.
        crop_size += pos * 2
        pos = 0

    pos2 = pos + crop_size

    if img_size < pos2:
        # Window ends past the image: shrink symmetrically as well.
        crop_size -= (pos2 - img_size) * 2
        pos2 = img_size
        pos = pos2 - crop_size

    return pos, pos2, crop_size


def retargeting(delta_out, driving_exp, factor, idxes):
    """Add ``driving_exp[0, idx] * factor`` to ``delta_out[0, idx]`` in place for each ``idx``."""
    for idx in idxes:
        #delta_out[0, idx] -= src_exp[0, idx] * factor
        delta_out[0, idx] += driving_exp[0, idx] * factor


class PreparedSrcImg:
    """Plain container for everything computed once per source image."""

    def __init__(self, src_rgb, crop_trans_m, x_s_info, f_s_user, x_s_user, mask_ori):
        self.src_rgb = src_rgb            # original frame
        self.crop_trans_m = crop_trans_m  # 2x3 matrix used to paste the crop back
        self.x_s_info = x_s_info          # result of pipeline.get_kp_info()
        self.f_s_user = f_s_user          # result of pipeline.extract_feature_3d()
        self.x_s_user = x_s_user          # result of pipeline.transform_keypoint()
        self.mask_ori = mask_ori          # float32 blend mask warped to frame size


import requests
from tqdm import tqdm


class LP_Engine:
    """Shared helper that loads the models lazily and prepares images for them."""

    pipeline = None
    detect_model = None
    mask_img = None
    temp_img_idx = 0

    def get_temp_img_name(self):
        """Return a new preview file name, incrementing the per-instance counter."""
        self.temp_img_idx += 1
        return "expression_edit_preview" + str(self.temp_img_idx) + ".png"

    def download_model(_, file_path, model_url):
        """Stream ``model_url`` into ``file_path`` with a progress bar.

        Failures are reported on stdout and are not raised; the caller finds
        out when it tries to open the missing file.
        """
        print('AdvancedLivePortrait: Downloading model...')
        try:
            # The request is inside the try so connection errors are reported too.
            response = requests.get(model_url, stream=True)
            if response.status_code == 200:
                total_size = int(response.headers.get('content-length', 0))
                block_size = 1024  # 1 Kibibyte

                # tqdm will display a progress bar
                with open(file_path, 'wb') as file, tqdm(
                    desc='Downloading',
                    total=total_size,
                    unit='iB',
                    unit_scale=True,
                    unit_divisor=1024,
                ) as bar:
                    for data in response.iter_content(block_size):
                        bar.update(len(data))
                        file.write(data)
            else:
                print(f'AdvancedLivePortrait: Model download failed: HTTP status {response.status_code}')
                print(f'AdvancedLivePortrait: Download it manually from: {model_url}')
                print(f'AdvancedLivePortrait: And put it in {file_path}')
        except requests.exceptions.RequestException as err:
            print(f'AdvancedLivePortrait: Model download failed: {err}')
            print(f'AdvancedLivePortrait: Download it manually from: {model_url}')
            print(f'AdvancedLivePortrait: And put it in {file_path}')
        except Exception as e:
            print(f'AdvancedLivePortrait: An unexpected error occurred: {e}')

    def remove_ddp_dumplicate_key(_, state_dict):
        """Return an ``OrderedDict`` copy of ``state_dict`` with ``'module.'`` removed from every key."""
        state_dict_new = OrderedDict()
        for key in state_dict.keys():
            state_dict_new[key.replace('module.', '')] = state_dict[key]
        return state_dict_new

    def filter_for_model(_, checkpoint, prefix):
        """Keep the entries whose key starts with ``prefix`` and strip ``'<prefix>_module.'`` from them."""
        filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if
                               key.startswith(prefix)}
        return filtered_checkpoint

    def load_model(self, model_config, model_type):
        """Build the network for ``model_type`` and load its weights.

        A ``.pth`` checkpoint is looked up first; when missing, the
        ``.safetensors`` file is used and downloaded if necessary.

        Returns the model in eval mode, or for
        ``'stitching_retargeting_module'`` a dict ``{'stitching': network}``.

        Raises:
            ValueError: if ``model_type`` is not one of the known types.
        """
        device = get_device()

        if model_type == 'stitching_retargeting_module':
            ckpt_path = os.path.join(get_model_dir("liveportrait"), "retargeting_models", model_type + ".pth")
        else:
            ckpt_path = os.path.join(get_model_dir("liveportrait"), "base_models", model_type + ".pth")

        is_safetensors = None
        if not os.path.isfile(ckpt_path):
            is_safetensors = True
            ckpt_path = os.path.join(get_model_dir("liveportrait"), model_type + ".safetensors")
            if not os.path.isfile(ckpt_path):
                self.download_model(ckpt_path,
                                    "https://huggingface.co/Kijai/LivePortrait_safetensors/resolve/main/" + model_type + ".safetensors")
        model_params = model_config['model_params'][f'{model_type}_params']
        if model_type == 'appearance_feature_extractor':
            model = AppearanceFeatureExtractor(**model_params).to(device)
        elif model_type == 'motion_extractor':
            model = MotionExtractor(**model_params).to(device)
        elif model_type == 'warping_module':
            model = WarpingNetwork(**model_params).to(device)
        elif model_type == 'spade_generator':
            model = SPADEDecoder(**model_params).to(device)
        elif model_type == 'stitching_retargeting_module':
            # Special handling for stitching and retargeting module
            config = model_config['model_params']['stitching_retargeting_module_params']
            checkpoint = comfy.utils.load_torch_file(ckpt_path)

            stitcher = StitchingRetargetingNetwork(**config.get('stitching'))
            if is_safetensors:
                stitcher.load_state_dict(self.filter_for_model(checkpoint, 'retarget_shoulder'))
            else:
                stitcher.load_state_dict(self.remove_ddp_dumplicate_key(checkpoint['retarget_shoulder']))
            stitcher = stitcher.to(device)
            stitcher.eval()

            return {
                'stitching': stitcher,
            }
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        model.load_state_dict(comfy.utils.load_torch_file(ckpt_path))
        model.eval()
        return model

    def load_models(self):
        """Load every sub-model and assemble ``self.pipeline`` (a ``LivePortraitWrapper``)."""
        model_path = get_model_dir("liveportrait")
        # makedirs also creates missing parents and tolerates an existing directory.
        os.makedirs(model_path, exist_ok=True)

        model_config_path = os.path.join(current_directory, 'LivePortrait', 'config', 'models.yaml')
        with open(model_config_path, 'r') as config_file:
            model_config = yaml.safe_load(config_file)

        appearance_feature_extractor = self.load_model(model_config, 'appearance_feature_extractor')
        motion_extractor = self.load_model(model_config, 'motion_extractor')
        warping_module = self.load_model(model_config, 'warping_module')
        spade_generator = self.load_model(model_config, 'spade_generator')
        stitching_retargeting_module = self.load_model(model_config, 'stitching_retargeting_module')

        self.pipeline = LivePortraitWrapper(InferenceConfig(), appearance_feature_extractor, motion_extractor,
                                            warping_module, spade_generator, stitching_retargeting_module)

    def get_detect_model(self):
        """Return the YOLO face detector, downloading and loading it on first use."""
        if self.detect_model is None:
            model_dir = get_model_dir("ultralytics")
            os.makedirs(model_dir, exist_ok=True)
            model_path = os.path.join(model_dir, "face_yolov8n.pt")
            if not os.path.exists(model_path):
                self.download_model(model_path, "https://huggingface.co/Bingsu/adetailer/resolve/main/face_yolov8n.pt")
            self.detect_model = YOLO(model_path)

        return self.detect_model

    def get_face_bboxes(self, image_rgb):
        """Run the face detector (confidence 0.7) and return the boxes as an array of ``x1, y1, x2, y2`` rows."""
        detect_model = self.get_detect_model()
        pred = detect_model(image_rgb, conf=0.7, device="")
        return pred[0].boxes.xyxy.cpu().numpy()

    def detect_face(self, image_rgb, crop_factor, sort = True):
        """Return a square crop ``[x1, y1, x2, y2]`` around the most central face.

        Boxes narrower than 30 px are ignored; among the rest the one whose
        horizontal centre is closest to the image centre wins. The square side
        is ``max(box_w, box_h) * crop_factor``. With ``sort`` true the square
        is shifted back inside the image where possible. When no face
        qualifies, the full image ``[0, 0, w, h]`` is returned.
        """
        bboxes = self.get_face_bboxes(image_rgb)
        w, h = get_rgb_size(image_rgb)

        print(f"w, h:{w, h}")

        cx = w / 2
        min_diff = w
        best_box = None
        for x1, y1, x2, y2 in bboxes:
            bbox_w = x2 - x1
            if bbox_w < 30:
                continue
            diff = abs(cx - (x1 + bbox_w / 2))
            if diff < min_diff:
                best_box = [x1, y1, x2, y2]
                print(f"diff, min_diff, best_box:{diff, min_diff, best_box}")
                min_diff = diff

        if best_box is None:
            print("Failed to detect face!!")
            return [0, 0, w, h]

        x1, y1, x2, y2 = best_box

        #for x1, y1, x2, y2 in bboxes:
        bbox_w = x2 - x1
        bbox_h = y2 - y1

        crop_w = bbox_w * crop_factor
        crop_h = bbox_h * crop_factor

        # Use a square crop with the larger of the two sides.
        crop_w = max(crop_h, crop_w)
        crop_h = crop_w

        kernel_x = int(x1 + bbox_w / 2)
        kernel_y = int(y1 + bbox_h / 2)

        new_x1 = int(kernel_x - crop_w / 2)
        new_x2 = int(kernel_x + crop_w / 2)
        new_y1 = int(kernel_y - crop_h / 2)
        new_y2 = int(kernel_y + crop_h / 2)

        if not sort:
            return [int(new_x1), int(new_y1), int(new_x2), int(new_y2)]

        # Slide the square horizontally so it stays inside the image.
        if new_x1 < 0:
            new_x2 -= new_x1
            new_x1 = 0
        elif w < new_x2:
            new_x1 -= (new_x2 - w)
            new_x2 = w
            if new_x1 < 0:
                new_x2 -= new_x1
                new_x1 = 0

        # Same for the vertical axis.
        if new_y1 < 0:
            new_y2 -= new_y1
            new_y1 = 0
        elif h < new_y2:
            new_y1 -= (new_y2 - h)
            new_y2 = h
            if new_y1 < 0:
                new_y2 -= new_y1
                new_y1 = 0

        # Still overflowing on both axes: shrink by the smaller overflow.
        if w < new_x2 and h < new_y2:
            over_x = new_x2 - w
            over_y = new_y2 - h
            over_min = min(over_x, over_y)
            new_x2 -= over_min
            new_y2 -= over_min

        return [int(new_x1), int(new_y1), int(new_x2), int(new_y2)]

    def calc_face_region(self, square, dsize):
        """Clamp the right/bottom edge of ``square`` to ``dsize`` = (width, height).

        Returns ``(region, is_changed)``; ``square`` itself is not modified.
        """
        region = copy.deepcopy(square)
        is_changed = False
        if dsize[0] < region[2]:
            region[2] = dsize[0]
            is_changed = True
        if dsize[1] < region[3]:
            region[3] = dsize[1]
            is_changed = True

        return region, is_changed

    def expand_img(self, rgb_img, square):
        """Place ``rgb_img`` on a canvas of the size of ``square``, offset by any negative origin of ``square``."""
        #new_img = rgb_crop(rgb_img, face_region)
        crop_trans_m = create_transform_matrix(max(-square[0], 0), max(-square[1], 0), 1, 1)
        # flags must be passed by keyword: the 4th positional argument of warpAffine is `dst`.
        new_img = cv2.warpAffine(rgb_img, crop_trans_m, (square[2] - square[0], square[3] - square[1]),
                                 flags=cv2.INTER_LINEAR)
        return new_img

    def get_pipeline(self):
        """Return the LivePortrait pipeline, loading all models on first use."""
        if self.pipeline is None:
            print("Load pipeline...")
            self.load_models()

        return self.pipeline

    def prepare_src_image(self, img):
        """Resize ``img`` to 256x256 if needed and return it as a float tensor on the active device.

        The result is scaled to 0..1 and permuted from ``(B, H, W, C)`` to
        ``(B, C, H, W)``; a 3-D input gets a batch axis of 1.

        Raises:
            ValueError: if the (resized) image is neither 3-D nor 4-D.
        """
        h, w = img.shape[:2]
        input_shape = [256,256]
        if h != input_shape[0] or w != input_shape[1]:
            # INTER_AREA when shrinking, INTER_LINEAR otherwise.
            if 256 < h:
                interpolation = cv2.INTER_AREA
            else:
                interpolation = cv2.INTER_LINEAR
            x = cv2.resize(img, (input_shape[0], input_shape[1]), interpolation = interpolation)
        else:
            x = img.copy()

        if x.ndim == 3:
            x = x[np.newaxis].astype(np.float32) / 255.  # HxWx3 -> 1xHxWx3, normalized to 0~1
        elif x.ndim == 4:
            x = x.astype(np.float32) / 255.  # BxHxWx3, normalized to 0~1
        else:
            raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
        x = np.clip(x, 0, 1)  # clip to 0~1
        x = torch.from_numpy(x).permute(0, 3, 1, 2)  # 1xHxWx3 -> 1x3xHxW
        x = x.to(get_device())
        return x

    def GetMaskImg(self):
        """Return the blend-mask template image, reading it from disk on first use."""
        if self.mask_img is None:
            path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "./LivePortrait/utils/resources/mask_template.png")
            self.mask_img = cv2.imread(path, cv2.IMREAD_COLOR)
        return self.mask_img

    def crop_face(self, img_rgb, crop_factor):
        """Detect the face in ``img_rgb`` and return the cropped (and, if clamped, expanded) face image."""
        crop_region = self.detect_face(img_rgb, crop_factor)
        face_region, is_changed = self.calc_face_region(crop_region, get_rgb_size(img_rgb))
        face_img = rgb_crop(img_rgb, face_region)
        if is_changed:
            face_img = self.expand_img(face_img, crop_region)
        return face_img

    def prepare_source(self, source_image, crop_factor, is_video = False, tracking = False):
        """Build ``PreparedSrcImg`` data for the source frame(s).

        Args:
            source_image: batch tensor with values in 0..1 (multiplied by 255 and cast to bytes here).
            crop_factor: forwarded to ``detect_face``.
            is_video: when false, only the first frame is processed and a
                single ``PreparedSrcImg`` is returned; otherwise a list.
            tracking: when true the face is re-detected on every frame,
                otherwise the crop found on the first frame is reused.
        """
        print("Prepare source...")
        engine = self.get_pipeline()
        source_image_np = (source_image * 255).byte().numpy()

        psi_list = []
        for img_rgb in source_image_np:
            if tracking or len(psi_list) == 0:
                crop_region = self.detect_face(img_rgb, crop_factor)
                face_region, is_changed = self.calc_face_region(crop_region, get_rgb_size(img_rgb))

                # The mask template and the decoded crop are mapped from a 512-px square.
                s_x = (face_region[2] - face_region[0]) / 512.
                s_y = (face_region[3] - face_region[1]) / 512.
                crop_trans_m = create_transform_matrix(crop_region[0], crop_region[1], s_x, s_y)
                mask_ori = cv2.warpAffine(self.GetMaskImg(), crop_trans_m, get_rgb_size(img_rgb),
                                          flags=cv2.INTER_LINEAR)
                mask_ori = mask_ori.astype(np.float32) / 255.

                if is_changed:
                    # The crop was clamped: paste back using the scale of the unclamped square.
                    s = (crop_region[2] - crop_region[0]) / 512.
                    crop_trans_m = create_transform_matrix(crop_region[0], crop_region[1], s, s)

            face_img = rgb_crop(img_rgb, face_region)
            if is_changed:
                face_img = self.expand_img(face_img, crop_region)
            i_s = self.prepare_src_image(face_img)
            x_s_info = engine.get_kp_info(i_s)
            f_s_user = engine.extract_feature_3d(i_s)
            x_s_user = engine.transform_keypoint(x_s_info)
            psi = PreparedSrcImg(img_rgb, crop_trans_m, x_s_info, f_s_user, x_s_user, mask_ori)
            if is_video == False:
                return psi
            psi_list.append(psi)

        return psi_list

    def prepare_driving_video(self, face_images):
        """Return the list of ``get_kp_info`` results, one per driving frame."""
        print("Prepare driving video...")
        pipeline = self.get_pipeline()
        f_img_np = (face_images * 255).byte().numpy()

        out_list = []
        for f_img in f_img_np:
            i_d = self.prepare_src_image(f_img)
            d_info = pipeline.get_kp_info(i_d)
            out_list.append(d_info)

        return out_list

    def calc_fe(_, x_d_new, eyes, eyebrow, wink, pupil_x, pupil_y, mouth, eee, woo, smile,
                rotate_pitch, rotate_yaw, rotate_roll):
        """Apply the slider values to ``x_d_new`` in place and return the rotation.

        Each slider adds hand-tuned offsets to fixed entries of ``x_d_new``
        (indexed ``[0, keypoint, axis]``). ``mouth`` and ``wink`` also adjust
        the rotation, and ``pupil_y`` reduces ``eyes``.

        Returns:
            ``torch.Tensor([rotate_pitch, rotate_yaw, rotate_roll])`` after adjustments.
        """
        x_d_new[0, 20, 1] += smile * -0.01
        x_d_new[0, 14, 1] += smile * -0.02
        x_d_new[0, 17, 1] += smile * 0.0065
        x_d_new[0, 17, 2] += smile * 0.003
        x_d_new[0, 13, 1] += smile * -0.00275
        x_d_new[0, 16, 1] += smile * -0.00275
        x_d_new[0, 3, 1] += smile * -0.0035
        x_d_new[0, 7, 1] += smile * -0.0035

        x_d_new[0, 19, 1] += mouth * 0.001
        x_d_new[0, 19, 2] += mouth * 0.0001
        x_d_new[0, 17, 1] += mouth * -0.0001
        rotate_pitch -= mouth * 0.05

        x_d_new[0, 20, 2] += eee * -0.001
        x_d_new[0, 20, 1] += eee * -0.001
        #x_d_new[0, 19, 1] += eee * 0.0006
        x_d_new[0, 14, 1] += eee * -0.001

        x_d_new[0, 14, 1] += woo * 0.001
        x_d_new[0, 3, 1] += woo * -0.0005
        x_d_new[0, 7, 1] += woo * -0.0005
        x_d_new[0, 17, 2] += woo * -0.0005

        x_d_new[0, 11, 1] += wink * 0.001
        x_d_new[0, 13, 1] += wink * -0.0003
        x_d_new[0, 17, 0] += wink * 0.0003
        x_d_new[0, 17, 1] += wink * 0.0003
        x_d_new[0, 3, 1] += wink * -0.0003
        rotate_roll -= wink * 0.1
        rotate_yaw -= wink * 0.1

        if 0 < pupil_x:
            x_d_new[0, 11, 0] += pupil_x * 0.0007
            x_d_new[0, 15, 0] += pupil_x * 0.001
        else:
            x_d_new[0, 11, 0] += pupil_x * 0.001
            x_d_new[0, 15, 0] += pupil_x * 0.0007

        x_d_new[0, 11, 1] += pupil_y * -0.001
        x_d_new[0, 15, 1] += pupil_y * -0.001
        eyes -= pupil_y / 2.

        x_d_new[0, 11, 1] += eyes * -0.001
        x_d_new[0, 13, 1] += eyes * 0.0003
        x_d_new[0, 15, 1] += eyes * -0.001
        x_d_new[0, 16, 1] += eyes * 0.0003
        x_d_new[0, 1, 1] += eyes * -0.00025
        x_d_new[0, 2, 1] += eyes * 0.00025

        if 0 < eyebrow:
            x_d_new[0, 1, 1] += eyebrow * 0.001
            x_d_new[0, 2, 1] += eyebrow * -0.001
        else:
            x_d_new[0, 1, 0] += eyebrow * -0.001
            x_d_new[0, 2, 0] += eyebrow * 0.001
            x_d_new[0, 1, 1] += eyebrow * 0.0003
            x_d_new[0, 2, 1] += eyebrow * -0.0003

        return torch.Tensor([rotate_pitch, rotate_yaw, rotate_roll])


g_engine = LP_Engine()


class ExpressionSet:
    """Bundle of expression offsets ``e``, rotation ``r``, scale ``s`` and ``t``.

    ``t`` is filled from the ``'t'`` entry of the keypoint info (presumably a
    translation - confirm against LivePortraitWrapper). All arithmetic
    helpers modify the instance in place.
    """

    def __init__(self, erst = None, es = None):
        """Copy ``es`` if given, else unpack ``erst`` = ``(e, r, s, t)``, else start from zeros."""
        if es is not None:
            self.e = copy.deepcopy(es.e)  # [:, :, :]
            self.r = copy.deepcopy(es.r)  # [:]
            self.s = copy.deepcopy(es.s)
            self.t = copy.deepcopy(es.t)
        elif erst is not None:
            self.e = erst[0]
            self.r = erst[1]
            self.s = erst[2]
            self.t = erst[3]
        else:
            self.e = torch.from_numpy(np.zeros((1, 21, 3))).float().to(get_device())
            self.r = torch.Tensor([0, 0, 0])
            self.s = 0
            self.t = 0

    def div(self, value):
        """Divide every component by ``value`` in place."""
        self.e /= value
        self.r /= value
        self.s /= value
        self.t /= value

    def add(self, other):
        """Add the components of ``other`` in place."""
        self.e += other.e
        self.r += other.r
        self.s += other.s
        self.t += other.t

    def sub(self, other):
        """Subtract the components of ``other`` in place."""
        self.e -= other.e
        self.r -= other.r
        self.s -= other.s
        self.t -= other.t

    def mul(self, value):
        """Multiply every component by ``value`` in place."""
        self.e *= value
        self.r *= value
        self.s *= value
        self.t *= value

    #def apply_ratio(self, ratio): self.exp *= ratio


def logging_time(original_fn):
    """Decorator that prints the wall-clock duration of each call to ``original_fn``."""
    def wrapper_fn(*args, **kwargs):
        start_time = time.time()
        result = original_fn(*args, **kwargs)
        end_time = time.time()
        print("WorkingTime[{}]: {} sec".format(original_fn.__name__, end_time - start_time))
        return result

    return wrapper_fn


#exp_data_dir = os.path.join(current_directory, "exp_data")
exp_data_dir = os.path.join(folder_paths.output_directory, "exp_data")
# makedirs also creates a missing output directory and tolerates an existing one.
os.makedirs(exp_data_dir, exist_ok=True)


class SaveExpData:
    """Node that serialises an ``ExpressionSet`` to ``<exp_data_dir>/<file_name>.exp``."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "file_name": ("STRING", {"multiline": False, "default": ""}),
        },
            "optional": {"save_exp": ("EXP_DATA",), }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("file_name",)
    FUNCTION = "run"
    CATEGORY = "AdvancedLivePortrait"
    OUTPUT_NODE = True

    def run(self, file_name, save_exp:ExpressionSet=None):
        """Dump ``save_exp`` with dill; nothing is written when it is missing or the name is empty.

        Returns a one-element tuple, matching ``RETURN_TYPES``.
        """
        if save_exp is None or file_name == "":
            return (file_name,)

        with open(os.path.join(exp_data_dir, file_name + ".exp"), "wb") as f:
            dill.dump(save_exp, f)

        return (file_name,)


class LoadExpData:
    """Node that loads a saved ``ExpressionSet`` and scales it by ``ratio``."""

    @classmethod
    def INPUT_TYPES(s):
        # The choices are the ``.exp`` files currently present in exp_data_dir.
        file_list = [os.path.splitext(file)[0] for file in os.listdir(exp_data_dir) if file.endswith('.exp')]
        return {"required": {
            "file_name": (sorted(file_list, key=str.lower),),
            "ratio": ("FLOAT", {"default": 1, "min": 0, "max": 1, "step": 0.01}),
        },
        }

    RETURN_TYPES = ("EXP_DATA",)
    RETURN_NAMES = ("exp",)
    FUNCTION = "run"
    CATEGORY = "AdvancedLivePortrait"

    def run(self, file_name, ratio):
        """Load ``<file_name>.exp``, multiply it by ``ratio`` and return it as ``(es,)``."""
        # es = ExpressionSet()
        # NOTE(security): dill.load can execute arbitrary code; only load .exp files you trust.
        with open(os.path.join(exp_data_dir, file_name + ".exp"), 'rb') as f:
            es = dill.load(f)
        es.mul(ratio)
        return (es,)


class ExpData:
    """Node that builds an ``ExpressionSet`` from up to five ``(code, value)`` pairs."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required":{
            #"code": ("STRING", {"multiline": False, "default": ""}),
            "code1": ("INT", {"default": 0}),
            "value1": ("FLOAT", {"default": 0, "min": -100, "max": 100, "step": 0.1}),
            "code2": ("INT", {"default": 0}),
            "value2": ("FLOAT", {"default": 0, "min": -100, "max": 100, "step": 0.1}),
            "code3": ("INT", {"default": 0}),
            "value3": ("FLOAT", {"default": 0, "min": -100, "max": 100, "step": 0.1}),
            "code4": ("INT", {"default": 0}),
            "value4": ("FLOAT", {"default": 0, "min": -100, "max": 100, "step": 0.1}),
            "code5": ("INT", {"default": 0}),
            "value5": ("FLOAT", {"default": 0, "min": -100, "max": 100, "step": 0.1}),
        },
            "optional":{"add_exp": ("EXP_DATA",),}
        }

    RETURN_TYPES = ("EXP_DATA",)
    RETURN_NAMES = ("exp",)
    FUNCTION = "run"
    CATEGORY = "AdvancedLivePortrait"

    def run(self, code1, value1, code2, value2, code3, value3, code4, value4, code5, value5, add_exp=None):
        """Add ``value * 0.001`` to ``e[0, code // 10, code % 10]`` for each pair.

        The result starts from a copy of ``add_exp`` when given, else from zeros.
        """
        if add_exp is None:
            es = ExpressionSet()
        else:
            es = ExpressionSet(es = add_exp)

        codes = [code1, code2, code3, code4, code5]
        values = [value1, value2, value3, value4, value5]
        for code, value in zip(codes, values):
            # A code packs "keypoint index * 10 + axis", the format printed by PrintExpData.
            idx = int(code / 10)
            r = code % 10
            es.e[0, idx, r] += value * 0.001

        return (es,)


class PrintExpData:
    """Node that prints the largest expression components and passes the data through."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "cut_noise": ("FLOAT", {"default": 0, "min": 0, "max": 100, "step": 0.1}),
        },
            "optional": {"exp": ("EXP_DATA",), }
        }

    RETURN_TYPES = ("EXP_DATA",)
    RETURN_NAMES = ("exp",)
    FUNCTION = "run"
    CATEGORY = "AdvancedLivePortrait"
    OUTPUT_NODE = True

    def run(self, cut_noise, exp = None):
        """Print ``[code, value]`` pairs whose magnitude (x1000) exceeds ``cut_noise``, largest first."""
        if exp is None:
            return (exp,)

        cuted_list = []
        # ExpressionSet stores the expression tensor in ``e``.
        e = exp.e * 1000
        for idx in range(21):
            for r in range(3):
                a = abs(e[0, idx, r])
                if cut_noise < a:
                    cuted_list.append((a, e[0, idx, r], idx*10+r))

        sorted_list = sorted(cuted_list, reverse=True, key=lambda item: item[0])
        print(f"sorted_list: {[[item[2], round(float(item[1]),1)] for item in sorted_list]}")
        return (exp,)


class Command:
    """One parsed animation command: a per-frame delta ``es`` applied for ``change`` frames, then held for ``keep`` frames."""

    def __init__(self, es, change, keep):
        self.es:ExpressionSet = es
        self.change = change
        self.keep = keep


crop_factor_default = 1.7
crop_factor_min = 1.5
crop_factor_max = 2.5


class AdvancedLivePortrait:
    """Node that animates source image(s) from a driving video and/or a command script."""

    def __init__(self):
        self.src_images = None
        self.driving_images = None
        self.pbar = comfy.utils.ProgressBar(1)
        self.crop_factor = None

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "retargeting_eyes": ("FLOAT", {"default": 0, "min": 0, "max": 1, "step": 0.01}),
                "retargeting_mouth": ("FLOAT", {"default": 0, "min": 0, "max": 1, "step": 0.01}),
                "crop_factor": ("FLOAT", {"default": crop_factor_default,
                                          "min": crop_factor_min, "max": crop_factor_max, "step": 0.1}),
                "turn_on": ("BOOLEAN", {"default": True}),
                "tracking_src_vid": ("BOOLEAN", {"default": False}),
                "animate_without_vid": ("BOOLEAN", {"default": False}),
                "command": ("STRING", {"multiline": True, "default": ""}),
            },
            "optional": {
                "src_images": ("IMAGE",),
                "motion_link": ("EDITOR_LINK",),
                "driving_images": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "run"
    OUTPUT_NODE = True
    CATEGORY = "AdvancedLivePortrait"

    # INPUT_IS_LIST = False
    # OUTPUT_IS_LIST = (False,)

    def parsing_command(self, command, motoin_link):
        """Parse the multi-line command script into ``Command`` objects.

        Each non-empty line has the form ``<idx>=<change>:<keep>``. ``idx`` 0
        means a zero expression, any other value selects
        ``motoin_link[idx]``. Spaces are ignored and blank lines skipped.

        Returns:
            ``(cmd_list, total_length)``, where ``total_length`` is the sum
            of all ``change + keep`` frame counts.

        Raises:
            AssertionError: if a line cannot be parsed.
        """
        # str.replace returns a new string, so the result has to be kept.
        command = command.replace(' ', '')
        # if command == '': return
        lines = command.split('\n')

        cmd_list = []

        total_length = 0

        #old_es = None
        for i, line in enumerate(lines, start=1):
            if line == '':
                continue
            try:
                cmds = line.split('=')
                idx = int(cmds[0])
                if idx == 0:
                    es = ExpressionSet()
                else:
                    es = ExpressionSet(es = motoin_link[idx])
                cmds = cmds[1].split(':')
                change = int(cmds[0])
                keep = int(cmds[1])
            except Exception as err:
                # Raised explicitly (not via `assert`) so the check survives `python -O`.
                raise AssertionError(f"(AdvancedLivePortrait) Command Err Line {i}: {line}") from err

            total_length += change + keep
            # Store the per-frame step so the target is reached after `change` frames.
            es.div(change)
            cmd_list.append(Command(es, change, keep))

        return cmd_list, total_length

    def run(self, retargeting_eyes, retargeting_mouth, turn_on, tracking_src_vid, animate_without_vid, command, crop_factor,
            src_images=None, driving_images=None, motion_link=None):
        """Render the animated frames.

        The source comes from ``src_images`` or, when absent, from
        ``motion_link[0]``. Prepared source and driving data are cached on
        the instance and reused while the same input objects (by ``id``) and
        ``crop_factor`` are passed. The number of frames is the longer of
        the driving and source sequences, extended to the command length
        when ``animate_without_vid`` is set.

        Returns:
            ``(images,)`` as a batched tensor, or ``(None,)`` when there is
            nothing to render.
        """
        if turn_on == False:
            return (None,)
        src_length = 1

        if src_images is None:
            if motion_link is not None:
                self.psi_list = [motion_link[0]]
            else:
                return (None,)

        if src_images is not None:
            src_length = len(src_images)
            if id(src_images) != id(self.src_images) or self.crop_factor != crop_factor:
                self.crop_factor = crop_factor
                self.src_images = src_images
                if 1 < src_length:
                    self.psi_list = g_engine.prepare_source(src_images, crop_factor, True, tracking_src_vid)
                else:
                    self.psi_list = [g_engine.prepare_source(src_images, crop_factor)]

        cmd_list, cmd_length = self.parsing_command(command, motion_link)
        if cmd_list is None:
            return (None,)
        cmd_idx = 0

        driving_length = 0
        if driving_images is not None:
            if id(driving_images) != id(self.driving_images):
                self.driving_images = driving_images
                self.driving_values = g_engine.prepare_driving_video(driving_images)
            driving_length = len(self.driving_values)

        total_length = max(driving_length, src_length)

        if animate_without_vid:
            total_length = max(total_length, cmd_length)

        c_i_es = ExpressionSet()  # accumulated command expression
        c_o_es = ExpressionSet()  # per-frame fade-out of the previous command
        d_0_es = None             # first driving frame, used as the reference
        out_list = []

        psi = None
        pipeline = g_engine.get_pipeline()
        for i in range(total_length):

            if i < src_length:
                psi = self.psi_list[i]
                s_info = psi.x_s_info
                s_es = ExpressionSet(erst=(s_info['kp'] + s_info['exp'], torch.Tensor([0, 0, 0]), s_info['scale'], s_info['t']))

            new_es = ExpressionSet(es = s_es)

            if i < cmd_length:
                cmd = cmd_list[cmd_idx]
                if 0 < cmd.change:
                    cmd.change -= 1
                    c_i_es.add(cmd.es)
                    c_i_es.sub(c_o_es)
                elif 0 < cmd.keep:
                    cmd.keep -= 1

                new_es.add(c_i_es)

                if cmd.change == 0 and cmd.keep == 0:
                    cmd_idx += 1
                    if cmd_idx < len(cmd_list):
                        # Fade the current state out over the next command's change frames.
                        c_o_es = ExpressionSet(es = c_i_es)
                        cmd = cmd_list[cmd_idx]
                        c_o_es.div(cmd.change)
            elif 0 < cmd_length:
                # Commands are finished: keep the final expression.
                new_es.add(c_i_es)

            if i < driving_length:
                d_i_info = self.driving_values[i]
                d_i_r = torch.Tensor([d_i_info['pitch'], d_i_info['yaw'], d_i_info['roll']])#.float().to(device="cuda:0")

                if d_0_es is None:
                    d_0_es = ExpressionSet(erst = (d_i_info['exp'], d_i_r, d_i_info['scale'], d_i_info['t']))

                    retargeting(s_es.e, d_0_es.e, retargeting_eyes, (11, 13, 15, 16))
                    retargeting(s_es.e, d_0_es.e, retargeting_mouth, (14, 17, 19, 20))

                # Apply driving motion relative to the first driving frame.
                new_es.e += d_i_info['exp'] - d_0_es.e
                new_es.r += d_i_r - d_0_es.r
                new_es.t += d_i_info['t'] - d_0_es.t

            r_new = get_rotation_matrix(
                s_info['pitch'] + new_es.r[0], s_info['yaw'] + new_es.r[1], s_info['roll'] + new_es.r[2])
            d_new = new_es.s * (new_es.e @ r_new) + new_es.t
            d_new = pipeline.stitching(psi.x_s_user, d_new)
            crop_out = pipeline.warp_decode(psi.f_s_user, psi.x_s_user, d_new)
            crop_out = pipeline.parse_output(crop_out['out'])[0]

            # flags by keyword: the 4th positional argument of warpAffine is `dst`.
            crop_with_fullsize = cv2.warpAffine(crop_out, psi.crop_trans_m, get_rgb_size(psi.src_rgb),
                                                flags=cv2.INTER_LINEAR)
            out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(
                np.uint8)
            out_list.append(out)

            self.pbar.update_absolute(i+1, total_length, ("PNG", Image.fromarray(crop_out), None))

        if len(out_list) == 0:
            return (None,)

        out_imgs = torch.cat([pil2tensor(img_rgb) for img_rgb in out_list])
        return (out_imgs,)


class ExpressionEditor:
    """Node that edits the expression of a single source image with sliders and an optional sample image."""

    def __init__(self):
        self.sample_image = None
        self.src_image = None
        self.crop_factor = None

    @classmethod
    def INPUT_TYPES(s):
        display = "number"
        #display = "slider"
        return {
            "required": {

                "rotate_pitch": ("FLOAT", {"default": 0, "min": -20, "max": 20, "step": 0.5, "display": display}),
                "rotate_yaw": ("FLOAT", {"default": 0, "min": -20, "max": 20, "step": 0.5, "display": display}),
                "rotate_roll": ("FLOAT", {"default": 0, "min": -20, "max": 20, "step": 0.5, "display": display}),

                "blink": ("FLOAT", {"default": 0, "min": -20, "max": 5, "step": 0.5, "display": display}),
                "eyebrow": ("FLOAT", {"default": 0, "min": -10, "max": 15, "step": 0.5, "display": display}),
                "wink": ("FLOAT", {"default": 0, "min": 0, "max": 25, "step": 0.5, "display": display}),
                "pupil_x": ("FLOAT", {"default": 0, "min": -15, "max": 15, "step": 0.5, "display": display}),
                "pupil_y": ("FLOAT", {"default": 0, "min": -15, "max": 15, "step": 0.5, "display": display}),
                "aaa": ("FLOAT", {"default": 0, "min": -30, "max": 120, "step": 1, "display": display}),
                "eee": ("FLOAT", {"default": 0, "min": -20, "max": 15, "step": 0.2, "display": display}),
                "woo": ("FLOAT", {"default": 0, "min": -20, "max": 15, "step": 0.2, "display": display}),
                "smile": ("FLOAT", {"default": 0, "min": -0.3, "max": 1.3, "step": 0.01, "display": display}),

                "src_ratio": ("FLOAT", {"default": 1, "min": 0, "max": 1, "step": 0.01, "display": display}),
                "sample_ratio": ("FLOAT", {"default": 1, "min": -0.2, "max": 1.2, "step": 0.01, "display": display}),
                "sample_parts": (["OnlyExpression", "OnlyRotation", "OnlyMouth", "OnlyEyes", "All"],),
                "crop_factor": ("FLOAT", {"default": crop_factor_default,
                                          "min": crop_factor_min, "max": crop_factor_max, "step": 0.1}),
            },

            "optional": {"src_image": ("IMAGE",), "motion_link": ("EDITOR_LINK",),
                         "sample_image": ("IMAGE",), "add_exp": ("EXP_DATA",),
            },
        }

    RETURN_TYPES = ("IMAGE", "EDITOR_LINK", "EXP_DATA")
    RETURN_NAMES = ("image", "motion_link", "save_exp")

    FUNCTION = "run"

    OUTPUT_NODE = True

    CATEGORY = "AdvancedLivePortrait"

    # INPUT_IS_LIST = False
    # OUTPUT_IS_LIST = (False,)

    def run(self, rotate_pitch, rotate_yaw, rotate_roll, blink, eyebrow, wink, pupil_x, pupil_y, aaa, eee, woo, smile,
            src_ratio, sample_ratio, sample_parts, crop_factor, src_image=None, sample_image=None, motion_link=None, add_exp=None):
        """Render the edited face and return image, motion link and expression.

        The prepared source comes from ``motion_link[0]`` when a link is
        given, otherwise from ``src_image`` (cached by object ``id`` and
        ``crop_factor``). The new ``ExpressionSet`` is appended to the
        returned link list. A preview of the crop is saved to the temp
        directory and reported in the ``ui`` part of the result.

        Returns:
            ``{"ui": ..., "result": (image, motion_link, exp)}``, or
            ``(None, None, None)`` when no source is available.
        """
        # The yaw slider is applied with the opposite sign.
        rotate_yaw = -rotate_yaw

        new_editor_link = None
        if motion_link is not None:
            self.psi = motion_link[0]
            new_editor_link = motion_link.copy()
        elif src_image is not None:
            if id(src_image) != id(self.src_image) or self.crop_factor != crop_factor:
                self.crop_factor = crop_factor
                self.psi = g_engine.prepare_source(src_image, crop_factor)
                self.src_image = src_image
            new_editor_link = []
            new_editor_link.append(self.psi)
        else:
            # One placeholder per declared output.
            return (None, None, None)

        pipeline = g_engine.get_pipeline()

        psi = self.psi
        s_info = psi.x_s_info
        #delta_new = copy.deepcopy()
        s_exp = s_info['exp'] * src_ratio
        # Entry 5 keeps its source value regardless of src_ratio.
        s_exp[0, 5] = s_info['exp'][0, 5]
        s_exp += s_info['kp']

        es = ExpressionSet()

        if sample_image is not None:
            if id(self.sample_image) != id(sample_image):
                self.sample_image = sample_image
                d_image_np = (sample_image * 255).byte().numpy()
                d_face = g_engine.crop_face(d_image_np[0], 1.7)
                i_d = g_engine.prepare_src_image(d_face)
                self.d_info = pipeline.get_kp_info(i_d)
                self.d_info['exp'][0, 5, 0] = 0
                self.d_info['exp'][0, 5, 1] = 0

            # "OnlyExpression", "OnlyRotation", "OnlyMouth", "OnlyEyes", "All"
            if sample_parts == "OnlyExpression" or sample_parts == "All":
                es.e += self.d_info['exp'] * sample_ratio
            if sample_parts == "OnlyRotation" or sample_parts == "All":
                rotate_pitch += self.d_info['pitch'] * sample_ratio
                rotate_yaw += self.d_info['yaw'] * sample_ratio
                rotate_roll += self.d_info['roll'] * sample_ratio
            elif sample_parts == "OnlyMouth":
                retargeting(es.e, self.d_info['exp'], sample_ratio, (14, 17, 19, 20))
            elif sample_parts == "OnlyEyes":
                retargeting(es.e, self.d_info['exp'], sample_ratio, (1, 2, 11, 13, 15, 16))

        es.r = g_engine.calc_fe(es.e, blink, eyebrow, wink, pupil_x, pupil_y, aaa, eee, woo, smile,
                                rotate_pitch, rotate_yaw, rotate_roll)

        if add_exp is not None:
            es.add(add_exp)

        new_rotate = get_rotation_matrix(s_info['pitch'] + es.r[0], s_info['yaw'] + es.r[1],
                                         s_info['roll'] + es.r[2])
        x_d_new = (s_info['scale'] * (1 + es.s)) * ((s_exp + es.e) @ new_rotate) + s_info['t']

        x_d_new = pipeline.stitching(psi.x_s_user, x_d_new)

        crop_out = pipeline.warp_decode(psi.f_s_user, psi.x_s_user, x_d_new)
        crop_out = pipeline.parse_output(crop_out['out'])[0]

        # flags by keyword: the 4th positional argument of warpAffine is `dst`.
        crop_with_fullsize = cv2.warpAffine(crop_out, psi.crop_trans_m, get_rgb_size(psi.src_rgb),
                                            flags=cv2.INTER_LINEAR)
        out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(np.uint8)

        out_img = pil2tensor(out)

        filename = g_engine.get_temp_img_name() #"fe_edit_preview.png"
        folder_paths.get_save_image_path(filename, folder_paths.get_temp_directory())
        img = Image.fromarray(crop_out)
        img.save(os.path.join(folder_paths.get_temp_directory(), filename), compress_level=1)
        results = list()
        results.append({"filename": filename, "type": "temp"})

        new_editor_link.append(es)

        return {"ui": {"images": results}, "result": (out_img, new_editor_link, es)}


NODE_CLASS_MAPPINGS = {
    "AdvancedLivePortrait": AdvancedLivePortrait,
    "ExpressionEditor": ExpressionEditor,
    "LoadExpData": LoadExpData,
    "SaveExpData": SaveExpData,
    "ExpData": ExpData,
    # NOTE(review): the trailing colon in this id looks accidental, but renaming
    # it would break saved workflows that reference it - confirm before changing.
    "PrintExpData:": PrintExpData,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "AdvancedLivePortrait": "Advanced Live Portrait (PHM)",
    "ExpressionEditor": "Expression Editor (PHM)",
    "LoadExpData": "Load Exp Data (PHM)",
    "SaveExpData": "Save Exp Data (PHM)"
}