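# Hosting-page residue kept for provenance (author, commit message, commit hash, page links, file size):
# gavinyuan
# update: GPEN weights
# 5c64773
# raw
# history blame
# 18.1 kB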
import os
import uuid
import glob
import shutil
from pathlib import Path
from multiprocessing.pool import Pool
import gradio as gr
import torch
from torchvision import transforms
import cv2
import numpy as np
from PIL import Image
import tqdm
from modules.networks.faceshifter import FSGenerator
from inference.alignment import norm_crop, norm_crop_with_M, paste_back
from inference.utils import save, get_5_from_98, get_detector, get_lmk
from third_party.PIPNet.lib.tools import get_lmk_model, demo_image
from inference.landmark_smooth import kalman_filter_landmark, savgol_filter_landmark
from inference.tricks import Trick
def make_abs_path(fn):
    """Return the absolute form of ``fn`` taken relative to this file's directory."""
    here = os.path.dirname(os.path.realpath(__file__))
    return os.path.abspath(os.path.join(here, fn))


# Generator family served by the demo; the model-loading code branches on it.
fs_model_name = 'faceshifter'
# Side length passed to the alignment crops and to the generator constructor.
in_size = 256

# Mouth-network settings forwarded to the generators and the trainer config.
mouth_net_param = {
    "use": True,
    "feature_dim": 128,
    "crop_param": (28, 56, 84, 112),
    "weight_path": make_abs_path("./weights/arcface/mouth_net_28_56_84_112.pth"),
}

# Shared helper object for masks, mouth fine-tuning and tensor/array conversion.
trick = Trick()

# Image -> tensor, then normalisation with mean 0.5 and std 0.5.
T = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])
tensor2pil_transform = transforms.ToPILImage()
def extract_generator(ckpt: str, pt: str):
    """Save only the generator weights of a FaceShifter training checkpoint.

    The Lightning module matching the module-level ``in_size`` is built from
    the trainer's ``config.yaml`` (with ``mouth_net`` overridden by
    ``mouth_net_param``), ``ckpt`` is loaded into it non-strictly, and the
    ``state_dict`` of its ``generator`` attribute is written to ``pt``.

    Args:
        ckpt: Path of the training checkpoint; it must contain a
            ``"state_dict"`` entry.
        pt: Destination path of the extracted generator weights.

    Raises:
        ValueError: If ``in_size`` is neither 256 nor 512.
    """
    print(f'[extract_generator] loading ckpt...')
    # Function-scope imports: these modules are only used by this function.
    from trainer.faceshifter.faceshifter_pl import FaceshifterPL512, FaceshifterPL
    import yaml

    config_file = make_abs_path('../../trainer/faceshifter/config.yaml')
    with open(config_file, 'r') as cfg_stream:
        config = yaml.load(cfg_stream, Loader=yaml.FullLoader)
    config['mouth_net'] = mouth_net_param

    # Guard first, then pick the module class for the supported sizes.
    if in_size not in (256, 512):
        raise ValueError('Not supported in_size.')
    if in_size == 256:
        net = FaceshifterPL(n_layers=3, num_D=3, config=config)
    else:
        net = FaceshifterPL512(n_layers=3, num_D=3, config=config, verbose=False)

    state = torch.load(ckpt, map_location="cpu", )
    net.load_state_dict(state["state_dict"], strict=False)
    net.eval()

    generator = net.generator
    torch.save(generator.state_dict(), pt)
    print(f'[extract_generator] extracted from {ckpt}, pth saved to {pt}')
# ---- load model ----
# Every branch must leave behind: ``fs_model``, ``pt_path`` and a function
# ``infer_batch_to_img(i_s, i_t, post=False)`` used by ``swap_image``.
if fs_model_name == 'faceshifter':
    pt_path = make_abs_path("./weights/extracted/G_mouth1_t38_post.pth")
    # Training checkpoint used to (re-)extract the generator weights. None is
    # configured by default; the commented pairs below are alternative setups.
    ckpt_path = None
    # pt_path = make_abs_path("../ffplus/extracted_ckpt/G_mouth1_t512_6.pth")
    # ckpt_path = "/apdcephfs/share_1290939/gavinyuan/out/triplet512_6/epoch=3-step=128999.ckpt"
    # pt_path = make_abs_path("../ffplus/extracted_ckpt/G_mouth1_t512_4.pth")
    # ckpt_path = "/apdcephfs/share_1290939/gavinyuan/out/triplet512_4/epoch=2-step=185999.ckpt"
    if not os.path.exists(pt_path) or 't512' in pt_path:
        if ckpt_path is None:
            # Without this guard the call below fails with an opaque NameError.
            raise FileNotFoundError(
                f'[demo] cannot extract generator weights to {pt_path}: '
                f'no ckpt_path is configured'
            )
        extract_generator(ckpt_path, pt_path)

    fs_model = FSGenerator(
        make_abs_path("./weights/arcface/ms1mv3_arcface_r100_fp16/backbone.pth"),
        mouth_net_param=mouth_net_param,
        in_size=in_size,
        downup=in_size == 512,
    )
    fs_model.load_state_dict(torch.load(pt_path, "cpu"), strict=True)
    fs_model.eval()

    @torch.no_grad()
    def infer_batch_to_img(i_s, i_t, post: bool = False):
        """Swap source ``i_s`` onto target ``i_t`` and return the first result as an array.

        With ``post`` set, the target's region selected by
        ``trick.get_any_mask(i_t, par=[0, 17])`` is blended back over the
        result, and the mouth is fine-tuned when ``in_size`` is 256.
        """
        i_r = fs_model(i_s, i_t)[0]  # x, id_vector, att
        if post:
            target_hair_mask = trick.get_any_mask(i_t, par=[0, 17])
            target_hair_mask = trick.smooth_mask(target_hair_mask)
            i_r = target_hair_mask * i_t + (target_hair_mask * (-1) + 1) * i_r
            i_r = trick.finetune_mouth(i_s, i_t, i_r) if in_size == 256 else i_r
        img_r = trick.tensor_to_arr(i_r)[0]
        return img_r

elif fs_model_name in ('simswap_triplet', 'simswap_vanilla'):
    from modules.networks.simswap import Generator_Adain_Upsample
    sw_model = Generator_Adain_Upsample(
        input_nc=3, output_nc=3, latent_size=512, n_blocks=9, deep=False,
        mouth_net_param=mouth_net_param
    )
    if fs_model_name == 'simswap_triplet':
        pt_path = make_abs_path("../ffplus/extracted_ckpt/G_mouth1_st5.pth")
        ckpt_path = make_abs_path("/apdcephfs/share_1290939/gavinyuan/out/"
                                  "simswap_triplet_5/epoch=12-step=782999.ckpt")
    else:  # 'simswap_vanilla' is the only other name this branch admits
        pt_path = make_abs_path("../ffplus/extracted_ckpt/G_tmp_sv4_off.pth")
        ckpt_path = make_abs_path("/apdcephfs/share_1290939/gavinyuan/out/"
                                  "simswap_vanilla_4/epoch=694-step=1487999.ckpt")
    sw_model.load_state_dict(torch.load(pt_path, "cpu"), strict=False)
    sw_model.eval()
    fs_model = sw_model

    # The identity and mouth networks are taken from the full training module.
    from trainer.simswap.simswap_pl import SimSwapPL
    import yaml
    with open(make_abs_path('../../trainer/simswap/config.yaml'), 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    config['mouth_net'] = mouth_net_param
    net = SimSwapPL(config=config, use_official_arc='off' in pt_path)

    checkpoint = torch.load(ckpt_path, map_location="cpu")
    net.load_state_dict(checkpoint["state_dict"], strict=False)
    net.eval()

    sw_mouth_net = net.mouth_net  # maybe None
    sw_netArc = net.netArc
    fs_model = fs_model.cuda()
    sw_mouth_net = sw_mouth_net.cuda() if sw_mouth_net is not None else sw_mouth_net
    sw_netArc = sw_netArc.cuda()

    @torch.no_grad()
    def infer_batch_to_img(i_s, i_t, post: bool = False):
        """Swap source ``i_s`` onto target ``i_t`` and return the first result as an array.

        With ``post`` set, the target's region selected by
        ``trick.get_any_mask(i_t, par=[0, 17])`` is blended back over the
        result. The output is clamped to [-1, 1] before conversion.
        """
        i_r = fs_model(source=i_s, target=i_t, net_arc=sw_netArc, mouth_net=sw_mouth_net,)
        if post:
            target_hair_mask = trick.get_any_mask(i_t, par=[0, 17])
            target_hair_mask = trick.smooth_mask(target_hair_mask)
            i_r = target_hair_mask * i_t + (target_hair_mask * (-1) + 1) * i_r
        i_r = i_r.clamp(-1, 1)
        i_r = trick.tensor_to_arr(i_r)[0]
        return i_r

elif fs_model_name == 'simswap_official':
    from simswap.image_infer import SimSwapOfficialImageInfer
    fs_model = SimSwapOfficialImageInfer()
    pt_path = 'Simswap Official'
    mouth_net_param = {
        "use": False
    }

    @torch.no_grad()
    def infer_batch_to_img(i_s, i_t, post: bool = False):
        """Swap source ``i_s`` onto target ``i_t`` with the official SimSwap wrapper.

        ``post`` is accepted so the signature matches the other backends
        (``swap_image`` always passes it); it has no effect here.
        """
        i_r = fs_model.image_infer(source_tensor=i_s, target_tensor=i_t)
        i_r = i_r.clamp(-1, 1)
        # NOTE(review): the other backends return trick.tensor_to_arr(...)[0];
        # this one returns the clamped tensor, while swap_image slices and
        # .astype()s the result like an array - confirm the intended type.
        return i_r

else:
    raise ValueError('Not supported fs_model_name.')
print(f'[demo] model loaded from {pt_path}')
def swap_image(
    source_image,
    target_path,
    out_path,
    transform,
    G,
    align_source="arcface",
    align_target="set1",
    gpu_mode=True,
    paste_back=True,
    use_post=False,
    use_gpen=False,
    in_size=256,
):
    """Swap the face from ``source_image`` onto the image at ``target_path``.

    Both images are landmark-aligned, cropped to ``in_size`` and run through
    the module-level ``infer_batch_to_img``. The swapped crop is written to
    ``out_path`` as ``out_<target file name>``; ``save`` is then called with
    the path ``paste_back_out_<target file name>`` in the same directory.
    Three debug images (``cropped_source.png``, ``cropped_target.png``,
    ``result.png``) are also written to the current working directory.

    Args:
        source_image: Path of the source (identity) image.
        target_path: Path of the target image; its last ``/`` component names
            the outputs.
        out_path: Output directory, created if missing.
        transform: Callable turning an aligned crop into a tensor.
        G: Generator; if it is a ``torch.nn.Module`` it is put in eval mode
            and, with ``gpu_mode``, moved to CUDA. Inference itself goes
            through ``infer_batch_to_img``.
        align_source: Alignment mode for the source crop.
        align_target: Alignment mode for the target crop.
        gpu_mode: Run landmark detection and inference on CUDA.
        paste_back: Not read by this function.
        use_post: Forwarded to ``infer_batch_to_img`` as ``post``.
        use_gpen: Forwarded to ``save`` as ``use_post``.
        in_size: Side length of the aligned crops.
    """
    out_name = "out_" + target_path.rsplit("/", 1)[-1]

    if isinstance(G, torch.nn.Module):
        G.eval()
        if gpu_mode:
            G = G.cuda()

    if gpu_mode:
        device = torch.device(0)
    else:
        device = torch.device('cpu')

    # Source: detect 98 landmarks, reduce to 5, align and crop.
    src_rgb = np.array(Image.open(source_image).convert("RGB"))
    lmk_net, face_detector = get_lmk_model()
    src_lmk98 = demo_image(src_rgb, lmk_net, face_detector, device=device)[0]
    src_crop = norm_crop(src_rgb, get_5_from_98(src_lmk98), in_size,
                         mode=align_source, borderValue=0.0)
    src_tensor = transform(src_crop).unsqueeze(0)

    # Target: same pipeline, but keep the alignment matrix and the full frame.
    tgt_rgb = np.array(Image.open(target_path).convert("RGB"))
    original_target = tgt_rgb.copy()
    tgt_lmk98 = demo_image(tgt_rgb, lmk_net, face_detector, device=device)[0]
    tgt_crop, M = norm_crop_with_M(tgt_rgb, get_5_from_98(tgt_lmk98), in_size,
                                   mode=align_target, borderValue=0.0)
    tgt_tensor = transform(tgt_crop).unsqueeze(0)

    if gpu_mode:
        tgt_tensor = tgt_tensor.cuda()
        src_tensor = src_tensor.cuda()

    # Debug dumps; the channel axis is reversed before cv2.imwrite.
    cv2.imwrite('cropped_source.png', trick.tensor_to_arr(src_tensor)[0, :, :, ::-1])
    cv2.imwrite('cropped_target.png', trick.tensor_to_arr(tgt_tensor)[0, :, :, ::-1])

    # both inputs should be 512
    result = infer_batch_to_img(src_tensor, tgt_tensor, post=use_post)
    cv2.imwrite('result.png', result[:, :, ::-1])

    os.makedirs(out_path, exist_ok=True)
    crop_file = os.path.join(out_path, out_name)
    Image.fromarray(result.astype(np.uint8)).save(crop_file)

    pasted_file = os.path.join(out_path, "paste_back_" + out_name)
    save((result, M, original_target, pasted_file, None),
         trick=trick, use_post=use_gpen)
def process_video(
    source_image,
    target_path,
    out_path,
    transform,
    G,
    align_source="arcface",
    align_target="set1",
    gpu_mode=True,
    frames=9999999,
    use_tddfav2=False,
    landmark_smooth="kalman",
):
    """Swap the source face onto every frame of a target video.

    If ``target_path`` is not a directory it is treated as a video file: its
    frame rate is probed and ``ffmpeg`` dumps its frames to
    ``./tmp/frame_%05d.png`` (stale PNGs in ``./tmp/`` and ``out_path`` are
    removed first). Landmarks are detected on every frame, smoothed, and each
    frame is aligned, cropped to the module-level ``in_size`` and swapped.
    Nothing is pasted back or written here except the extracted frames.

    Args:
        source_image: Path of the source (identity) image.
        target_path: Video file, or a directory of ``*.png`` frames.
        out_path: Directory the returned output names point into; created
            if missing.
        transform: Callable turning an aligned crop into a tensor.
        G: Generator; a ``torch.nn.Module`` is put in eval mode and, with
            ``gpu_mode``, moved to CUDA.
        align_source: Alignment mode for the source crop.
        align_target: Alignment mode for the target crops.
        gpu_mode: Run detection and inference on CUDA.
        frames: Maximum number of frames to swap.
        use_tddfav2: Use ``get_detector``/``get_lmk`` instead of
            ``get_lmk_model`` for landmarks.
        landmark_smooth: ``'kalman'``, ``'savgol'`` or ``'cancel'``.

    Returns:
        Tuple ``(targets, t_facial_masks, Ms, original_frames, names, fps)``:
        swapped crops, paste masks, alignment matrices, original frames,
        output file names (one entry per processed frame) and the frame rate.

    Raises:
        KeyError: If ``landmark_smooth`` is not one of the supported choices.
    """
    # Function-scope import, as elsewhere in this file; needed for ffmpeg.
    import subprocess

    if isinstance(G, torch.nn.Module):
        G.eval()
        if gpu_mode:
            G = G.cuda()
    device = torch.device(0) if gpu_mode else torch.device('cpu')

    # ---- Target video to frames (.png) ----
    fps = 25.0
    if not os.path.isdir(target_path):
        vidcap = cv2.VideoCapture(target_path)
        try:
            # A failed probe reports 0; keep the default in that case.
            fps = vidcap.get(cv2.CAP_PROP_FPS) or fps
        finally:
            vidcap.release()
        try:
            for match in glob.glob(os.path.join("./tmp/", "*.png")):
                os.remove(match)
            for match in glob.glob(os.path.join(out_path, "*.png")):
                os.remove(match)
        except OSError as e:
            # Best-effort cleanup of stale frames; report and carry on.
            print(e)
        os.makedirs("./tmp/", exist_ok=True)
        # Argument list instead of a shell string: the path may contain
        # spaces or shell metacharacters. The exit status is ignored.
        subprocess.run(
            ["ffmpeg", "-i", target_path,
             "-qscale:v", "1", "-qmin", "1", "-qmax", "1", "-vsync", "0",
             "./tmp/frame_%05d.png"],
            check=False,
        )
        target_path = "./tmp/"
    globbed_images = sorted(glob.glob(os.path.join(target_path, "*.png")))

    # ---- Get target landmarks ----
    print('[Extracting target landmarks...]')
    if not use_tddfav2:
        align_net, align_detector = get_lmk_model()
    else:
        align_net, align_detector = get_detector(gpu_mode=gpu_mode)
    target_lmks = []
    for frame_path in tqdm.tqdm(globbed_images):
        target = np.array(Image.open(frame_path).convert("RGB"))
        lmk = demo_image(target, align_net, align_detector, device=device)
        target_lmks.append(lmk[0])

    # ---- Landmark smoothing ----
    target_lmks = np.array(target_lmks, np.float32)  # (#frames, 98, 2)
    if landmark_smooth == 'kalman':
        target_lmks = kalman_filter_landmark(target_lmks,
                                             process_noise=0.01,
                                             measure_noise=0.01).astype(np.int32)
    elif landmark_smooth == 'savgol':
        target_lmks = savgol_filter_landmark(target_lmks).astype(np.int32)
    elif landmark_smooth == 'cancel':
        target_lmks = target_lmks.astype(np.int32)
    else:
        raise KeyError('Not supported landmark_smooth choice')

    # ---- Crop source image ----
    source_img = np.array(Image.open(source_image).convert("RGB"))
    if not use_tddfav2:
        lmk = get_5_from_98(demo_image(source_img, align_net, align_detector, device=device)[0])
    else:
        lmk = get_lmk(source_img, align_net, align_detector)
    source_img = norm_crop(source_img, lmk, in_size, mode=align_source, borderValue=0.0)
    source_img = transform(source_img).unsqueeze(0)
    if gpu_mode:
        source_img = source_img.cuda()

    # ---- Process by frames ----
    targets = []
    t_facial_masks = []
    Ms = []
    original_frames = []
    names = []
    for count, image in enumerate(tqdm.tqdm(globbed_images)):
        if count >= frames:
            # Stop after exactly ``frames`` frames.
            break
        names.append(os.path.join(out_path, Path(image).name))
        target = np.array(Image.open(image).convert("RGB"))
        original_frames.append(target)

        # Crop target frame
        lmk = get_5_from_98(target_lmks[count])
        target, M = norm_crop_with_M(target, lmk, in_size, mode=align_target, borderValue=0.0)
        target = transform(target).unsqueeze(0)  # in [-1,1]
        if gpu_mode:
            target = target.cuda()

        # Finetune paste mask
        target_facial_mask = trick.get_any_mask(
            target, par=[1, 2, 3, 4, 5, 6, 10, 11, 12, 13]).squeeze()  # in [0,1]
        target_facial_mask = target_facial_mask.cpu().numpy().astype(np.float32)
        # NOTE(review): the landmarks of all frames are passed here, not
        # target_lmks[count] - confirm what finetune_mask expects.
        target_facial_mask = trick.finetune_mask(target_facial_mask, target_lmks)  # in [0,1]
        t_facial_masks.append(target_facial_mask)

        # Face swapping
        with torch.no_grad():
            if 'faceshifter' in fs_model_name:
                output = G(source_img, target)
                if isinstance(output, (tuple, list)):
                    # The generator's return is indexed with [0] elsewhere in
                    # this file; the blend below needs the image tensor.
                    output = output[0]
                target_hair_mask = trick.get_any_mask(target, par=[0, 17])
                target_hair_mask = trick.smooth_mask(target_hair_mask)
                output = target_hair_mask * target + (target_hair_mask * (-1) + 1) * output
                output = trick.finetune_mouth(source_img, target, output)
            elif 'simswap' in fs_model_name and 'official' not in fs_model_name:
                output = fs_model(source=source_img, target=target,
                                  net_arc=sw_netArc, mouth_net=sw_mouth_net,)
                if 'vanilla' not in fs_model_name:
                    target_hair_mask = trick.get_any_mask(target, par=[0, 17])
                    target_hair_mask = trick.smooth_mask(target_hair_mask)
                    output = target_hair_mask * target + (target_hair_mask * (-1) + 1) * output
                    output = trick.finetune_mouth(source_img, target, output)
                output = output.clamp(-1, 1)
            elif 'simswap_official' in fs_model_name:
                output = fs_model.image_infer(source_tensor=source_img, target_tensor=target)
                output = output.clamp(-1, 1)

            # First batch item, mapped from [-1, 1] to [0, 1] for ToPILImage.
            if isinstance(output, tuple):
                target = output[0][0] * 0.5 + 0.5
            else:
                target = output[0] * 0.5 + 0.5

        targets.append(np.array(tensor2pil_transform(target)))
        Ms.append(M)

    os.makedirs(out_path, exist_ok=True)
    return targets, t_facial_masks, Ms, original_frames, names, fps
def swap_image_gr(img1, img2, use_post=False, use_gpen=False, ):
    """Gradio callback: swap the face of ``img1`` onto ``img2``.

    Both arrays are written into a fresh per-request directory under
    ``./online_data`` (next to this file), ``swap_image`` is run on them, and
    the file ``paste_back_out_target.png`` from that directory is read back.

    Args:
        img1: Source image array (presumably RGB from Gradio - the channel
            axis is reversed before ``cv2.imwrite``).
        img2: Target image array, handled the same way.
        use_post: Forwarded to ``swap_image``.
        use_gpen: Forwarded to ``swap_image``.

    Returns:
        The pasted-back result with its channel axis reversed after
        ``cv2.imread``.
    """
    # One directory per request, keyed by a uuid1 hex string.
    data_dir = os.path.join(make_abs_path("./online_data"), uuid.uuid1().hex)
    os.makedirs(data_dir, exist_ok=True)

    source_path = os.path.join(data_dir, "source.png")
    target_path = os.path.join(data_dir, "target.png")
    result_path = os.path.join(data_dir, "paste_back_out_target.png")

    for path, image in ((source_path, img1), (target_path, img2)):
        cv2.imwrite(path, image[:, :, ::-1])

    swap_image(
        source_path,
        target_path,
        data_dir,
        T,
        fs_model,
        gpu_mode=use_gpu,
        align_target='ffhq',
        align_source='ffhq',
        use_post=use_post,
        use_gpen=use_gpen,
        in_size=in_size,
    )

    swapped = cv2.imread(result_path)
    return swapped[..., ::-1]
def swap_video_gr(img1, target_path, frames=9999999):
    """Gradio callback: swap the face of ``img1`` onto the video at ``target_path``.

    The source image is written into a fresh per-request directory under
    ``./online_data``; ``process_video`` produces the swapped crops, which are
    pasted back into the original frames and saved as PNGs in ``<dir>/out``.
    ``ffmpeg`` then encodes those PNGs to ``output.mp4``, taking the first
    audio stream of the target video if it has one. The temporary ``./tmp/``
    directory and the PNG frames are removed afterwards.

    Args:
        img1: Source image array (presumably RGB from Gradio - the channel
            axis is reversed before ``cv2.imwrite``).
        target_path: Path of the target video.
        frames: Maximum number of frames, forwarded to ``process_video``.

    Returns:
        Path of the encoded video.
    """
    # Function-scope import, as elsewhere in this file; needed for ffmpeg.
    import subprocess

    root_dir = make_abs_path("./online_data")
    req_id = uuid.uuid1().hex
    data_dir = os.path.join(root_dir, req_id)
    os.makedirs(data_dir, exist_ok=True)
    source_path = os.path.join(data_dir, "source.png")
    cv2.imwrite(source_path, img1[:, :, ::-1])
    out_dir = os.path.join(data_dir, "out")
    out_name = "output.mp4"

    targets, t_facial_masks, Ms, original_frames, names, fps = process_video(
        source_path,
        target_path,
        out_dir,
        T,
        fs_model,
        gpu_mode=use_gpu,
        frames=frames,
        align_target='ffhq',
        align_source='ffhq',
        use_tddfav2=False,
    )

    # NOTE(review): 170 worker processes is a fixed value - confirm it suits
    # the deployment machine.
    pool_process = 170
    audio = True
    concat = False

    if pool_process <= 1:
        for target, M, original_target, name, t_facial_mask in tqdm.tqdm(
            zip(targets, Ms, original_frames, names, t_facial_masks)
        ):
            if M is None or target is None:
                # No swap result for this frame: keep the original.
                Image.fromarray(original_target.astype(np.uint8)).save(name)
                continue
            Image.fromarray(paste_back(np.array(target), M, original_target, t_facial_mask)).save(name)
    else:
        with Pool(pool_process) as pool:
            pool.map(save, zip(targets, Ms, original_frames, names, t_facial_masks))

    video_save_path = os.path.join(out_dir, out_name)
    swapped_frames = f"{out_dir}/frame_%05d.png"
    # ffmpeg is invoked with argument lists rather than shell strings:
    # target_path comes from an upload and may contain spaces or shell
    # metacharacters. Exit statuses are ignored.
    if audio:
        print("use audio")
        subprocess.run(
            ["ffmpeg", "-y", "-r", str(fps), "-i", swapped_frames, "-i", target_path,
             "-map", "0:v:0", "-map", "1:a:0?", "-c:a", "copy",
             "-c:v", "libx264", "-r", str(fps), "-crf", "10",
             "-pix_fmt", "yuv420p", video_save_path],
            check=False,
        )
    else:
        print("no audio")
        # Encode the swapped frames in out_dir, not the extracted source
        # frames in ./tmp/.
        subprocess.run(
            ["ffmpeg", "-y", "-r", str(fps), "-i", swapped_frames,
             "-c:v", "libx264", "-r", str(fps), "-crf", "10",
             "-pix_fmt", "yuv420p", video_save_path],
            check=False,
        )

    # ffmpeg -i left.mp4 -i right.mp4 -filter_complex hstack output.mp4
    if concat:
        concat_video_save_path = os.path.join(out_dir, "concat_" + out_name)
        subprocess.run(
            ["ffmpeg", "-y", "-i", target_path, "-i", video_save_path,
             "-filter_complex", "hstack", concat_video_save_path],
            check=False,
        )

    # delete tmp file; ./tmp/ may not exist, so a missing directory is ignored
    shutil.rmtree("./tmp/", ignore_errors=True)
    for match in glob.glob(os.path.join(out_dir, "*.png")):
        os.remove(match)
    print(video_save_path)
    return video_save_path
if __name__ == "__main__":
    # Read as a module-level global by swap_image_gr and swap_video_gr.
    use_gpu = torch.cuda.is_available()

    with gr.Blocks() as demo:
        gr.Markdown("SuperSwap")

        # Image tab: source/target pictures plus two option checkboxes.
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column(scale=3):
                    src_image_box = gr.Image(label='source')
                    tgt_image_box = gr.Image(label='target')
                    post_checkbox = gr.Checkbox(label="Post-Process")
                    gpen_checkbox = gr.Checkbox(label="Super Resolution")
                with gr.Column(scale=2):
                    image_result_box = gr.Image()
            run_image_button = gr.Button("Run: Face Swapping")

        # Video tab: source picture plus target video.
        with gr.Tab("Video"):
            with gr.Row():
                with gr.Column(scale=3):
                    src_for_video_box = gr.Image(label='source')
                    tgt_video_box = gr.Video(label='target')
                with gr.Column(scale=2):
                    video_result_box = gr.Video()
            run_video_button = gr.Button("Run: Face Swapping")

        # Wire each button to its callback.
        run_image_button.click(
            swap_image_gr,
            inputs=[src_image_box, tgt_image_box, post_checkbox, gpen_checkbox],
            outputs=image_result_box,
        )
        run_video_button.click(
            swap_video_gr,
            inputs=[src_for_video_box, tgt_video_box],
            outputs=video_result_box,
        )

    demo.launch(server_name="0.0.0.0", server_port=7860)