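"""Deprecated face enhancement nodes for mtb (GFPGAN / RestoreFormer)."""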
import os
import sys

import comfy.utils
import cv2
import folder_paths
import numpy as np
import torch
from comfy import model_management
from PIL import Image

from ..log import NullWriter, log
from ..utils import get_model_path, np2tensor, pil2tensor, tensor2np
class MTB_LoadFaceEnhanceModel:
"""Loads a GFPGan or RestoreFormer model for face enhancement."""
def __init__(self) -> None:
pass
@classmethod
def get_models_root(cls):
fr = get_model_path("face_restore")
        # i.e. <models_dir>/face_restore
if fr.exists():
return (fr, None)
um = get_model_path("upscale_models")
return (fr, um) if um.exists() else (None, None)
    @staticmethod
    def _list_checkpoints(root):
        """List the GFPGAN/RestoreFormer .pth checkpoints in the given folder."""
        return [
            x
            for x in root.iterdir()
            if x.name.endswith(".pth")
            and ("GFPGAN" in x.name or "RestoreFormer" in x.name)
        ]

    @classmethod
    def get_models(cls):
        fr_models_path, um_models_path = cls.get_models_root()
        if fr_models_path is None and um_models_path is None:
            if not hasattr(cls, "_warned"):
                log.warning("Face restoration models not found.")
                cls._warned = True
            return []

        if not fr_models_path.exists():
            # Legacy fallback: older mtb versions stored these checkpoints in
            # upscale_models; this fallback will be removed in a future version.
            if um_models_path.exists():
                return cls._list_checkpoints(um_models_path)
            return []

        return cls._list_checkpoints(fr_models_path)
@classmethod
def INPUT_TYPES(cls):
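        """Expose the detected checkpoints and an upscale factor to the UI."""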
return {
"required": {
"model_name": (
[x.name for x in cls.get_models()],
{"default": "None"},
),
"upscale": ("INT", {"default": 1}),
},
"optional": {"bg_upsampler": ("UPSCALE_MODEL", {"default": None})},
}
RETURN_TYPES = ("FACEENHANCE_MODEL",)
RETURN_NAMES = ("model",)
FUNCTION = "load_model"
CATEGORY = "mtb/facetools"
DEPRECATED = True
    def load_model(self, model_name, upscale=2, bg_upsampler=None):
        from gfpgan import GFPGANer

        basic = "RestoreFormer" not in model_name
        fr_root, um_root = self.get_models_root()
        if fr_root is None:
            # get_models_root returned (None, None): nothing to load from.
            raise RuntimeError("No face restoration model folder found.")

        if bg_upsampler is not None:
            log.warning(
                f"Upscale value overridden to {bg_upsampler.scale} from bg_upsampler"
            )
            upscale = bg_upsampler.scale
            bg_upsampler = BGUpscaleWrapper(bg_upsampler)

        # Silence GFPGANer's console output, restoring stdout even on failure.
        sys.stdout = NullWriter()
        try:
            model = GFPGANer(
                model_path=(
                    (fr_root if fr_root.exists() else um_root) / model_name
                ).as_posix(),
                upscale=upscale,
                arch="clean" if basic else "RestoreFormer",  # "original" for v1.0 only
                channel_multiplier=2,  # 1 for v1.0 only
                bg_upsampler=bg_upsampler,
            )
        finally:
            sys.stdout = sys.__stdout__

        return (model,)
class BGUpscaleWrapper:
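    """Adapts a ComfyUI UPSCALE_MODEL to the `enhance` interface that
    GFPGANer expects from a background upsampler."""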
def __init__(self, upscale_model) -> None:
self.upscale_model = upscale_model
    def enhance(self, img: np.ndarray, outscale=2):
        # GFPGANer passes a BGR np.ndarray here, not a PIL image. `outscale`
        # is part of the expected interface, but the effective scale comes
        # from the wrapped model itself.
        device = model_management.get_torch_device()
        self.upscale_model.to(device)
        tile = 192
        overlap = 8

        imgt = np2tensor(img)
        imgt = imgt.movedim(-1, -3).to(device)
steps = imgt.shape[0] * comfy.utils.get_tiled_scale_steps(
imgt.shape[3],
imgt.shape[2],
tile_x=tile,
tile_y=tile,
overlap=overlap,
)
log.debug(f"Steps: {steps}")
pbar = comfy.utils.ProgressBar(steps)
s = comfy.utils.tiled_scale(
imgt,
lambda a: self.upscale_model(a),
tile_x=tile,
tile_y=tile,
overlap=overlap,
upscale_amount=self.upscale_model.scale,
pbar=pbar,
)
self.upscale_model.cpu()
s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
return (tensor2np(s)[0],)
class MTB_RestoreFace:
"""Uses GFPGan to restore faces"""
def __init__(self) -> None:
pass
RETURN_TYPES = ("IMAGE",)
FUNCTION = "restore"
CATEGORY = "mtb/facetools"
DEPRECATED = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
"model": ("FACEENHANCE_MODEL",),
                # Inputs are already-aligned faces
"aligned": ("BOOLEAN", {"default": False}),
# Only restore the center face
"only_center_face": ("BOOLEAN", {"default": False}),
                # Adjustable restoration weight (currently no effect, see note in do_restore)
"weight": ("FLOAT", {"default": 0.5}),
"save_tmp_steps": ("BOOLEAN", {"default": True}),
},
"optional": {
"preserve_alpha": ("BOOLEAN", {"default": True}),
},
}
    def do_restore(
        self,
        image: torch.Tensor,
        model,
        aligned,
        only_center_face,
        weight,
        save_tmp_steps,
        preserve_alpha: bool = False,
    ) -> torch.Tensor:
        pimage = tensor2np(image)[0]
        width, height = pimage.shape[1], pimage.shape[0]

        alpha_channel = None
        if preserve_alpha and image.size(-1) == 4:
            # Split off the alpha channel *before* the RGB->BGR conversion,
            # which expects a 3-channel image.
            alpha_channel = pimage[:, :, 3]
            pimage = pimage[:, :, :3]

        source_img = cv2.cvtColor(pimage, cv2.COLOR_RGB2BGR)

        # Silence GFPGAN's console output, restoring stdout even on failure.
        sys.stdout = NullWriter()
        try:
            cropped_faces, restored_faces, restored_img = model.enhance(
                source_img,
                has_aligned=aligned,
                only_center_face=only_center_face,
                paste_back=True,
                # TODO: weight has no effect in 1.3 and 1.4 (only tested these for now...)
                weight=weight,
            )
        finally:
            sys.stdout = sys.__stdout__

        log.warning(f"Weight value has no effect for now. (value: {weight})")

        if save_tmp_steps:
            self.save_intermediate_images(
                cropped_faces, restored_faces, height, width
            )

        if restored_img is None:
            # GFPGANer skips paste-back for aligned inputs; fall back to the
            # first restored face (or the source) so we always return an image.
            restored_img = restored_faces[0] if restored_faces else source_img

        output = Image.fromarray(cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB))
        if alpha_channel is not None:
            alpha_resized = Image.fromarray(alpha_channel).resize(
                output.size, Image.LANCZOS
            )
            output.putalpha(alpha_resized)

        return pil2tensor(output)
def restore(
self,
image: torch.Tensor,
model,
aligned=False,
only_center_face=False,
weight=0.5,
save_tmp_steps=True,
preserve_alpha: bool = False,
) -> tuple[torch.Tensor]:
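        """Run face restoration on each image in the batch and re-stack the results."""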
out = [
self.do_restore(
image[i],
model,
aligned,
only_center_face,
weight,
save_tmp_steps,
preserve_alpha,
)
for i in range(image.size(0))
]
return (torch.cat(out, dim=0),)
def get_step_image_path(self, step, idx):
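        """Build a unique output path in the ComfyUI temp directory for a step image."""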
(
full_output_folder,
filename,
counter,
_subfolder,
_filename_prefix,
) = folder_paths.get_save_image_path(
f"{step}_{idx:03}",
folder_paths.temp_directory,
)
file = f"{filename}_{counter:05}_.png"
return os.path.join(full_output_folder, file)
def save_intermediate_images(
self, cropped_faces, restored_faces, height, width
):
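        """Save each cropped face, its restored version, and a side-by-side
        comparison image to the temp directory."""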
for idx, (cropped_face, restored_face) in enumerate(
zip(cropped_faces, restored_faces, strict=False)
):
face_id = idx + 1
file = self.get_step_image_path("cropped_faces", face_id)
cv2.imwrite(file, cropped_face)
file = self.get_step_image_path("cropped_faces_restored", face_id)
cv2.imwrite(file, restored_face)
file = self.get_step_image_path("cropped_faces_compare", face_id)
# save comparison image
cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
cv2.imwrite(file, cmp_img)
__nodes__ = [MTB_RestoreFace, MTB_LoadFaceEnhanceModel]
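# A minimal usage sketch of these (deprecated) nodes outside of a ComfyUI
# graph, assuming a GFPGAN checkpoint such as "GFPGANv1.4.pth" (hypothetical
# filename) sits in models/face_restore, and that `images` is an IMAGE batch
# tensor of shape [B, H, W, C] in the 0..1 range:
#
#   loader = MTB_LoadFaceEnhanceModel()
#   (gfpgan,) = loader.load_model("GFPGANv1.4.pth", upscale=2)
#   restorer = MTB_RestoreFace()
#   (restored,) = restorer.restore(images, gfpgan, weight=0.5)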