Spaces:
Sleeping
Sleeping
import os
from contextlib import redirect_stdout

import comfy
import comfy.utils
import cv2
import folder_paths
import numpy as np
import torch
from comfy import model_management
from PIL import Image

from ..log import NullWriter, log
from ..utils import get_model_path, np2tensor, pil2tensor, tensor2np
class MTB_LoadFaceEnhanceModel:
    """Loads a GFPGan or RestoreFormer model for face enhancement."""

    def __init__(self) -> None:
        pass

    @staticmethod
    def _is_face_model(name: str) -> bool:
        """Return True for ``.pth`` files whose name contains GFPGAN or RestoreFormer."""
        return name.endswith(".pth") and (
            "GFPGAN" in name or "RestoreFormer" in name
        )

    @classmethod
    def get_models_root(cls):
        """Locate the directories that may hold face-restoration checkpoints.

        Returns:
            ``(face_restore_dir, None)`` when the ``face_restore`` folder exists;
            ``(face_restore_dir, upscale_models_dir)`` when only the legacy
            ``upscale_models`` folder exists (the first item then points at a
            missing directory); ``(None, None)`` when neither exists.
        """
        fr = get_model_path("face_restore")
        if fr.exists():
            return (fr, None)
        um = get_model_path("upscale_models")
        return (fr, um) if um.exists() else (None, None)

    @classmethod
    def get_models(cls):
        """List candidate checkpoint paths.

        ``face_restore`` is searched when it exists. Otherwise the legacy
        ``upscale_models`` folder is searched instead. A warning is logged
        (once per class) when neither folder exists.

        Returns:
            A list of paths to ``.pth`` files whose name contains ``GFPGAN``
            or ``RestoreFormer``; empty when nothing was found.
        """
        fr_models_path, um_models_path = cls.get_models_root()
        if fr_models_path is None and um_models_path is None:
            if not hasattr(cls, "_warned"):
                log.warning("Face restoration models not found.")
                cls._warned = True
            return []

        if fr_models_path is not None and fr_models_path.exists():
            search_dir = fr_models_path
        elif um_models_path is not None and um_models_path.exists():
            # Legacy location: older mtb versions stored these checkpoints in
            # upscale_models; kept as a fallback.
            search_dir = um_models_path
        else:
            return []

        return [x for x in search_dir.iterdir() if cls._is_face_model(x.name)]

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: a model file name, an upscale factor and
        an optional background upsampler."""
        return {
            "required": {
                "model_name": (
                    [x.name for x in cls.get_models()],
                    {"default": "None"},
                ),
                "upscale": ("INT", {"default": 1}),
            },
            "optional": {"bg_upsampler": ("UPSCALE_MODEL", {"default": None})},
        }

    RETURN_TYPES = ("FACEENHANCE_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_model"
    CATEGORY = "mtb/facetools"
    DEPRECATED = True

    def load_model(self, model_name, upscale=2, bg_upsampler=None):
        """Instantiate a ``GFPGANer`` for ``model_name``.

        Args:
            model_name: checkpoint file name inside the resolved models folder.
                Names containing "RestoreFormer" select the RestoreFormer
                architecture; all others use "clean".
            upscale: output scale. It is overridden by ``bg_upsampler.scale``
                when a background upsampler is supplied.
            bg_upsampler: optional upscale model. It is wrapped in
                ``BGUpscaleWrapper`` before being handed to ``GFPGANer``.

        Returns:
            A 1-tuple holding the ``GFPGANer`` instance.

        Raises:
            FileNotFoundError: neither the ``face_restore`` nor the
                ``upscale_models`` folder exists.
        """
        from gfpgan import GFPGANer

        basic = "RestoreFormer" not in model_name

        fr_root, um_root = self.get_models_root()
        root = fr_root if fr_root is not None and fr_root.exists() else um_root
        if root is None:
            raise FileNotFoundError(
                f"Cannot load {model_name}: neither the face_restore nor the "
                "upscale_models directory exists"
            )

        if bg_upsampler is not None:
            log.warning(
                f"Upscale value overridden to {bg_upsampler.scale} from bg_upsampler"
            )
            upscale = bg_upsampler.scale
            bg_upsampler = BGUpscaleWrapper(bg_upsampler)

        # Silence GFPGANer's prints. redirect_stdout restores the *previous*
        # stream (not sys.__stdout__), even if construction raises.
        with redirect_stdout(NullWriter()):
            model = GFPGANer(
                model_path=(root / model_name).as_posix(),
                upscale=upscale,
                # "original" would be needed for v1.0 only
                arch="clean" if basic else "RestoreFormer",
                channel_multiplier=2,  # 1 for v1.0 only
                bg_upsampler=bg_upsampler,
            )
        return (model,)
class BGUpscaleWrapper:
    """Expose a ComfyUI upscale model through an ``enhance`` method.

    ``MTB_LoadFaceEnhanceModel.load_model`` passes instances of this class to
    ``GFPGANer`` as its ``bg_upsampler``.
    """

    def __init__(self, upscale_model) -> None:
        self.upscale_model = upscale_model

    def enhance(self, img: np.ndarray, outscale=2):
        """Upscale ``img`` using tiled inference on the torch device.

        Args:
            img: image array from the caller, converted with ``np2tensor``
                (presumably an HWC numpy image — TODO confirm).
            outscale: accepted for interface compatibility but ignored. The
                upscale amount is always ``self.upscale_model.scale``.

        Returns:
            A 1-tuple containing the upscaled image as a numpy array.
        """
        device = model_management.get_torch_device()
        tile = 128 + 64
        overlap = 8

        self.upscale_model.to(device)
        try:
            imgt = np2tensor(img).movedim(-1, -3).to(device)
            steps = imgt.shape[0] * comfy.utils.get_tiled_scale_steps(
                imgt.shape[3],
                imgt.shape[2],
                tile_x=tile,
                tile_y=tile,
                overlap=overlap,
            )
            log.debug(f"Steps: {steps}")
            pbar = comfy.utils.ProgressBar(steps)
            s = comfy.utils.tiled_scale(
                imgt,
                lambda a: self.upscale_model(a),
                tile_x=tile,
                tile_y=tile,
                overlap=overlap,
                upscale_amount=self.upscale_model.scale,
                pbar=pbar,
            )
        finally:
            # Return the weights to CPU even when tiling fails, so device
            # memory is not left occupied.
            self.upscale_model.cpu()

        s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
        return (tensor2np(s)[0],)
import sys | |
class MTB_RestoreFace:
    """Uses GFPGan to restore faces"""

    def __init__(self) -> None:
        pass

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "restore"
    CATEGORY = "mtb/facetools"
    DEPRECATED = True

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs (image batch, enhance model and options)."""
        return {
            "required": {
                "image": ("IMAGE",),
                "model": ("FACEENHANCE_MODEL",),
                # Input are aligned faces
                "aligned": ("BOOLEAN", {"default": False}),
                # Only restore the center face
                "only_center_face": ("BOOLEAN", {"default": False}),
                # Adjustable weights
                "weight": ("FLOAT", {"default": 0.5}),
                "save_tmp_steps": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "preserve_alpha": ("BOOLEAN", {"default": True}),
            },
        }

    def do_restore(
        self,
        image: torch.Tensor,
        model,
        aligned,
        only_center_face,
        weight,
        save_tmp_steps,
        preserve_alpha: bool = False,
    ) -> torch.Tensor:
        """Restore the faces in a single frame.

        Args:
            image: one frame of the batch. ``image.size(-1) == 4`` is treated
                as an image with an alpha channel.
            model: the enhancer. Its ``enhance`` is called with a BGR image
                and must return ``(cropped_faces, restored_faces, restored_img)``.
            aligned: forwarded to ``enhance`` as ``has_aligned``.
            only_center_face: forwarded to ``enhance``.
            weight: forwarded to ``enhance``. A warning notes it currently has
                no effect.
            save_tmp_steps: when true, write the face crops to the temp folder.
            preserve_alpha: when true and the frame has 4 channels, the alpha
                channel is resized to the output and re-attached.

        Returns:
            The restored frame, converted with ``pil2tensor``.
        """
        pimage = tensor2np(image)[0]
        height, width = pimage.shape[0], pimage.shape[1]

        alpha_channel = None
        if preserve_alpha and image.size(-1) == 4:
            alpha_channel = pimage[:, :, 3]

        # Strip alpha explicitly before converting to the BGR layout the
        # enhancer is fed.
        source_img = cv2.cvtColor(
            np.ascontiguousarray(pimage[:, :, :3]), cv2.COLOR_RGB2BGR
        )

        # Silence the enhancer's prints. The previous stdout is restored even
        # if enhance() raises.
        with redirect_stdout(NullWriter()):
            cropped_faces, restored_faces, restored_img = model.enhance(
                source_img,
                has_aligned=aligned,
                only_center_face=only_center_face,
                paste_back=True,
                # TODO: weight has no effect in 1.3 and 1.4 (only tested these for now...)
                weight=weight,
            )
        log.warning(f"Weight value has no effect for now. (value: {weight})")

        if save_tmp_steps:
            self.save_intermediate_images(
                cropped_faces, restored_faces, height, width
            )

        if restored_img is None:
            # NOTE(review): GFPGANer appears to return no pasted-back image for
            # aligned input. Fall back to the first restored face, or to the
            # untouched source, instead of crashing in pil2tensor(None).
            if restored_faces is not None and len(restored_faces) > 0:
                restored_img = restored_faces[0]
            else:
                restored_img = source_img

        output = Image.fromarray(cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB))
        if alpha_channel is not None:
            alpha_resized = Image.fromarray(alpha_channel).resize(
                output.size, Image.LANCZOS
            )
            output.putalpha(alpha_resized)
        return pil2tensor(output)

    def restore(
        self,
        image: torch.Tensor,
        model,
        aligned=False,
        only_center_face=False,
        weight=0.5,
        save_tmp_steps=True,
        preserve_alpha: bool = False,
    ) -> tuple[torch.Tensor]:
        """Run ``do_restore`` on every frame along dim 0 and concatenate the results.

        An empty batch is returned unchanged, because ``torch.cat`` rejects an
        empty list.
        """
        if image.size(0) == 0:
            return (image,)
        out = [
            self.do_restore(
                frame,
                model,
                aligned,
                only_center_face,
                weight,
                save_tmp_steps,
                preserve_alpha,
            )
            for frame in image
        ]
        return (torch.cat(out, dim=0),)

    def get_step_image_path(self, step, idx):
        """Return a fresh ``.png`` path in ComfyUI's temp directory for ``step``/``idx``."""
        (
            full_output_folder,
            filename,
            counter,
            _subfolder,
            _filename_prefix,
        ) = folder_paths.get_save_image_path(
            f"{step}_{idx:03}",
            folder_paths.temp_directory,
        )
        file = f"{filename}_{counter:05}_.png"
        return os.path.join(full_output_folder, file)

    def save_intermediate_images(
        self, cropped_faces, restored_faces, height, width
    ):
        """Write each cropped face, its restored version, and a side-by-side
        comparison to the temp directory.

        ``height`` and ``width`` are accepted but not used.
        """
        for idx, (cropped_face, restored_face) in enumerate(
            zip(cropped_faces, restored_faces, strict=False)
        ):
            face_id = idx + 1
            file = self.get_step_image_path("cropped_faces", face_id)
            cv2.imwrite(file, cropped_face)

            file = self.get_step_image_path("cropped_faces_restored", face_id)
            cv2.imwrite(file, restored_face)

            file = self.get_step_image_path("cropped_faces_compare", face_id)
            # save comparison image
            cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
            cv2.imwrite(file, cmp_img)
# Node classes listed by this module (presumably collected by the package's
# node loader — verify against the loader).
__nodes__ = [
    MTB_RestoreFace,
    MTB_LoadFaceEnhanceModel,
]